diff --git a/brainbox/examples/Loading_from_ONE_and_running_functions.py b/brainbox/examples/Loading_from_ONE_and_running_functions.py
index e9f5f96bd..877631dc8 100644
--- a/brainbox/examples/Loading_from_ONE_and_running_functions.py
+++ b/brainbox/examples/Loading_from_ONE_and_running_functions.py
@@ -83,21 +83,21 @@ wf_car = bb.io.extract_waveforms(path_to_ephys_file, ts, ch, t=2.0, car=True)
 
 # Plot variances of a spike feature for all units and for a subset of units
-fig1, var_vals, p_vals = bb.plot.feat_vars(spks_b, units=[], feat_name='amps')
-fig2, var_vals, p_vals = bb.plot.feat_vars(spks_b, units=filtered_units, feat_name='amps')
+fig1, var_vals, p_vals = bb.plot.feat_vars(units_b, units=[], feat_name='amps')
+fig2, var_vals, p_vals = bb.plot.feat_vars(units_b, units=filtered_units, feat_name='amps')
 
 # Plot distribution cutoff of a spike feature for a single unit
-fig3, fraction_missing = bb.plot.feat_cutoff(spks_b, unit=1, feat_name='amps')
+fig3, fraction_missing = bb.plot.feat_cutoff(units_b, unit=1, feat_name='amps')
 
 # Plot and compare two sets of waveforms from two different time epochs for a single unit
 ts = units_b['times']['1']
 ts1 = ts[np.where(ts<60)[0]]
 ts2 = ts[np.where(ts>180)[0][:len(ts1)]]
-fig4, wf_1, wf_2, s = bb.plot.single_unit_wf_comp(path_to_ephys_file, spks_b, clstrs_b, unit=1,
+fig4, wf_1, wf_2, s = bb.plot.single_unit_wf_comp(path_to_ephys_file, units_b, clstrs_b, unit=1,
                                                   ts1=ts1, ts2=ts2, n_ch=20, car=True)
 
 # Plot the instantaneous firing rate and its coefficient of variation for a single unit
-fig5, fr, cv, cvs = bb.plot.firing_rate(spks_b, unit=1, t='all', hist_win=0.01, fr_win=0.5,
+fig5, fr, cv, cvs = bb.plot.firing_rate(units_b, unit=1, t='all', hist_win=0.01, fr_win=0.5,
                                         n_bins=10, show_fr_cv=True)
 
 # Save figs in a directory
diff --git a/brainbox/metrics/metrics.py b/brainbox/metrics/metrics.py
index bfa4575e4..2e4cd5d0b 100644
--- a/brainbox/metrics/metrics.py
+++ b/brainbox/metrics/metrics.py
@@ -1,5 +1,18 @@
 """
 Computes metrics for assessing quality of single units.
+
+Run the following to set up the workspace to run the docstring examples:
+>>> import brainbox as bb
+>>> import alf.io as aio
+>>> import numpy as np
+>>> import matplotlib.pyplot as plt
+>>> import ibllib.ephys.spikes as e_spks
+# (*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory):
+>>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
+# Load the alf spikes bunch and clusters bunch, and get a units bunch.
+>>> spks_b = aio.load_object(path_to_alf_out, 'spikes')
+>>> clstrs_b = aio.load_object(path_to_alf_out, 'clusters')
+>>> units_b = bb.processing.get_units_bunch(spks_b)  # may take a few mins to compute
 """
 import brainbox as bb
 import numpy as np
@@ -9,7 +22,7 @@
 # import spikemetrics as sm
 
 
-def unit_stability(spks_b, units=[], feat_names=['amps'], dist='norm', test='ks'):
+def unit_stability(units_b, units=None, feat_names=['amps'], dist='norm', test='ks'):
     '''
     Computes the probability that the empirical spike feature distribution(s), for specified
     feature(s), for all units, comes from a specific theoretical distribution, based on a specified
@@ -17,16 +30,16 @@ def unit_stability(spks_b, units=[], feat_names=['amps'], dist='norm', test='ks'
 
     Parameters
     ----------
-    spks_b : bunch
-        A spikes bunch containing fields with spike information (e.g. cluster IDs, times, features,
-        etc.) for all spikes.
+    units_b : bunch
+        A units bunch containing fields with spike information (e.g. cluster IDs, times, features,
+        etc.) for all units.
     units : array-like (optional)
-        A subset of all units for which to create the bar plot. (If `[]`, all units are used)
+        A subset of all units for which to create the bar plot. (If `None`, all units are used)
     feat_names : list of strings (optional)
         A list of names of spike features that can be found in `spks` to specify which features
         to use for calculating unit stability.
     dist : string (optional)
-        The type of hypothetical null distribution from which the empirical spike feature
-        distributions are presumed to belong to.
+        The type of hypothetical null distribution to which the empirical spike feature
+        distributions are presumed to belong.
     test : string (optional)
         The statistical test used to calculate the probability that the empirical spike feature
@@ -53,40 +66,33 @@ def unit_stability(spks_b, units=[], feat_names=['amps'], dist='norm', test='ks'
     unit. Create a histogram of the variances of the spike amplitudes for each unit, color-coded
     by depth of channel of max amplitudes. Get cluster IDs of those units which have variances
     greater than 50.
-        >>> import brainbox as bb
-        >>> import alf.io as aio
-        >>> import numpy as np
-        >>> import matplotlib.pyplot as plt
-        >>> import ibllib.ephys.spikes as e_spks
-        (*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory):
-        >>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
-        # Load a spikes bunch and calculate unit stability:
-        >>> spks_b = aio.load_object(path_to_alf_out, 'spikes')
-        >>> p_vals_b, variances_b = bb.metrics.unit_stability(spks)
+        >>> p_vals_b, variances_b = bb.metrics.unit_stability(units_b)
         # Plot histograms of variances color-coded by depth of channel of max amplitudes
-        >>> fig = bb.plot.feat_vars(spks_b, feat_name='amps')
+        >>> var_vals, p_vals = bb.plot.feat_vars(units_b, feat_name='amps')
         # Get all unit IDs which have amps variance > 50
        >>> var_vals = np.array(tuple(variances_b['amps'].values()))
        >>> bad_units = np.where(var_vals > 50)
     '''
 
-    # Get units bunch and number of units.
-    units_b = bb.processing.get_units_bunch(spks_b, feat_names)
-    if len(units) != 0:  # we're using a subset of all units
+    # Get units.
+    if not(units is None):  # we're using a subset of all units
         unit_list = list(units_b[feat_names[0]].keys())
         # for each `feat` and unit in `unit_list`, remove unit from `units_b` if not in `units`
         for feat in feat_names:
             [units_b[feat].pop(unit) for unit in unit_list if not(int(unit) in units)]
     unit_list = list(units_b[feat_names[0]].keys())  # get new `unit_list` after removing units
+    # Initialize `p_vals` and `variances`.
     p_vals_b = bb.core.Bunch()
     variances_b = bb.core.Bunch()
+    # Set the test as a lambda function (in future, more tests can be added to this dict)
     tests = \
         {
             'ks': lambda x, y: stats.kstest(x, y)
         }
     test_fun = tests.get(test)
+    # Compute the statistical tests and variances. For each feature, iteratively get each unit's
+    # p-values and variances, and add them as keys to the respective bunches `p_vals_feat` and
+    # `variances_feat`. After iterating through all units, add these bunches as keys to their
@@ -96,7 +102,7 @@ def unit_stability(spks_b, units=[], feat_names=['amps'], dist='norm', test='ks'
         variances_feat = bb.core.Bunch((unit, 0) for unit in unit_list)
         for unit in unit_list:
             # If we're missing units/features, create a NaN placeholder and skip them:
-            if not(str(type(units_b['amps'][unit])) == "<class 'numpy.ndarray'>"):
+            if not(str(type(units_b[feat][unit])) == "<class 'numpy.ndarray'>"):
                 p_val = np.nan
                 var = np.nan
             else:
@@ -108,10 +114,11 @@ def unit_stability(spks_b, units=[], feat_names=['amps'], dist='norm', test='ks'
             variances_feat[str(unit)] = var
         p_vals_b[feat] = p_vals_feat
         variances_b[feat] = variances_feat
+
     return p_vals_b, variances_b
 
 
-def feat_cutoff(spks_b, unit, feat_name='amps', spks_per_bin=20, sigma=5):
+def feat_cutoff(feat, spks_per_bin=20, sigma=5, min_num_bins=50):
     '''
     Computes the approximate fraction of spikes missing from a spike feature distribution for a
     given unit, assuming the distribution is symmetric.
@@ -120,18 +127,15 @@ def feat_cutoff(spks_b, unit, feat_name='amps', spks_per_bin=20, sigma=5):
 
     Parameters
     ----------
-    spks_b : bunch
-        A spikes bunch containing fields with spike information (e.g. cluster IDs, times, features,
-        etc.) for all spikes.
-    unit : int
-        The unit number for the feature to plot.
-    feat_name : string (optional)
-        The spike feature to plot.
+    feat : ndarray
+        The spikes' feature values.
     spks_per_bin : int (optional)
         The number of spikes per bin from which to compute the spike feature histogram.
     sigma : int (optional)
         The standard deviation for the gaussian kernel used to compute the pdf from the spike
         feature histogram.
+    min_num_bins : int (optional)
+        The minimum number of bins used to compute the spike feature histogram.
 
     Returns
     -------
@@ -150,30 +154,21 @@ def feat_cutoff(spks_b, unit, feat_name='amps', spks_per_bin=20, sigma=5):
 
     Examples
     --------
-    1) Determine the fraction of spikes missing from a unit based on the recorded unit's spike
+    1) Determine the fraction of spikes missing from unit 1 based on the recorded unit's spike
     amplitudes, assuming the distribution of the unit's spike amplitudes is symmetric.
-        >>> import brainbox as bb
-        >>> import alf.io as aio
-        >>> import ibllib.ephys.spikes as e_spks
-        (*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory):
-        >>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
-        # Get a spikes bunch and calculate estimated fraction of missing spikes.
-        >>> spks_b = aio.load_object(path_to_alf_out, 'spikes')
-        >>> fraction_missing = bb.metrics.feat_cutoff(spks_b, unit=1, feat_name='amps')
-        # Plot histogram and pdf of the spike amplitude distribution.
-        >>> fig = bb.plot.feat_cutoff(spks_b, unit=1, feat_name='amps')
+        # Get unit 1 amplitudes from a units bunch, and calculate the fraction of spikes missing.
+        >>> feat = units_b['amps']['1']
+        >>> fraction_missing, pdf, cutoff_idx = bb.metrics.feat_cutoff(feat)
     '''
-
-    min_num_bins = 50
-    units = bb.processing.get_units_bunch(spks_b, [feat_name])
-    feature = units[feat_name][str(unit)]
+
+    # Ensure minimum number of spikes requirement is met.
     error_str = 'The number of spikes in this unit is {0}, ' \
-                'but it must be at least {1}'.format(feature.size, spks_per_bin * min_num_bins)
-    assert (feature.size > (spks_per_bin * min_num_bins)), error_str
+                'but it must be at least {1}'.format(feat.size, spks_per_bin * min_num_bins)
+    assert (feat.size > (spks_per_bin * min_num_bins)), error_str
 
     # Calculate the spike feature histogram and pdf:
-    num_bins = np.int(feature.size / spks_per_bin)
-    hist, bins = np.histogram(feature, num_bins, density=True)
+    num_bins = np.int(feat.size / spks_per_bin)
+    hist, bins = np.histogram(feat, num_bins, density=True)
     pdf = filters.gaussian_filter1d(hist, sigma)
 
     # Find where the distribution stops being symmetric around the peak:
@@ -217,16 +212,9 @@ def wf_similarity(wf1, wf2):
     --------
     1) Compute the similarity between the first and last 100 waveforms for unit1, across the 20
     channels around the channel of max amplitude.
-        >>> import brainbox as bb
-        >>> import alf.io as aio
-        >>> import ibllib.ephys.spikes as e_spks
-        (*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory):
-        >>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
-        # Get a spikes bunch, a clusters bunch, a units bunch, the channels around the max amp
-        # channel for the unit, two sets of timestamps for the units, and the two corresponding
-        # sets of waveforms for those two sets of timestamps. Then compute `s`.
-        >>> spks_b = aio.load_object(path_to_alf_out, 'spikes')
-        >>> clstrs_b = aio.load_object(path_to_alf_out, 'clusters')
+        # Get the channels around the max amp channel for the unit, two sets of timestamps for the
+        # unit, and the two corresponding sets of waveforms for those two sets of timestamps.
+        # Then compute `s`.
        >>> max_ch = clstrs_b['channels'][1]
        >>> if max_ch < 10:  # take only channels greater than `max_ch`.
        >>> ch = np.arange(max_ch, max_ch + 20)
@@ -234,7 +222,6 @@ def wf_similarity(wf1, wf2):
        >>> ch = np.arange(max_ch - 20, max_ch)
        >>> else:  # take `n_c_ch` around `max_ch`.
        >>> ch = np.arange(max_ch - 10, max_ch + 10)
-        >>> units_b = bb.processing.get_units_bunch(spks_b, ['times'])
        >>> ts1 = units_b['times']['1'][:100]
        >>> ts2 = units_b['times']['1'][-100:]
        >>> wf1 = bb.io.extract_waveforms(path_to_ephys_file, ts1, ch)
        >>> wf2 = bb.io.extract_waveforms(path_to_ephys_file, ts2, ch)
@@ -250,12 +237,15 @@ def wf_similarity(wf1, wf2):
     warnings.filterwarnings('ignore', r'invalid value encountered in true_divide')
     assert wf1.shape == wf2.shape, ('The shapes of the sets of waveforms are inconsistent ({})'
                                     '({})'.format(wf1.shape, wf2.shape))
+
+    # Get number of spikes, samples, and channels of waveforms.
     n_spks = wf1.shape[0]
     n_samples = wf1.shape[1]
     n_ch = wf1.shape[2]
-    # Create a matrix that will hold the similarity values of each spike in `wf1` to `wf2`
+
+    # Create a matrix that will hold the similarity values of each spike in `wf1` to `wf2`.
+    # Iterate over both sets of spikes, computing `s` for each pair.
     similarity_matrix = np.zeros((n_spks, n_spks))
-    # Iterate over both sets of spikes, computing `s` for each pair
     for spk1 in range(n_spks):
         for spk2 in range(n_spks):
             s_spk = \
@@ -263,29 +253,24 @@ def wf_similarity(wf1, wf2):
                 wf1[spk1, :, :] * wf2[spk2, :, :] /
                 np.sqrt(wf1[spk1, :, :]**2 * wf2[spk2, :, :]**2))) / (n_samples * n_ch)
             similarity_matrix[spk1, spk2] = s_spk
+
     # Return mean of similarity matrix
     s = np.mean(similarity_matrix)
     return s
 
 
-def firing_rate_coeff_var(spks_b, unit, t='all', hist_win=0.01, fr_win=0.5, n_bins=10):
+def firing_rate_coeff_var(ts, hist_win=0.01, fr_win=0.5, n_bins=10):
     '''
     Computes the coefficient of variation of the firing rate: the ratio of the standard
     deviation to the mean.
 
     Parameters
     ----------
-    spks_b : bunch
-        A spikes bunch containing fields with spike information (e.g. cluster IDs, times, features,
-        etc.) for all spikes.
-    unit : int
-        The unit number for which to calculate the firing rate.
-    t : str or pair of floats (optional)
-        The total time period for which the instantaneous firing rate is returned. Default: the
-        time period from `unit`'s first to last spike.
-    hist_win : float (optional)
+    ts : ndarray
+        The spike timestamps from which to compute the firing rate.
+    hist_win : float
         The time window (in s) to use for computing spike counts.
-    fr_win : float (optional)
+    fr_win : float
         The time window (in s) to use as a moving slider to compute the instantaneous firing rate.
     n_bins : int (optional)
         The number of bins in which to compute a coefficient of variation of the firing rate.
@@ -307,23 +292,119 @@ def firing_rate_coeff_var(spks_b, unit, t='all', hist_win=0.01, fr_win=0.5, n_bi
 
     Examples
     --------
-    1) Compute the coefficient of variation of the firing rate for unit1 from the time of its
-    first to last spike, and compute the coefficient of variation of the firing rate for unit2 from
-    the first to second minute.
-        >>> import brainbox as bb
-        >>> import alf.io as aio
-        >>> import ibllib.ephys.spikes as e_spks
-        (*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory):
-        >>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
-        # Get a spikes bunch and calculate the firing rate.
-        >>> spks_b = aio.load_object(path_to_alf_out, 'spikes')
-        >>> cv, cvs, fr = bb.metrics.firing_rate_coeff_var(spks_b, unit=1)
-        >>> cv_2, cvs_2, fr_2 = bb.metrics.firing_rate_coeff_var(spks_b, unit=2)
+    1) Compute the coefficient of variation of the firing rate for unit 1 from the time of its
+    first to last spike, and compute the coefficient of variation of the firing rate for unit 2
+    from the first to second minute.
+        >>> ts_1 = units_b['times']['1']
+        >>> ts_2 = units_b['times']['2']
+        >>> ts_2 = ts_2[np.intersect1d(np.where(ts_2 > 60)[0], np.where(ts_2 < 120)[0])]
+        >>> cv, cvs, fr = bb.metrics.firing_rate_coeff_var(ts_1)
+        >>> cv_2, cvs_2, fr_2 = bb.metrics.firing_rate_coeff_var(ts_2)
     '''
-    fr = bb.singlecell.firing_rate(spks_b, unit, t=t, hist_win=hist_win, fr_win=fr_win)
+    # Calculate overall instantaneous firing rate and firing rate for each bin.
+    fr = bb.singlecell.firing_rate(ts, hist_win=hist_win, fr_win=fr_win)
     bin_sz = np.int(fr.size / n_bins)
     fr_binned = np.array([fr[(b * bin_sz):(b * bin_sz + bin_sz)] for b in range(n_bins)])
+
+    # Calculate coefficient of variations of firing rate for each bin, and the mean c.v.
     cvs = np.std(fr_binned, axis=1) / np.mean(fr_binned, axis=1)
     cv = np.mean(cvs)
+
     return cv, cvs, fr
+
+
+def isi_viol(ts, rp=0.002):
+    '''
+    Computes the fraction of isi violations for a unit.
+
+    Parameters
+    ----------
+    ts : ndarray
+        The spike timestamps of the unit, from which the isi violations are computed.
+    rp : float
+        The refractory period (in s).
+
+    Returns
+    -------
+    frac_isi_viol : float
+        The fraction of isi violations.
+    n_isi_viol : int
+        The number of isi violations.
+    isis : ndarray
+        The inter-spike intervals (isis).
+
+    See Also
+    --------
+
+    Examples
+    --------
+    1) Get the fraction of isi violations, the total number of isi violations, and the array of
+    isis for unit 1.
+        >>> unit_idxs = np.where(spks_b['clusters'] == 1)[0]
+        >>> ts = spks_b['times'][unit_idxs]
+        >>> frac_isi_viol, n_isi_viol, isis = bb.metrics.isi_viol(ts)
+    '''
+
+    isis = np.diff(ts)
+    v = np.where(isis < rp)[0]  # violations
+    frac_isi_viol = len(v) / len(isis)
+    return frac_isi_viol, len(v), isis
+
+
+def max_drift(depths):
+    '''
+    Computes the maximum drift of spikes in a unit.
+
+    Parameters
+    ----------
+    depths : ndarray
+        The spike depths from which to compute the maximum drift.
+
+    Returns
+    -------
+    md : float
+        The maximum drift of the unit (in mm).
+
+    See Also
+    --------
+
+    Examples
+    --------
+    1) Get the maximum drift for unit 1.
+        >>> unit_idxs = np.where(spks_b['clusters'] == 1)[0]
+        >>> depths = spks_b['depths'][unit_idxs]
+        >>> md = bb.metrics.max_drift(depths)
+    '''
+
+    md = np.max(depths) - np.min(depths)
+    return md
+
+
+def cum_drift(depths):
+    '''
+    Computes the cumulative drift (normalized by the total number of spikes) of spikes in a unit.
+
+    Parameters
+    ----------
+    depths : ndarray
+        The spike depths from which to compute the cumulative drift.
+
+    Returns
+    -------
+    cd : float
+        The cumulative drift of the unit (in mm).
+
+    See Also
+    --------
+
+    Examples
+    --------
+    1) Get the cumulative drift for unit 1.
+        >>> unit_idxs = np.where(spks_b['clusters'] == 1)[0]
+        >>> depths = spks_b['depths'][unit_idxs]
+        >>> cd = bb.metrics.cum_drift(depths)
+    '''
+
+    cd = np.sum(np.abs(np.diff(depths))) / len(depths)
+    return cd
diff --git a/brainbox/plot/plot.py b/brainbox/plot/plot.py
index 87abc9e5e..29ed9c94e 100644
--- a/brainbox/plot/plot.py
+++ b/brainbox/plot/plot.py
@@ -1,6 +1,19 @@
 """
 Plots metrics that assess quality of single units. Some functions here generate plots for the
 output of functions in the brainbox `metrics.py` module.
+
+Run the following to set up the workspace to run the docstring examples:
+>>> import brainbox as bb
+>>> import alf.io as aio
+>>> import numpy as np
+>>> import matplotlib.pyplot as plt
+>>> import ibllib.ephys.spikes as e_spks
+# (*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory):
+>>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
+# Load the alf spikes bunch and clusters bunch, and get a units bunch.
+>>> spks_b = aio.load_object(path_to_alf_out, 'spikes')
+>>> clstrs_b = aio.load_object(path_to_alf_out, 'clusters')
+>>> units_b = bb.processing.get_units_bunch(spks_b)  # may take a few mins to compute
 """
 
 import os.path as op
@@ -10,18 +23,19 @@
 import brainbox as bb
 
 
-def feat_vars(spks_b, units=[], feat_name='amps', dist='norm', test='ks', cmap_name='coolwarm'):
+def feat_vars(units_b, units=None, feat_name='amps', dist='norm', test='ks', cmap_name='coolwarm',
+              ax=None):
     '''
     Plots the variances of a particular spike feature for all units as a bar plot, where each bar
     is color-coded corresponding to the depth of the max amplitude channel of the respective unit.
 
     Parameters
     ----------
-    spks_b : bunch
-        A spikes bunch containing fields with spike information (e.g. cluster IDs, times, features,
-        etc.) for all spikes.
+    units_b : bunch
+        A units bunch containing fields with spike information (e.g. cluster IDs, times, features,
+        etc.) for all units.
     units : array-like (optional)
-        A subset of all units for which to create the bar plot. (If `[]`, all units are used)
+        A subset of all units for which to create the bar plot. (If `None`, all units are used)
     feat_name : string (optional)
         The spike feature to plot.
     dist : string (optional)
@@ -32,11 +46,11 @@ def feat_vars(spks_b, units=[], feat_name='amps', dist='norm', test='ks', cmap_n
         distributions come from `dist`.
     cmap_name : string (optional)
         The name of the colormap associated with the plot.
+    ax : axessubplot (optional)
+        The axis handle to plot the bar plot on. (if `None`, a new figure and axis is created)
 
     Returns
     -------
-    fig : figure
-        A figure object containing the plot.
     var_vals : ndarray
         Contains the variances of `feat_name` for each unit.
     p_vals : ndarray
@@ -49,29 +63,23 @@ def feat_vars(spks_b, units=[], feat_name='amps', dist='norm', test='ks', cmap_n
 
     Examples
     --------
-    1) Create a bar plot of the variances of the spike amplitudes for each unit.
-        >>> import brainbox as bb
-        >>> import alf.io as aio
-        >>> import ibllib.ephys.spikes as e_spks
-        (*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory):
-        >>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
-        # Get spikes bunch and create the bar plot.
-        >>> spks_b = aio.load_object(path_to_alf_out, 'spikes')
-        >>> fig, var_vals, p_vals = bb.plot.feat_vars(spks_b)
+    1) Create a bar plot of the variances of the spike amplitudes for all units.
+        >>> var_vals, p_vals = bb.plot.feat_vars(units_b)
     '''
 
-    # Get units bunch and calculate variances.
-    units_b = bb.processing.get_units_bunch(spks_b, ['depths'])
-    if len(units) != 0:  # we're using a subset of all units
+    # Get units.
+    if not(units is None):  # we're using a subset of all units
         unit_list = list(units_b['depths'].keys())
         # For each unit in `unit_list`, remove unit from `units_b` if not in `units`
         [units_b['depths'].pop(unit) for unit in unit_list if not(int(unit) in units)]
     unit_list = list(units_b['depths'].keys())  # get new `unit_list` after removing units
+
     # Calculate variances for all units
-    p_vals_b, variances_b = bb.metrics.unit_stability(spks_b, units=units, feat_names=[feat_name],
-                                                      dist=dist, test=test)
+    p_vals_b, variances_b = bb.metrics.unit_stability(
+        units_b, units=units, feat_names=[feat_name], dist=dist, test=test)
     var_vals = np.array(tuple(variances_b[feat_name].values()))
     p_vals = np.array(tuple(p_vals_b[feat_name].values()))
+
     # Specify and remove bad units (i.e. missing unit numbers from spike sorter output).
     bad_units = np.where(np.isnan(var_vals))[0]
     if len(bad_units) > 0:
@@ -79,29 +87,35 @@ def feat_vars(spks_b, units=[], feat_name='amps', dist='norm', test='ks', cmap_n
         good_units = unit_list
     else:
         good_units = unit_list
+
     # Get depth of max amplitude channel for good units
     depths = np.asarray([np.mean(units_b['depths'][str(unit)]) for unit in good_units])
+
     # Create unit normalized colormap based on `depths`, sorted by depth.
     cmap = plt.cm.get_cmap(cmap_name)
     depths_norm = depths / np.max(depths)
     rgba = np.asarray([cmap(depth) for depth in np.sort(np.flip(depths_norm))])
+
     # Plot depth-color-coded h bar plot of variances for `feature` for each unit, where units are
     # sorted descendingly by depth along y-axis.
-    fig, ax = plt.subplots()
+    if ax is None:
+        ax = plt.gca()
+    fig = ax.figure
     ax.barh(y=[int(unit) for unit in good_units], width=var_vals[np.argsort(depths)], color=rgba)
     cbar = fig.colorbar(plt.cm.ScalarMappable(cmap=cmap), ax=ax)
     max_d = np.max(depths)
-    cbar.set_ticks(cbar.get_ticks())  # must call `set_ticks` to call `set_ticklabels`
     tick_labels = [int(max_d) * tick for tick in (0, 0.2, 0.4, 0.6, 0.8, 1.0)]
+    cbar.set_ticks(cbar.get_ticks())  # must call `set_ticks` to call `set_ticklabels`
     cbar.set_ticklabels(tick_labels)
     ax.set_title('{feat} variance'.format(feat=feat_name))
     ax.set_ylabel('unit number (sorted by depth)')
     ax.set_xlabel('variance')
     cbar.set_label('depth', rotation=0)
-    return fig, var_vals, p_vals
+
+    return var_vals, p_vals
 
 
-def feat_cutoff(spks_b, unit, feat_name='amps', spks_per_bin=20, sigma=5):
+def feat_cutoff(feat, feat_name, unit, spks_per_bin=20, sigma=5, min_num_bins=50, ax=None):
     '''
     Plots the pdf of an estimated symmetric spike feature distribution, with a vertical cutoff
     line that indicates the approximate fraction of spikes missing from the distribution, assuming
@@ -109,23 +123,24 @@ def feat_cutoff(spks_b, unit, feat_name='amps', spks_per_bin=20, sigma=5):
 
     Parameters
     ----------
-    spks_b : bunch
-        A spikes bunch containing fields with spike information (e.g. cluster IDs, times, features,
-        etc.) for all spikes.
-    unit : int
-        The unit number for the feature to plot.
-    feat_name : string (optional)
+    feat : ndarray
+        The spikes' feature values.
+    feat_name : string
         The spike feature to plot.
+    unit : int
+        The unit from which the spike feature distribution comes.
     spks_per_bin : int (optional)
         The number of spikes per bin from which to compute the spike feature histogram.
     sigma : int (optional)
         The standard deviation for the gaussian kernel used to compute the pdf from the spike
         feature histogram.
+    min_num_bins : int (optional)
+        The minimum number of bins used to compute the spike feature histogram.
+    ax : axessubplot (optional)
+        The axis handle to plot the histogram on. (if `None`, a new figure and axis is created)
 
     Returns
     -------
-    fig : figure
-        A figure object containing the plot.
     fraction_missing : float
         The fraction of missing spikes (0-0.5). *Note: If more than 50% of spikes are missing, an
         accurate estimate isn't possible.
@@ -138,65 +153,58 @@ def feat_cutoff(spks_b, unit, feat_name='amps', spks_per_bin=20, sigma=5):
    --------
    1) Plot cutoff line indicating the fraction of spikes missing from a unit based on the recorded
    unit's spike amplitudes, assuming the distribution of the unit's spike amplitudes is symmetric.
-        >>> import brainbox as bb
-        >>> import alf.io as aio
-        >>> import ibllib.ephys.spikes as e_spks
-        (*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory):
-        >>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
-        # Get a spikes bunch, a units bunch, and plot feature cutoff for spike amplitudes for unit1
-        >>> spks_b = aio.load_object(path_to_alf_out, 'spikes')
-        >>> fig, fraction_missing = bb.plot.feat_cutoff(spks_b, unit=1, feat_name='amps')
+        >>> feat = units_b['amps']['1']
+        >>> fraction_missing = bb.plot.feat_cutoff(feat, feat_name='amps', unit=1)
     '''
 
-    # Calculate and plot the feature distribution histogram and pdf with symmetric cutoff:
+    # Calculate the feature distribution histogram and fraction of spikes missing.
     fraction_missing, pdf, cutoff_idx = \
-        bb.metrics.feat_cutoff(spks_b, unit, feat_name=feat_name,
-                               spks_per_bin=spks_per_bin, sigma=sigma)
-    fig, ax = plt.subplots(nrows=1, ncols=2)
-    units_b = bb.processing.get_units_bunch(spks_b, [feat_name])
-    feature = units_b[feat_name][str(unit)]
-    num_bins = np.int(feature.size / spks_per_bin)
-    ax[0].hist(feature, bins=num_bins)
-    ax[0].set_xlabel('{0}'.format(feat_name))
-    ax[0].set_ylabel('count')
-    ax[0].set_title('histogram of {0} for unit{1}'.format(feat_name, str(unit)))
-    ax[1].plot(pdf)
-    ax[1].vlines(cutoff_idx, 0, np.max(pdf), colors='r')
-    ax[1].set_xlabel('bin number')
-    ax[1].set_ylabel('density')
-    ax[1].set_title('cutoff of pdf at end of symmetry around peak\n'
-                    '(estimated {:.2f}% missing spikes)'.format(fraction_missing * 100))
-    return fig, fraction_missing
-
-
-def single_unit_wf_comp(ephys_file, spks_b, clstrs_b, unit, n_ch=20, ts1='start', ts2='end',
-                        n_spks=100, sr=30000, n_ch_probe=385, dtype='int16', car=True,
-                        col=['b', 'r']):
+        bb.metrics.feat_cutoff(feat, spks_per_bin, sigma, min_num_bins)
+
+    # Plot.
+    if ax is None:  # create two axes
+        fig, ax = plt.subplots(nrows=1, ncols=2)
+    if len(ax) == 2:  # plot histogram and pdf on two separate axes
+        num_bins = np.int(feat.size / spks_per_bin)
+        ax[0].hist(feat, bins=num_bins)
+        ax[0].set_xlabel('{0}'.format(feat_name))
+        ax[0].set_ylabel('count')
+        ax[0].set_title('histogram of {0} for unit{1}'.format(feat_name, str(unit)))
+        ax[1].plot(pdf)
+        ax[1].vlines(cutoff_idx, 0, np.max(pdf), colors='r')
+        ax[1].set_xlabel('bin number')
+        ax[1].set_ylabel('density')
+        ax[1].set_title('cutoff of pdf at end of symmetry around peak\n'
+                        '(estimated {:.2f}% missing spikes)'.format(fraction_missing * 100))
+    else:  # just plot pdf
+        ax = ax[0]
+        ax.plot(pdf)
+        ax.vlines(cutoff_idx, 0, np.max(pdf), colors='r')
+        ax.set_xlabel('bin number')
+        ax.set_ylabel('density')
+        ax.set_title('cutoff of pdf at end of symmetry around peak\n'
+                     '(estimated {:.2f}% missing spikes)'.format(fraction_missing * 100))
+
+    return fraction_missing
+
+
+def wf_comp(ephys_file, ts1, ts2, ch, sr=30000, n_ch_probe=385, dtype='int16', car=True,
+            col=['b', 'r'], ax=None):
     '''
-    Plots waveforms from a single unit across a specified number of channels between two separate
-    time periods, after (optionally) common-average-referencing. In this way, waveforms can be
-    compared to see if there is, e.g. drift during the recording.
+    Plots two different sets of waveforms across specified channels after (optionally)
+    common-average-referencing. In this way, waveforms can be compared to see if there is,
+    e.g. drift during the recording, or if two units should be merged, or one unit should be split.
 
     Parameters
     ----------
     ephys_file : string
         The file path to the binary ephys data.
-    spks_b : bunch
-        A spikes bunch containing fields with spike information (e.g. cluster IDs, times, features,
-        etc.) for all spikes.
-    clstrs_b : bunch
-        A clusters bunch containing fields with unit information (e.g. mean amps, channel of max
-        amplitude, etc...), used here to extract the channel of max amplitude for the given unit.
-    unit : int
-        The unit number for the waveforms to plot.
-    n_ch : int (optional)
-        The number of channels around the channel of max amplitude to plot.
-    ts1 : array_like (optional)
-        A set of timestamps for which to compare waveforms with `ts2`.
-    ts2: array_like (optional)
+    ts1 : array_like
+        A set of timestamps for which to compare waveforms with `ts2`.
+    ts2 : array_like
         A set of timestamps for which to compare waveforms with `ts1`.
-    n_spks: int (optional)
-        The number of spikes to plot for each channel if `ts1` and `ts2` are kept as their defaults
+    ch : array-like
+        The channels to use for extracting and plotting the waveforms.
     sr : int (optional)
         The sampling rate (in hz) that the ephys data was acquired at.
     n_ch_probe : int (optional)
@@ -208,11 +216,11 @@ def single_unit_wf_comp(ephys_file, spks_b, clstrs_b, unit, n_ch=20, ts1='start'
     col: list of strings or float arrays (optional)
         Two elements in the list, where each specifies the color the `ts1` and `ts2` waveforms
         will be plotted in, respectively.
+    ax : axessubplot (optional)
+        The axis handle to plot the waveforms on. (if `None`, a new figure and axis is created)
 
     Returns
     -------
-    fig : figure
-        A figure object containing the plot.
     wf1 : ndarray
         The waveforms for the spikes in `ts1`: an array of shape (#spikes, #samples, #channels).
     wf2 : ndarray
@@ -224,92 +232,76 @@ def single_unit_wf_comp(ephys_file, spks_b, clstrs_b, unit, n_ch=20, ts1='start'
 
     See Also
     --------
     io.extract_waveforms
+    metrics.wf_similarity
 
     Examples
     --------
     1) Compare first and last 100 spike waveforms for unit1, across 20 channels around the channel
     of max amplitude, and compare the waveforms in the first minute to the waveforms in the fourth
-    minutes for unit2, across 15 channels around the mean.
-        >>> import brainbox as bb
-        >>> import alf.io as aio
-        >>> import ibllib.ephys.spikes as e_spks
-        (*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory):
-        >>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
-        # Get a spikes bunch, a clusters bunch, and plot the first and last 100 waveforms for
-        # unit1 across 20 channels.
-        >>> spks_b = aio.load_object(path_to_alf_out, 'spikes')
-        >>> clstrs_b = aio.load_object(path_to_alf_out, 'clusters')
-        >>> fig, wf1, wf2, _ = bb.plot.single_unit_wf_comp(path_to_ephys_file, spks_b, clstrs_b,
                                                            unit=1)
-        # Get a units bunch, and plot waveforms for unit2 from the first and fourth minutes
-        # across 15 channels.
-        >>> units_b = bb.processing.get_units_bunch(spks_b, ['times'])
+    minute for unit2, across 10 channels around the channel of max amplitude.
+        # Get first and last 100 spikes, and 20 channels around channel of max amp for unit 1:
+        >>> ts1 = units_b['times']['1'][:100]
+        >>> ts2 = units_b['times']['1'][-100:]
+        >>> max_ch = clstrs_b['channels'][1]
+        >>> n_c_ch, n_ch_probe = 10, 385  # channels on each side of `max_ch`; channels on probe
+        >>> if max_ch < n_c_ch:  # take only channels greater than `max_ch`.
+        >>> ch = np.arange(max_ch, max_ch + 20)
+        >>> elif (max_ch + n_c_ch) > n_ch_probe:  # take only channels less than `max_ch`.
+        >>> ch = np.arange(max_ch - 20, max_ch)
+        >>> else:  # take `n_c_ch` around `max_ch`.
+        >>> ch = np.arange(max_ch - 10, max_ch + 10)
+        >>> wf1, wf2, s = bb.plot.wf_comp(path_to_ephys_file, ts1, ts2, ch)
+        # Plot waveforms for unit2 from the first and fourth minutes across 10 channels.
         >>> ts = units_b['times']['2']
-        >>> ts1 = ts[np.where(ts<60)[0]]
-        >>> ts2 = ts[np.where(ts>180)[0][:len(ts1)]]
-        >>> fig2, wf1_2, wf2_2, _ = bb.plot.single_unit_wf_comp(path_to_ephys_file, spks_b,
                                                                 clstrs_b, unit=2, ts1=ts1, ts2=ts2)
+        >>> ts1_2 = ts[np.where(ts<60)[0]]
+        >>> ts2_2 = ts[np.where(ts>180)[0][:len(ts1_2)]]
+        >>> max_ch = clstrs_b['channels'][2]
+        >>> n_c_ch = 5  # now take 5 channels on each side of `max_ch`
+        >>> if max_ch < n_c_ch:  # take only channels greater than `max_ch`.
+        >>> ch = np.arange(max_ch, max_ch + 10)
+        >>> elif (max_ch + n_c_ch) > n_ch_probe:  # take only channels less than `max_ch`.
+        >>> ch = np.arange(max_ch - 10, max_ch)
+        >>> else:  # take `n_c_ch` around `max_ch`.
+        >>> ch = np.arange(max_ch - 5, max_ch + 5)
+        >>> wf1_2, wf2_2, s_2 = bb.plot.wf_comp(path_to_ephys_file, ts1_2, ts2_2, ch)
     '''
 
-    # Take the first and last `n_spks` timestamps by default.
-    if ts1 == 'start' or ts2 == 'end':
-        units_b = bb.processing.get_units_bunch(spks_b, ['times'])
-        ts1 = units_b['times'][str(unit)][:n_spks] if ts1 == 'start' else ts1
-        ts2 = units_b['times'][str(unit)][-n_spks:] if ts2 == 'end' else ts2
-    # Get the channel of max amplitude and `n_ch` around it.
-    max_ch = clstrs_b['channels'][unit]
-    n_c_ch = n_ch // 2
-    if max_ch < n_c_ch:  # take only channels greater than `max_ch`.
-        ch = np.arange(max_ch, max_ch + n_ch)
-    elif (max_ch + n_c_ch) > n_ch_probe:  # take only channels less than `max_ch`.
-        ch = np.arange(max_ch - n_ch, max_ch)
-    else:  # take `n_c_ch` around `max_ch`.
-        ch = np.arange(max_ch - n_c_ch, max_ch + n_c_ch)
     # Extract the waveforms for these timestamps and compute similarity score.
     wf1 = bb.io.extract_waveforms(ephys_file, ts1, ch, sr=sr, n_ch_probe=n_ch_probe, dtype=dtype,
                                   car=car)
     wf2 = bb.io.extract_waveforms(ephys_file, ts2, ch, sr=sr, n_ch_probe=n_ch_probe, dtype=dtype,
                                   car=car)
     s = bb.metrics.wf_similarity(wf1, wf2)
+
     # Plot these waveforms against each other.
-    fig, ax = plt.subplots(nrows=n_ch, ncols=2)  # left col is all waveforms, right col is mean
+    n_ch = len(ch)
+    if ax is None:
+        fig, ax = plt.subplots(nrows=n_ch, ncols=2)  # left col is all waveforms, right col is mean
     for cur_ax, cur_ch in enumerate(ch):
         ax[cur_ax][0].plot(wf1[:, :, cur_ax].T, c=col[0])
         ax[cur_ax][0].plot(wf2[:, :, cur_ax].T, c=col[1])
         ax[cur_ax][1].plot(np.mean(wf1[:, :, cur_ax], axis=0), c=col[0])
         ax[cur_ax][1].plot(np.mean(wf2[:, :, cur_ax], axis=0), c=col[1])
         ax[cur_ax][0].set_ylabel('Ch {0}'.format(cur_ch))
-    ax[0][0].set_title('all waveforms')
-    ax[0][1].set_title('mean waveforms')
+    ax[0][0].set_title('All Waveforms. S = {:.2f}'.format(s))
+    ax[0][1].set_title('Mean Waveforms')
     plt.legend(['1st spike set', '2nd spike set'])
-    fig.suptitle('comparison of waveforms from two sets of spikes for unit {0} \
-                 \n s = {1:.2f}'.format(unit, s))
-    return fig, wf1, wf2, s
+
+    return wf1, wf2, s
 
 
-def amp_heatmap(ephys_file, spks_b, clstrs_b, unit, t='all', n_ch=20, sr=30000, n_ch_probe=385,
-                dtype='int16', cmap_name='RdBu', car=True):
+def amp_heatmap(ephys_file, ts, ch, sr=30000, n_ch_probe=385, dtype='int16', cmap_name='RdBu',
+                car=True, ax=None):
     '''
-    Plots a heatmap of the normalized voltage values over space and time at the timestamps of a
-    particular unit over a specified number of channels, after (optionally)
-    common-average-referencing.
+    Plots a heatmap of the normalized voltage values over time and space for given timestamps and
+    channels, after (optionally) common-average-referencing.
 
     Parameters
     ----------
     ephys_file : string
         The file path to the binary ephys data.
-    spks_b : bunch
-        A spikes bunch containing fields with spike information (e.g. cluster IDs, times, features,
-        etc.) for all spikes.
-    clstrs_b : bunch
-        A clusters bunch containing fields with unit information (e.g. mean amps, channel of max
-        amplitude, etc...), used here to extract the channel of max amplitude for the given unit.
-    unit : int
-        The unit number for which to plot the amp heatmap.
-    t : str or pair of floats (optional)
-        The time period from which to get the spike amplitudes. Default: all spike amplitudes.
-    n_ch: int (optional)
-        The number of channels for which to plot the amp heatmap.
+    ts : array_like
+        A set of timestamps for which to get the voltage values.
+    ch : array-like
+        The channels to use for extracting the voltage values.
     sr : int (optional)
         The sampling rate (in hz) that the ephys data was acquired at.
     n_ch_probe : int (optional)
@@ -320,48 +312,36 @@ def amp_heatmap(ephys_file, spks_b, clstrs_b, unit, t='all', n_ch=20, sr=30000,
         The name of the colormap associated with the plot.
     car: bool (optional)
         A flag for whether or not to perform common-average-referencing before extracting waveforms
+    ax : axessubplot (optional)
+        The axis handle to plot the heatmap on. (if `None`, a new figure and axis is created)
 
     Returns
     -------
-    fig : figure
-        A figure object containing the plot.
-    v_vals_norm : ndarray
-        The unit-normalized voltage values displayed in `fig`.
+    v_vals : ndarray
+        The voltage values.
 
     Examples
     --------
     1) Plot a heatmap of the spike amplitudes across 20 channels around the channel of max
-    amplitude for unit1.
-        >>> import brainbox as bb
-        >>> import alf.io as aio
-        >>> import ibllib.ephys.spikes as e_spks
-        (*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory):
-        >>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
-        # Get a spikes bunch, a clusters bunch, and plot heatmap for unit1 across 20 channels.
-        >>> spks_b = aio.load_object(path_to_alf_out, 'spikes')
-        >>> clstrs_b = aio.load_object(path_to_alf_out, 'clusters')
-        >>> bb.plot.amp_heatmap(path_to_ephys_file, spks_b, clstrs_b, unit=1)
+    amplitude for all spikes in unit 1.
+        >>> ts = units_b['times']['1']
+        >>> max_ch = clstrs_b['channels'][1]
+        >>> n_c_ch, n_ch_probe = 10, 385  # channels on each side of `max_ch`; channels on probe
+        >>> if max_ch < n_c_ch:  # take only channels greater than `max_ch`.
+        >>> ch = np.arange(max_ch, max_ch + 20)
+        >>> elif (max_ch + n_c_ch) > n_ch_probe:  # take only channels less than `max_ch`.
+        >>> ch = np.arange(max_ch - 20, max_ch)
+        >>> else:  # take `n_c_ch` around `max_ch`.
+        >>> ch = np.arange(max_ch - 10, max_ch + 10)
+        >>> bb.plot.amp_heatmap(path_to_ephys_file, ts, ch)
     '''
 
     # Get memmapped array of `ephys_file`
     item_bytes = np.dtype(dtype).itemsize
     n_samples = op.getsize(ephys_file) // (item_bytes * n_ch_probe)
     file_m = np.memmap(ephys_file, shape=(n_samples, n_ch_probe), dtype=dtype, mode='r')
-    # Get voltage values for each peak amplitude sample for `n_ch` around `max_ch`:
-    # Get the channel of max amplitude and `n_ch` around it.
-    max_ch = clstrs_b['channels'][unit]
-    n_c_ch = n_ch // 2
-    if max_ch < n_c_ch:  # take only channels greater than `max_ch`.
-        ch = np.arange(max_ch, max_ch + n_ch)
-    elif (max_ch + n_c_ch) > n_ch_probe:  # take only channels less than `max_ch`.
-        ch = np.arange(max_ch - n_ch, max_ch)
-    else:  # take `n_c_ch` around `max_ch`.
-        ch = np.arange(max_ch - n_c_ch, max_ch + n_c_ch)
-    import pdb
-    pdb.set_trace()
-    unit_idxs = np.where(spks_b['clusters'] == unit)[0]
-    max_amp_samples = spks_b['samples'][unit_idxs]
-    ts = spks_b['times'][unit_idxs]
+
+    # Get voltage values for each peak amplitude sample for `ch`.
+    max_amp_samples = (ts * sr).astype(int)
     v_vals = file_m[np.ix_(max_amp_samples, ch)]
     if car:  # Compute and subtract temporal and spatial noise from `v_vals`.
         # Get subset of time (from first to last max amp sample)
@@ -373,35 +353,34 @@ def amp_heatmap(ephys_file, spks_b, clstrs_b, unit, t='all', n_ch=20, sr=30000,
         noise_s = np.median(file_m[np.ix_(t_subset, ch)], axis=0, out=out_noise_s)
         v_vals -= noise_t[max_amp_samples - max_amp_samples[0], None]
         v_vals -= noise_s[None, :]
+
     # Plot heatmap.
+    if ax is None:
+        ax = plt.gca()
     v_vals_norm = (v_vals / np.max(abs(v_vals))).T
-    fig, ax = plt.subplots()
     cbar_map = ax.imshow(v_vals_norm, cmap=cmap_name, aspect='auto',
                          extent=[ts[0], ts[-1], ch[0], ch[-1]], origin='lower')
     ax.set_yticks(np.arange(ch[0], ch[-1], 5))
-    ax.set_ylabel('channel numbers')
-    ax.set_xlabel('time (s)')
-    ax.set_title('heatmap of voltage at unit{0} timestamps'.format(unit))
+    ax.set_ylabel('Channel Numbers')
+    ax.set_xlabel('Time (s)')
+    ax.set_title('Voltage Heatmap')
+    fig = ax.figure
     cbar = fig.colorbar(cbar_map, ax=ax)
     cbar.set_label('V', rotation=90)
-    return fig, v_vals_norm
+
+    return v_vals
 
 
-def firing_rate(spks_b, unit, t='all', hist_win=0.01, fr_win=0.5, n_bins=10, show_fr_cv=True):
+def firing_rate(ts, hist_win=0.01, fr_win=0.5, n_bins=10, show_fr_cv=True, ax=None):
     '''
-    Plots the instantaneous firing rate of a unit over time, and optionally overlays the value of
-    the coefficient of variation of the firing rate for a specified number of bins.
+    Plots the instantaneous firing rate for the given spike timestamps over time, and optionally
+    overlays the value of the coefficient of variation of the firing rate for a specified number
+    of bins.
 
     Parameters
     ----------
-    spks_b : bunch
-        A spikes bunch containing fields with spike information (e.g. cluster IDs, times, features,
-        etc.) for all spikes.
-    unit : int
-        The unit number for which to calculate the firing rate.
-    t : str or pair of floats (optional)
-        The total time period for which the instantaneous firing rate is returned. Default: the
-        time period from `unit`'s first to last spike.
+    ts : ndarray
+        The spike timestamps from which to compute the firing rate.
     hist_win : float (optional)
         The time window (in s) to use for computing spike counts.
     fr_win : float (optional)
@@ -411,11 +390,11 @@ def firing_rate(spks_b, unit, t='all', hist_win=0.01, fr_win=0.5, n_bins=10, sho
     show_fr_cv : bool (optional)
         A flag for whether or not to compute and show the coefficients of variation of the firing
         rate for `n_bins`.
+    ax : axessubplot (optional)
+        The axis handle to plot the firing rate on. (if `None`, a new figure and axis is created)
 
     Returns
     -------
-    fig : figure
-        A figure object containing the plot.
     fr: ndarray
         The instantaneous firing rate over time (in hz).
     cv: float
@@ -432,33 +411,27 @@ def firing_rate(spks_b, unit, t='all', hist_win=0.01, fr_win=0.5, n_bins=10, sho
 
     Examples
     --------
-    1) Plot the firing rate for unit1 from the time of its first to last spike, showing the cv
-    of the firing rate for 10 evenly spaced bins, and plot the firing rate for unit2 from the first
-    to second minute, without showing the cv.
-        >>> import brainbox as bb
-        >>> import alf.io as aio
-        >>> import ibllib.ephys.spikes as e_spks
-        (*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory):
-        >>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
-        # Get a spikes bunch and calculate the firing rate.
-        >>> spks_b = aio.load_object(path_to_alf_out, 'spikes')
-        >>> fig, fr_1, cv_1, cvs_1 = bb.plot.firing_rate(spks_b, unit=1)
-        >>> fig2, fr_2 = bb.plot.firing_rate(spks_b, unit=2, t=[60,120], show_fr_cv=False)
+    1) Plot the firing rate for unit 1 from the time of its first to last spike, showing the cv
+    of the firing rate for 10 evenly spaced bins.
+        >>> ts = units_b['times']['1']
+        >>> fr, cv, cvs = bb.plot.firing_rate(ts)
     '''
-    fig, ax = plt.subplots()
+
+    if ax is None:
+        ax = plt.gca()
     if not(show_fr_cv):  # compute just the firing rate
-        fr = bb.singlecell.firing_rate(spks_b, unit, t=t, hist_win=hist_win, fr_win=fr_win)
+        fr = bb.singlecell.firing_rate(ts, hist_win=hist_win, fr_win=fr_win)
     else:  # compute firing rate and coefficients of variation
-        cv, cvs, fr = bb.metrics.firing_rate_coeff_var(spks_b, unit, t=t, hist_win=hist_win,
-                                                       fr_win=fr_win, n_bins=n_bins)
+        cv, cvs, fr = bb.metrics.firing_rate_coeff_var(ts, hist_win=hist_win, fr_win=fr_win,
                                                        n_bins=n_bins)
     x = np.arange(fr.size) * hist_win
     ax.plot(x, fr)
-    ax.set_title('Firing Rate for Unit {0}'.format(unit))
+    ax.set_title('Firing Rate')
     ax.set_xlabel('Time (s)')
     ax.set_ylabel('Rate (s$^-1$)')
     if not(show_fr_cv):
-        return fig, fr
+        return fr
     else:  # show coefficients of variation
         y_max = np.max(fr) * 1.05
         x_l = x[np.int(x.size / n_bins)]
@@ -468,7 +441,7 @@ def firing_rate(spks_b, unit, t='all', hist_win=0.01, fr_win=0.5, n_bins=10, sho
         # Plot text with cv of firing rate for each bin.
         [ax.text(x_l * (i + 1), y_max, 'cv={0:.2f}'.format(cvs[i]), fontsize=9, ha='right')
          for i in range(n_bins)]
-        return fig, fr, cv, cvs
+        return fr, cv, cvs
 
 
 def peri_event_time_histogram(
diff --git a/brainbox/processing/processing.py b/brainbox/processing/processing.py
index 50e21b227..96b32f0bb 100644
--- a/brainbox/processing/processing.py
+++ b/brainbox/processing/processing.py
@@ -237,7 +237,6 @@ def get_units_bunch(spks_b, *args):
     >>> import ibllib.ephys.spikes as e_spks
     (*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory):
     >>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
-    # Get a spikes bunch and filter the units.
     >>> spks_b = aio.load_object(path_to_alf_out, 'spikes')
     >>> units_b = bb.processing.get_units_bunch(spks_b)
     # Get amplitudes for unit 4.
@@ -268,15 +267,18 @@ def get_units_bunch(spks_b, *args):
     return units_b
 
 
-def filter_units(spks_b, params={'min_amp': 100, 'min_fr': 0.5, 'max_fpr': 0.1, 'rp': 0.002}):
+def filter_units(units_b, t, params={'min_amp': 100, 'min_fr': 0.5, 'max_fpr': 0.1, 'rp': 0.002}):
     '''
     Filters units according to some parameters.
 
     Parameters
     ----------
-    spks_b : bunch
-        A spikes bunch containing fields with spike information (e.g. cluster IDs, times, features,
-        etc.) for all spikes.
+    units_b : bunch
+        A units bunch containing fields with spike information (e.g. cluster IDs, times, features,
+        etc.) for all units; the values for each field are keyed by unit ID.
+    t : float
+        Duration of the recording session (in s).
     params : dict
         Parameters to use to filter the units:
         'min_amp' : The minimum mean amplitude (in uV) of the spikes in the unit
@@ -302,9 +304,18 @@ def filter_units(spks_b, params={'min_amp': 100, 'min_fr': 0.5, 'max_fpr': 0.1,
     >>> import ibllib.ephys.spikes as e_spks
     (*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory):
     >>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
-    # Get a spikes bunch and filter the units.
+    # Get a spikes bunch, units bunch, and filter the units.
    >>> spks_b = aio.load_object(path_to_alf_out, 'spikes')
-    >>> filtered_units_mask = bb.processing.filter_units(spks_b)
+    >>> units_b = bb.processing.get_units_bunch(spks_b, ['times', 'amps', 'clusters'])
+    >>> T = spks_b['times'][-1] - spks_b['times'][0]
+    >>> filtered_units_mask = bb.processing.filter_units(units_b, T)
     # Get an array of the filtered units' IDs.
     >>> filtered_units = np.where(filtered_units_mask)[0]
+
+    2) Filter units with no minimum amplitude, a minimum firing rate of 1 Hz, and a max false
+    positive rate of 0.2, given a refractory period of 2 ms.
+    >>> filtered_units_mask = bb.processing.filter_units(
+            units_b, T, params={'min_amp': 0, 'min_fr': 1, 'max_fpr': 0.2, 'rp': 0.002})
+    # Get an array of the filtered units' IDs.
+    >>> filtered_units = np.where(filtered_units_mask)[0]
     '''
@@ -314,13 +325,9 @@ def filter_units(spks_b, params={'min_amp': 100, 'min_fr': 0.5, 'max_fpr': 0.1,
     warnings.filterwarnings('ignore', r'invalid value encountered in greater')
     warnings.filterwarnings('ignore', r'invalid value encountered in less')
 
-    # Get units bunch and number of units.
-    units_b = bb.processing.get_units_bunch(spks_b, ['amps', 'times'])
-    n_units = np.max(spks_b['clusters']) + 1
-
-    # Get recording duration, and units' mean spike amps, number of spikes, firing rates,
-    # number of isi violations, and false positive rate.
-    T = spks_b['times'][-1] - spks_b['times'][0]
+    # Get units' mean spike amps, number of spikes, firing rates, number of isi violations,
+    # and false positive rate.
+    n_units = len(units_b['clusters'].keys())
     u_amps = np.zeros((n_units,))  # mean spike amplitude for each unit
     u_n_spks = np.zeros((n_units,))  # number of spikes for each unit
     u_fr = np.zeros((n_units,))  # firing rate over entire session for each unit
@@ -335,10 +342,10 @@ def filter_units(spks_b, params={'min_amp': 100, 'min_fr': 0.5, 'max_fpr': 0.1,
         else:
             u_amps[i] = units_b['amps'][str(i)][0]
             u_n_spks[i] = len(units_b['amps'][str(i)])
-            u_fr[i] = u_n_spks[i] / T
+            u_fr[i] = u_n_spks[i] / t
             n_isi_viol = len(np.where(np.diff(units_b['times'][str(i)]) < params['rp'])[0])
             # false positive rate is min of roots of solved quadratic equation (Hill, et al. 2011)
-            c = (T * n_isi_viol) / (2 * params['rp'] * u_n_spks[i]**2)  # 3rd term in quadratic
+            c = (t * n_isi_viol) / (2 * params['rp'] * u_n_spks[i]**2)  # 3rd term in quadratic
             u_fpr[i] = np.min(np.abs(np.roots([-1, 1, c])))
 
     # Get units that don't meet `params` requirements, and empty units, and filter them out.
@@ -349,4 +356,5 @@ def filter_units(spks_b, params={'min_amp': 100, 'min_fr': 0.5, 'max_fpr': 0.1,
                                              np.where(np.isnan(u_amps))[0])))
     filtered_units_mask = np.ones((n_units,))
     filtered_units_mask[units_to_rm] = 0
+
     return filtered_units_mask
diff --git a/brainbox/singlecell/singlecell.py b/brainbox/singlecell/singlecell.py
index e25e873f2..b5400d056 100644
--- a/brainbox/singlecell/singlecell.py
+++ b/brainbox/singlecell/singlecell.py
@@ -137,7 +137,7 @@ def calculate_peths(
     return peths, binned_spikes
 
 
-def firing_rate(spks_b, unit, t='all', hist_win=0.01, fr_win=0.5):
+def firing_rate(ts, hist_win=0.01, fr_win=0.5):
     '''
     Computes the instantaneous firing rate of a unit over time by computing a histogram of
     spike counts over a specified window of time, and summing this histogram over a sliding
@@ -145,14 +145,8 @@ def firing_rate(spks_b, unit, t='all', hist_win=0.01, fr_win=0.5):
 
     Parameters
     ----------
-    spks : bunch
-        A spikes bunch containing fields with spike information (e.g. cluster IDs, times, features,
-        etc.) for each unit.
-    unit : int
-        The unit number for which to calculate the firing rate.
-    t : str or pair of floats
-        The total time period for which the instantaneous firing rate is returned. Default: the
-        time period from `unit`'s first to last spike.
+    ts : ndarray
+        The spike timestamps from which to compute the firing rate.
     hist_win : float
         The time window (in s) to use for computing spike counts.
     fr_win : float
@@ -170,24 +164,20 @@ def firing_rate(spks_b, unit, t='all', hist_win=0.01, fr_win=0.5):
 
     Examples
     --------
-    1) Compute the firing rate for unit1 from the time of its first to last spike.
+    1) Compute the firing rate for unit 1 from the time of its first to last spike.
        >>> import brainbox as bb
        >>> import alf.io as aio
        >>> import ibllib.ephys.spikes as e_spks
        (*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory):
        >>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
-        # Get a spikes bunch and calculate the firing rate.
+        # Load a spikes bunch and get the timestamps for unit 1, and calculate the instantaneous
+        # firing rate.
        >>> spks_b = aio.load_object(path_to_alf_out, 'spikes')
-        >>> fr = bb.singlecell.firing_rate(spks_b, 1)
+        >>> unit_idxs = np.where(spks_b['clusters'] == 1)[0]
+        >>> ts = spks_b['times'][unit_idxs]
+        >>> fr = bb.singlecell.firing_rate(ts)
     '''
-    # Get unit timestamps.
-    unit_idxs = np.where(spks_b['clusters'] == unit)
-    ts = ts = spks_b['times'][unit_idxs]
-    if t != 'all':
-        t_first = np.where(ts > t[0])[0][0]
-        t_last = np.where(ts < t[1])[0][-1]
-        ts = ts[t_first:t_last]
+
     # Compute histogram of spike counts.
     t_tot = ts[-1] - ts[0]
     n_bins_hist = np.int(t_tot / hist_win)
diff --git a/brainbox/spike_features/spike_features.py b/brainbox/spike_features/spike_features.py
index 7353aaec0..a33cd228e 100644
--- a/brainbox/spike_features/spike_features.py
+++ b/brainbox/spike_features/spike_features.py
@@ -18,16 +18,16 @@ def depth(ephys_file, spks_b, clstrs_b, chnls_b, tmplts_b, unit, n_ch=12, n_ch_p
     ----------
     ephys_file : string
         The file path to the binary ephys data.
-    spks : bunch
+    spks_b : bunch
         A spikes bunch containing fields with spike information (e.g. cluster IDs, times, features,
         etc.) for all spikes.
-    clstrs : bunch
+    clstrs_b : bunch
        A clusters bunch containing fields with cluster information (e.g. amp, ch of max amp, depth
        of ch of max amp, etc.) for all clusters.
-    chnls : bunch
+    chnls_b : bunch
        A channels bunch containing fields with channel information (e.g. coordinates, indices,
        etc.) for all probe channels.
-    tmplts : bunch
+    tmplts_b : bunch
        A unit templates bunch containing fields with unit template information (e.g. template
        waveforms, etc.) for all unit templates.
     unit : numeric
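
A few usage sketches for exercising the refactored API follow. The core change in this PR is that functions now take per-unit arrays (or a units bunch) rather than a whole spikes bunch. The sketch below illustrates, with made-up synthetic data, the spikes-bunch-to-units-bunch reshaping this relies on; `bb.processing.get_units_bunch` is the real implementation, so this is only a conceptual illustration.

import numpy as np

rng = np.random.RandomState(0)
spks = {'clusters': rng.randint(0, 5, 1000),         # spike -> unit id
        'times': np.sort(rng.uniform(0, 60, 1000)),  # spike times (s)
        'amps': rng.normal(100, 20, 1000)}           # spike amplitudes

# Group each per-spike array by unit, keyed by unit id as a string:
units = {feat: {str(u): spks[feat][spks['clusters'] == u]
                for u in np.unique(spks['clusters'])}
         for feat in ('times', 'amps')}
ts = units['times']['1']  # per-unit arrays are what the new signatures expect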
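A quick way to sanity-check the refactored `bb.metrics.feat_cutoff` is to feed it a synthetic amplitude distribution with a known truncation. A sketch, assuming this branch of brainbox is installed; the ~0.3 figure follows from cutting a normal distribution off half a standard deviation below its mean.

import numpy as np
import brainbox as bb

rng = np.random.RandomState(0)
amps = rng.normal(100, 20, 10000)
amps = amps[amps > 90]  # truncate at mean - 0.5*std, so ~31% of spikes are "missing"
fraction_missing, pdf, cutoff_idx = bb.metrics.feat_cutoff(amps)
# `fraction_missing` should land near 0.3 (it is capped at 0.5 by construction).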
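The three new metrics (`isi_viol`, `max_drift`, `cum_drift`) can be smoke-tested the same way; the spike times and depths below are synthetic stand-ins for one unit's alf data.

import numpy as np
import brainbox as bb

rng = np.random.RandomState(0)
ts = np.sort(rng.uniform(0, 600, 5000))  # spike times (s) for one unit
depths = 2000 + 20 * rng.randn(5000)     # per-spike depths for the same unit

frac_viol, n_viol, isis = bb.metrics.isi_viol(ts, rp=0.002)
md = bb.metrics.max_drift(depths)  # peak-to-peak depth excursion
cd = bb.metrics.cum_drift(depths)  # summed |depth changes|, normalized by spike count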
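For reviewers of the `filter_units` change: a worked example of the Hill et al. (2011) false-positive-rate estimate, computed exactly as in the function body, with made-up numbers for one unit.

import numpy as np

t = 3600.0       # recording duration (s)
rp = 0.002       # refractory period (s)
n_spks = 40000   # spikes in the unit
n_isi_viol = 20  # ISIs shorter than `rp`

c = (t * n_isi_viol) / (2 * rp * n_spks ** 2)  # 3rd term in the quadratic
fpr = np.min(np.abs(np.roots([-1, 1, c])))     # min root, as in the source
print(fpr)  # ~0.011: about 1% of this unit's spikes estimated as false positives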
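The instantaneous firing rate in `singlecell.firing_rate` is described as a spike-count histogram at `hist_win` resolution, summed over a sliding window of length `fr_win`. A minimal sketch of that computation on a synthetic spike train; `np.convolve` stands in for the sliding-window sum here, and the library's exact windowing (truncated in this diff) may differ at the edges.

import numpy as np

rng = np.random.RandomState(0)
ts = np.sort(rng.uniform(0, 100, 500))  # ~5 Hz synthetic spike train
hist_win, fr_win = 0.01, 0.5

n_bins_hist = int((ts[-1] - ts[0]) / hist_win)
counts, _ = np.histogram(ts, n_bins_hist)
n_win = int(fr_win / hist_win)  # number of histogram bins per sliding window
fr = np.convolve(counts, np.ones(n_win), mode='same') / fr_win  # spikes/s
print(round(fr.mean(), 1))  # ~5.0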
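Finally, the main ergonomic point of the plotting refactor is that each function now accepts an `ax` (or pair of axes) and returns values instead of a figure, so panels can be composed. A sketch, assuming `units_b` has been built as in the module-docstring setup above and this branch is installed.

import matplotlib.pyplot as plt
import brainbox as bb

fig, axs = plt.subplots(1, 3, figsize=(12, 4))
ts = units_b['times']['1']
fr, cv, cvs = bb.plot.firing_rate(ts, ax=axs[0])             # firing rate panel
fraction_missing = bb.plot.feat_cutoff(units_b['amps']['1'], 'amps', unit=1,
                                       ax=axs[1:2])          # pdf-only panel
var_vals, p_vals = bb.plot.feat_vars(units_b, ax=axs[2])     # variance bar plot
fig.tight_layout()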