Skip to content

Commit

Permalink
metrics: change contamination name
Browse files Browse the repository at this point in the history
  • Loading branch information
oliche committed Nov 12, 2020
1 parent fba46c9 commit 1a67b8c
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 16 deletions.
16 changes: 8 additions & 8 deletions brainbox/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,7 @@ def ptp_over_noise(ephys_file, ts, ch, t=2.0, sr=30000, n_ch_probe=385, dtype='i
return ptp_sigma


def contamination(ts, rp=0.002):
def contamination_ks2(ts, rp=0.002):
"""
An estimate of the contamination of the unit (i.e. a pseudo false positive measure) based on
the number of spikes, number of isi violations, and time between the first and last spike.
Expand Down Expand Up @@ -643,7 +643,7 @@ def contamination(ts, rp=0.002):
return ce


def contamination_alt(ts, min_time, max_time, rp=0.002, min_isi=0.0001):
def contamination(ts, min_time, max_time, rp=0.002, min_isi=0.0001):
"""
An estimate of the contamination of the unit (i.e. a pseudo false positive measure) based on
the number of spikes, number of isi violations, and time between the first and last spike.
Expand Down Expand Up @@ -683,7 +683,7 @@ def contamination_alt(ts, min_time, max_time, rp=0.002, min_isi=0.0001):
1) Compute contamination estimate for unit 1, with a minimum isi for counting duplicate
spikes of 0.1 ms.
>>> ts = units_b['times']['1']
>>> ce = bb.metrics.contamination_alt(ts, min_isi=0.0001)
>>> ce = bb.metrics.contamination(ts, min_time, max_time, min_isi=0.0001)
"""

duplicate_spikes = np.where(np.diff(ts) <= min_isi)[0]
Expand Down Expand Up @@ -936,16 +936,16 @@ def quick_unit_metrics(spike_clusters, spike_times, spike_amps, spike_depths,
'amp_median',
'amp_std_dB',
'contamination',
'contamination_alt',
'contamination_ks2',
'drift',
'frac_isi_viol',
'missed_spikes_est',
'noise_cutoff',
'presence_ratio',
'presence_ratio_std',
'slidingRP_viol',
'spike_count',
]
'spike_count'
]

r = Bunch({k: np.full((nclust,), np.nan) for k in metrics_list})
r['cluster_id'] = cluster_ids
Expand Down Expand Up @@ -983,8 +983,8 @@ def quick_unit_metrics(spike_clusters, spike_times, spike_amps, spike_depths,

# compute metrics
r.frac_isi_viol[ic], _, _ = isi_viol(ts, rp=params['refractory_period'])
r.contamination[ic] = contamination(ts, rp=params['refractory_period'])
r.contamination_alt[ic], _ = contamination_alt(
r.contamination_ks2[ic] = contamination_ks2(ts, rp=params['refractory_period'])
r.contamination[ic], _ = contamination(
ts, tmin, tmax, rp=params['refractory_period'], min_isi=params['min_isi'])
r.slidingRP_viol[ic] = slidingRP_viol(ts,
bin_size=params['bin_size'],
Expand Down
Empty file.
14 changes: 7 additions & 7 deletions brainbox/tests/test_metrics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import numpy as np
import alf.io
from brainbox.metrics import quick_unit_metrics

REC_LEN_SECS = 1000
Expand All @@ -16,7 +15,7 @@ def multiple_spike_trains(firing_rates=None, rec_len_secs=1000, cluster_ids=None
if firing_rates is None:
firing_rates = np.random.randint(150, 600, 10)
if cluster_ids is None:
clusters_ids = np.arange(firing_rates.size)
cluster_ids = np.arange(firing_rates.size)
st = np.empty(0)
sa = np.empty(0)
sc = np.empty(0)
Expand All @@ -38,7 +37,7 @@ def single_spike_train(firing_rate=200, rec_len_secs=1000):
Basic spike train generator following a poisson process for spike-times and
:param firing_rate:
:param rec_len_secs:
:return:
:return: spike_times (secs), spike_amplitudes (V)
"""

# spike times: exponential decay prob
Expand All @@ -47,21 +46,22 @@ def single_spike_train(firing_rate=200, rec_len_secs=1000):

# spike amplitudes: log normal (estimated from an IBL session)
nspi = np.size(st)
sa = np.exp(np.random.normal(5.5, 0.5, nspi)) / 1e6
sa = np.exp(np.random.normal(5.5, 0.5, nspi)) / 1e6 # output is in V

return st, sa


def test_clusters_metrics():
    """Check quick_unit_metrics on synthetic spike trains with known firing rates,
    amplitude distribution (log-normal, mu=5.5, sigma=0.5) and sinusoidal drift."""
    frs = [3, 200, 259, 567]  # firing rates (Hz) for the 4 synthetic clusters
    t, a, c = multiple_spike_trains(firing_rates=frs, rec_len_secs=1000, cluster_ids=[0, 1, 3, 4])
    d = np.sin(2 * np.pi * c / 1000 * t) * 100  # sinusoidal shift where cluster id drives period
    dfm = quick_unit_metrics(c, t, a, d)

    assert np.allclose(dfm['amp_median'] / np.exp(5.5) * 1e6, 1, rtol=1.1)
    assert np.allclose(dfm['amp_std_dB'] / 20 * np.log10(np.exp(0.5)), 1, rtol=1.1)
    assert np.allclose(dfm['drift'], np.array([0, 1, 3, 4]) * 100 * 4 * 3.6, rtol=1.1)
    # FIX: the allclose result was previously discarded (no assert, so the check never
    # ran). Rates are drawn from a Poisson process, so allow a 10% relative tolerance.
    assert np.allclose(dfm['firing_rate'], frs, rtol=0.1)

    # probe_path = "/datadisk/FlatIron/mainenlab/Subjects/ZFM-01577/2020-11-04/001/alf/probe00"
    # spikes = alf.io.load_object(probe_path, 'spikes')
    # quick_unit_metrics(spikes['clusters'], spikes['times'], spikes['amps'], spikes['depths'])
1 change: 0 additions & 1 deletion ibllib/ephys/ephysqc.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,6 @@ def unit_metrics_ks2(ks2_path=None, m=None, save=True):

# compute labels based on metrics
df_labels = pd.DataFrame(unit_labels(m.spike_clusters, m.spike_times, m.amplitudes))
# add labels to metrics dataframe
r = r.set_index('cluster_id', drop=False).join(df_labels.set_index('cluster_id'))

# include the ks2 cluster contamination if `cluster_ContamPct` file exists
Expand Down

0 comments on commit 1a67b8c

Please sign in to comment.