Skip to content

Commit

Permalink
test: update causality plot tests, minor bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
skim0119 committed Jul 26, 2022
1 parent 4320dcf commit 491bfd2
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 42 deletions.
69 changes: 39 additions & 30 deletions miv/statistics/pairwise_causality.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,37 +7,46 @@
from miv.typing import SignalType


def pairwise_causality(signal: SignalType, start: float, end: float):

# Estimates pairwise Granger Causality

# Parameters
# ----------
# signal : SignalType
# Input signal
# start : float
# starting point from signal
# end : float
# End point from signal

# Returns
# -------
# C : Causality Matrix containing directional causalities for X -> Y and Y -> X,
# instantaneous causality between X,Y, and total causality. X and Y represents electrodes

p = len(signal[0])
def pairwise_causality(signal: SignalType, start: int, end: int):
"""
Estimates pairwise Granger Causality between all channels.
Parameters
----------
signal : SignalType
Input signal.
start : int
starting point from signal
end : int
End point from signal
Returns
-------
C : np.ndarray
Causality Matrix (shape=2x2) containing directional causalities for X -> Y and Y -> X,
instantaneous causality between X,Y, and total causality. X and Y represents electrodes
See Also
--------
miv.visualization.causality.pairwise_causality_plot
"""

p = len(signal[0]) # Number of channels
C = np.zeros((4, p, p)) # Causality Matrix
q = np.arange(0, p)

for j in q:
for k in q:
if j == k:
C[:, j, k] = 0
else:
C[:, j, k] = pairwise_granger(
np.transpose([signal[start:end, j], signal[start:end, k]]),
max_order=1,
)

for j in range(p):
for k in range(j + 1, p):
C[:, j, k] = pairwise_granger(
np.transpose([signal[start:end, j], signal[start:end, k]]),
max_order=1,
)
C[:, k, j] = pairwise_granger(
np.transpose([signal[start:end, k], signal[start:end, j]]),
max_order=1,
)
# for i in range(4):
# np.fill_diagonal(C[i], 0.0)

