Skip to content

Commit

Permalink
Merge pull request int-brain-lab#215 from int-brain-lab/guido_WIP
Browse files Browse the repository at this point in the history
Wrote a Brainbox function to map values on atlas
  • Loading branch information
nbonacchi authored Nov 17, 2020
2 parents 927d2e5 + fab40ec commit 9b2f2ae
Show file tree
Hide file tree
Showing 7 changed files with 236 additions and 20 deletions.
1 change: 1 addition & 0 deletions brainbox/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from brainbox import atlas
from . import behavior
from . import core
from . import experimental
Expand Down
1 change: 1 addition & 0 deletions brainbox/atlas/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from brainbox.atlas.atlas import plot_atlas
137 changes: 137 additions & 0 deletions brainbox/atlas/atlas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
'''
Functions which map metrics to the Allen atlas.
Code by G. Meijer
'''

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from ibllib import atlas


def _label2values(imlabel, fill_values, ba):
"""
Fills a slice from the label volume with values to display
:param imlabel: 2D np-array containing label ids (slice of the label volume)
:param fill_values: 1D np-array containing values to fill into the slice
:return: 2D np-array filled with values
"""
im_unique, ilabels, iim = np.unique(imlabel, return_index=True, return_inverse=True)
_, ir_unique, _ = np.intersect1d(ba.regions.id, im_unique, return_indices=True)
im = np.squeeze(np.reshape(fill_values[ir_unique[iim]], (*imlabel.shape, 1)))
return im


def plot_atlas(regions, values, ML=-1, AP=0, DV=-1, color_palette='Reds',
minmax=None, axs=None, custom_region_list=None):
"""
Plot a sagittal, coronal and horizontal slice of the Allen atlas with regions colored in
according to any value that the user specifies.
Parameters
----------
regions : 1D array
Array of strings with the acronyms of brain regions (in Allen convention) that should be
filled with color
values : 1D array
Array of values that correspond to the brain region acronyms
ML, AP, DV : float
The coordinates of the slices in mm
color_palette : any input that can be interpreted by sns.color_palette
The color palette of the plot
minmax : 2 element array
The min and max of the color map, if None it uses the min and max of values
axs : 3 element list of axis
A list of the three axis in which to plot the three slices
custom_region_list : 1D array with shape the same as ba.regions.acronym.shape
Input any custom list of acronyms that replaces the default list of acronyms
found in ba.regions.acronym. For example if you want to merge certain regions you can
give them the same name in the custom_region_list
"""

# Import Allen atlas
ba = atlas.AllenAtlas(25)

# Check input
assert regions.shape == values.shape
if minmax is not None:
assert len(minmax) == 2
if axs is not None:
assert len(axs) == 3
if custom_region_list is not None:
assert custom_region_list.shape == ba.regions.acronym.shape

# Get region boundaries volume
boundaries = np.diff(ba.label, axis=0, append=0)
boundaries = boundaries + np.diff(ba.label, axis=1, append=0)
boundaries = boundaries + np.diff(ba.label, axis=2, append=0)
boundaries[boundaries != 0] = 1

# Get all brain region names, use custom list if inputted
if custom_region_list is None:
all_regions = ba.regions.acronym
else:
all_regions = custom_region_list

# Add values to brain region list
region_values = np.ones(ba.regions.acronym.shape) * (np.min(values) - (np.max(values) + 1))
for i, region in enumerate(regions):
region_values[all_regions == region] = values[i]

# Set 'void' to default white
region_values[0] = np.min(values) - (np.max(values) + 1)

# Get slices with fill values
slice_sag = ba.slice(ML / 1000, axis=0, volume=ba.label) # saggital
slice_sag = _label2values(slice_sag, region_values, ba)
bound_sag = ba.slice(ML / 1000, axis=0, volume=boundaries)
slice_cor = ba.slice(AP / 1000, axis=1, volume=ba.label) # coronal
slice_cor = _label2values(slice_cor, region_values, ba)
bound_cor = ba.slice(AP / 1000, axis=1, volume=boundaries)
slice_hor = ba.slice(DV / 1000, axis=2, volume=ba.label) # horizontal
slice_hor = _label2values(slice_hor, region_values, ba)
bound_hor = ba.slice(DV / 1000, axis=2, volume=boundaries)

# Add boundaries to slices outside of the fill value region
slice_sag[bound_sag == 1] = np.max(values) + 1
slice_cor[bound_cor == 1] = np.max(values) + 1
slice_hor[bound_hor == 1] = np.max(values) + 1

# Construct color map
color_map = sns.color_palette(color_palette, 1000)
color_map.append((0.8, 0.8, 0.8)) # color of the boundaries between regions
color_map.insert(0, (1, 1, 1)) # color of the background and regions without a value

# Get color scale
if minmax is None:
cmin = np.min(values)
cmax = np.max(values)
else:
cmin = minmax[0]
cmax = minmax[1]

# Plot
if axs is None:
fig, axs = plt.subplots(1, 3, figsize=(16, 4))

