Skip to content

Commit

Permalink
drift display
Browse files Browse the repository at this point in the history
  • Loading branch information
oliche committed Aug 26, 2020
1 parent b4f21f0 commit ec0599e
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 12 deletions.
26 changes: 15 additions & 11 deletions brainbox/metrics/electrode_drift.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,21 @@
from brainbox.processing import bincount2D


def estimate_drift(spike_times, spike_amps, spike_depths):
def estimate_drift(spike_times, spike_amps, spike_depths, display=False):
"""
Estimate drift for spike sorted data.
:param spike_times:
:param spike_amps:
:param spike_depths:
:param display:
:return:
"""
# binning parameters
DT_SECS = 1 # output sampling rate of the depth estimation (seconds)
DEPTH_BIN_UM = 2 # binning parameter for depth
AMP_RES_V = 100 * 1e-6 # binning parameter for amplitudes
NXCORR = 50 # positive and negative lag in depth samples to look for depth
NT_SMOOTH = 9 # length of the Gaussian smoothing window in samples (DT_SECS rate)
DISPLAY = False

# experimental: try the amp with a log scale
na = int(np.ceil(np.nanmax(spike_amps) / AMP_RES_V))
Expand All @@ -30,22 +37,19 @@ def estimate_drift(spike_times, spike_amps, spike_depths):
# compute the depth lag by xcorr
# experimental: LP the fft for a better tracking ?
atd_ = np.fft.fft(atd_hist, axis=-1)
xcorr = np.real(np.fft.ifft(atd_ * np.conj(atd_[:, 0, :])[:, np.newaxis, :]))
xcorr = np.real(np.fft.ifft(atd_ * np.conj(np.median(atd_, axis=1))[:, np.newaxis, :]))
xcorr = np.sum(xcorr, axis=0)
xcorr = np.c_[xcorr[:, -NXCORR:], xcorr[:, :NXCORR + 1]]

# experimental: parabolic fit to get max values
raw_drift = (np.argmax(xcorr, axis=-1) - NXCORR) * DEPTH_BIN_UM
drift = smooth.rolling_window(raw_drift, window_len=NT_SMOOTH, window='hanning')

if DISPLAY:
if display:
import matplotlib.pyplot as plt
from ibllib.plots import Density
Density(atd_hist[5, :, :])
plt.figure()
plt.plot(xcorr.transpose())
plt.figure()
plt.plot(raw_drift - 50)
plt.plot(drift - 50)
from brainbox.plot import driftmap
_, axs = plt.subplots(2, 1, gridspec_kw={'height_ratios': [.15, .85]}, sharex=True)
axs[0].plot(DT_SECS * np.arange(drift.size), drift)
driftmap(spike_times, spike_depths, t_bin=0.1, d_bin=5, ax=axs[1])

return drift
2 changes: 1 addition & 1 deletion brainbox/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,7 @@ def driftmap(ts, feat, ax=None, plot_style='bincount',
R, times, depths = bb.processing.bincount2D(
ts[iok], feat[iok], t_bin, d_bin, weights=weights)
# plot raster map
ax.imshow(R, aspect='auto', cmap='binary',
ax.imshow(R, aspect='auto', cmap='binary', vmin=0, vmax=np.std(R) * 4,
extent=np.r_[times[[0, -1]], depths[[0, -1]]], origin='lower', **kwargs)

return cd, md
Expand Down

0 comments on commit ec0599e

Please sign in to comment.