Skip to content

Commit

Permalink
Merge branch 'update-0.2.0' of https://github.com/GazzolaLab/MiV-OS i…
Browse files Browse the repository at this point in the history
…nto update-0.2.0
  • Loading branch information
skim0119 committed Aug 8, 2022
2 parents 1f4feba + 52f13cb commit 02e5e2a
Show file tree
Hide file tree
Showing 20 changed files with 544 additions and 210 deletions.
1 change: 1 addition & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
exclude_lines =
# Enable pragma
pragma: no cover
TODO

# Don't complain if non-runnable code isn't run:
if 0:
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#* Variables
PYTHON := python
PYTHON := python3
PYTHONPATH := `pwd`

#* Poetry
Expand Down
5 changes: 3 additions & 2 deletions miv/io/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,12 @@ def bits_to_voltage(signal: SignalType, channel_info: Sequence[Dict[str, Any]]):
signal : SignalType, numpy array
channel_info : Dict[str, Dict[str, Any]]
Channel information dictionary. Typically located in `structure.oebin` file.
channel information includes bit-volts conversion ration and units (uV or mV).
Channel information includes bit-volts conversion ration and units (uV or mV).
Returns
-------
signal : numpy array
Output signal is in microVolts unit.
Output signal is in microVolts (uV) unit.
"""
resultant_unit = pq.Quantity(1, "uV") # Final Unit
Expand Down Expand Up @@ -119,6 +119,7 @@ def load_recording(
Returns
-------
signal : SignalType, neo.core.AnalogSignal
timestamps : TimestampsType
sampling_rate : float
Raises
Expand Down
10 changes: 6 additions & 4 deletions miv/io/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,8 @@ def _get_binned_matrix(
"empty_channels": empty_channels,
}

def save(self, tag: str, format: str): # TODO
def save(self, tag: str, format: str): # pragma: no cover
# TODO
assert tag == "continuous", "You cannot alter raw data, change the data tag"
# save_path = os.path.join(self.data_path, tag)

Expand Down Expand Up @@ -419,7 +420,8 @@ def query_path_name(self, query_path) -> Iterable[Data]:
return list(filter(lambda d: query_path in d.data_path, self.data_list))

# DataManager Representation
def tree(self):
# TODO: Display data structure
def tree(self): # pragma: no cover
"""
Pretty-print available recordings in DataManager in tree format.
Expand Down Expand Up @@ -492,12 +494,12 @@ def _get_experiment_paths(self) -> Iterable[str]:
path_list.append(path)
return path_list

def save(self, tag: str, format: str):
def save(self, tag: str, format: str): # pragma: no cover
raise NotImplementedError # TODO
for data in self.data_list:
data.save(tag, format)

def apply_filter(self, filter: FilterProtocol):
def apply_filter(self, filter: FilterProtocol): # pragma: no cover
raise NotImplementedError # TODO
for data in self.data_list:
data.load()
Expand Down
11 changes: 6 additions & 5 deletions miv/signal/spike/sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,15 @@ def __init__(
feature_extractor: SpikeFeatureExtractionProtocol,
clustering_method: UnsupervisedFeatureClusteringProtocol,
):
pass
self.featrue_extractor = feature_extractor
self.clustering_method = clustering_method

def __call__(self):
pass
def __call__(self, cutouts: np.ndarray, n_group: int = 3):
assert n_group >= 2, "n_group must be larger than 1"


# UnsupervisedFeatureClusteringProtocol
class SuperParamagneticClustering:
class SuperParamagneticClustering: # pragma : no cover
"""Super-Paramagnetic Clustering (SPC)
The implementation is heavily inspired from [1]_ and [2]_.
Expand Down Expand Up @@ -195,7 +196,7 @@ def project(self, n_components, cutouts):
"""


