From 44f18494ff2ea6d006101265164bd555be827eaa Mon Sep 17 00:00:00 2001
From: Jai Bhagat
Date: Thu, 9 Jan 2020 21:22:57 +0000
Subject: [PATCH] updated code and examples to allow for compressed files

---
 .../Loading_from_ONE_and_running_functions.py |  4 +-
 brainbox/io/io.py                             | 28 +++++------
 brainbox/metrics/metrics.py                   | 39 +++++++++++----
 brainbox/plot/plot.py                         | 48 +++++++++++++------
 brainbox/processing/processing.py             |  2 +
 5 files changed, 80 insertions(+), 41 deletions(-)

diff --git a/brainbox/examples/Loading_from_ONE_and_running_functions.py b/brainbox/examples/Loading_from_ONE_and_running_functions.py
index 4698ca4af..7128cb317 100644
--- a/brainbox/examples/Loading_from_ONE_and_running_functions.py
+++ b/brainbox/examples/Loading_from_ONE_and_running_functions.py
@@ -47,7 +47,7 @@
 # Ensure directories and paths can be found
 assert os.path.isdir(ephys_file_dir) and os.path.isdir(alf_probe_dir) \
     and os.path.isabs(ephys_file_path), 'Directories set incorrectly'
-
+
 # Call brainbox functions #
 #-------------------------#
 
@@ -55,7 +55,7 @@
 path_to_ephys_file = ephys_file_path
 path_to_alf_out = alf_probe_dir
 
-# Load alf objects:
+# Load alf objects
 spks_b = aio.load_object(path_to_alf_out, 'spikes')
 clstrs_b = aio.load_object(path_to_alf_out, 'clusters')
 chnls_b = aio.load_object(path_to_alf_out, 'channels')
diff --git a/brainbox/io/io.py b/brainbox/io/io.py
index 9b5fde215..ce5a3a36b 100644
--- a/brainbox/io/io.py
+++ b/brainbox/io/io.py
@@ -3,11 +3,10 @@
 # (Previously required `os.path` to get file info before memmapping)
 # import os.path as op
 from ibllib.io import spikeglx
-from scipy.signal import decimate
 
 
 def extract_waveforms(ephys_file, ts, ch, t=2.0, sr=30000, n_ch_probe=385, dtype='int16',
-                      offset=0, car=True, q=5):
+                      offset=0, car=True):
     '''
     Extracts spike waveforms from binary ephys data file, after (optionally)
     common-average-referencing (CAR) spatial noise.
@@ -32,8 +31,6 @@ def extract_waveforms(ephys_file, ts, ch, t=2.0, sr=30000, n_ch_probe=385, dtype
         The offset (in bytes) from the start of `ephys_file`.
     car: bool (optional)
         A flag to perform CAR before extracting waveforms.
-    q : int )optional)
-        The downsampling factor to use when performing CAR.
 
     Returns
     -------
@@ -80,12 +77,13 @@ def extract_waveforms(ephys_file, ts, ch, t=2.0, sr=30000, n_ch_probe=385, dtype
 
     # Exception handling for impossible channels
     ch = np.asarray(ch)
+    ch = ch.reshape((ch.size, 1))
     if np.any(ch < 0) or np.any(ch > n_ch_probe):
         raise Exception('At least one specified channel number is impossible. The minimum channel'
                         ' number was {}, and the maximum channel number was {}. Check specified'
                         ' channel numbers and try again.'.format(np.min(ch), np.max(ch)))
 
-    if car: # compute spatial noise in chunks
+    if car:  # compute spatial noise in chunks
         # (previously computed temporal noise also, but was too costly)
         # Get number of chunks.
         t_sample_first = ts_samples[0] - n_wf_samples
@@ -96,8 +94,8 @@
         # samples that make up the first chunk.
         chunk_sample = np.arange(t_sample_first, t_sample_last, n_chunk_samples, dtype=int)
         chunk_sample = np.append(chunk_sample, t_sample_last)
-        noise_s_chunks = np.zeros((n_chunks, ch.size))
-        # Give time estimate for calculating `noise_s_chunks`.
+        noise_s_chunks = np.zeros((n_chunks, ch.size))  # spatial noise array
+        # Give time estimate for computing `noise_s_chunks`.
         t0 = time.perf_counter()
         np.median(file_m[chunk_sample[0]:chunk_sample[1], ch], axis=0)
         dt = time.perf_counter() - t0
@@ -105,7 +103,7 @@
               ' ({})'.format(dt * n_chunks / 60, time.ctime()))
         # Compute noise for each chunk, then take the median noise of all chunks.
         for chunk in range(n_chunks):
-            noise_s_chunks[chunk,:] = np.median(
+            noise_s_chunks[chunk, :] = np.median(
                 file_m[chunk_sample[chunk]:chunk_sample[chunk + 1], ch], axis=0)
         noise_s = np.median(noise_s_chunks, axis=0)
         print('Done. ({})'.format(time.ctime()))
@@ -116,18 +114,18 @@
         t0 = time.perf_counter()
         for i in range(5):
             waveforms[i, :, :] = \
-                file_m[i * n_wf_samples * 2 + t_sample_first:
-                       i * n_wf_samples * 2 + t_sample_first + n_wf_samples * 2, ch].reshape(
-                           (n_wf_samples * 2, ch.size))
-        dt = time.perf_counter() - t0
-        print('Performing waveform extraction. Estimated time is {:.2f} mins. ({})'
-              .format(dt * len(ts) / 60 / 5, time.ctime()))
+            file_m[i * n_wf_samples * 2 + t_sample_first:
+                   i * n_wf_samples * 2 + t_sample_first + n_wf_samples * 2, ch].reshape(
+                       (n_wf_samples * 2, ch.size))
+        dt = time.perf_counter() - t0
+        print('Performing waveform extraction. Estimated time is {:.2f} mins. ({})'
+              .format(dt * len(ts) / 60 / 5, time.ctime()))
     for spk, _ in enumerate(ts):  # extract waveforms
         spk_ts_sample = ts_samples[spk]
         spk_samples = np.arange(spk_ts_sample - n_wf_samples, spk_ts_sample + n_wf_samples)
         # have to reshape to add an axis to broadcast `file_m` into `waveforms`
         waveforms[spk, :, :] = \
-            file_m[spk_samples[0]:spk_samples[-1]+1, ch].reshape((spk_samples.size, ch.size))
+            file_m[spk_samples[0]:spk_samples[-1] + 1, ch].reshape((spk_samples.size, ch.size))
     print('Done. ({})'.format(time.ctime()))
     if car:  # perform CAR (subtract spatial noise)
         waveforms -= noise_s[None, None, :]
diff --git a/brainbox/metrics/metrics.py b/brainbox/metrics/metrics.py
index 8f49137db..89d545cb8 100644
--- a/brainbox/metrics/metrics.py
+++ b/brainbox/metrics/metrics.py
@@ -15,11 +15,12 @@
 >>> units_b = bb.processing.get_units_bunch(spks_b)  # may take a few mins to compute
 """
 
-import os.path as op
+import time
 import numpy as np
 import scipy.stats as stats
 import scipy.ndimage.filters as filters
 import brainbox as bb
+from ibllib.io import spikeglx
 # add spikemetrics as dependency?
 # import spikemetrics as sm
 
@@ -504,7 +505,8 @@ def ptp_over_noise(ephys_file, ts, ch, t=2.0, sr=30000, n_ch_probe=385, dtype='i
     '''
 
     # Ensure `ch` is ndarray
-    ch = np.asarrray(ch)
+    ch = np.asarray(ch)
+    ch = ch.reshape((ch.size, 1))
 
     # Get waveforms.
     wf = bb.io.extract_waveforms(ephys_file, ts, ch, t=t, sr=sr, n_ch_probe=n_ch_probe,
@@ -516,12 +518,29 @@ def ptp_over_noise(ephys_file, ts, ch, t=2.0, sr=30000, n_ch_probe=385, dtype='i
         mean_ptp[cur_ch] = np.mean(np.max(wf[:, :, cur_ch], axis=1) -
                                    np.min(wf[:, :, cur_ch], axis=1))
 
-    # Compute MAD for all channels.
-    item_bytes = np.dtype(dtype).itemsize
-    n_samples = (op.getsize(ephys_file) - offset) // (item_bytes * n_ch_probe)
-    file_m = np.memmap(ephys_file, shape=(n_samples, n_ch_probe), dtype=dtype, mode='r')
-    noise = stats.median_absolute_deviation(file_m[:, ch], axis=0, scale=1)
-
-    # Return `mean_ptp` over `noise`
-    ptp_sigma = mean_ptp / noise
+    # Compute MAD for `ch` in chunks.
+    s_reader = spikeglx.Reader(ephys_file)
+    file_m = s_reader.data  # the memmapped array
+    n_chunk_samples = 5e6  # number of samples per chunk
+    n_chunks = np.ceil(file_m.shape[0] / n_chunk_samples).astype('int')
+    # Get samples that make up each chunk. e.g. `chunk_sample[1] - chunk_sample[0]` are the
+    # samples that make up the first chunk.
+    chunk_sample = np.arange(0, file_m.shape[0], n_chunk_samples, dtype=int)
+    chunk_sample = np.append(chunk_sample, file_m.shape[0])
+    # Give time estimate for computing MAD.
+    t0 = time.perf_counter()
+    stats.median_absolute_deviation(file_m[chunk_sample[0]:chunk_sample[1], ch], axis=0)
+    dt = time.perf_counter() - t0
+    print('Performing MAD computation. Estimated time is {:.2f} mins.'
+          ' ({})'.format(dt * n_chunks / 60, time.ctime()))
+    # Compute MAD for each chunk, then take the median MAD of all chunks.
+    mad_chunks = np.zeros((n_chunks, ch.size))
+    for chunk in range(n_chunks):
+        mad_chunks[chunk, :] = stats.median_absolute_deviation(
+            file_m[chunk_sample[chunk]:chunk_sample[chunk + 1], ch], axis=0, scale=1)
+    print('Done. ({})'.format(time.ctime()))
+
+    # Return `mean_ptp` over `mad`
+    mad = np.median(mad_chunks, axis=0)
+    ptp_sigma = mean_ptp / mad
 
     return ptp_sigma
diff --git a/brainbox/plot/plot.py b/brainbox/plot/plot.py
index 1f0d7e74f..c32ca6922 100644
--- a/brainbox/plot/plot.py
+++ b/brainbox/plot/plot.py
@@ -16,12 +16,13 @@
 >>> units_b = bb.processing.get_units_bunch(spks_b)  # may take a few mins to compute
 """
 
-import os.path as op
+import time
 from warnings import warn
 import numpy as np
 import matplotlib.pyplot as plt
 # from matplotlib.ticker import StrMethodFormatter
 import brainbox as bb
+from ibllib.io import spikeglx
 
 
 def feat_vars(units_b, units=None, feat_name='amps', dist='norm', test='ks', cmap_name='coolwarm',
@@ -262,6 +263,10 @@ def wf_comp(ephys_file, ts1, ts2, ch, sr=30000, n_ch_probe=385, dtype='int16', c
     >>> wf1_2, wf2_2, s_2 = bb.plot.wf_comp(path_to_ephys_file, ts1_2, ts2_2, ch)
     '''
 
+    # Ensure `ch` is ndarray
+    ch = np.asarray(ch)
+    ch = ch.reshape((ch.size, 1))
+
     # Extract the waveforms for these timestamps and compute similarity score.
     wf1 = bb.io.extract_waveforms(ephys_file, ts1, ch, sr=sr, n_ch_probe=n_ch_probe,
                                   dtype=dtype, car=car)
@@ -270,7 +275,7 @@
     s = bb.metrics.wf_similarity(wf1, wf2)
 
     # Plot these waveforms against each other.
-    n_ch = len(ch)
+    n_ch = ch.size
     if ax is None:
         fig, ax = plt.subplots(nrows=n_ch, ncols=2)  # left col is all waveforms, right col is mean
     for cur_ax, cur_ch in enumerate(ch):
@@ -332,24 +337,39 @@ def amp_heatmap(ephys_file, ts, ch, sr=30000, n_ch_probe=385, dtype='int16', cma
     >>> ch = np.arange(max_ch - 10, max_ch + 10)
     >>> bb.plot.amp_heatmap(path_to_ephys_file, ts, ch)
     '''
+    # Ensure `ch` is ndarray
+    ch = np.asarray(ch)
+    ch = ch.reshape((ch.size, 1))
 
     # Get memmapped array of `ephys_file`
-    item_bytes = np.dtype(dtype).itemsize
-    n_samples = op.getsize(ephys_file) // (item_bytes * n_ch_probe)
-    file_m = np.memmap(ephys_file, shape=(n_samples, n_ch_probe), dtype=dtype, mode='r')
+    s_reader = spikeglx.Reader(ephys_file)
+    file_m = s_reader.data
 
     # Get voltage values for each peak amplitude sample for `ch`.
     max_amp_samples = (ts * sr).astype(int)
-    v_vals = file_m[np.ix_(max_amp_samples, ch)]
-    if car: # Compute and subtract temporal and spatial noise from `v_vals`.
+    v_vals = file_m[max_amp_samples, ch]
+    if car:  # compute spatial noise in chunks, and subtract from `v_vals`.
-        # Get subset of time (from first to last max amp sample)
-        t_subset = np.arange(max_amp_samples[0], max_amp_samples[-1] + 1, dtype='int16')
-        # Specify output arrays as `dtype='int16'`
-        out_noise_t = np.zeros((len(t_subset),), dtype='int16')
-        out_noise_s = np.zeros((len(ch),), dtype='int16')
-        noise_t = np.median(file_m[np.ix_(t_subset, ch)], axis=1, out=out_noise_t)
-        noise_s = np.median(file_m[np.ix_(t_subset, ch)], axis=0, out=out_noise_s)
-        v_vals -= noise_t[max_amp_samples - max_amp_samples[0], None]
+        n_chunk_samples = 5e6  # number of samples per chunk
+        n_chunks = np.ceil((max_amp_samples[-1] - max_amp_samples[0]) /
+                           n_chunk_samples).astype('int')
+        # Get samples that make up each chunk. e.g. `chunk_sample[1] - chunk_sample[0]` are the
+        # samples that make up the first chunk.
+        chunk_sample = np.arange(max_amp_samples[0], max_amp_samples[-1], n_chunk_samples,
+                                 dtype=int)
+        chunk_sample = np.append(chunk_sample, max_amp_samples[-1])
+        noise_s_chunks = np.zeros((n_chunks, ch.size))  # spatial noise array
+        # Give time estimate for computing `noise_s_chunks`.
+        t0 = time.perf_counter()
+        np.median(file_m[chunk_sample[0]:chunk_sample[1], ch], axis=0)
+        dt = time.perf_counter() - t0
+        print('Performing spatial CAR before waveform extraction. Estimated time is {:.2f} mins.'
+              ' ({})'.format(dt * n_chunks / 60, time.ctime()))
+        # Compute noise for each chunk, then take the median noise of all chunks.
+        for chunk in range(n_chunks):
+            noise_s_chunks[chunk, :] = np.median(
+                file_m[chunk_sample[chunk]:chunk_sample[chunk + 1], ch], axis=0)
+        noise_s = np.median(noise_s_chunks, axis=0)
         v_vals -= noise_s[None, :]
 
     # Plot heatmap.
diff --git a/brainbox/processing/processing.py b/brainbox/processing/processing.py
index 85aefd3ed..d84ae0c73 100644
--- a/brainbox/processing/processing.py
+++ b/brainbox/processing/processing.py
@@ -240,6 +240,8 @@ def get_units_bunch(spks_b, *args):
     >>> units_b = bb.processing.get_units_bunch(spks_b)
     # Get amplitudes for unit 4.
    >>> amps = units_b['amps']['4']
+
+    TODO add computation time estimate?
     '''
 
     # Initialize `units`
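
Usage note: below is a minimal sketch, not part of the patch itself, of how the two updated entry points are called after this change. The file path, spike times, and channel range are hypothetical, and the sketch assumes `spikeglx.Reader` transparently handles the compressed ephys files this patch targets as well as flat `.bin` binaries.

    import numpy as np
    import brainbox as bb

    # Hypothetical inputs: a compressed ephys file, a few spike times (in s)
    # for one unit, and the channels around that unit's peak channel.
    path_to_ephys_file = '/path/to/ephys.ap.cbin'
    ts = np.array([5.0, 5.5, 6.0])
    ch = np.arange(20, 30)

    # Both functions now memmap the file through `spikeglx.Reader` and work
    # through it in 5e6-sample chunks (spatial noise for CAR, and the MAD,
    # respectively) instead of memmapping the raw binary with `np.memmap`.
    wf = bb.io.extract_waveforms(path_to_ephys_file, ts, ch, t=2.0, sr=30000, car=True)
    ptp_sigma = bb.metrics.ptp_over_noise(path_to_ephys_file, ts, ch, t=2.0, sr=30000)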