
Commit

updated code and examples to allow for compressed files
Jai Bhagat committed Jan 9, 2020
1 parent 75e3040 commit 44f1849
Showing 5 changed files with 80 additions and 41 deletions.
4 changes: 2 additions & 2 deletions brainbox/examples/Loading_from_ONE_and_running_functions.py
@@ -47,15 +47,15 @@
# Ensure directories and paths can be found
assert os.path.isdir(ephys_file_dir) and os.path.isdir(alf_probe_dir) \
and os.path.isabs(ephys_file_path), 'Directories set incorrectly'

# Call brainbox functions #
#-------------------------#

# Change variable names to same names used in brainbox docstrings
path_to_ephys_file = ephys_file_path
path_to_alf_out = alf_probe_dir

# Load alf objects:
# Load alf objects
spks_b = aio.load_object(path_to_alf_out, 'spikes')
clstrs_b = aio.load_object(path_to_alf_out, 'clusters')
chnls_b = aio.load_object(path_to_alf_out, 'channels')
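For orientation, the ALF bunches loaded above are the inputs the rest of brainbox consumes. A minimal sketch of the downstream pattern, assuming `aio` is `alf.io` as in this example script, and using a hypothetical ALF path:

import brainbox as bb
import alf.io as aio

path_to_alf_out = '/path/to/alf/probe00'  # hypothetical ALF output directory
spks_b = aio.load_object(path_to_alf_out, 'spikes')  # bunch of arrays: 'times', 'clusters', 'amps', ...
units_b = bb.processing.get_units_bunch(spks_b)  # regroup each array by unit; may take a few mins
amps_unit4 = units_b['amps']['4']  # e.g. amplitudes for unit 4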
28 changes: 13 additions & 15 deletions brainbox/io/io.py
@@ -3,11 +3,10 @@
# (Previously required `os.path` to get file info before memmapping)
# import os.path as op
from ibllib.io import spikeglx
from scipy.signal import decimate


def extract_waveforms(ephys_file, ts, ch, t=2.0, sr=30000, n_ch_probe=385, dtype='int16',
offset=0, car=True, q=5):
offset=0, car=True):
'''
Extracts spike waveforms from binary ephys data file, after (optionally)
common-average-referencing (CAR) spatial noise.
@@ -32,8 +31,6 @@ def extract_waveforms(ephys_file, ts, ch, t=2.0, sr=30000, n_ch_probe=385, dtype
The offset (in bytes) from the start of `ephys_file`.
car: bool (optional)
A flag to perform CAR before extracting waveforms.
q : int (optional)
The downsampling factor to use when performing CAR.
Returns
-------
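With `q` (and its decimate-based downsampling) gone from the signature, a call now needs only the remaining arguments. A hedged usage sketch; the file path and spike selection are hypothetical, with `.cbin` standing in for a compressed file of the kind this commit targets:

import numpy as np
import brainbox as bb

ephys_file = '/path/to/raw.ap.cbin'  # hypothetical; a flat .bin works too
ts = np.array([1.0, 1.5, 2.2])  # spike times (in s) to extract waveforms around
ch = np.arange(20, 30)  # channels of interest
# passing q=5 would now raise a TypeError; CAR is just a boolean flag
wf = bb.io.extract_waveforms(ephys_file, ts, ch, t=2.0, sr=30000,
                             n_ch_probe=385, dtype='int16', car=True)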
@@ -80,12 +77,13 @@ def extract_waveforms(ephys_file, ts, ch, t=2.0, sr=30000, n_ch_probe=385, dtype

# Exception handling for impossible channels
ch = np.asarray(ch)
ch = ch.reshape((ch.size, 1))
if np.any(ch < 0) or np.any(ch > n_ch_probe):
raise Exception('At least one specified channel number is impossible. The minimum channel'
' number was {}, and the maximum channel number was {}. Check specified'
' channel numbers and try again.'.format(np.min(ch), np.max(ch)))

if car: # compute spatial noise in chunks
if car: # compute spatial noise in chunks
# (previously computed temporal noise also, but was too costly)
# Get number of chunks.
t_sample_first = ts_samples[0] - n_wf_samples
@@ -96,16 +94,16 @@ def extract_waveforms(ephys_file, ts, ch, t=2.0, sr=30000, n_ch_probe=385, dtype
# samples that make up the first chunk.
chunk_sample = np.arange(t_sample_first, t_sample_last, n_chunk_samples, dtype=int)
chunk_sample = np.append(chunk_sample, t_sample_last)
noise_s_chunks = np.zeros((n_chunks, ch.size))
# Give time estimate for calculating `noise_s_chunks`.
noise_s_chunks = np.zeros((n_chunks, ch.size)) # spatial noise array
# Give time estimate for computing `noise_s_chunks`.
t0 = time.perf_counter()
np.median(file_m[chunk_sample[0]:chunk_sample[1], ch], axis=0)
dt = time.perf_counter() - t0
print('Performing spatial CAR before waveform extraction. Estimated time is {:.2f} mins.'
' ({})'.format(dt * n_chunks / 60, time.ctime()))
# Compute noise for each chunk, then take the median noise of all chunks.
for chunk in range(n_chunks):
noise_s_chunks[chunk,:] = np.median(
noise_s_chunks[chunk, :] = np.median(
file_m[chunk_sample[chunk]:chunk_sample[chunk + 1], ch], axis=0)
noise_s = np.median(noise_s_chunks, axis=0)
print('Done. ({})'.format(time.ctime()))
Expand All @@ -116,18 +114,18 @@ def extract_waveforms(ephys_file, ts, ch, t=2.0, sr=30000, n_ch_probe=385, dtype
t0 = time.perf_counter()
for i in range(5):
waveforms[i, :, :] = \
file_m[i * n_wf_samples * 2 + t_sample_first:
i * n_wf_samples * 2 + t_sample_first + n_wf_samples * 2, ch].reshape(
(n_wf_samples * 2, ch.size))
dt = time.perf_counter() - t0
print('Performing waveform extraction. Estimated time is {:.2f} mins. ({})'
.format(dt * len(ts) / 60 / 5, time.ctime()))
for spk, _ in enumerate(ts): # extract waveforms
spk_ts_sample = ts_samples[spk]
spk_samples = np.arange(spk_ts_sample - n_wf_samples, spk_ts_sample + n_wf_samples)
# have to reshape to add an axis to broadcast `file_m` into `waveforms`
waveforms[spk, :, :] = \
file_m[spk_samples[0]:spk_samples[-1]+1, ch].reshape((spk_samples.size, ch.size))
file_m[spk_samples[0]:spk_samples[-1] + 1, ch].reshape((spk_samples.size, ch.size))
print('Done. ({})'.format(time.ctime()))
if car: # perform CAR (subtract spatial noise)
waveforms -= noise_s[None, None, :]
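The chunked median above trades a little statistical exactness for bounded memory: the median of per-chunk medians approximates the channel-wise median of the whole span without ever loading it at once. A self-contained toy sketch of the same pattern (synthetic data, no file I/O):

import numpy as np

rng = np.random.default_rng(0)
data = rng.normal(loc=5.0, scale=2.0, size=(10000, 4))  # stand-in for file_m[:, ch]
n_chunk_samples = 1000
bounds = np.arange(0, data.shape[0], n_chunk_samples, dtype=int)
bounds = np.append(bounds, data.shape[0])
chunk_medians = np.stack([np.median(data[b0:b1], axis=0)
                          for b0, b1 in zip(bounds[:-1], bounds[1:])])
noise_s = np.median(chunk_medians, axis=0)  # close to 5.0 on every channel
data_car = data - noise_s[None, :]  # broadcast-subtract, as in `waveforms -= noise_s[None, None, :]`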
39 changes: 29 additions & 10 deletions brainbox/metrics/metrics.py
@@ -15,11 +15,12 @@
>>> units_b = bb.processing.get_units_bunch(spks_b) # may take a few mins to compute
"""

import os.path as op
import time
import numpy as np
import scipy.stats as stats
import scipy.ndimage.filters as filters
import brainbox as bb
from ibllib.io import spikeglx
# add spikemetrics as dependency?
# import spikemetrics as sm

@@ -504,7 +505,8 @@ def ptp_over_noise(ephys_file, ts, ch, t=2.0, sr=30000, n_ch_probe=385, dtype='i
'''

# Ensure `ch` is ndarray
ch = np.asarrray(ch)
ch = np.asarray(ch)
ch = ch.reshape((ch.size, 1))

# Get waveforms.
wf = bb.io.extract_waveforms(ephys_file, ts, ch, t=t, sr=sr, n_ch_probe=n_ch_probe,
@@ -516,12 +518,29 @@ def ptp_over_noise(ephys_file, ts, ch, t=2.0, sr=30000, n_ch_probe=385, dtype='i
mean_ptp[cur_ch] = np.mean(np.max(wf[:, :, cur_ch], axis=1) -
np.min(wf[:, :, cur_ch], axis=1))

# Compute MAD for all channels.
item_bytes = np.dtype(dtype).itemsize
n_samples = (op.getsize(ephys_file) - offset) // (item_bytes * n_ch_probe)
file_m = np.memmap(ephys_file, shape=(n_samples, n_ch_probe), dtype=dtype, mode='r')
noise = stats.median_absolute_deviation(file_m[:, ch], axis=0, scale=1)

# Return `mean_ptp` over `noise`
ptp_sigma = mean_ptp / noise
# Compute MAD for `ch` in chunks.
s_reader = spikeglx.Reader(ephys_file)
file_m = s_reader.data # the memmapped array
n_chunk_samples = 5e6 # number of samples per chunk
n_chunks = np.ceil(file_m.shape[0] / n_chunk_samples).astype('int')
# Get samples that make up each chunk. e.g. `chunk_sample[1] - chunk_sample[0]` are the
# samples that make up the first chunk.
chunk_sample = np.arange(0, file_m.shape[0], n_chunk_samples, dtype=int)
chunk_sample = np.append(chunk_sample, file_m.shape[0])
# Give time estimate for computing MAD.
t0 = time.perf_counter()
stats.median_absolute_deviation(file_m[chunk_sample[0]:chunk_sample[1], ch], axis=0)
dt = time.perf_counter() - t0
print('Performing MAD computation. Estimated time is {:.2f} mins.'
' ({})'.format(dt * n_chunks / 60, time.ctime()))
# Compute MAD for each chunk, then take the median MAD of all chunks.
mad_chunks = np.zeros((n_chunks, ch.size))
for chunk in range(n_chunks):
mad_chunks[chunk, :] = stats.median_absolute_deviation(
file_m[chunk_sample[chunk]:chunk_sample[chunk + 1], ch], axis=0, scale=1)
print('Done. ({})'.format(time.ctime()))

# Return `mean_ptp` over `mad`
mad = np.median(mad_chunks, axis=0)
ptp_sigma = mean_ptp / mad
return ptp_sigma
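The switch from a raw np.memmap to spikeglx.Reader is what makes compressed files workable here: the Reader presents the same (n_samples, n_channels) array interface regardless of the on-disk format. A hedged sketch of that access pattern, with a hypothetical path and the assumption that slicing `.data` reads (and, for compressed files, decompresses) only the requested span:

from ibllib.io import spikeglx

s_reader = spikeglx.Reader('/path/to/raw.ap.cbin')  # hypothetical; a flat .bin also works
file_m = s_reader.data  # memmap-like (n_samples, n_channels) array, as used above
chunk = file_m[:5000000, :10]  # pull one chunk's worth of ten channels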
48 changes: 34 additions & 14 deletions brainbox/plot/plot.py
@@ -16,12 +16,13 @@
>>> units_b = bb.processing.get_units_bunch(spks_b) # may take a few mins to compute
"""

import os.path as op
import time
from warnings import warn
import numpy as np
import matplotlib.pyplot as plt
# from matplotlib.ticker import StrMethodFormatter
import brainbox as bb
from ibllib.io import spikeglx


def feat_vars(units_b, units=None, feat_name='amps', dist='norm', test='ks', cmap_name='coolwarm',
@@ -262,6 +263,10 @@ def wf_comp(ephys_file, ts1, ts2, ch, sr=30000, n_ch_probe=385, dtype='int16', c
>>> wf1_2, wf2_2, s_2 = bb.plot.wf_comp(path_to_ephys_file, ts1_2, ts2_2, ch)
'''

# Ensure `ch` is ndarray
ch = np.asarray(ch)
ch = ch.reshape((ch.size, 1))

# Extract the waveforms for these timestamps and compute similarity score.
wf1 = bb.io.extract_waveforms(ephys_file, ts1, ch, sr=sr, n_ch_probe=n_ch_probe, dtype=dtype,
car=car)
@@ -270,7 +275,7 @@ def wf_comp(ephys_file, ts1, ts2, ch, sr=30000, n_ch_probe=385, dtype='int16', c
s = bb.metrics.wf_similarity(wf1, wf2)

# Plot these waveforms against each other.
n_ch = len(ch)
n_ch = ch.size
if ax is None:
fig, ax = plt.subplots(nrows=n_ch, ncols=2) # left col is all waveforms, right col is mean
for cur_ax, cur_ch in enumerate(ch):
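A hedged usage sketch for `wf_comp`: a common stability check compares a unit's waveforms between the first and second half of its spike train. `units_b` is as built in the earlier sketch; the file path is hypothetical:

import numpy as np
import brainbox as bb

ts = units_b['times']['1']  # spike times (s) for unit 1
ts1, ts2 = ts[:len(ts) // 2], ts[len(ts) // 2:]  # first half vs. second half
ch = np.arange(10, 20)  # channels to compare across
wf1, wf2, s = bb.plot.wf_comp('/path/to/raw.ap.bin', ts1, ts2, ch)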
@@ -332,24 +337,39 @@ def amp_heatmap(ephys_file, ts, ch, sr=30000, n_ch_probe=385, dtype='int16', cma
>>> ch = np.arange(max_ch - 10, max_ch + 10)
>>> bb.plot.amp_heatmap(path_to_ephys_file, ts, ch)
'''
# Ensure `ch` is ndarray
ch = np.asarray(ch)
ch = ch.reshape((ch.size, 1))

# Get memmapped array of `ephys_file`
item_bytes = np.dtype(dtype).itemsize
n_samples = op.getsize(ephys_file) // (item_bytes * n_ch_probe)
file_m = np.memmap(ephys_file, shape=(n_samples, n_ch_probe), dtype=dtype, mode='r')
s_reader = spikeglx.Reader(ephys_file)
file_m = s_reader.data

# Get voltage values for each peak amplitude sample for `ch`.
max_amp_samples = (ts * sr).astype(int)
v_vals = file_m[np.ix_(max_amp_samples, ch)]
if car: # Compute and subtract temporal and spatial noise from `v_vals`.
v_vals = file_m[max_amp_samples, ch]
if car: # compute spatial noise in chunks, and subtract from `v_vals`.
# Get subset of time (from first to last max amp sample)
t_subset = np.arange(max_amp_samples[0], max_amp_samples[-1] + 1, dtype='int16')
# Specify output arrays as `dtype='int16'`
out_noise_t = np.zeros((len(t_subset),), dtype='int16')
out_noise_s = np.zeros((len(ch),), dtype='int16')
noise_t = np.median(file_m[np.ix_(t_subset, ch)], axis=1, out=out_noise_t)
noise_s = np.median(file_m[np.ix_(t_subset, ch)], axis=0, out=out_noise_s)
v_vals -= noise_t[max_amp_samples - max_amp_samples[0], None]
n_chunk_samples = 5e6 # number of samples per chunk
n_chunks = np.ceil((max_amp_samples[-1] - max_amp_samples[0]) /
n_chunk_samples).astype('int')
# Get samples that make up each chunk. e.g. `chunk_sample[1] - chunk_sample[0]` are the
# samples that make up the first chunk.
chunk_sample = np.arange(max_amp_samples[0], max_amp_samples[-1], n_chunk_samples,
dtype=int)
chunk_sample = np.append(chunk_sample, max_amp_samples[-1])
noise_s_chunks = np.zeros((n_chunks, ch.size)) # spatial noise array
# Give time estimate for computing `noise_s_chunks`.
t0 = time.perf_counter()
np.median(file_m[chunk_sample[0]:chunk_sample[1], ch], axis=0)
dt = time.perf_counter() - t0
print('Performing spatial CAR before plotting heatmap. Estimated time is {:.2f} mins.'
' ({})'.format(dt * n_chunks / 60, time.ctime()))
# Compute noise for each chunk, then take the median noise of all chunks.
for chunk in range(n_chunks):
noise_s_chunks[chunk, :] = np.median(
file_m[chunk_sample[chunk]:chunk_sample[chunk + 1], ch], axis=0)
noise_s = np.median(noise_s_chunks, axis=0)
v_vals -= noise_s[None, :]

# Plot heatmap.
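A hedged usage sketch for `amp_heatmap`, following the channel-window recipe in its docstring; the peak-channel lookup assumes the clusters bunch (`clstrs_b`, loaded as in the example script) stores a peak channel per cluster, and the file path is hypothetical:

import numpy as np
import brainbox as bb

ts = units_b['times']['1']  # spike times (s) for unit 1
max_ch = clstrs_b['channels'][1]  # assumed: peak channel of unit 1
ch = np.arange(max_ch - 10, max_ch + 10)  # 20 channels around the peak
bb.plot.amp_heatmap('/path/to/raw.ap.bin', ts, ch)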
2 changes: 2 additions & 0 deletions brainbox/processing/processing.py
@@ -240,6 +240,8 @@ def get_units_bunch(spks_b, *args):
>>> units_b = bb.processing.get_units_bunch(spks_b)
# Get amplitudes for unit 4.
>>> amps = units_b['amps']['4']
TODO add computation time estimate?
'''

# Initialize `units`
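The TODO in the docstring above could reuse the timing pattern this commit applies elsewhere: time the regrouping of a single unit, then scale by the unit count. A hypothetical sketch, not part of the commit, assuming `spks_b` carries the usual 'clusters' and 'times' arrays:

import time
import numpy as np

t0 = time.perf_counter()
unit_ids = np.unique(spks_b['clusters'])
unit0_times = spks_b['times'][spks_b['clusters'] == unit_ids[0]]  # regroup one unit
dt = time.perf_counter() - t0
print('Estimated time is {:.2f} mins. ({})'.format(dt * unit_ids.size / 60, time.ctime()))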