class WaveletDecomposition:
class WaveletDecomposition: # TODO
"""
Wavelet Decomposition for spike sorting.
The implementation is heavily inspired from [1]_ and [2]_;
Expand Down
32 changes: 9 additions & 23 deletions miv/statistics/burst.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,20 +45,17 @@ def burst(spiketrains: SpikestampsType, channel: float, min_isi: float, min_len:
burst_spike = (spike_interval <= min_isi).astype(
np.bool_
) # Only spikes within specified min ISI are 1 otherwise 0 and are stored
burst = [] # List to store burst parameters

flag = False
start_idx = -1
# Try to find the start and end indices for burst interval
delta = np.logical_xor(burst_spike[:-1], burst_spike[1:])
for idx, dval in enumerate(delta):
if dval:
flag = ~flag
if flag:
start_idx = idx + 1
else:
if idx + 1 - start_idx >= min_len:
burst.append((start_idx, idx + 1))
Q = np.array(burst)
interval = np.where(delta)[0]
if len(interval) % 2:
interval = np.append(interval, len(delta))
interval += 1
interval = interval.reshape([-1, 2])
mask = np.diff(interval) >= min_len
interval = interval[mask.ravel(), :]
Q = np.array(interval)

if np.sum(Q) == 0:
start_time = 0
Expand All @@ -75,14 +72,3 @@ def burst(spiketrains: SpikestampsType, channel: float, min_isi: float, min_len:
burst_rate = burst_len / (burst_duration)

return start_time, burst_duration, burst_len, burst_rate


if __name__ == "__main__":
import timeit

from neo.core import SpikeTrain

arr = np.random.random(10000)
arr = np.sort(arr)
train0 = [SpikeTrain(times=arr, units="sec", t_stop=arr.max())]
o_algorithm = timeit.timeit(lambda: burst(train0, 0, 0.2, 5), number=1000)
175 changes: 66 additions & 109 deletions miv/visualization/connectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,99 +43,56 @@ def plot_connectivity(

if directionality:
g = graphviz.Digraph("G", filename="connectivity", engine="neato")
for i in elec_mat:
for j in elec_mat:
if i == j:
continue
else:
if connectivity_matrix[int(i - 1), int(j - 1)] != 0:

ki = int(np.where(mea_map == i)[0][0])
ji = int(np.where(mea_map == i)[1][0])
kj = int(np.where(mea_map == j)[0][0])
jj = int(np.where(mea_map == j)[1][0])

x1 = str(ki + 0.0)
y1 = str(ji) + "!"
posstr1 = x1 + "," + y1

x2 = str(kj + 0.0)
y2 = str(jj) + "!"
posstr2 = x2 + "," + y2
g.edge(str(i), str(j), color="red")

g.node(
str(i),
pos=posstr1,
fillcolor="deepskyblue2",
color="black",
style="filled,solid",
shape="circle",
fontcolor="black",
fontsize="15",
fontname="Times:Roman bold",
)

g.node(
str(j),
pos=posstr2,
fillcolor="deepskyblue2",
color="black",
style="filled,solid",
shape="circle",
fontcolor="black",
fontsize="15",
fontname="Times:Roman bold",
)
return g.view()

else:
g = graphviz.Graph("G", filename="connectivity", engine="neato")
for i in elec_mat:
for j in elec_mat:
if i == j:
continue
else:
if connectivity_matrix[int(i - 1), int(j - 1)] != 0 and i < j:

ki = int(np.where(mea_map == i)[0][0])
ji = int(np.where(mea_map == i)[1][0])
kj = int(np.where(mea_map == j)[0][0])
jj = int(np.where(mea_map == j)[1][0])

x1 = str(ki + 0.0)
y1 = str(ji) + "!"
posstr1 = x1 + "," + y1

x2 = str(kj + 0.0)
y2 = str(jj) + "!"
posstr2 = x2 + "," + y2
g.edge(str(i), str(j), color="red")

g.node(
str(i),
pos=posstr1,
fillcolor="deepskyblue2",
color="black",
style="filled,solid",
shape="circle",
fontcolor="black",
fontsize="15",
fontname="Times:Roman bold",
)

g.node(
str(j),
pos=posstr2,
fillcolor="deepskyblue2",
color="black",
style="filled,solid",
shape="circle",
fontcolor="black",
fontsize="15",
fontname="Times:Roman bold",
)
return g.view()

for i in elec_mat:
for j in elec_mat:
if i == j:
continue
if not directionality and i >= j:
continue
if np.isclose(connectivity_matrix[int(i - 1), int(j - 1)], 0):
continue
# Register
ki = int(np.where(mea_map == i)[0][0])
ji = int(np.where(mea_map == i)[1][0])
kj = int(np.where(mea_map == j)[0][0])
jj = int(np.where(mea_map == j)[1][0])

x1 = str(ki + 0.0)
y1 = str(ji) + "!"
posstr1 = x1 + "," + y1

x2 = str(kj + 0.0)
y2 = str(jj) + "!"
posstr2 = x2 + "," + y2
g.edge(str(i), str(j), color="red")

g.node(
str(i),
pos=posstr1,
fillcolor="deepskyblue2",
color="black",
style="filled,solid",
shape="circle",
fontcolor="black",
fontsize="15",
fontname="Times:Roman bold",
)

g.node(
str(j),
pos=posstr2,
fillcolor="deepskyblue2",
color="black",
style="filled,solid",
shape="circle",
fontcolor="black",
fontsize="15",
fontname="Times:Roman bold",
)
return g.view()


def plot_connectivity_interactive(
Expand Down Expand Up @@ -184,24 +141,24 @@ def plot_connectivity_interactive(
for j in elec_mat:
if i == j:
continue
else:
if connectivity_matrix[int(i - 1), int(j - 1)] != 0 and i < j:
ki = int(np.where(mea_map == i)[0][0]) * 100
ji = int(np.where(mea_map == i)[1][0]) * 100
kj = int(np.where(mea_map == j)[0][0]) * 100
jj = int(np.where(mea_map == j)[1][0]) * 100

x1 = str(ki + 0.0)
y1 = str(ji)

x2 = str(kj + 0.0)
y2 = str(jj)

net.add_node(
str(int(i)), x=x1, y=y1, shape="dot", color="#039AFB"
) # size = size1)
net.add_node(str(int(j)), x=x2, y=y2, shape="dot", color="#039AFB!")
net.add_edge(str(int(i)), str(int(j)), width="1")
if np.isclose(connectivity_matrix[int(i - 1), int(j - 1)], 0.0) or i >= j:
continue
ki = int(np.where(mea_map == i)[0][0]) * 100
ji = int(np.where(mea_map == i)[1][0]) * 100
kj = int(np.where(mea_map == j)[0][0]) * 100
jj = int(np.where(mea_map == j)[1][0]) * 100

x1 = str(ki + 0.0)
y1 = str(ji)

x2 = str(kj + 0.0)
y2 = str(jj)

net.add_node(
str(int(i)), x=x1, y=y1, shape="dot", color="#039AFB"
) # size = size1)
net.add_node(str(int(j)), x=x2, y=y2, shape="dot", color="#039AFB!")
net.add_edge(str(int(i)), str(int(j)), width="1")

for n in net.nodes:
n.update({"physics": False})
Expand Down
17 changes: 1 addition & 16 deletions miv/visualization/fft_domain.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
__all__ = ["plot_frequency_domain", "plot_spectral"]

import os

import matplotlib.pyplot as plt
import numpy as np
from scipy import fftpack
Expand All @@ -10,16 +8,14 @@
from miv.typing import SignalType


def plot_frequency_domain(signal: SignalType, sampling_rate: float) -> plt.Figure:
def plot_frequency_domain(signal: SignalType) -> plt.Figure:
"""
Plot DFT frequency domain
Parameters
----------
signal : SignalType
Input signal
sampling_rate : float
Sampling frequency
Returns
-------
Expand All @@ -29,20 +25,9 @@ def plot_frequency_domain(signal: SignalType, sampling_rate: float) -> plt.Figur
# FFT
fig = plt.figure()
sig_fft = fftpack.fft(signal)
# sample_freq = fftpack.fftfreq(signal.size, d=1 / sampling_rate)
plt.plot(np.abs(sig_fft) ** 2)
plt.xlabel("Frequency [Hz]")
plt.ylabel("DFT frequency")

# Welch (https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.welch.html)
f, Pxx_den = welch(signal, sampling_rate, nperseg=1024)
f_med, Pxx_den_med = welch(signal, sampling_rate, nperseg=1024, average="median")
plt.figure()
plt.semilogy(f, Pxx_den, label="mean")
plt.semilogy(f_med, Pxx_den_med, label="median")
plt.xlabel("frequency [Hz]")
plt.ylabel("PSD [uV**2/Hz]")
plt.legend()
return fig


Expand Down
Loading

0 comments on commit 02e5e2a

Please sign in to comment.