Skip to content

Commit

Permalink
LFP functions (int-brain-lab#160)
Browse files Browse the repository at this point in the history
* Added classifier function

* Started on LFP analysis functions

* Change loading method to spikeglx.Reader

* Changed loading method

* Added example power spectrum plots

* Added coherence plot

* Added spike triggered LFP functions

* Include leave-one-block-out cross validation

* Updated cross-validation procedure

* Added sub selection of neuron functionality

* Output predictions and probabilities

* Take out LFP scripts for now

* Started on example

* Added example

* Started on LFP functions

* Added possibility to shuffle and updated description

* Update LFP scripts

* PEP8 styling

* More PEP8
  • Loading branch information
guidomeijer authored May 8, 2020
1 parent 39f44cb commit 0d29ade
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 33 deletions.
1 change: 1 addition & 0 deletions brainbox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from . import core
from . import experimental
from . import io
from . import lfp
from . import metrics
from . import plot
from . import population
Expand Down
41 changes: 12 additions & 29 deletions brainbox/examples/lfp_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,59 +16,42 @@
# Read in raw LFP data from probe00
raw = spikeglx.Reader(lf_paths[0])
signal = raw.read(nsel=slice(None, 100000, None), csel=slice(None, None, None))[0]
signal = signal * raw.channel_conversion_sample2v['lf'] # Convert samples into uV
signal = np.rot90(signal)

ts = one.load(eid[0], 'ephysData.raw.timestamps')

# %% Calculate power spectrum and coherence between two random channels

ps_freqs, ps = bb.lfp.power_spectrum(signal, fs=raw.fs)
ps_freqs, ps = bb.lfp.power_spectrum(signal, fs=raw.fs, segment_length=1, segment_overlap=0.5)
random_ch = np.random.choice(raw.nc, 2)
coh_freqs, coh, phase_lag = bb.lfp.coherence(signal[random_ch[0], :],
signal[random_ch[1], :], fs=raw.fs)

# %% Create power spectrum and coherence plot

fig = plt.figure(figsize=(18, 12))
gs = GridSpec(2, 2, figure=fig)
cmap = sns.cubehelix_palette(dark=1, light=0, as_cmap=True)

gs = GridSpec(3, 2, figure=fig)
cmap = sns.color_palette('cubehelix', 50)
ax1 = fig.add_subplot(gs[:, 0])
sns.heatmap(data=np.log10(ps[:, ps_freqs < 140]), cbar=True, ax=ax1, yticklabels=50,
cmap=cmap, cbar_kws={'label': 'log10 power ($V^2$)'})
ax1.set(xticks=np.arange(0, np.sum(ps_freqs < 140), 50),
xticklabels=np.array(ps_freqs[np.arange(0, np.sum(ps_freqs < 140), 50)], dtype=int),
ylabel='Channels', xlabel='Frequency (Hz)')


ax2 = fig.add_subplot(gs[0, 1])
ax2.plot(ps_freqs, ps[random_ch[0], :])
ax2.set(xlim=[1, 140], yscale='log', ylabel='Power ($V^2$)',
ax2.plot(signal[random_ch[0]])
ax2.set(ylabel='Power ($V^2$)',
xlabel='Frequency (Hz)', title='Channel %d' % random_ch[0])

ax3 = fig.add_subplot(gs[1, 1])
ax3.plot(coh_freqs, coh)
ax3.set(xlim=[1, 140], ylabel='Coherence', xlabel='Frequency (Hz)',
ax3.plot(ps_freqs, ps[random_ch[0], :])
ax3.set(xlim=[1, 140], yscale='log', ylabel='Power ($V^2$)',
xlabel='Frequency (Hz)', title='Channel %d' % random_ch[0])

ax4 = fig.add_subplot(gs[2, 1])
ax4.plot(coh_freqs, coh)
ax4.set(xlim=[1, 140], ylabel='Coherence', xlabel='Frequency (Hz)',
title='Channel %d and %d' % (random_ch[0], random_ch[1]))

plt.tight_layout(pad=5)

# %% Calculate spike triggered average

# Read in spike data
spikes = one.load_object(eid[0], 'spikes')
clusters = one.load_object(eid[0], 'clusters')

# Pick two random neurons
random_neurons = np.random.choice(
clusters.metrics.cluster_id[clusters.metrics.ks2_label == 'good'], 2)
spiketrain = spikes.times[spikes.clusters == random_neurons[0]]
sta, time = bb.lfp.spike_triggered_average(signal[random_ch[0], :], spiketrain)

# %% Plot spike triggered LFP

f, ax1 = plt.subplots(1, 1)

ax1.plot(time, sta)
ax1.set(ylabel='Spike triggered LFP average (uV)', xlabel='Time (ms)')
1 change: 1 addition & 0 deletions brainbox/lfp/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .lfp import *
115 changes: 115 additions & 0 deletions brainbox/lfp/lfp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 13 14:57:53 2020
Functions to analyse LFP signals
@author: Guido Meijer
"""

from scipy.signal import welch, csd, filtfilt, butter
import numpy as np


def butter_filter(signal, highpass_freq=None, lowpass_freq=None, order=4, fs=2500):

# The filter type is determined according to the values of cut-off frequencies
Fn = fs / 2.
if lowpass_freq and highpass_freq:
if highpass_freq < lowpass_freq:
Wn = (highpass_freq / Fn, lowpass_freq / Fn)
btype = 'bandpass'
else:
Wn = (lowpass_freq / Fn, highpass_freq / Fn)
btype = 'bandstop'
elif lowpass_freq:
Wn = lowpass_freq / Fn
btype = 'lowpass'
elif highpass_freq:
Wn = highpass_freq / Fn
btype = 'highpass'
else:
raise ValueError("Either highpass_freq or lowpass_freq must be given")

# Filter signal
b, a = butter(order, Wn, btype=btype, output='ba')
filtered_data = filtfilt(b=b, a=a, x=signal, axis=1)

return filtered_data


def power_spectrum(signal, fs=2500, segment_length=0.5, segment_overlap=0.5, scaling='density'):
"""
Calculate the power spectrum of an LFP signal
Parameters
----------
signal : 2D array
LFP signal from different channels in V with dimensions (channels X samples)
fs : int
Sampling frequency
segment_length : float
Length of the segments for which the spectral density is calcualted in seconds
segment_overlap : float
Fraction of overlap between the segments represented as a float number between 0 (no
overlap) and 1 (complete overlap)
Returns
----------
freqs : 1D array
Frequencies for which the spectral density is calculated
psd : 2D array
Power spectrum in V^2 with dimensions (channels X frequencies)
"""

# Transform segment from seconds to samples
segment_samples = int(fs * segment_length)
overlap_samples = int(segment_overlap * segment_samples)

# Calculate power spectrum
freqs, psd = welch(signal, fs=fs, nperseg=segment_samples, noverlap=overlap_samples,
scaling=scaling)
return freqs, psd


def coherence(signal_a, signal_b, fs=2500, segment_length=1, segment_overlap=0.5):
"""
Calculate the coherence between two LFP signals
Parameters
----------
signal_a : 1D array
LFP signal from different channels with dimensions (channels X samples)
fs : int
Sampling frequency
segment_length : float
Length of the segments for which the spectral density is calcualted in seconds
segment_overlap : float
Fraction of overlap between the segments represented as a float number between 0 (no
overlap) and 1 (complete overlap)
Returns
----------
freqs : 1D array
Frequencies for which the coherence is calculated
coherence : 1D array
Coherence takes a value between 0 and 1, with 0 or 1 representing no or perfect coherence,
respectively
phase_lag : 1D array
Estimate of phase lag in radian between the input time series for each frequency
"""

# Transform segment from seconds to samples
segment_samples = int(fs * segment_length)
overlap_samples = int(segment_overlap * segment_samples)

# Calculate coherence
freqs, Pxx = welch(signal_a, fs=fs, nperseg=segment_samples, noverlap=overlap_samples)
_, Pyy = welch(signal_b, fs=fs, nperseg=segment_samples, noverlap=overlap_samples)
_, Pxy = csd(signal_a, signal_b, fs=fs, nperseg=segment_samples, noverlap=overlap_samples)
coherence = np.abs(Pxy) ** 2 / (Pxx * Pyy)
phase_lag = np.angle(Pxy)

return freqs, coherence, phase_lag
19 changes: 15 additions & 4 deletions brainbox/population/population.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import KFold, LeaveOneOut, LeaveOneGroupOut
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, roc_auc_score
from sklearn.utils import shuffle as sklearn_shuffle


def _get_spike_counts_in_bins(spike_times, spike_clusters, intervals):
Expand Down Expand Up @@ -193,7 +194,7 @@ def xcorr(spike_times, spike_clusters, bin_size=None, window_size=None):

def decode(spike_times, spike_clusters, event_times, event_groups, pre_time=0, post_time=0.5,
classifier='bayes', cross_validation='kfold', num_splits=5, prob_left=None,
custom_validation=None, n_neurons='all', iterations=1):
custom_validation=None, n_neurons='all', iterations=1, shuffle=False):
"""
Use decoding to classify groups of trials (e.g. stim left/right). Classification is done using
the population vector of summed spike counts from the specified time window. Cross-validation
Expand Down Expand Up @@ -247,6 +248,12 @@ def decode(spike_times, spike_clusters, event_times, event_groups, pre_time=0, p
(split2_train_idxs, split2_test_idxs),
(split3_train_idxs, split3_test_idxs),
...)
n_neurons : string or integer
number of neurons to randomly subselect from the population (default is 'all')
iterations : int
number of times to repeat the decoding (especially usefull when subselecting neurons)
shuffle : boolean
whether to shuffle the trial labels each decoding iteration
Returns
-------
Expand Down Expand Up @@ -314,6 +321,10 @@ def decode(spike_times, spike_clusters, event_times, event_groups, pre_time=0, p
use_neurons = np.random.choice(pop_vector.shape[1], n_neurons, replace=False)
sub_pop_vector = pop_vector[:, use_neurons]

# Shuffle trail labels if necessary
if shuffle is True:
event_groups = sklearn_shuffle(event_groups)

if cross_validation == 'none':

# Fit the model on all the data and predict
Expand Down Expand Up @@ -374,15 +385,15 @@ def decode(spike_times, spike_clusters, event_times, event_groups, pre_time=0, p
'confusion_matrix': conf_matrix_norm,
'n_groups': np.shape(np.unique(event_groups))[0],
'classifier': classifier, 'cross_validation': '%d-fold' % num_splits,
'iterations': iterations})
'iterations': iterations, 'shuffle': shuffle})

else:
results = dict({'accuracy': acc, 'f1': f1, 'auroc': auroc,
'predictions': pred, 'probabilities': prob,
'confusion_matrix': conf_matrix_norm,
'n_groups': np.shape(np.unique(event_groups))[0],
'classifier': classifier, 'cross_validation': cross_validation,
'iterations': iterations})

'iterations': iterations, 'shuffle': shuffle})
return results


Expand Down

0 comments on commit 0d29ade

Please sign in to comment.