# Saggital
sns.heatmap(np.rot90(slice_sag, 3), cmap=color_map, cbar=True, vmin=cmin, vmax=cmax, ax=axs[0])
axs[0].set(title='ML: %.1f mm' % ML)
plt.axis('off')
axs[0].get_xaxis().set_visible(False)
axs[0].get_yaxis().set_visible(False)

# Coronal
sns.heatmap(np.rot90(slice_cor, 3), cmap=color_map, cbar=True, vmin=cmin, vmax=cmax, ax=axs[1])
axs[1].set(title='AP: %.1f mm' % AP)
plt.axis('off')
axs[1].get_xaxis().set_visible(False)
axs[1].get_yaxis().set_visible(False)

# Horizontal
sns.heatmap(np.rot90(slice_hor, 3), cmap=color_map, cbar=True, vmin=cmin, vmax=cmax, ax=axs[2])
axs[2].set(title='DV: %.1f mm' % DV)
plt.axis('off')
axs[2].get_xaxis().set_visible(False)
axs[2].get_yaxis().set_visible(False)
47 changes: 47 additions & 0 deletions brainbox/examples/plot_atlas_color_values.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import numpy as np
import matplotlib.pyplot as plt
from ibllib import atlas
from brainbox.atlas import plot_atlas


def combine_layers_cortex(regions, delete_duplicates=False):
remove = ["1", "2", "3", "4", "5", "6a", "6b", "/"]
for i, region in enumerate(regions):
for j, char in enumerate(remove):
regions[i] = regions[i].replace(char, "")
if delete_duplicates:
regions = list(set(regions))
return regions


# Coordinates of slices in mm
ML = -0.5
AP = 1
DV = -2

# Generate some mock data
ba = atlas.AllenAtlas(25)
all_regions = ba.regions.acronym
regions = np.random.choice(all_regions, size=500, replace=False) # pick 500 random regions
values = np.random.uniform(-1, 1, 500) # generate 500 random values

# Plot atlas
f, axs = plt.subplots(2, 3, figsize=(20, 10))
plot_atlas(regions, values, ML, AP, DV, color_palette="RdBu_r", minmax=[-1, 1], axs=axs[0])

# Now combine all layers of cortex
plot_regions = combine_layers_cortex(regions)
combined_cortex = combine_layers_cortex(all_regions)