# C or causality matrix contains four p X p matrices. These are directional causalities for X -> Y and Y -> X,
# instantaneous causality between X,Y, and total causality. X and Y represents electrodes
Expand Down
7 changes: 6 additions & 1 deletion miv/statistics/spiketrain_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,12 @@ def binned_spiketrain(
n_bins = int((t_end - t_start) / bin_size + 1)
time = np.linspace(t_start, bin_size * (n_bins - 1), n_bins)
bin_spike = np.zeros(n_bins)
spike = spiketrains[channel].magnitude
if isinstance(spiketrains[channel], np.ndarray):
spike = spiketrains[channel]
elif isinstance(spiketrains[channel], neo.core.SpikeTrain):
spike = spiketrains[channel].magnitude
else:
raise TypeError(f"type {type(spiketrains[channel])} is not supported.")
bins = np.digitize(spike, time)
bin_spike[bins - 1] = 1

Expand Down
5 changes: 5 additions & 0 deletions miv/visualization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,8 @@
from miv.visualization.fft_domain import *
from miv.visualization.raw_signal import *
from miv.visualization.waveform import *

# TODO:
# Unit testing in visualization is simply performed such that the file or image is properly generated.
# It does not check the correctness of the visualization. The inspection is done manually.
# If anyone have better idea on how to do unittest on visualization, please help us out.
17 changes: 10 additions & 7 deletions miv/visualization/causality.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,24 @@

import matplotlib.pyplot as plt
import numpy as np
from elephant.causality.granger import pairwise_granger
from viziphant.spike_train_correlation import plot_corrcoef

from miv.statistics import pairwise_causality
from miv.statistics.spiketrain_statistics import binned_spiketrain
from miv.typing import SignalType, SpikestampsType


def pairwise_causality_plot(signal: SignalType, start: float, end: float):
def pairwise_causality_plot(signal: SignalType, start: int, end: int):
"""
Plots pairwise Granger Causality
Parameters
----------
signal : SignalType
Input signal
start : float
start : int
Starting point of the signal
end : float
end : int
End point of the signal
Returns
Expand All @@ -34,6 +33,10 @@ def pairwise_causality_plot(signal: SignalType, start: float, end: float):
axes : matplotlib.axes
axes parameters for plot modification
See Also
--------
miv.statistics.pairwise_causality
"""

# Causality
Expand Down Expand Up @@ -70,7 +73,7 @@ def spike_triggered_average_plot(
spiketrains: SpikestampsType,
channel_y: int,
sampling_freq: float,
window_length: float,
window_length: int,
):
"""
Plots the spike-triggered average of Local Field Potential (LFP) from channel X
Expand All @@ -90,7 +93,7 @@ def spike_triggered_average_plot(
Channel to consider for spiketrain data
sampling_freq : float
sampling frequency for LFP recordings
window_length : float
window_length : int
window length to consider before and after spike
Returns
Expand All @@ -103,7 +106,7 @@ def spike_triggered_average_plot(
"""

# Spike Triggered Average
dt = 1 / sampling_freq
dt = 1.0 / sampling_freq
n = np.shape(signal[:, channel_x])[0] / sampling_freq
assert (
window_length < np.shape(signal[:, channel_x])[0] / 2
Expand Down
54 changes: 54 additions & 0 deletions tests/visualization/test_causality.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import matplotlib.pyplot as plt
import numpy as np
import pytest
from matplotlib.axes import Axes

from miv.typing import SignalType


@pytest.fixture
def mock_numpy_signal():
num_length = 512
num_channel = 32

# signal = np.arange(num_length * num_channel).reshape([num_length, num_channel])
signal_list = []
x = np.ones(num_length)
for i in range(num_channel):
_signal = (i / 1.5) + x + np.random.randn(num_length) # jitter
signal_list.append(_signal)
signal = np.array(signal_list).T
return signal


@pytest.fixture
def mock_numpy_spiketrains():
spiketrains = [
np.sort(np.rint(np.arange(200, 500, 100))).astype(np.int_) for _ in range(32)
]
return spiketrains


@pytest.mark.parametrize("start, end", [(1, 100), (1, 50), (250, 500)])
def test_pairwise_causality_plot_numpy(mock_numpy_signal, start, end):
from miv.visualization.causality import pairwise_causality_plot

fig, axes = pairwise_causality_plot(mock_numpy_signal, start, end)

assert axes.shape == (2, 2), "Dimension of axes does not match."
assert isinstance(fig, plt.Figure)


@pytest.mark.parametrize("ch1, ch2", [(0, 1), (1, 0)])
@pytest.mark.parametrize("window_length", [10, 20, 50])
def test_spike_triggered_average_plot_numpy(
mock_numpy_signal, mock_numpy_spiketrains, ch1, ch2, window_length
):
from miv.visualization.causality import spike_triggered_average_plot

fig, axes = spike_triggered_average_plot(
mock_numpy_signal, ch1, mock_numpy_spiketrains, ch2, 1, window_length
)

assert isinstance(axes, Axes)
assert isinstance(fig, plt.Figure)
18 changes: 14 additions & 4 deletions tests/visualization/test_raw_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,20 @@ def test_multi_channel_signal_plot_filecheck(tmp_path):
from miv.visualization import multi_channel_signal_plot

length = 64
channel = 64
channel = 16
signal = np.arange(channel * length).reshape([length, channel])
X, Y = np.mgrid[:8, :8]
mea_geometry = zip(range(64), X.ravel(), Y.ravel())
X, Y = np.mgrid[:4, :4]
mea_geometry = zip(range(16), X.ravel(), Y.ravel())
output_path = os.path.join(tmp_path, "output.mp4")
multi_channel_signal_plot(signal, mea_geometry, 0, length, 10, 30, output_path)
multi_channel_signal_plot(
signal,
mea_geometry,
0,
length,
10,
30,
output_path,
max_subplot_in_x=4,
max_subplot_in_y=4,
)
assert os.path.exists(output_path), "output video file does not exist."

0 comments on commit 491bfd2

Please sign in to comment.