# Plot atlas
plot_atlas(
plot_regions,
values,
ML,
AP,
DV,
color_palette="RdBu_r",
minmax=[-1, 1],
axs=axs[1],
custom_region_list=combined_cortex,
)
62 changes: 46 additions & 16 deletions brainbox/population/population.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
'''

import numpy as np
import scipy as sp
import types
from itertools import groupby
from sklearn.ensemble import RandomForestClassifier
Expand Down Expand Up @@ -197,7 +198,7 @@ def xcorr(spike_times, spike_clusters, bin_size=None, window_size=None):

def decode(spike_times, spike_clusters, event_times, event_groups, pre_time=0, post_time=0.5,
classifier='bayes', cross_validation='kfold', num_splits=5, prob_left=None,
custom_validation=None, n_neurons='all', iterations=1, shuffle=False):
custom_validation=None, n_neurons='all', iterations=1, shuffle=False, phase_rand=False):
"""
Use decoding to classify groups of trials (e.g. stim left/right). Classification is done using
the population vector of summed spike counts from the specified time window. Cross-validation
Expand All @@ -222,10 +223,11 @@ def decode(spike_times, spike_clusters, event_times, event_groups, pre_time=0, p
time (in seconds) preceding the event times
post_time : float
time (in seconds) following the event times
classifier : string
which decoder to use, options are:
classifier : string or sklearn object
which decoder to use, either input a scikit learn clf object directly or a string.
When it's a string options are (all classifiers are used with default options):
'bayes' Naive Bayes
'forest' Random forest (with 100 trees)
'forest' Random forest
'regression' Logistic regression
'lda' Linear Discriminant Analysis
cross_validation : string
Expand Down Expand Up @@ -257,6 +259,9 @@ def decode(spike_times, spike_clusters, event_times, event_groups, pre_time=0, p
number of times to repeat the decoding (especially usefull when subselecting neurons)
shuffle : boolean
whether to shuffle the trial labels each decoding iteration
phase_rand : boolean
whether to use phase randomization of the activity over trials to use as a "chance"
predictor
Returns
-------
Expand Down Expand Up @@ -289,17 +294,25 @@ def decode(spike_times, spike_clusters, event_times, event_groups, pre_time=0, p
# Get matrix of all neuronal responses
times = np.column_stack(((event_times - pre_time), (event_times + post_time)))
pop_vector, cluster_ids = _get_spike_counts_in_bins(spike_times, spike_clusters, times)
pop_vector = np.rot90(pop_vector)
pop_vector = pop_vector.T

# Exclude last trial if the number of trials is even and phase shuffling
if (phase_rand is True) & (event_groups.shape[0] % 2 == 0):
event_groups = event_groups[:-1]
pop_vector = pop_vector[:-1]

# Initialize classifier
if classifier == 'forest':
clf = RandomForestClassifier(n_estimators=100)
elif classifier == 'bayes':
clf = GaussianNB()
elif classifier == 'regression':
clf = LogisticRegression(solver='liblinear', multi_class='auto')
elif classifier == 'lda':
clf = LinearDiscriminantAnalysis()
if type(classifier) == str:
if classifier == 'forest':
clf = RandomForestClassifier()
elif classifier == 'bayes':
clf = GaussianNB()
elif classifier == 'regression':
clf = LogisticRegression()
elif classifier == 'lda':
clf = LinearDiscriminantAnalysis()
else:
clf = classifier

# Pre-allocate variables
acc = np.zeros(iterations)
Expand Down Expand Up @@ -328,6 +341,23 @@ def decode(spike_times, spike_clusters, event_times, event_groups, pre_time=0, p
if shuffle is True:
event_groups = sklearn_shuffle(event_groups)

# Perform phase randomization of activity over trials if necessary
if phase_rand is True:
if i == 0:
original_pop_vector = sub_pop_vector
rand_pop_vector = np.empty(original_pop_vector.shape)
frequencies = int((original_pop_vector.shape[0] - 1) / 2)
fsignal = sp.fft.fft(original_pop_vector, axis=0)
power = np.abs(fsignal[1:1 + frequencies])
phases = 2 * np.pi * np.random.rand(frequencies)
for k in range(original_pop_vector.shape[1]):
newfsignal = fsignal[0, k]
newfsignal = np.append(newfsignal, np.exp(1j * phases) * power[:, k])
newfsignal = np.append(newfsignal, np.flip(np.exp(-1j * phases) * power[:, k]))
newsignal = sp.fft.ifft(newfsignal)
rand_pop_vector[:, k] = np.abs(newsignal.real)
sub_pop_vector = rand_pop_vector

if cross_validation == 'none':

# Fit the model on all the data and predict
Expand All @@ -341,13 +371,13 @@ def decode(spike_times, spike_clusters, event_times, event_groups, pre_time=0, p
else:
# Perform cross-validation
if cross_validation == 'leave-one-out':
cv = LeaveOneOut().split(pop_vector)
cv = LeaveOneOut().split(sub_pop_vector)
elif cross_validation == 'kfold':
cv = KFold(n_splits=num_splits).split(pop_vector)
cv = KFold(n_splits=num_splits).split(sub_pop_vector)
elif cross_validation == 'block':
block_lengths = [sum(1 for i in g) for k, g in groupby(prob_left)]
blocks = np.repeat(np.arange(len(block_lengths)), block_lengths)
cv = LeaveOneGroupOut().split(pop_vector, groups=blocks)
cv = LeaveOneGroupOut().split(sub_pop_vector, groups=blocks)
elif cross_validation == 'custom':
cv = custom_validation

Expand Down
4 changes: 2 additions & 2 deletions brainbox/task/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def responsive_units(spike_times, spike_clusters, event_times,
stats[i], p_values[i] = wilcoxon(baseline_counts[i, :], spike_counts[i, :])

# Perform FDR correction for multiple testing
sig_units, p_values, _, _ = multipletests(p_values, alpha)
sig_units, p_values, _, _ = multipletests(p_values, alpha, method='fdr_bh')
significant_units = cluster_ids[sig_units]

return significant_units, stats, p_values, cluster_ids
Expand Down Expand Up @@ -181,7 +181,7 @@ def differentiate_units(spike_times, spike_clusters, event_times, event_groups,
stats[i], p_values[i] = ttest_rel(counts_1[i, :], counts_2[i, :])

# Perform FDR correction for multiple testing
sig_units, p_values, _, _ = multipletests(p_values, alpha)
sig_units, p_values, _, _ = multipletests(p_values, alpha, method='fdr_bh')
significant_units = cluster_ids[sig_units]

return significant_units, stats, p_values, cluster_ids
Expand Down
4 changes: 2 additions & 2 deletions brainbox/tests/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_responsive_units(self):
post_time=[0, 0.5],
alpha=alpha)
num_clusters = np.size(np.unique(spike_clusters))
self.assertTrue(np.size(sig_units) == 125)
self.assertTrue(np.size(sig_units) == 232)
self.assertTrue(np.sum(p_values < alpha) == np.size(sig_units))
self.assertTrue(np.size(cluster_ids) == num_clusters)

Expand All @@ -64,7 +64,7 @@ def test_differentiate_units(self):
post_time=0.5,
alpha=alpha)
num_clusters = np.size(np.unique(spike_clusters))
self.assertTrue(np.size(sig_units) == 1)
self.assertTrue(np.size(sig_units) == 0)
self.assertTrue(np.sum(p_values < alpha) == np.size(sig_units))
self.assertTrue(np.size(cluster_ids) == num_clusters)

Expand Down

0 comments on commit 9b2f2ae

Please sign in to comment.