diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c34b9254..f64d7411 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -4,7 +4,7 @@ Thanks for your interest in contributing MiV-OS project. The following is a set of guidelines how to contributes. These are mostly guidelines, not rules. Use your best judgment, and feel free to propose changes to this document in a pull request. -#### Table Of Contents +**Table Of Contents** [TLTR! I need three-line summary!!](#three-line-summary) @@ -36,10 +36,15 @@ The following is a set of guidelines how to contributes. These are mostly guidel ### Installation and packages First **create the fork repository and clone** to your local machine. -We provide [requirements.txt](requirements.txt) to include all the dependencies. +We provide [requirements.txt](https://github.com/GazzolaLab/MiV-OS/blob/main/requirements.txt) to include all the dependencies that is required to develop. You can either install using `pip install -r requirements.txt` or ```bash -$ pip install -r requirements.txt +$ pip install miv-os[dev] ``` +If you are more interested in working for documentation, use +```bash +$ pip install miv-os[docs] +``` +More details are included [here](https://github.com/GazzolaLab/MiV-OS/blob/main/docs/README.md). ### Pre-Commit diff --git a/Makefile b/Makefile index 81021c25..e5d5a257 100644 --- a/Makefile +++ b/Makefile @@ -7,5 +7,5 @@ mypy: coverage: @pytest --cov=miv tests/ -all:test mypy +all:test mypy coverage ci: test mypy diff --git a/README.md b/README.md index 93540f24..c73a426d 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ [![codecov][badge-codecov]][link-codecov] [//]: # (Remove this line for the first release) -_The package is under pre-alpha development. First Alpha release target mid-April 2022._ +_The package is under pre-alpha development. First Alpha release target mid-May 2022._ diff --git a/docs/_static/assets/spike_cutout_example.png b/docs/_static/assets/spike_cutout_example.png new file mode 100644 index 00000000..75dd309f Binary files /dev/null and b/docs/_static/assets/spike_cutout_example.png differ diff --git a/docs/api/_toctree/FilterAPI/miv.signal.filter.FilterCollection.rst b/docs/api/_toctree/FilterAPI/miv.signal.filter.FilterCollection.rst deleted file mode 100644 index b21e210b..00000000 --- a/docs/api/_toctree/FilterAPI/miv.signal.filter.FilterCollection.rst +++ /dev/null @@ -1,7 +0,0 @@ -miv.signal.filter.FilterCollection -================================== - -.. currentmodule:: miv.signal.filter - -.. autoclass:: FilterCollection - :members: append, insert diff --git a/docs/api/_toctree/SpikeSortingAPI/miv.signal.spike.PCADecomposition.rst b/docs/api/_toctree/SpikeSortingAPI/miv.signal.spike.PCADecomposition.rst new file mode 100644 index 00000000..b082e359 --- /dev/null +++ b/docs/api/_toctree/SpikeSortingAPI/miv.signal.spike.PCADecomposition.rst @@ -0,0 +1,13 @@ +miv.signal.spike.PCADecomposition +================================= + +.. currentmodule:: miv.signal.spike + +.. autoclass:: PCADecomposition + + .. rubric:: Methods + + .. 
autosummary:: + + ~PCADecomposition.__init__ + ~PCADecomposition.project diff --git a/docs/api/_toctree/SpikeSortingAPI/miv.signal.spike.SpikeFeatureExtractionProtocol.rst b/docs/api/_toctree/SpikeSortingAPI/miv.signal.spike.SpikeFeatureExtractionProtocol.rst new file mode 100644 index 00000000..0fdadf8b --- /dev/null +++ b/docs/api/_toctree/SpikeSortingAPI/miv.signal.spike.SpikeFeatureExtractionProtocol.rst @@ -0,0 +1,13 @@ +miv.signal.spike.SpikeFeatureExtractionProtocol +=============================================== + +.. currentmodule:: miv.signal.spike + +.. autoclass:: SpikeFeatureExtractionProtocol + + + .. rubric:: Methods + + .. autosummary:: + + ~SpikeFeatureExtractionProtocol.__init__ diff --git a/docs/api/_toctree/SpikeSortingAPI/miv.signal.spike.SuperParamagneticClustering.rst b/docs/api/_toctree/SpikeSortingAPI/miv.signal.spike.SuperParamagneticClustering.rst new file mode 100644 index 00000000..b7364c35 --- /dev/null +++ b/docs/api/_toctree/SpikeSortingAPI/miv.signal.spike.SuperParamagneticClustering.rst @@ -0,0 +1,13 @@ +miv.signal.spike.SuperParamagneticClustering +============================================ + +.. currentmodule:: miv.signal.spike + +.. autoclass:: SuperParamagneticClustering + + + .. rubric:: Methods + + .. autosummary:: + + ~SuperParamagneticClustering.__init__ diff --git a/docs/api/_toctree/SpikeSortingAPI/miv.signal.spike.UnsupervisedFeatureClusteringProtocol.rst b/docs/api/_toctree/SpikeSortingAPI/miv.signal.spike.UnsupervisedFeatureClusteringProtocol.rst new file mode 100644 index 00000000..6cc6fb87 --- /dev/null +++ b/docs/api/_toctree/SpikeSortingAPI/miv.signal.spike.UnsupervisedFeatureClusteringProtocol.rst @@ -0,0 +1,14 @@ +miv.signal.spike.UnsupervisedFeatureClusteringProtocol +====================================================== + +.. currentmodule:: miv.signal.spike + +.. autoclass:: UnsupervisedFeatureClusteringProtocol + + .. rubric:: Methods + + .. autosummary:: + + ~UnsupervisedFeatureClusteringProtocol.__init__ + ~UnsupervisedFeatureClusteringProtocol.fit + ~UnsupervisedFeatureClusteringProtocol.predict diff --git a/docs/api/_toctree/SpikeSortingAPI/miv.signal.spike.WaveletDecomposition.rst b/docs/api/_toctree/SpikeSortingAPI/miv.signal.spike.WaveletDecomposition.rst new file mode 100644 index 00000000..f8676829 --- /dev/null +++ b/docs/api/_toctree/SpikeSortingAPI/miv.signal.spike.WaveletDecomposition.rst @@ -0,0 +1,13 @@ +miv.signal.spike.WaveletDecomposition +===================================== + +.. currentmodule:: miv.signal.spike + +.. autoclass:: WaveletDecomposition + + .. rubric:: Methods + + .. autosummary:: + + ~WaveletDecomposition.__init__ + ~WaveletDecomposition.project diff --git a/docs/api/_toctree/StatisticsAPI/miv.statistics.spikestamps_statistics.rst b/docs/api/_toctree/StatisticsAPI/miv.statistics.spikestamps_statistics.rst new file mode 100644 index 00000000..9d19a195 --- /dev/null +++ b/docs/api/_toctree/StatisticsAPI/miv.statistics.spikestamps_statistics.rst @@ -0,0 +1,6 @@ +miv.statistics.spikestamps\_statistics +====================================== + +.. currentmodule:: miv.statistics + +.. 
autofunction:: spikestamps_statistics diff --git a/docs/api/_toctree/StatisticsAPI/miv.statistics.summarizer.StatisticsSummary.rst b/docs/api/_toctree/StatisticsAPI/miv.statistics.summarizer.StatisticsSummary.rst deleted file mode 100644 index 4c4dd6a0..00000000 --- a/docs/api/_toctree/StatisticsAPI/miv.statistics.summarizer.StatisticsSummary.rst +++ /dev/null @@ -1,13 +0,0 @@ -miv.statistics.summarizer.StatisticsSummary -=========================================== - -.. currentmodule:: miv.statistics.summarizer - -.. autoclass:: StatisticsSummary - - .. rubric:: Methods - - .. autosummary:: - - ~StatisticsSummary.__init__ - ~StatisticsSummary.spikestamps_summary diff --git a/docs/api/_toctree/VisualizationAPI/miv.visualization.fft_domain.plot_frequency_domain.rst b/docs/api/_toctree/VisualizationAPI/miv.visualization.fft_domain.plot_frequency_domain.rst new file mode 100644 index 00000000..c0120f70 --- /dev/null +++ b/docs/api/_toctree/VisualizationAPI/miv.visualization.fft_domain.plot_frequency_domain.rst @@ -0,0 +1,6 @@ +miv.visualization.fft\_domain.plot\_frequency\_domain +===================================================== + +.. currentmodule:: miv.visualization.fft_domain + +.. autofunction:: plot_frequency_domain diff --git a/docs/api/_toctree/VisualizationAPI/miv.visualization.waveform.extract_waveforms.rst b/docs/api/_toctree/VisualizationAPI/miv.visualization.waveform.extract_waveforms.rst new file mode 100644 index 00000000..8e138bc0 --- /dev/null +++ b/docs/api/_toctree/VisualizationAPI/miv.visualization.waveform.extract_waveforms.rst @@ -0,0 +1,6 @@ +miv.visualization.waveform.extract\_waveforms +============================================= + +.. currentmodule:: miv.visualization.waveform + +.. autofunction:: extract_waveforms diff --git a/docs/api/_toctree/VisualizationAPI/miv.visualization.waveform.plot_waveforms.rst b/docs/api/_toctree/VisualizationAPI/miv.visualization.waveform.plot_waveforms.rst new file mode 100644 index 00000000..288b0bd7 --- /dev/null +++ b/docs/api/_toctree/VisualizationAPI/miv.visualization.waveform.plot_waveforms.rst @@ -0,0 +1,6 @@ +miv.visualization.waveform.plot\_waveforms +========================================== + +.. currentmodule:: miv.visualization.waveform + +.. autofunction:: plot_waveforms diff --git a/docs/api/io.rst b/docs/api/io.rst index da6820c9..4c67339e 100644 --- a/docs/api/io.rst +++ b/docs/api/io.rst @@ -1,6 +1,8 @@ -********************* -Input / Output Module -********************* +******************** +Data Managing Module +******************** + +.. automodule:: miv.io.data .. automodule:: miv.io.binary :members: diff --git a/docs/api/signal.rst b/docs/api/signal.rst index b8461213..2f096736 100644 --- a/docs/api/signal.rst +++ b/docs/api/signal.rst @@ -1,34 +1,7 @@ -************************* -Signal Processing Modules -************************* +********************* +Signal Pre-Processing +********************* +.. automodule:: miv.signal.filter.filter_collection -Filter -###### - -.. currentmodule:: miv.signal.filter - -.. automodule:: miv.signal.filter - - .. autosummary:: - :nosignatures: - :toctree: _toctree/FilterAPI - - FilterProtocol - ButterBandpass - FilterCollection - -Spike Detection -############### - -.. automodule:: miv.signal.spike - - .. autosummary:: - :nosignatures: - :toctree: _toctree/DetectionAPI - - SpikeDetectionProtocol - ThresholdCutoff - -Spike Sorting -############# +.. 
automodule:: miv.signal.spike.detection diff --git a/docs/api/sorting.rst b/docs/api/sorting.rst new file mode 100644 index 00000000..9d6be9f0 --- /dev/null +++ b/docs/api/sorting.rst @@ -0,0 +1,5 @@ +******************** +Spike Sorting Module +******************** + +.. automodule:: miv.signal.spike.sorting diff --git a/docs/api/statistics.rst b/docs/api/statistics.rst index 8562e7c4..d2802b62 100644 --- a/docs/api/statistics.rst +++ b/docs/api/statistics.rst @@ -2,40 +2,4 @@ Statistics Modules ****************** -Statistics Tools -================ - -Spikestamps ------------ - -.. currentmodule:: miv.statistics - .. automodule:: miv.statistics.summarizer - - .. autosummary:: - :nosignatures: - :toctree: _toctree/StatisticsAPI - - StatisticsSummary - -Useful External Packages -######################## - -scipy statistics -================ - -`scipy `_ - -.. autosummary:: - - scipy.stats.describe - -elephant.statistics -=================== - -`elephant documentation: `_ - -.. autosummary:: - - elephant.statistics.mean_firing_rate - elephant.statistics.instantaneous_rate diff --git a/docs/api/visualization.rst b/docs/api/visualization.rst new file mode 100644 index 00000000..9232fc08 --- /dev/null +++ b/docs/api/visualization.rst @@ -0,0 +1,48 @@ +******************* +Visualization Tools +******************* + +Plotting Tools +============== + +DFT Plot +-------- + +.. currentmodule:: miv.visualization + +.. automodule:: miv.visualization.fft_domain + + .. autosummary:: + :nosignatures: + :toctree: _toctree/VisualizationAPI + + plot_frequency_domain + +Spike Waveform Overlap +---------------------- + +.. currentmodule:: miv.visualization + +.. automodule:: miv.visualization.waveform + + .. autosummary:: + :nosignatures: + :toctree: _toctree/VisualizationAPI + + extract_waveforms + plot_waveforms + +Useful External Packages +======================== + +Here are few external `python` packages that can also be used for visualization. + +Viziphant +--------- + +`viziphant (elephant) documentation: `_ + +.. autosummary:: + + viziphant.rasterplot.rasterplot + viziphant.rasterplot.rasterplot_rates diff --git a/docs/conf.py b/docs/conf.py index e87860e6..04fa4468 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -108,7 +108,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". 
-html_static_path = ["_static"] +html_static_path = ["_static", "_static/assets"] html_css_files = ["css/*", "css/logo.css"] # -- Options for numpydoc --------------------------------------------------- diff --git a/docs/guide/data_management.md b/docs/guide/data_management.md new file mode 100644 index 00000000..9d103e8e --- /dev/null +++ b/docs/guide/data_management.md @@ -0,0 +1,94 @@ +--- +jupytext: + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.13.8 +kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +--- + +# Data Management + +- Data +- DataManager +- load_continuous_data (raw) + +```{code-cell} ipython3 +:tags: [hide-cell] + +import os +import numpy as np +import quantities as pq +import matplotlib.pyplot as plt +from glob import glob +from miv.io import * +``` + +```{code-cell} ipython3 +datapath = './2022-03-10_16-19-09' +os.path.exists(datapath) +``` + +```{code-cell} ipython3 +filepath = './2022-03-10_16-19-09/Record Node 104/experiment1/recording1/continuous/Rhythm_FPGA-100.0/continuous.dat' +os.path.exists(filepath) +``` + +## 1. Data Load + +```{code-cell} ipython3 +# Load dataset from OpenEphys recording +folder_path: str = "~/Open Ephys/2022-03-10-16-19-09" # Data Path +# Provide the path of experimental recording tree to the DataSet class +# Data set class will load the data and create a list of objects for each data +# dataset = load_data(folder_path, device="OpenEphys") +dataset = Dataset(data_folder_path=folder_path, + device="OpenEphys", + channels=32, + sampling_rate=30E3, + timestamps_npy="", # We can read similar to continuous.dat + + ) +#TODO: synchornized_timestamp what for shifted ?? +# Masking channels for data set. Channels can be a list. +# Show user the tree. Implement representation method. filter_collection.html#FilterCollection.insert +# An example code to get the tree https://github.com/skim0119/mindinvitro/blob/master/utility/common.py +# Trimming the tree?? +``` + +### 1.1. Meta Data Structure + +```{code-cell} ipython3 +# Get signal and rate(hz) +record_node: int = dataset.get_nodes[0] +recording = dataset[record_node]["experiment1"]["recording1"] # Returns the object for recording 1 +# TODO: does openephys returns the timestamp?? +timestamp = recording.timestamp # returns the time stamp for the recording. + +signal, _, rate = recording.continuous["100"] +# time = recording.continuous["100"].timestamp / rate +num_channels = signal.shape[1] +``` + +### 1.2 Raw Data + ++++ + +If the data is provided in single `continuous.dat` instead of meta-data, user must provide number of channels and sampling rate in order to import data accurately. + +> **WARNING** The size of the raw datafile can be _large_ depending on sampling rate and the amount of recorded duration. We highly recommand using meta-data structure to handle datafiles, since it only loads the data during the processing and unloads once the processing is done. + +```{code-cell} ipython3 +from miv.io import load_continuous_data_file + +datapath = 'continuous.dat' +rate = 30_000 +num_channel = 64 +timestamps, signal = load_continuous_data_file(datapath, num_channel, rate) +``` + +## 2. 
Instant Visualization diff --git a/docs/guide/signal_processing.ipynb b/docs/guide/signal_processing.ipynb deleted file mode 100644 index b5bcb9b0..00000000 --- a/docs/guide/signal_processing.ipynb +++ /dev/null @@ -1,327 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "d81cdede", - "metadata": {}, - "source": [ - "# Signal Processing Guideline\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "566761ed", - "metadata": { - "tags": [ - "hide-cell" - ] - }, - "outputs": [], - "source": [ - "import os\n", - "import numpy as np\n", - "import quantities as pq\n", - "import matplotlib.pyplot as plt\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "id": "cb0de5c5", - "metadata": {}, - "source": [ - "## 1. Data Load" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "799f0547", - "metadata": { - "tags": [ - "hide-cell" - ] - }, - "outputs": [], - "source": [ - "from miv.io import load_data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2e106e97", - "metadata": {}, - "outputs": [], - "source": [ - "# Load dataset from OpenEphys recording\n", - "folder_path: str = \"~/Open Ephys/2022-03-10-16-19-09\" # Data Path\n", - "dataset = load_data(folder_path, device=\"OpenEphys\")" - ] - }, - { - "cell_type": "markdown", - "id": "f667f6e4", - "metadata": {}, - "source": [ - "### 1.1. Meta Data Structure" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1c21bc85", - "metadata": {}, - "outputs": [], - "source": [ - "# Get signal and rate(hz)\n", - "record_node: int = dataset.get_nodes[0]\n", - "recording = dataset[record_node][\"experiment1\"][\"recording1\"]\n", - "signal, _, rate = recording.continuous[\"100\"]\n", - "# time = recording.continuous[\"100\"].timestamp / rate\n", - "num_channels = signal.shape[1]" - ] - }, - { - "cell_type": "markdown", - "id": "a72339e8", - "metadata": {}, - "source": [ - "### 1.2. Array Data Structure" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "66972c64", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "markdown", - "id": "16ce7358", - "metadata": {}, - "source": [ - "### 1.3 Raw Data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a575ab1d", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "markdown", - "id": "cd5d1420", - "metadata": {}, - "source": [ - "## 2. Filtering Raw Signal" - ] - }, - { - "cell_type": "markdown", - "id": "c6eda669", - "metadata": {}, - "source": [ - "We provide a set of basic signal filter tools. It is highly recommended to filter the signal before doing the spike-detection.\n", - "Here, we provide examples of how to create and apply the filter to the [`dataset`](../api/io.rst)." - ] - }, - { - "cell_type": "markdown", - "id": "c4402012", - "metadata": {}, - "source": [ - "If you have further suggestion on other filters to include, please leave an issue on our [GitHub issue page](https://github.com/GazzolaLab/MiV-OS/issues) with `enhancement` tag." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "783e3dd4", - "metadata": { - "tags": [ - "hide-cell" - ] - }, - "outputs": [], - "source": [ - "from miv.signal.filter import FilterCollection, ButterBandpass" - ] - }, - { - "cell_type": "markdown", - "id": "1432314b", - "metadata": {}, - "source": [ - "### 2.1 Filter Collection\n", - "\n", - "[Here](../api/signal.html#filter) is the list of provided filters.\n", - "All filters are `Callable`, taking `signal` and `sampling_rate` as parameters.\n", - "To define a multiple filters together, we provide [`FilterCollection`](../api/_toctree/FilterAPI/miv.signal.filter.FilterCollection) that execute multiple filters in a series.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e49e9f31", - "metadata": {}, - "outputs": [], - "source": [ - "# Butter bandpass filter\n", - "pre_filter = ButterBandpass(lowcut=300, highcut=3000, order=5)\n", - "\n", - "# How to construct sequence of filters\n", - "pre_filter = (\n", - " FilterCollection(tag=\"Filter Example\")\n", - " .append(ButterBandpass(lowcut=300, highcut=3000, order=5))\n", - " #.append(Limiter(400*pq.mV))\n", - " #.append(Filter1(**filter1_kwargs))\n", - " #.append(Filter2(**filter2_kwargs))\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "7430675d", - "metadata": {}, - "source": [ - "### 2.2 Apply Filter\n", - "\n", - "There are two way to apply the filter on the signal.\n", - "- If the signal is stored in `numpy array` format, you can directly call the filter `prefilter(signal, sampling_rate)`.\n", - "- If you want to apply the filter to all signals in the `dataset`, `dataset` provide `.apply_filter` method that takes any `filter` (any filter that abide [`filter protocol`](../api/_toctree/FilterAPI/miv.signal.filter.FilterProtocol)).\n", - " - You can select [subset of `dataset`](../api/dataset.html#data-subset) and [mask-out channels](../api/dataset.html#mask-channel) before applying the filter.\n", - " \n", - "You can check the list of all provided filters [here](../api/signal.html#filter)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c3dc284e", - "metadata": {}, - "outputs": [], - "source": [ - "# Apply filter to entire dataset\n", - "dataset.apply_filter(pre_filter)\n", - "filtered_signal = dataset[record_node]['experiment1']['recording1'].filtered_signal\n", - "\n", - "# Apply filter to array\n", - "rate = 30_000\n", - "filtered_signal = pre_filter(data_array, sampling_rate=rate)\n", - "\n", - "# Retrieve data from dataset and apply filter\n", - "data = dataset[record_node]['experiment1']['recording1']\n", - "filtered_signal = pre_filter(data, sampling_rate=rate)" - ] - }, - { - "cell_type": "markdown", - "id": "c0aad0c0", - "metadata": {}, - "source": [ - "## 3. Spike Detection\n", - "\n", - "You can check the available method [here](../api/signal.html#spike-detection).\n", - "\n", - "Most simple example of spike-detection method is using `ThresholdCutoff`." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "16848ba7", - "metadata": { - "tags": [ - "hide-cell" - ] - }, - "outputs": [], - "source": [ - "from miv.signal.spike import ThresholdCutoff" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bb1e5ee9", - "metadata": {}, - "outputs": [], - "source": [ - "# Define spike-detection method\n", - "spike_detection = ThresholdCutoff()\n", - "\n", - "# The detection can be used directly as the following.\n", - "# signal : np.array or neo.core.AnalogSignal, shape(N_channels, N)\n", - "# timestamps : np.array, shape(N) \n", - "# sampling_rate : float\n", - "timestamps = spike_detection(signal, timestamps, sampling_rate=30_000, cutoff=3.5)\n", - "\n", - "# The detection can be applied on the dataset\n", - "dataset.apply_spike_detection(spike_detection)" - ] - }, - { - "cell_type": "markdown", - "id": "11db868c", - "metadata": {}, - "source": [ - "## 4. Spike Visualization" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9ed78920", - "metadata": { - "tags": [ - "hide-cell" - ] - }, - "outputs": [], - "source": [ - "import neo\n", - "from viziphant.rasterplot import rasterplot_rates" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9a4dffb6", - "metadata": {}, - "outputs": [], - "source": [ - "# Plot\n", - "rasterplot_rates(spiketrain_list)" - ] - } - ], - "metadata": { - "celltoolbar": "Tags", - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/docs/guide/signal_processing.md b/docs/guide/signal_processing.md new file mode 100644 index 00000000..e61f4a6c --- /dev/null +++ b/docs/guide/signal_processing.md @@ -0,0 +1,184 @@ +--- +jupytext: + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.13.8 +kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +--- + +# Signal Processing Guideline + +```{code-cell} ipython3 +:tags: [hide-cell] + +import os +import numpy as np +import quantities as pq +import matplotlib.pyplot as plt + +``` + +## 1. Data Load + +```{code-cell} ipython3 +:tags: [hide-cell] + +from miv.io import load_data +from miv.io.data import Data, Dataset +``` + +```{code-cell} ipython3 +# Load dataset from OpenEphys recording +folder_path: str = "~/Open Ephys/2022-03-10-16-19-09" # Data Path +# Provide the path of experimental recording tree to the DataSet class +# Data set class will load the data and create a list of objects for each data +# dataset = load_data(folder_path, device="OpenEphys") +dataset = Dataset(data_folder_path=folder_path, + device="OpenEphys", + channels=32, + sampling_rate=30E3, + timestamps_npy="", # We can read similar to continuous.dat + + ) +#TODO: synchornized_timestamp what for shifted ?? +# Masking channels for data set. Channels can be a list. +# Show user the tree. Implement representation method. filter_collection.html#FilterCollection.insert +# An example code to get the tree https://github.com/skim0119/mindinvitro/blob/master/utility/common.py +# Trimming the tree?? +``` + +### 1.1. 
Meta Data Structure
+
+```{code-cell} ipython3
+# Get signal and sampling rate (Hz)
+record_node: int = dataset.get_nodes[0]
+recording = dataset[record_node]["experiment1"]["recording1"] # Returns the object for recording 1
+# TODO: does openephys return the timestamp??
+timestamp = recording.timestamp # returns the time stamp for the recording.
+
+signal, _, rate = recording.continuous["100"]
+# time = recording.continuous["100"].timestamp / rate
+num_channels = signal.shape[1]
+```
+
+### 1.2 Raw Data
+
++++
+
+If the data is provided as a single `continuous.dat` file instead of the meta-data structure, the user must provide the number of channels and the sampling rate in order to import the data accurately.
+
+> **WARNING** The size of the raw datafile can be _large_ depending on the sampling rate and the recorded duration. We highly recommend using the meta-data structure to handle datafiles, since it only loads the data during processing and unloads it once the processing is done.
+
+```{code-cell} ipython3
+from miv.io import load_continuous_data_file
+
+datapath = 'continuous.dat'
+rate = 30_000
+num_channel = 64
+timestamps, signal = load_continuous_data_file(datapath, num_channel, rate)
+```
+
+## 2. Filtering Raw Signal
+
++++
+
+We provide a set of basic signal filter tools. It is highly recommended to filter the signal before spike detection.
+Here, we provide examples of how to create and apply a filter to the [`dataset`](../api/io.rst).
+
++++
+
+If you have further suggestions on other filters to include, please leave an issue on our [GitHub issue page](https://github.com/GazzolaLab/MiV-OS/issues) with the `enhancement` tag.
+
+```{code-cell} ipython3
+:tags: [hide-cell]
+
+from miv.signal.filter import FilterCollection, ButterBandpass
+```
+
+### 2.1 Filter Collection
+
+[Here](../api/signal.html#filter) is the list of provided filters.
+All filters are `Callable`, taking `signal` and `sampling_rate` as parameters.
+To chain multiple filters together, we provide [`FilterCollection`](miv.signal.filter.FilterCollection), which executes multiple filters in series.
+
+```{code-cell} ipython3
+# Butter bandpass filter
+pre_filter = ButterBandpass(lowcut=300, highcut=3000, order=5)
+
+# How to construct a sequence of filters
+pre_filter = (
+    FilterCollection(tag="Filter Example")
+    .append(ButterBandpass(lowcut=300, highcut=3000, order=5))
+    #.append(Limiter(400*pq.mV))
+    #.append(Filter1(**filter1_kwargs))
+    #.append(Filter2(**filter2_kwargs))
+)
+```
+
+### 2.2 Apply Filter
+
+There are two ways to apply a filter to the signal.
+- If the signal is stored in `numpy array` format, you can call the filter directly: `pre_filter(signal, sampling_rate)`.
+- If you want to apply the filter to all signals in the `dataset`, `dataset` provides an `.apply_filter` method that takes any filter abiding by the [`filter protocol`](../api/_toctree/FilterAPI/miv.signal.filter.FilterProtocol).
+  - You can select a [subset of the `dataset`](../api/dataset.html#data-subset) and [mask-out channels](../api/dataset.html#mask-channel) before applying the filter; see the sketch after this list.
+
+You can check the list of all provided filters [here](../api/signal.html#filter).
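+
+The channel masking mentioned above can be combined with filtering. The cell below is only a sketch: it assumes the `Data` interface from `miv.io.data` (`set_channel_mask` and the `load` context manager), and the recording path is hypothetical.
+
+```{code-cell} ipython3
+# Sketch only: the recording path below is hypothetical.
+from miv.io.data import Data
+
+data = Data("2022-03-10_16-19-09/Record Node 104/experiment1/recording1")
+data.set_channel_mask(range(12, 23))  # ignore channels 12-22 before filtering
+with data.load() as (signal, timestamps, sampling_rate):
+    filtered_signal = pre_filter(signal, sampling_rate=sampling_rate)
+```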
+ +```{code-cell} ipython3 +# Apply filter to entire dataset +dataset.apply_filter(pre_filter) +filtered_signal = dataset[record_node]['experiment1']['recording1'].filtered_signal + +# Apply filter to array +rate = 30_000 +filtered_signal = pre_filter(data_array, sampling_rate=rate) + +# Retrieve data from dataset and apply filter +data = dataset[record_node]['experiment1']['recording1'] +filtered_signal = pre_filter(data, sampling_rate=rate) +``` + +## 3. Spike Detection + +You can check the available method [here](../api/signal.html#spike-detection). + +Most simple example of spike-detection method is using `ThresholdCutoff`. + +```{code-cell} ipython3 +:tags: [hide-cell] + +from miv.signal.spike import ThresholdCutoff +``` + +```{code-cell} ipython3 +# Define spike-detection method +spike_detection = ThresholdCutoff() + +# The detection can be used directly as the following. +# signal : np.array or neo.core.AnalogSignal, shape(N_channels, N) +# timestamps : np.array, shape(N) +# sampling_rate : float +timestamps = spike_detection(signal, timestamps, sampling_rate=30_000, cutoff=3.5) + +# The detection can be applied on the dataset +dataset.apply_spike_detection(spike_detection) +``` + +## 4. Spike Visualization + +```{code-cell} ipython3 +:tags: [hide-cell] + +import neo +from viziphant.rasterplot import rasterplot_rates +``` + +```{code-cell} ipython3 +# Plot +rasterplot_rates(spiketrain_list) +``` diff --git a/docs/guide/spike_cutout.md b/docs/guide/spike_cutout.md new file mode 100644 index 00000000..b29c321d --- /dev/null +++ b/docs/guide/spike_cutout.md @@ -0,0 +1,67 @@ +--- +jupytext: + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.13.8 +kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +--- + +# Spike Cutout Visualization + +```{code-cell} ipython3 +:tags: [hide-cell] + +import os, sys +import numpy as np +import quantities as pq +import matplotlib.pyplot as plt +``` + +```{code-cell} ipython3 +:tags: [remove-cell] + +from miv.io import load_continuous_data_file + +datapath = '2022-03-10_16-19-09/Record Node 104/spontaneous/recording1/continuous/Rhythm_FPGA-100.0/continuous.dat' +rate = 30_000 +timestamps, signal = load_continuous_data_file(datapath, 64, rate) +``` + +## Pre-Filter + +```{code-cell} ipython3 +:tags: [hide-cell] + +from miv.signal.filter import ButterBandpass +from miv.signal.spike import ThresholdCutoff +``` + +```{code-cell} ipython3 +pre_filter = ButterBandpass(lowcut=300, highcut=3000, order=5) +filtered_signal = pre_filter(signal, sampling_rate=rate) + +spike_detection = ThresholdCutoff() +spks = spike_detection(filtered_signal, timestamps, sampling_rate=30_000, progress_bar=False) +``` + +## Plot + +```{code-cell} ipython3 +:tags: [hide-cell] + +from miv.visualization import extract_waveforms, plot_waveforms +``` + +```{code-cell} ipython3 +cutouts = extract_waveforms( + filtered_signal, spks, channel=7, sampling_rate=rate +) +plot_waveforms(cutouts, rate, n_spikes=250) +``` + +![spike cutout output](../_static/assets/spike_cutout_example.png) diff --git a/docs/guide/spike_sorting.md b/docs/guide/spike_sorting.md new file mode 100644 index 00000000..a627244a --- /dev/null +++ b/docs/guide/spike_sorting.md @@ -0,0 +1,155 @@ +--- +jupytext: + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.13.8 +kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +--- + +# + +## References + +- Spike 
sorting based on discrete wavelet transform coefficients (Letelier 2000) +- Unsupervised spike detection and sorting with wavelets and superparamagnetic clustering (Quiroga 2004) +- A novel and fully automatic spike-sorting implementation with variable number of features (Chaure 2018) + +```{code-cell} ipython3 +:tags: [hide-cell] + +import os, sys +import numpy as np +import scipy +import scipy.special +import quantities as pq +import matplotlib.pyplot as plt +import pywt +``` + +```{code-cell} ipython3 +:tags: [remove-cell] + +sys.path.append('../..') +``` + +```{code-cell} ipython3 +:tags: [remove-cell] + +from miv.io import load_continuous_data_file + +datapath = '2022-03-10_16-19-09/Record Node 104/spontaneous/recording1/continuous/Rhythm_FPGA-100.0/continuous.dat' +rate = 30_000 +timestamps, signal = load_continuous_data_file(datapath, 64, rate) +``` + +## Pre-Filter + +```{code-cell} ipython3 +:tags: [hide-cell] + +from miv.signal.filter import ButterBandpass +from miv.signal.spike import ThresholdCutoff +``` + +```{code-cell} ipython3 +pre_filter = ButterBandpass(lowcut=300, highcut=3000, order=5) +filtered_signal = pre_filter(signal, sampling_rate=rate) + +spike_detection = ThresholdCutoff() +spks = spike_detection(filtered_signal, timestamps, sampling_rate=30_000, progress_bar=False) +``` + +## Plot + +```{code-cell} ipython3 +:tags: [hide-cell] + +from miv.visualization import extract_waveforms, plot_waveforms +``` + +```{code-cell} ipython3 +cutouts = extract_waveforms( + filtered_signal, spks, channel=7, sampling_rate=rate +) +plot_waveforms(cutouts, rate, n_spikes=250) +``` + +## Simple Clustering + +```{code-cell} ipython3 +:tags: [hide-cell] + +from sklearn.mixture import GaussianMixture +from sklearn.decomposition import PCA +from sklearn.preprocessing import StandardScaler +``` + +```{code-cell} ipython3 +scaler = StandardScaler() +scaled_cutouts = scaler.fit_transform(cutouts) + +pca = PCA() +pca.fit(scaled_cutouts) +# print(pca.explained_variance_ratio_) + +pca.n_components = 2 +transformed = pca.fit_transform(scaled_cutouts) +``` +```{code-cell} ipython3 +# Clustering +n_components = 3 # Number of clustering components +gmm = GaussianMixture(n_components=n_components, n_init=10) +labels = gmm.fit_predict(transformed) +``` + +```{code-cell} ipython3 +tmp_list = [] +for i in range(n_components): + idx = labels == i + tmp_list.append(timestamps[idx]) + spikestamps_clustered.append(tmp_list) + +_ = plt.figure(figsize=(8, 8)) +for i in range(n_components): + idx = labels == i + _ = plt.plot(transformed[idx, 0], transformed[idx, 1], ".") + _ = plt.title("Cluster assignments by a GMM") + _ = plt.xlabel("Principal Component 1") + _ = plt.ylabel("Principal Component 2") + _ = plt.legend([0, 1, 2]) + _ = plt.axis("tight") + +_ = plt.figure(figsize=(8, 8)) +for i in range(n_components): + idx = labels == i + color = plt.rcParams["axes.prop_cycle"].by_key()["color"][i] + plot_waveforms( + cutouts[idx, :], rate, n_spikes=100, color=color, + ) +# custom legend +custom_lines = [plt.Line2D([0], [0], color=plt.rcParams["axes.prop_cycle"].by_key()["color"][i], lw=4,) \ + for i in range(n_components)] +plt.legend(custom_lines, [f"component {i}" for i in range(n_components)]) +``` + +## Wavelet Decomposition + +```{code-cell} ipython3 +:tags: [hide-cell] + +from miv.signal.spike import SpikeSorting, WaveletDecomposition +from sklearn.clusterr import MeanShift +``` + +```{raw-cell} +spike_sorting = SpikeSorting( + feature_extractor=WaveletDecomposition(), + 
clsutering_method=sklearn.cluster.MeanShift() +) +label, index = spike_sorting(cutouts, return_index=True) +``` diff --git a/docs/index.rst b/docs/index.rst index aec5ad51..3ebebdc6 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,7 +1,5 @@ .. MiV-OS documentation master file, created by sphinx-quickstart on Thu Mar 24 23:35:49 2022. - You can adapt this file completely to your liking, but it should at least - contain the root `toctree` directive. MiV-OS documentation! ===================== @@ -17,18 +15,29 @@ Installation Instruction You can also download the source code from `GitHub `_ directly. +Contribution +------------ + +Any contribution to this project is welcome! If you are interested or have any questions, please don't hesitate to contact us. +If you are interested in contributing to this project, we prepared contribution guideline :ref:`here `. + .. toctree:: :maxdepth: 2 - :caption: Contents: + :caption: Overview overview/about + overview/references + overview/contribution .. toctree:: :maxdepth: 2 :caption: User Guide + guide/data_management guide/signal_processing + guide/spike_cutout + guide/spike_sorting .. toctree:: :maxdepth: 2 @@ -36,7 +45,9 @@ You can also download the source code from `GitHub Indices and tables ================== diff --git a/docs/overview/contribution.md b/docs/overview/contribution.md new file mode 100644 index 00000000..5b272509 --- /dev/null +++ b/docs/overview/contribution.md @@ -0,0 +1,3 @@ +```{include} ../../CONTRIBUTING.md +:relative-images: +``` diff --git a/docs/overview/references.rst b/docs/overview/references.rst new file mode 100644 index 00000000..1d26e3af --- /dev/null +++ b/docs/overview/references.rst @@ -0,0 +1,22 @@ +********** +References +********** + +Neural Ensemble +############### + +- Python-Neo [1]_ +- Elephant/Viziphant [2]_ + +Algorithm +######### + +- PyWavelets [3]_ + +--------------- + +.. [1] Garcia S., Guarino D., Jaillet F., Jennings T.R., Pröpper R., Rautenberg P.L., Rodgers C., Sobolev A.,Wachtler T., Yger P. and Davison A.P. (2014) Neo: an object model for handling electrophysiology data in multiple formats. Frontiers in Neuroinformatics 8:10: doi:10.3389/fninf.2014.00010 + +.. [2] Denker M, Yegenoglu A, Grün S (2018) Collaborative HPC-enabled workflows on the HBP Collaboratory using the Elephant framework. Neuroinformatics 2018, P19. doi:10.12751/incf.ni2018.0019 + +.. [3] Gregory R. Lee, Ralf Gommers, Filip Wasilewski, Kai Wohlfahrt, Aaron O’Leary (2019). PyWavelets: A Python package for wavelet analysis. Journal of Open Source Software, 4(36), 1237, https://doi.org/10.21105/joss.01237. diff --git a/miv/io/__init__.py b/miv/io/__init__.py index e49ec5ed..0ec7ac30 100644 --- a/miv/io/__init__.py +++ b/miv/io/__init__.py @@ -1 +1,2 @@ +from miv.io.data import * from miv.io.binary import * diff --git a/miv/io/binary.py b/miv/io/binary.py index c21cd109..29a10bb8 100644 --- a/miv/io/binary.py +++ b/miv/io/binary.py @@ -1,214 +1,225 @@ __doc__ = """ -We expect the data structure to follow the default format exported from OpenEphys system: `format `_. 
+------------------------------------- -Original Author +Raw Data Loader +############### -- open-ephys/analysis-tools/Python3/Binary.py (commit: 871e003) -- original author: malfatti - - date: 2019-07-27 -- last modified by: skim449 - - date: 2022-04-11 """ -__all__ = ["Load", "load_data", "load_continuous_data_file"] -from typing import Optional +__all__ = ["load_continuous_data", "load_recording", "oebin_read", "apply_channel_mask"] + +from typing import Any, Dict, Optional, Union, List, Set, Sequence + +import os import numpy as np from ast import literal_eval from glob import glob +import quantities as pq +import neo +from miv.typing import SignalType, TimestampsType -def ApplyChannelMap(Data, ChannelMap): - print("Retrieving channels according to ChannelMap... ", end="") - for R, Rec in Data.items(): - if Rec.shape[1] < len(ChannelMap) or max(ChannelMap) > Rec.shape[1] - 1: - print("") - print("Not enough channels in data to apply channel map. Skipping...") - continue - Data[R] = Data[R][:, ChannelMap] - - return Data +def apply_channel_mask(signal: np.ndarray, channel_mask: Set[int]): + """Apply channel mask on the given signal. + Parameters + ---------- + signal : np.ndarray + Shape of the signal is expected to be (num_data_point, num_channels). + channel_mask : Set[int] -def BitsToVolts(Data, ChInfo, Unit): - print("Converting to uV... ", end="") - Data = {R: Rec.astype("float32") for R, Rec in Data.items()} + Returns + ------- + output signal : SignalType - if Unit.lower() == "uv": - U = 1 - elif Unit.lower() == "mv": - U = 10 ** -3 + Raises + ------ + IndexError + Typically raise index error when the dimension of the signal is less than 2. + AttributeError + If signal is non numpy array type. - for R in Data.keys(): - for C in range(len(ChInfo)): - Data[R][:, C] = Data[R][:, C] * ChInfo[C]["bit_volts"] * U - if "ADC" in ChInfo[C]["channel_name"]: - Data[R][:, C] *= 10 ** 6 + """ - return Data + num_channels = signal.shape[1] + channel_index_set = set(range(num_channels)) - channel_mask + channel_index = np.array(np.sort(list(channel_index_set))) + signal = signal[:, channel_index] + return signal -def Load( - Folder, Processor=None, Experiment=None, Recording=None, Unit="uV", ChannelMap=[] -): +def bits_to_voltage(signal: SignalType, channel_info: Sequence[Dict[str, Any]]): """ - Loads data recorded by Open Ephys in Binary format as numpy memmap. + Convert binary bit data to voltage (microVolts) + + Parameters + ---------- + signal : SignalType, numpy array + channel_info : Dict[str, Dict[str, Any]] + Channel information dictionary. Typically located in `structure.oebin` file. + channel information includes bit-volts conversion ration and units (uV or mV). + + Returns + ------- + signal : numpy array + Output signal is in microVolts unit. - Load(Folder, Processor=None, Experiment=None, Recording=None, Unit='uV', ChannelMap=[]) + """ + resultant_unit = pq.Quantity(1, "uV") # Final Unit + for channel in range(len(channel_info)): + bit_to_volt_conversion = channel_info[channel]["bit_volts"] + recorded_unit = pq.Quantity([1], channel_info[channel]["units"]) + unit_conversion = (recorded_unit / resultant_unit).simplified + signal[:, channel] *= bit_to_volt_conversion * unit_conversion + if "ADC" in channel_info[channel]["channel_name"]: + signal[:, channel] *= 10 ** 6 + return signal + + +def oebin_read(file_path: str): + """ + Oebin file reader in dictionary form Parameters ---------- - Folder: str - Folder containing at least the subfolder 'experiment1'. 
- - Processor: str or None, optional - Processor number to load, according to subsubsubfolders under - Folder>experimentX/recordingY/continuous . The number used is the one - after the processor name. For example, to load data from the folder - 'Channel_Map-109_100.0' the value used should be '109'. - If not set, load all processors. - - Experiment: int or None, optional - Experiment number to load, according to subfolders under Folder. - If not set, load all experiments. - - Recording: int or None, optional - Recording number to load, according to subsubfolders under Folder>experimentX . - If not set, load all recordings. - - Unit: str or None, optional - Unit to return the data, either 'uV' or 'mV' (case insensitive). In - both cases, return data in float32. Defaults to 'uV'. - If anything else, return data in int16. - - ChannelMap: list, optional - If empty (default), load all channels. - If not empty, return only channels in ChannelMap, in the provided order. - CHANNELS ARE COUNTED STARTING AT 0. + file_path : str Returns ------- - Data: dict - Dictionary with data in the structure Data[Processor][Experiment][Recording]. + info : Dict[str, any] + recording information stored in oebin file. + """ + # TODO: may need fix for multiple continuous data. + # TODO: may need to include processor name/id + info = literal_eval(open(file_path).read()) + return info - Rate: dict - Dictionary with sampling rates in the structure Rate[Processor][Experiment]. +def load_recording( + folder: str, + channel_mask: Optional[Set[int]] = None, +): + """ + Loads data recorded by Open Ephys in Binary format as numpy memmap. + The path should contain + + - continuous//continuous.dat: signal (cannot have multiple file) + - continuous//timestamps.dat: timestamps + - structure.oebin: number of channels and sampling rate. - Example + Parameters + ---------- + folder: str + folder containing at least the subfolder 'experiment1'. + channel_mask: Set[int], optional + Channel index list to ignore in import (default=None) + + Returns ------- - import Binary + signal : SignalType, neo.core.AnalogSignal + sampling_rate : float - Folder = '/home/user/PathToData/2019-07-27_00-00-00' - Data, Rate = Binary.Load(Folder) + Raises + ------ + AssertionError + If more than one "continuous.dat" file exist in the directory. 
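+
+    Examples
+    --------
+    A minimal usage sketch; the recording path below is hypothetical.
+
+    >>> # hypothetical recording path
+    >>> signal, timestamps, sampling_rate = load_recording(
+    ...     "2022-03-10_16-19-09/Record Node 104/experiment1/recording1",
+    ...     channel_mask={0, 1},
+    ... )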
- ChannelMap = [0,15,1,14] - Recording = 3 - Data2, Rate2 = Binary.Load(Folder, Recording=Recording, ChannelMap=ChannelMap, Unit='Bits') """ - Files = sorted(glob(Folder + "/**/*.dat", recursive=True)) - InfoFiles = sorted(glob(Folder + "/*/*/structure.oebin")) - - Data, Rate = {}, {} - for F, File in enumerate(Files): - File = File.replace("\\", "/") # Replace windows file delims - Exp, Rec, _, Proc = File.split("/")[-5:-1] - Exp = str(int(Exp[10:]) - 1) - Rec = str(int(Rec[9:]) - 1) - Proc = Proc.split(".")[0].split("-")[-1] - if "_" in Proc: - Proc = Proc.split("_")[0] - - if Proc not in Data.keys(): - Data[Proc], Rate[Proc] = {}, {} - - if Experiment: - if int(Exp) != Experiment - 1: - continue - - if Recording: - if int(Rec) != Recording - 1: - continue - - if Processor: - if Proc != Processor: - continue - - print("Loading recording", int(Rec) + 1, "...") - if Exp not in Data[Proc]: - Data[Proc][Exp] = {} - Data[Proc][Exp][Rec] = np.memmap(File, dtype="int16", mode="c") - - Info = literal_eval(open(InfoFiles[F]).read()) - ProcIndex = [ - Info["continuous"].index(_) - for _ in Info["continuous"] - if str(_["source_processor_id"]) == Proc - ][ - 0 - ] # Changed to source_processor_id from recorded_processor_id - - ChNo = Info["continuous"][ProcIndex]["num_channels"] - if Data[Proc][Exp][Rec].shape[0] % ChNo: - print("Rec", Rec, "is broken") - del Data[Proc][Exp][Rec] - continue - - SamplesPerCh = Data[Proc][Exp][Rec].shape[0] // ChNo - Data[Proc][Exp][Rec] = Data[Proc][Exp][Rec].reshape((SamplesPerCh, ChNo)) - Rate[Proc][Exp] = Info["continuous"][ProcIndex]["sample_rate"] - - for Proc in Data.keys(): - for Exp in Data[Proc].keys(): - if Unit.lower() in ["uv", "mv"]: - ChInfo = Info["continuous"][ProcIndex]["channels"] - Data[Proc][Exp] = BitsToVolts(Data[Proc][Exp], ChInfo, Unit) - - if ChannelMap: - Data[Proc][Exp] = ApplyChannelMap(Data[Proc][Exp], ChannelMap) - - print("Done.") - - return (Data, Rate) - - -def load_data(): - raise NotImplementedError - - -def load_continuous_data_file( - data_file: str, - channels: int, - timestamps_npy: Optional[str] = "", - sampling_rate: float = 30000, + file_path: List[str] = glob(os.path.join(folder, "**", "*.dat"), recursive=True) + assert ( + len(file_path) == 1 + ), f"There should be only one 'continuous.dat' file. (There exists {file_path})" + + # load structure information dictionary + info_file: str = os.path.join(folder, "structure.oebin") + info: Dict[str, Any] = oebin_read(info_file) + num_channels: int = info["continuous"][0]["num_channels"] + sampling_rate: float = float(info["continuous"][0]["sample_rate"]) + # channel_info: Dict[str, Any] = info["continuous"][0]["channels"] + + # TODO: maybe need to support multiple continuous.dat files in the future + signal, timestamps = load_continuous_data(file_path[0], num_channels, sampling_rate) + + if channel_mask: + signal = apply_channel_mask(signal, channel_mask) + + # To Voltage + signal = bits_to_voltage(signal, info["continuous"][0]["channels"]) + # signal = neo.core.AnalogSignal( + # signal*pq.uV, sampling_rate=sampling_rate * pq.Hz + # ) + return signal, timestamps, sampling_rate + + +def load_continuous_data( + data_path: str, + num_channels: int, + sampling_rate: float, + timestamps_path: Optional[str] = None, + start_at_zero: bool = True, ): """ - Describe function + Load single continous data file and return timestamps and raw data in numpy array. + Typical `data_path` from OpenEphys has a name `continuous.dat`. + + .. 
note:: + The output data is raw-data without unit conversion. In order to convert the unit + to voltage, you need to multiply by `bit_volts` conversion ratio. This ratio and + units are typially saved in `structure.oebin` file. Parameters ---------- - data_file: continuous.dat file from Open_Ethys recording - channels: number of recording channels recorded from + data_path : str + continuous.dat file path from Open_Ethys recording. + num_channels : int + number of recording channels recorded. Note, this method will not throw an error + if you don't provide the correct number of channels. + sampling_rate : float + data sampling rate. + timestamps_path : Optional[str] + If None, first check if the file "timestamps.npy" exists on the same directory. + If the file doesn't exist, we deduce the timestamps based on the sampling rate + and the length of the data. + start_at_zero : bool + If True, the timestamps is adjusted to start at zero. + Note, recorded timestamps might not start at zero for some reason. Returns ------- - raw_data: - timestamps: + raw_data: SignalType, numpy array + timestamps: TimestampsType, numpy array + + Raises + ------ + FileNotFoundError + If data_path is invalid. + ValueError + If the error message shows the array cannot be reshaped due to shape, + make sure the num_channels is set accurately. """ - raw_data: np.ndarray = np.memmap(data_file, dtype="int16") - length = raw_data.size // channels - raw_data = np.reshape(raw_data, (length, channels)) + # Read raw data signal + raw_data: np.ndarray = np.memmap(data_path, dtype="int16", mode="c") + length = raw_data.size // num_channels + raw_data = np.reshape(raw_data, (length, num_channels)).astype("float32") + + # Get timestamps_path + if timestamps_path is None: + dirname = os.path.dirname(data_path) + timestamps_path = os.path.join(dirname, "timestamps.npy") - timestamps_zeroed = np.array(range(0, length)) / sampling_rate - if timestamps_npy == "": - timestamps = timestamps_zeroed - else: - timestamps = np.load(timestamps_npy) / sampling_rate + # Get timestamps + if os.path.exists(timestamps_path): + timestamps = np.array(np.load(timestamps_path), dtype=np.float64) + timestamps /= float(sampling_rate) + else: # If timestamps_path doesn't exist, deduce the stamps + timestamps = np.array(range(0, length)) / sampling_rate - # only take first 32 channels - raw_data = raw_data[:, 0:32] + # Adjust timestamps to start from zero + if start_at_zero and not np.isclose(timestamps[0], 0.0): + timestamps -= timestamps[0] - return np.array(timestamps), np.array(raw_data) + return np.array(raw_data), timestamps diff --git a/miv/io/data.py b/miv/io/data.py index d0d7096c..5c5ee386 100644 --- a/miv/io/data.py +++ b/miv/io/data.py @@ -1,70 +1,163 @@ +__doc__ = """ + +.. Note:: + We expect the data structure to follow the default format + exported from OpenEphys system: + `format `_. + +.. Note:: + For simple experiments, you may prefer to use :ref:`api/io:Raw Data Loader`. + However, we generally recommend to use ``Data`` or ``DataManager`` for + handling data, especially when the size of the raw data is large. + +Module +###### + +.. currentmodule:: miv.io.data + +.. autoclass:: Data + :members: + +---------------------- + +.. 
autoclass:: DataManager + :members: + +""" +__all__ = ["Data", "DataManager"] + +from typing import Any, Optional, Iterable, Callable, List, Set + from collections.abc import MutableSequence -from typing import Optional +import logging + import os +from glob import glob import numpy as np +from contextlib import contextmanager + +from miv.io.binary import load_continuous_data, load_recording from miv.signal.filter import FilterProtocol from miv.typing import SignalType class Data: - """ - For each continues.dat file, there will be one Data object + """Single data unit handler. + + Each data unit that contains single recording. This class provides useful tools, + such as masking channel, export data, interface with other packages, etc. + If you have multiple recordings you would like to handle at the same time, use + `DataManager` instead. + + By default recording setup, the following directory structure is expected in ``data_path``:: + + recording1 # <- recording data_path + ├── continuous + │   └── Rhythm_FPGA-100.0 + │   ├── continuous.dat + │   ├── synchronized_timestamps.npy + │   └── timestamps.npy + ├── events + │   ├── Message_Center-904.0 + │   │   └── TEXT_group_1 + │   │   ├── channels.npy + │   │   ├── text.npy + │   │   └── timestamps.npy + │   └── Rhythm_FPGA-100.0 + │   └── TTL_1 + │   ├── channel_states.npy + │   ├── channels.npy + │   ├── full_words.npy + │   └── timestamps.npy + ├── sync_messages.txt + ├── structure.oebin + └── analysis # <- post-processing result + ├── spike_data.npz + ├── plot + ├── spike + └── mea_overlay + + + Parameters + ---------- + data_path : str """ def __init__( self, data_path: str, - channels: int, - sampling_rate: float = 30000, - timestamps_npy: Optional[str] = "", - ): - self.data_path = data_path - self.channels = channels - self.sampling_rate = sampling_rate - self.timestamps_npy = timestamps_npy - - def load( - self, ): + self.data_path: str = data_path + self.analysis_path: str = os.path.join(data_path, "analysis") + self.masking_channel_set: Set[int] = set() + @contextmanager + def load(self): """ - Describe function + Context manager for loading data instantly. - Parameters - ---------- - data_file: continuous.dat file from Open_Ethys recording - channels: number of recording channels recorded from + Examples + -------- + >>> data = Data(data_path) + >>> with data.load() as (timestamps, raw_signal): + ... ... Returns ------- - raw_data: - timestamps: + signal : SignalType, neo.core.AnalogSignal + timestamps : TimestampsType, numpy array + sampling_rate : float + + Raises + ------ + FileNotFoundError + If some key files are missing. """ + # TODO: Not sure this is safe implementation + if not self.check_path_validity(): + raise FileNotFoundError("Data directory does not have all necessary files.") + try: + signal, timestamps, sampling_rate = load_recording( + self.data_path, self.masking_channel_set + ) + yield signal, timestamps, sampling_rate + except FileNotFoundError as e: + logging.error( + f"The file could not be loaded because the file {self.data_path} does not exist." + ) + logging.error(e.strerror) + except ValueError as e: + logging.error( + "The data size does not match the number of channel. Check if oebin or continuous.dat file is corrupted." 
+ ) + logging.error(e.strerror) + finally: + del timestamps + del signal - raw_data: np.ndarray = np.memmap(self.data_path, dtype="int16") - length = raw_data.size // self.channels - raw_data = np.reshape(raw_data, (length, self.channels)) + def set_channel_mask(self, channel_id: Iterable[int]): + """ + Set the channel masking. - timestamps_zeroed = np.array(range(0, length)) / self.sampling_rate - if self.timestamps_npy == "": - timestamps = timestamps_zeroed - else: - timestamps = np.load(self.timestamps_npy) / self.sampling_rate + Parameters + ---------- + channel_id : Iterable[int], list + List of channel id that will be ignored. - # only take first 32 channels - raw_data = raw_data[:, 0 : self.channels] + Notes + ----- + If the index exceed the number of channels, it will be ignored. - # TODO: do we want timestaps a member of the class? - return np.array(timestamps), np.array(raw_data) + Examples + -------- + >>> data = Data(data_path) + >>> data.set_channel_mask(range(12,23)) - def unload( - self, - ): - # TODO: remove the data from memory - pass + """ + self.masking_channel_set.update(channel_id) - def save(self, tag: str, format: str): + def save(self, tag: str, format: str): # TODO assert tag == "continuous", "You cannot alter raw data, change the data tag" # save_path = os.path.join(self.data_path, tag) @@ -80,65 +173,176 @@ def save(self, tag: str, format: str): "Please choose one of the supported formats: dat, npz, neo", ) + def check_path_validity(self): + """ + Check if necessary files exist in the directory. -class Dataset(MutableSequence): - def __init__( - self, - data_folder_path: str, - channels: int, - sampling_rate: float = 30000, - timestamps_npy: Optional[str] = "", - ): - self.data_folder_path = data_folder_path + - Check `continious.dat` exists. (only one) + - Check `structure.oebin` exists. - # From the path get data paths and create data objects - self.load_data_sets(channels, sampling_rate, timestamps_npy) + Returns + ------- + bool + Return true if all necessary files exist in the directory. - def load_data_sets(self, channels, sampling_rate, timestamps_npy): """ - Create data objects from the data three. + + continuous_dat_paths = glob( + os.path.join(self.data_path, "**", "continuous.dat"), recursive=True + ) + if len(continuous_dat_paths) != 1: + logging.warning( + f"One and only one continuous.dat file can exist in the data path. Found: {continuous_dat_paths}" + ) + return False + if not os.path.exists(os.path.join(self.data_path, "structure.oebin")): + logging.warning("Missing structure.oebin in the data path.") + return False + return True + + +class DataManager(MutableSequence): + """ + Data collection manager. + + By default recording setup, the directory is named after the date and time + of the recording. The structure of ``data_collection_path`` typically look + like below:: + + 2022-03-10_16-19-09 <- data_collection_path + └── Record Node 104 + └── experiment1 + └── recording1 <- data_path (Data module) + ├── experiment2 + ├── experiment3 + ├── experiment4 + ├── spontaneous + ├── settings.xml + ├── settings_2.xml + └── settings_3.xml Parameters ---------- - path + data_collection_path : str + Path for data collection. 
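+
+    Examples
+    --------
+    A minimal usage sketch; the collection path below is hypothetical.
+
+    >>> # hypothetical data collection path
+    >>> data_manager = DataManager("2022-03-10_16-19-09")
+    >>> data_manager.tree()  # list available recordings
+    >>> data = data_manager[0]  # DataManager behaves as a sequence of Data
+    >>> with data.load() as (signal, timestamps, sampling_rate):
+    ...     pass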
- Returns
- -------
+ """
+
+ def __init__(self, data_collection_path: str):
+ self.data_collection_path = data_collection_path
+ self.data_list: Iterable[Data] = []
+
+ # From the path get data paths and create data objects
+ self._load_data_paths()
+
+ @property
+ def data_path_list(self) -> Iterable[str]:
+ return [data.data_path for data in self.data_list]
+
+ def tree(self):
+ """
+ Pretty-print available recordings in DataManager in tree format.
+
+ Examples
+ --------
+ >>> data_collection = DataManager("2022-05-15_13-51-36")
+ >>> data_collection.tree()
+ 2022-05-15_13-51-36
+ 0:
+ └── Record Node 103/experiment3_std2_pt_ESC/recording1
+ 1:
+ └── Record Node 103/experiment2_std1_pt_ESC/recording1
+ 2:
+ └── Record Node 103/experiment1_cont_ESC/recording1
+ """
+ # TODO: Either use logging or other str stream
+ if not self.data_list:
+ print(
+ "Data list is empty. Check that data_collection_path exists and is correct."
+ )
+ return
+ print(self.data_collection_path)
+ for idx, data in enumerate(self.data_list):
+ print(" " * 4 + f"{idx}: {data}")
+ print(
+ " " * 4
+ + " └── "
+ + data.data_path[len(self.data_collection_path) + 1 :]
+ )
+
+ def _load_data_paths(self):
+ """
+ Create data objects from the data tree.
 """
 # From the path get the data path list
- self.data_path_list = self._get_data_path_from_tree()
+ data_path_list = self._get_experiment_paths()
- # Create an object for each continues.dat and store them in data list to manipulate later.
+ # Create data objects
 self.data_list = []
- for data_path in self.data_path_list:
- self.data_list.append(
- Data(data_path, channels, sampling_rate, timestamps_npy)
- )
+ invalid_count = 0
+ for path in data_path_list:
+ data = Data(path)
+ if data.check_path_validity():
+ self.data_list.append(data)
+ else:
+ invalid_count += 1
+ logging.info(
+ f"Total {len(data_path_list)} recordings found. There are {invalid_count} invalid paths."
+ )
- def _get_data_path_from_tree(self):
+ def _get_experiment_paths(self) -> Iterable[str]:
 """
- This function gets the data for each continues.dat file inside the data folder.
+ Get experiment paths.
+
 Returns
 -------
 data_path_list : list
 """
- # TODO: implement algorithm to get paths of all continues.dat files.
- # Use self.data_folder_path - raise NotImplementedError("Loading data tree not implemented yet") - data_path_list = [] - return data_path_list + # Use self.data_collection_path + path_list = [] + for path in glob( + os.path.join(self.data_collection_path, "*", "experiment*", "recording*") + ): + if ( + ("Record Node" in path) + and ("experiment" in path) + and os.path.isdir(path) + ): + path_list.append(path) + return path_list def save(self, tag: str, format: str): + raise NotImplementedError # TODO for data in self.data_list: data.save(tag, format) def apply_filter(self, filter: FilterProtocol): + raise NotImplementedError # TODO for data in self.data_list: data.load() data = filter(data, sampling_rate=0) data.save(tag="filter", format="npz") data.unload() - # def apply_spike_detection(self, method: DetectionProtocol): - # raise NotImplementedError("Wait until we make it") + # MutableSequence abstract methods + def __len__(self): + return len(self.data_list) + + def __getitem__(self, idx): + return self.data_list[idx] + + def __delitem__(self, idx): + del self.data_list[idx] + + def __setitem__(self, idx, data): + if data.check_path_validity(): + self.data_list[idx] = data + else: + logging.warning("Invalid data cannot be loaded to the DataManager.") + + def insert(self, idx, data): + if data.check_path_validity(): + self.data_list.insert(idx, data) + else: + logging.warning("Invalid data cannot be loaded to the DataManager.") diff --git a/miv/signal/filter/butter_bandpass_filter.py b/miv/signal/filter/butter_bandpass_filter.py index 3d882f8f..ef29a008 100644 --- a/miv/signal/filter/butter_bandpass_filter.py +++ b/miv/signal/filter/butter_bandpass_filter.py @@ -7,6 +7,7 @@ import numpy.typing as npt import scipy.signal as sps +import matplotlib.pyplot as plt from miv.typing import SignalType @@ -32,10 +33,38 @@ class ButterBandpass: order: int = 5 tag: str = "" - def __call__(self, signal: SignalType, sampling_rate: float) -> SignalType: + def __call__( + self, + signal: SignalType, + sampling_rate: float, + plot_frequency_response: bool = False, + **kwargs, + ) -> SignalType: + """__call__. + + Parameters + ---------- + signal : SignalType + signal + sampling_rate : float + sampling_rate + plot_frequency_response : bool + plot_frequency_response + kwargs : + kwargs + + Returns + ------- + SignalType + + """ b, a = self._butter_bandpass(sampling_rate) y = sps.lfilter(b, a, signal) - return y + if plot_frequency_response: + fig = self.plot_frequency_response(a, b) + return y, fig + else: + return y def __post_init__(self): assert ( @@ -54,3 +83,14 @@ def _butter_bandpass(self, sampling_rate: float): high = self.highcut / nyq b, a = sps.butter(self.order, [low, high], btype="band") return b, a + + def plot_frequency_response(self, a, b): + w, h = sps.freqs(b, a) + fig = plt.figure() + plt.semilogx(w, 20 * np.log10(abs(h))) + plt.title( + f"Butterworth filter (order{self.order}) frequency response [{self.lowcut},{self.highcut}]" + ) + plt.xlabel("Frequency") + plt.ylabel("Amplitude") + return fig diff --git a/miv/signal/filter/filter_collection.py b/miv/signal/filter/filter_collection.py index 3465777e..35338479 100644 --- a/miv/signal/filter/filter_collection.py +++ b/miv/signal/filter/filter_collection.py @@ -1,4 +1,23 @@ -__doc__ = "" +__doc__ = """ + +Signal Filter +############# + + + +.. currentmodule:: miv.signal.filter + +.. autoclass:: FilterCollection + :members: append, insert + +.. 
autosummary:: + :nosignatures: + :toctree: _toctree/FilterAPI + + FilterProtocol + ButterBandpass + +""" __all__ = ["FilterCollection"] from typing import Union, List @@ -46,8 +65,8 @@ def __init__(self, tag: str = ""): def __call__(self, signal: SignalType, sampling_rate: float) -> SignalType: for filter in self.filters: - y: SignalType = filter(signal, sampling_rate) - return y + signal = filter(signal, sampling_rate) + return signal # MutableSequence abstract methods def __len__(self): diff --git a/miv/signal/filter/protocol.py b/miv/signal/filter/protocol.py index ece88bca..9393f508 100644 --- a/miv/signal/filter/protocol.py +++ b/miv/signal/filter/protocol.py @@ -24,3 +24,7 @@ def __call__(self, array: SignalType, sampling_rate: float, **kwargs) -> SignalT samping_rate : float """ ... + + def __repr__(self) -> str: + """String representation for interactive debugging.""" + ... diff --git a/miv/signal/spike/__init__.py b/miv/signal/spike/__init__.py index d08d9d78..7ea090f6 100644 --- a/miv/signal/spike/__init__.py +++ b/miv/signal/spike/__init__.py @@ -1,2 +1,3 @@ from miv.signal.spike.protocol import * from miv.signal.spike.detection import * +from miv.signal.spike.sorting import * diff --git a/miv/signal/spike/detection.py b/miv/signal/spike/detection.py index b5b67849..d179148e 100644 --- a/miv/signal/spike/detection.py +++ b/miv/signal/spike/detection.py @@ -1,4 +1,20 @@ -__doc__ = "" +__doc__ = """ + +Spike Detection +############### + + + +.. currentmodule:: miv.signal.spike + +.. autosummary:: + :nosignatures: + :toctree: _toctree/DetectionAPI + + SpikeDetectionProtocol + ThresholdCutoff + +""" __all__ = ["ThresholdCutoff"] from typing import Union, List, Iterable @@ -6,6 +22,8 @@ import numpy as np import quantities as pq +from tqdm import tqdm + from dataclasses import dataclass from miv.typing import SignalType, TimestampsType, SpikestampsType @@ -17,16 +35,23 @@ class ThresholdCutoff: """ThresholdCutoff Spike sorting step by step guide is well documented `here `_. 
- Parameters + Attributes ---------- dead_time : float (default=0.003) - search_range : floatfloatfloatfloat + search_range : float (default=0.002) + cutoff : Union[float, np.ndarray] + (default=5.0) + use_mad : bool + (default=False) + tag : str """ dead_time: float = 0.003 search_range: float = 0.002 + cutoff: float = 5.0 + use_mad: bool = True tag: str = "Threshold Cutoff Spike Detection" def __call__( @@ -35,8 +60,7 @@ def __call__( timestamps: TimestampsType, sampling_rate: float, units: Union[str, pq.UnitTime] = "sec", - cutoff: Union[float, np.ndarray] = 5.0, - use_mad: bool = True, + progress_bar: bool = True, ) -> List[SpikestampsType]: """Execute threshold-cutoff method and return spike stamps @@ -50,10 +74,8 @@ def __call__( sampling_rate units : Union[str, pq.UnitTime] (default='sec') - cutoff : Union[float, np.ndarray] - (default=5.0) - use_mad : bool - (default=False) + progress_bar : bool + Toggle progress bar (default=True) Returns ------- @@ -62,12 +84,14 @@ def __call__( """ # Spike detection for each channel spiketrain_list = [] - num_channels = len(signal) # type: ignore - for channel in range(num_channels): - array = signal[channel] # type: ignore + num_channels = signal.shape[1] # type: ignore + for channel in tqdm(range(num_channels), disable=not progress_bar): + array = signal[:, channel] # type: ignore # Spike Detection: get spikestamp - spike_threshold = self.compute_spike_threshold(array, use_mad=use_mad) + spike_threshold = self.compute_spike_threshold( + array, cutoff=self.cutoff, use_mad=self.use_mad + ) crossings = self.detect_threshold_crossings( array, sampling_rate, spike_threshold, self.dead_time ) @@ -76,13 +100,15 @@ def __call__( ) spikestamp = spikes / sampling_rate # Convert spikestamp to neo.SpikeTrain (for plotting) - spiketrain = neo.SpikeTrain(spikestamp, units=units) + spiketrain = neo.SpikeTrain( + spikestamp, units=units, t_stop=timestamps.max() + ) spiketrain_list.append(spiketrain) return spiketrain_list def compute_spike_threshold( self, signal: SignalType, cutoff: float = 5.0, use_mad: bool = True - ) -> float: + ) -> float: # TODO: make this function compatible to array of cutoffs (for each channel) """ Returns the threshold for the spike detection given an array of signal. diff --git a/miv/signal/spike/protocol.py b/miv/signal/spike/protocol.py index a24ffeb0..da092835 100644 --- a/miv/signal/spike/protocol.py +++ b/miv/signal/spike/protocol.py @@ -1,7 +1,12 @@ -__all__ = ["SpikeDetectionProtocol"] +__all__ = [ + "SpikeDetectionProtocol", + "SpikeFeatureExtractionProtocol", + "UnsupervisedFeatureClusteringProtocol", +] -from typing import Protocol +from typing import Protocol, Union, Any, Iterable +import numpy as np import neo.core from miv.typing import SignalType, TimestampsType, SpikestampsType @@ -15,3 +20,32 @@ def __call__( def __repr__(self) -> str: ... + + +# TODO: Behavior is clear, but not sure what the name should be +# class SpikeSortingProtocol(Protocol): +# def __call__( +# self, signals: Iterable[SignalType] +# ) -> np.ndarray: +# ... +# +# def __repr__(self) -> str: +# ... + + +class SpikeFeatureExtractionProtocol(Protocol): + """ex) wavelet transform, PCA, etc.""" + + def __repr__(self) -> str: + ... + + +class UnsupervisedFeatureClusteringProtocol(Protocol): + def __repr__(self) -> str: + ... + + def fit(self, X: np.ndarray): + ... + + def predict(self, X: np.ndarray) -> np.ndarray: + ... 
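The two new protocols above are structural (typing.Protocol), so any class providing the matching methods satisfies them without inheritance. As an illustrative sketch only (the wrapper class name and defaults below are invented here and are not part of this patch), a thin sklearn-backed clusterer that would satisfy UnsupervisedFeatureClusteringProtocol could look like:

    import numpy as np
    from sklearn.cluster import KMeans

    class KMeansClustering:
        """Hypothetical clusterer satisfying UnsupervisedFeatureClusteringProtocol."""

        def __init__(self, n_clusters: int = 3):
            self.n_clusters = n_clusters
            self._model = KMeans(n_clusters=n_clusters, n_init=10)

        def __repr__(self) -> str:
            return f"KMeansClustering(n_clusters={self.n_clusters})"

        def fit(self, X: np.ndarray):
            # X: (n_spikes, n_features) matrix produced by a feature-decomposition step
            self._model.fit(X)

        def predict(self, X: np.ndarray) -> np.ndarray:
            # Returns one cluster label per spike
            return self._model.predict(X)

    # Usage sketch with synthetic features
    features = np.random.randn(100, 4)
    clusterer = KMeansClustering(n_clusters=3)
    clusterer.fit(features)
    labels = clusterer.predict(features)  # shape (100,)
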
diff --git a/miv/signal/spike/sorting.py b/miv/signal/spike/sorting.py
new file mode 100644
index 00000000..6c243a86
--- /dev/null
+++ b/miv/signal/spike/sorting.py
@@ -0,0 +1,347 @@
+__doc__ = """
+
+A typical spike-sorting procedure can be described in three steps: (1) spike detection, (2) feature decomposition, and (3) clustering.
+We provide a separate module to perform spike detection; see :ref:`here `.
+
+We provide the `SpikeSorting` module, which composes a *feature-decomposition* method and an *unsupervised-clustering* method.
+Commonly used feature-decomposition methods include PCA and wavelet decomposition.
+For the clustering method, we implemented a few methods that commonly appear in the literature (listed below).
+Additionally, one can use out-of-the-box clustering modules from `sklearn`.
+
+.. note:: Depending on the clustering method, there might be an additional step to find the optimum number of clusters.
+
+.. currentmodule:: miv.signal.spike
+
+.. autoclass:: SpikeSorting
+ :members:
+
+Available Feature Extractor
+###########################
+
+.. autosummary::
+ :toctree: _toctree/SpikeSortingAPI
+
+ SpikeFeatureExtractionProtocol
+ WaveletDecomposition
+ PCADecomposition
+
+Unsupervised Clustering
+#######################
+
+.. autosummary::
+ :toctree: _toctree/SpikeSortingAPI
+
+ UnsupervisedFeatureClusteringProtocol
+ SuperParamagneticClustering
+
+Other external tools
+--------------------
+
+The following external modules can also be used for spike sorting.
+
+Sklearn Clustering
+~~~~~~~~~~~~~~~~~~
+.. autosummary::
+
+ sklearn.cluster.MeanShift
+ sklearn.cluster.KMeans
+
+"""
+__all__ = [
+ "SpikeSorting",
+ "WaveletDecomposition",
+ "PCADecomposition",
+ "SuperParamagneticClustering",
+]
+
+from typing import Any, Union, Optional
+
+from dataclasses import dataclass
+
+import numpy as np
+import scipy
+import scipy.special
+import quantities as pq
+
+import pywt
+
+import matplotlib.pyplot as plt
+
+from sklearn.mixture import GaussianMixture
+from sklearn.decomposition import PCA
+from sklearn.preprocessing import StandardScaler
+
+import neo
+from miv.typing import SignalType, TimestampsType, SpikestampsType
+from miv.signal.spike.protocol import (
+ SpikeFeatureExtractionProtocol,
+ UnsupervisedFeatureClusteringProtocol,
+)
+
+
+class SpikeSorting:
+ """
+ Spike sorting module.
+ The user can specify the method for feature extraction (e.g. WaveletDecomposition, PCADecomposition)
+ and the method for clustering (e.g. MeanShift, KMeans, SuperParamagneticClustering).
+
+
+ Examples
+ --------
+ >>> spike_sorting = SpikeSorting(
+ ... feature_extractor=WaveletDecomposition(),
+ ... clustering_method=sklearn.cluster.MeanShift()
+ ... )
+ >>> label, index = spike_sorting(cutouts, return_index=True)
+
+
+ Parameters
+ ----------
+ feature_extractor : SpikeFeatureExtractionProtocol
+ clustering_method : UnsupervisedFeatureClusteringProtocol
+
+ """
+
+ def __init__(
+ self,
+ feature_extractor: SpikeFeatureExtractionProtocol,
+ clustering_method: UnsupervisedFeatureClusteringProtocol,
+ ):
+ pass
+
+ def __call__(self):
+ pass
+
+
+# UnsupervisedFeatureClusteringProtocol
+class SuperParamagneticClustering:
+ """Super-Paramagnetic Clustering (SPC)
+
+ The implementation is heavily inspired by [1]_ and [2]_.
+
+
+ .. [1] Quiroga RQ, Nadasdy Z, Ben-Shaul Y. Unsupervised spike detection and sorting with wavelets and superparamagnetic clustering. Neural Comput. 2004 Aug;16(8):1661-87. doi: 10.1162/089976604774201631. PMID: 15228749.
+ .. [2] Fernando J. Chaure, Hernan G.
Rey, and Rodrigo Quian Quiroga. A novel and fully automatic spike-sorting implementation with variable number of features. Journal of Neurophysiology 2018 120:4, 1859-1871. https://doi.org/10.1152/jn.00339.2018 + + """ + + pass + + +class PCADecomposition: + """PCA Decomposition + + Other studies that use PCA decomposition: [1]_, [2]_ + + .. [1] G. Hilgen, M. Sorbaro, S. Pirmoradian, J.-O. Muthmann, I. Kepiro, S. Ullo, C. Juarez Ramirez, A. Puente Encinas, A. Maccione, L. Berdondini, V. Murino, D. Sona, F. Cella Zanacchi, E. Sernagor, M.H. Hennig (2016). Unsupervised spike sorting for large scale, high density multielectrode arrays. Cell Reports 18, 2521–2532. bioRxiv: http://dx.doi.org/10.1101/048645. + .. [2] Yger P, Spampinato GL, Esposito E, Lefebvre B, Deny S, Gardella C, Stimberg M, Jetter F, Zeck G, Picaud S, Duebel J, Marre O. A spike sorting toolbox for up to thousands of electrodes validated with ground truth recordings in vitro and in vivo. Elife. 2018 Mar 20;7:e34518. doi: 10.7554/eLife.34518. PMID: 29557782; PMCID: PMC5897014. + + """ + + def __init__(self): + pass + + def project(self, n_features, cutouts): + scaler = StandardScaler() + scaled_cutouts = scaler.fit_transform(cutouts) + + pca = PCA() + pca.fit(scaled_cutouts) + # print(pca.explained_variance_ratio_) + + pca.n_components = 2 + transformed = pca.fit_transform(scaled_cutouts) + + # Clustering + n_components = 3 # Number of clustering components + gmm = GaussianMixture(n_components=n_components, n_init=10) + labels = gmm.fit_predict(transformed) + return labels + + """ + tmp_list = [] + for i in range(n_components): + idx = labels == i + tmp_list.append(timestamps[idx]) + spikestamps_clustered.append(tmp_list) + + _ = plt.figure(figsize=(8, 8)) + for i in range(n_components): + idx = labels == i + _ = plt.plot(transformed[idx, 0], transformed[idx, 1], ".") + _ = plt.title("Cluster assignments by a GMM") + _ = plt.xlabel("Principal Component 1") + _ = plt.ylabel("Principal Component 2") + _ = plt.legend([0, 1, 2]) + _ = plt.axis("tight") + + _ = plt.figure(figsize=(8, 8)) + for i in range(n_components): + idx = labels == i + color = plt.rcParams["axes.prop_cycle"].by_key()["color"][i] + plot_waveforms( + cutouts[idx, :], + rate, + n_spikes=100, + color=color, + ) + # custom legend + custom_lines = [ + plt.Line2D( + [0], + [0], + color=plt.rcParams["axes.prop_cycle"].by_key()["color"][i], + lw=4, + ) + for i in range(n_components) + ] + plt.legend(custom_lines, [f"component {i}" for i in range(n_components)]) + """ + + +class WaveletDecomposition: + """ + Wavelet Decomposition for spike sorting. + The implementation is heavily inspired from [1]_ and [2]_; + their MatLab implementation (wave_clus) can be found `here `_. + + The default setting uses four-level multiresolution decomposition with Haar wavelets. + To learn about possible choice of wavelet, check `PyWavelets module `_. + + Other studies that use wavelet decomposition: [3]_ + + .. [1] Letelier JC, Weber PP. Spike sorting based on discrete wavelet transform coefficients. J Neurosci Methods. 2000 Sep 15;101(2):93-106. doi: 10.1016/s0165-0270(00)00250-8. PMID: 10996370. + .. [2] Quiroga RQ, Nadasdy Z, Ben-Shaul Y. Unsupervised spike detection and sorting with wavelets and superparamagnetic clustering. Neural Comput. 2004 Aug;16(8):1661-87. doi: 10.1162/089976604774201631. PMID: 15228749. + .. [3] Nenadic, Z., and Burdick, J. W. (2005). Spike detection using the continuous wavelet transform. IEEE Trans. BioMed. Eng. 52, 74–87. 
doi: 10.1109/TBME.2004.839800 + + """ + + def __init__(self): + pass + + def project(self, n_features): + ## Wavelet Decomposition + number_of_spikes = 400 + data_length = 100 + cutouts = np.empty([number_of_spikes, data_length]) + # spikes_l = cutouts[0] + coeffs = pywt.wavedec(cutouts, "haar", level=4) + features = np.concatenate(coeffs, axis=1) + + def test_ks(x): + # Calculates CDF + + # xCDF = [] + yCDF = [] + x = x[~np.isnan(x)] + n = x.shape[0] + x.sort() + + # Get cumulative sums + yCDF = (np.arange(n) + 1) / n + + # Remove duplicates; only need final one with total count + notdup = np.concatenate([np.diff(x), [1]]) > 0 + x_expcdf = x[notdup] + y_expcdf = np.concatenate([[0], yCDF[notdup]]) + + # The theoretical CDF (theocdf) is assumed to ben ormal + # with unknown mean and sigma + zScore = (x_expcdf - x.mean()) / x.std() + # theocdf = normcdf(zScore, 0, 1) + + mu = 0 + sigma = 1 + theocdf = 0.5 * scipy.special.erfc(-(zScore - mu) / (np.sqrt(2) * sigma)) + + # Compute the Maximum distance: max|S(x) - theocdf(x)|. + + delta1 = ( + y_expcdf[:-1] - theocdf + ) # Vertical difference at jumps approaching from the LEFT. + delta2 = ( + y_expcdf[1:] - theocdf + ) # Vertical difference at jumps approaching from the RIGHT. + deltacdf = np.abs(np.concatenate([delta1, delta2])) + + KSmax = deltacdf.max() + return KSmax + + ks = [] + for idx, feature in enumerate(np.moveaxis(features, 1, 0)): + std_feature = np.std(feature) + mean_feature = np.mean(feature) + thr_dist = std_feature * 3 + thr_dist_min = mean_feature - thr_dist + thr_dist_max = mean_feature + thr_dist + aux = feature[ + np.logical_and(feature > thr_dist_min, feature < thr_dist_max) + ] + + if aux.shape[0] > 10: + ks.append(test_ks(aux)) + else: + ks.append(0) + + max_inputs = 0.75 + min_inputs = 10 + + # if all: + # max_inputs = features.shape[1] + if max_inputs < 1: + max_inputs = np.ceil(max_inputs * features.shape[1]).astype(int) + + ind = np.argsort(ks) + A = np.array(ks)[ind] + A = A[A.shape[0] - max_inputs :] # Cutoff coeffs + + ncoeff = A.shape[0] + maxA = A.max() + nd = 10 + d = (A[nd - 1 :] - A[: -nd + 1]) / maxA * ncoeff / nd + all_above1 = d[np.nonzero(d >= 1)] + if all_above1.shape[0] >= 2: + # temp_bla = smooth(diff(all_above1),3) + aux2 = np.diff(all_above1) + temp_bla = np.convolve(aux, np.ones(3) / 3) + temp_bla = temp_bla[1:-1] + temp_bla[0] = aux2[0] + temp_bla[-1] = aux2[-1] + # ask to be above 1 for 3 consecutive coefficients + thr_knee_diff = all_above1[np.nonzero(temp_bla[1:] == 1)[:1]] + nd / 2 + inputs = max_inputs - thr_knee_diff + 1 + else: + inputs = min_inputs + + plot_feature_stats = True + if plot_feature_stats: + fig = plt.figure() + plt.stairs(np.sort(ks)) + plt.plot( + [len(ks) - inputs + 1, len(ks) - inputs + 1], + fig.axes[0].get_ylim(), + "r", + ) + plt.plot( + [len(ks) - max_inputs, len(ks) - max_inputs], + fig.axes[0].get_ylim(), + "--k", + ) + plt.ylabel("ks_stat") + plt.xlabel("# features") + plt.title( + f"number of spikes = {number_of_spikes}, inputs_selected = {inputs}" + ) + + if inputs > max_inputs: + inputs = max_inputs + elif inputs.shape[0] == 0 or inputs < min_inputs: + inputs = min_inputs + + coeff = ind[-inputs:] + # CRATES INPUT MATRIX FOR SPC + input_for_spc = np.zeros((number_of_spikes, inputs)) + + for i in range(number_of_spikes): + for j in range(inputs): + input_for_spc[i, j] = features[i, coeff[j]] diff --git a/miv/statistics/__init__.py b/miv/statistics/__init__.py index 7ee99a30..920f46d3 100644 --- a/miv/statistics/__init__.py +++ b/miv/statistics/__init__.py @@ -1 
+1,2 @@
 from miv.statistics.summarizer import *
+from miv.statistics.utility import *
diff --git a/miv/statistics/summarizer.py b/miv/statistics/summarizer.py
index c519083e..ba540bdc 100644
--- a/miv/statistics/summarizer.py
+++ b/miv/statistics/summarizer.py
@@ -1,7 +1,46 @@
-__doc__ = ""
-__all__ = ["StatisticsSummary"]
+__doc__ = """
-from typing import Any, Optional, Iterable, Union
+
+Statistics Tools
+================
+
+Spikestamps
+-----------
+
+.. currentmodule:: miv.statistics
+
+.. autosummary::
+ :nosignatures:
+ :toctree: _toctree/StatisticsAPI
+
+ spikestamps_statistics
+
+Useful External Packages
+========================
+
+Here are a few external `python` packages that can be used for further statistical analysis.
+
+scipy statistics
+----------------
+
+`scipy `_
+
+.. autosummary::
+
+ scipy.stats.describe
+
+elephant.statistics
+-------------------
+
+`elephant documentation: `_
+
+.. autosummary::
+
+ elephant.statistics.mean_firing_rate
+ elephant.statistics.instantaneous_rate
+"""
+__all__ = ["spikestamps_statistics"]
+
+from typing import Any, Optional, Iterable, Union, Dict
 import numpy as np
 import matplotlib.pyplot as plt
 import datetime
@@ -12,25 +51,35 @@
 import elephant.statistics
-class StatisticsSummary:
- """StatisticsSummary."""
-
- def __init__(self):
- pass
-
- def spikestamps_summary(
- self,
- spikestamps: Iterable[neo.core.SpikeTrain],
- t_start: Optional[float] = None,
- t_stop: Optional[float] = None,
- ) -> Iterable[Any]:
- rates = elephant.statistics.mean_firing_rate(
- spikestamps, t_start, t_stop, axis=0
- )
- rates_mean_over_channel = np.mean(rates)
- rates_variance_over_channel = np.var(rates)
- return {
- "rates": rates,
- "mean": rates_mean_over_channel,
- "variance": rates_variance_over_channel,
- }
+# FIXME: For now, we provide the free function for simple usage. For more
+# advanced statistical analysis, we should have a module wrapper.
+def spikestamps_statistics(
+ spikestamps: Union[np.ndarray, Iterable[float], Iterable[neo.core.SpikeTrain]],
+ t_start: Optional[float] = None,
+ t_stop: Optional[float] = None,
+) -> Dict[str, Any]:
+ """
+ Process basic spikestamps statistics: rates, mean, variance.
+
+ Parameters
+ ----------
+ spikestamps : Iterable[neo.core.SpikeTrain]
+ t_start : Optional[float]
+ t_stop : Optional[float]
+
+ Returns
+ -------
+ Dict[str, Any]
+
+ """
+ rates = [
+ elephant.statistics.mean_firing_rate(spikestamp, t_start, t_stop, axis=0)
+ for spikestamp in spikestamps
+ ]
+ rates_mean_over_channel = np.mean(rates)
+ rates_variance_over_channel = np.var(rates)
+ return {
+ "rates": rates,
+ "mean": rates_mean_over_channel,
+ "variance": rates_variance_over_channel,
+ }
diff --git a/miv/statistics/utility.py b/miv/statistics/utility.py
new file mode 100644
index 00000000..b4756e3b
--- /dev/null
+++ b/miv/statistics/utility.py
@@ -0,0 +1,21 @@
+__all__ = ["signal_to_noise"]
+
+import numpy as np
+
+from miv.typing import SignalType
+
+
+def signal_to_noise(signal: SignalType, axis: int = 0, ddof: int = 0):
+ """signal_to_noise.
+
+ Parameters
+ ----------
+ signal : SignalType
+ axis : int
+ Axis of interest.
By default, signal axis is 0 (default=1) + ddof : int + """ + signal_np = np.asanyarray(signal) + m = signal_np.mean(axis) + sd = signal_np.std(axis=axis, ddof=ddof) + return np.abs(np.where(sd == 0, 0, m / sd)) diff --git a/miv/typing.py b/miv/typing.py index 487330c2..e53b1105 100644 --- a/miv/typing.py +++ b/miv/typing.py @@ -7,7 +7,8 @@ import neo SignalType = Union[ - npt.ArrayLike, np.ndarray, neo.core.AnalogSignal -] # Shape should be [n_channel, signal_length] -TimestampsType = npt.ArrayLike -SpikestampsType = Union[npt.ArrayLike, neo.core.SpikeTrain] + np.ndarray, + neo.core.AnalogSignal, # npt.DTypeLike +] # Shape should be [signal_length, n_channel] +TimestampsType = np.ndarray +SpikestampsType = Union[np.ndarray, neo.core.SpikeTrain] diff --git a/miv/version.py b/miv/version.py index bc4ffb3d..4d389414 100644 --- a/miv/version.py +++ b/miv/version.py @@ -1 +1 @@ -VERSION = "0.0.2" +VERSION = "0.0.3" diff --git a/miv/visualization/__init__.py b/miv/visualization/__init__.py new file mode 100644 index 00000000..336c0548 --- /dev/null +++ b/miv/visualization/__init__.py @@ -0,0 +1,2 @@ +from miv.visualization.fft_domain import * +from miv.visualization.waveform import * diff --git a/miv/visualization/fft_domain.py b/miv/visualization/fft_domain.py new file mode 100644 index 00000000..88b2731a --- /dev/null +++ b/miv/visualization/fft_domain.py @@ -0,0 +1,47 @@ +__all__ = ["plot_frequency_domain"] + +import os +import numpy as np + +from scipy import fftpack +from scipy.signal import welch + +import matplotlib.pyplot as plt + +from miv.typing import SignalType + + +def plot_frequency_domain(signal: SignalType, sampling_rate: float) -> plt.Figure: + """ + Plot DFT frequency domain + + Parameters + ---------- + signal : SignalType + Input signal + sampling_rate : float + Sampling frequency + + Returns + ------- + figure: plt.Figure + + """ + # FFT + fig = plt.figure() + sig_fft = fftpack.fft(signal) + # sample_freq = fftpack.fftfreq(signal.size, d=1 / sampling_rate) + plt.plot(np.abs(sig_fft) ** 2) + plt.xlabel("Frequency [Hz]") + plt.ylabel("DFT frequency") + + # Welch (https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.welch.html) + f, Pxx_den = welch(signal, sampling_rate, nperseg=1024) + f_med, Pxx_den_med = welch(signal, sampling_rate, nperseg=1024, average="median") + plt.figure() + plt.semilogy(f, Pxx_den, label="mean") + plt.semilogy(f_med, Pxx_den_med, label="median") + plt.xlabel("frequency [Hz]") + plt.ylabel("PSD [uV**2/Hz]") + plt.legend() + return fig diff --git a/miv/visualization/waveform.py b/miv/visualization/waveform.py new file mode 100644 index 00000000..8dcc9913 --- /dev/null +++ b/miv/visualization/waveform.py @@ -0,0 +1,140 @@ +__doc__ = """ +Module for extracting each spike waveform and visualize. +""" +__all__ = ["extract_waveforms", "plot_waveforms"] + +from typing import Any, Optional, Union, Tuple, Dict + +import os +import numpy as np + +import quantities as pq + +from sklearn.mixture import GaussianMixture +from sklearn.decomposition import PCA +from sklearn.preprocessing import StandardScaler + +from scipy.signal import lfilter, savgol_filter + +import matplotlib.pyplot as plt + +import neo +from miv.typing import SignalType, SpikestampsType + +# TODO: Modularize the entire process. 
+ + +def extract_waveforms( + signal: SignalType, + spikestamps: SpikestampsType, + channel: int, + sampling_rate: float, + pre: pq.Quantity = 0.001 * pq.s, + post: pq.Quantity = 0.002 * pq.s, +) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: + """ + Extract spike waveforms as signal cutouts around each spike index as a spikes x samples numpy array + + Parameters + ---------- + signal : SignalType + The signal as a 1-dimensional numpy array + spikestamps : SpikestampsType + The sample index of all spikes as a 1-dim numpy array + channel : int + Interested channel + sampling_rate : float + The sampling frequency in Hz + pre : pq.Quantity + The duration of the cutout before the spike in seconds. (default=0.001 s) + post : pq.Quantity + The duration of the cutout after the spike in seconds. (default=0.002 s) + + Returns + ------- + Stack of spike cutout: np.ndarray + Return stacks of spike cutout; shape(n_spikes, width). + + """ + # TODO: Refactor this part + signal = signal[:, channel] + spikestamps = spikestamps[channel] + + cutouts = [] + pre_idx = int(pre * sampling_rate) + post_idx = int(post * sampling_rate) + + # Padding signal + signal = np.pad(signal, ((pre_idx, post_idx),), constant_values=0) + for time in spikestamps: + index = int(round(time * sampling_rate)) + # if index - pre_idx >= 0 and index + post_idx <= signal.shape[0]: + # cutout = signal[(index - pre_idx) : (index + post_idx)] + # cutouts.append(cutout) + cutout = signal[index : (index + post_idx + pre_idx)] + cutouts.append(cutout) + + return np.stack(cutouts) + + +def plot_waveforms( + cutouts: np.ndarray, + sampling_rate: float, + pre: float = 0.001, + post: float = 0.002, + n_spikes: Optional[int] = 100, + color: str = "k", # TODO: change typing to matplotlib color + plot_kwargs: Dict[Any, Any] = None, +) -> plt.Figure: + """ + Plot an overlay of spike cutouts + + Parameters + ---------- + cutouts : np.ndarray + A spikes x samples array of cutouts + sampling_rate : float + The sampling frequency in Hz + pre : float + The duration of the cutout before the spike in seconds + post : float + The duration of the cutout after the spike in seconds + n_spikes : Optional[int] + The number of cutouts to plot. None to plot all. (Default: 100) + color : str + The line color as a pyplot line/marker style. (Default: 'k'=black) + plot_kwargs : Dict[Any, Any] + Addtional keyword-arguments for matplotlib.pyplot.plot. 
+ + Returns + ------- + Figure : plt.Figure + + """ + if n_spikes is None: + n_spikes = cutouts.shape[0] + n_spikes = min(n_spikes, cutouts.shape[0]) + + if not plot_kwargs: + plot_kwargs = {} + + # TODO: Need to match unit + time_in_us = np.arange(-pre * 1000, post * 1000, 1e3 / sampling_rate) + fig = plt.figure(figsize=(12, 6)) + + for i in range(n_spikes): + plt.plot( + time_in_us, + cutouts[ + i, + ], + color, + linewidth=1, + alpha=0.3, + **plot_kwargs + ) + plt.xlabel("Time (ms)") + plt.ylabel("Voltage (uV)") + plt.title("Cutouts") + + return fig diff --git a/requirements.txt b/requirements.txt index 85004cdd..b4745d26 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,6 +8,7 @@ numpy>=1.19.2 omegaconf pandas Pillow +PyWavelets quantities scikit-learn scipy>=1.5.2 diff --git a/tests/filter/__init__.py b/tests/filter/__init__.py new file mode 100644 index 00000000..57d0fa8f --- /dev/null +++ b/tests/filter/__init__.py @@ -0,0 +1 @@ +from tests.filter.mock_filter import mock_filter_list diff --git a/tests/filter/test_butterworth.py b/tests/filter/test_butterworth.py index f66aacec..ebf2ebe9 100644 --- a/tests/filter/test_butterworth.py +++ b/tests/filter/test_butterworth.py @@ -49,7 +49,7 @@ def test_butterworth_filter_analytical(lowcut, highcut, order, tag, sig, rate, r @pytest.mark.parametrize("lowcut, highcut, order, tag", ParameterSet) -def test_butterworth_repr_string_test(lowcut, highcut, order, tag): +def test_butterworth_repr_string(lowcut, highcut, order, tag): filt = ButterBandpass(lowcut, highcut, order, tag) for v in [lowcut, highcut, order, tag]: assert str(v) in repr(filt) diff --git a/tests/filter/test_filter_collection.py b/tests/filter/test_filter_collection.py index e69de29b..ecd45db3 100644 --- a/tests/filter/test_filter_collection.py +++ b/tests/filter/test_filter_collection.py @@ -0,0 +1,167 @@ +from typing import runtime_checkable +from typing import Protocol, Any + +import pytest + +import numpy as np + +from miv.signal.filter import FilterProtocol, FilterCollection +from tests.filter.mock_filter import mock_filter_list, mock_nonfilter_list + +from tests.filter.test_filter_protocol import RuntimeFilterProtocol + + +def test_empty_filter_protocol_abide(): + empty_filter = FilterCollection() + assert isinstance(empty_filter, RuntimeFilterProtocol) + + +@pytest.mark.parametrize("MockFilter1", mock_filter_list) +@pytest.mark.parametrize("MockFilter2", mock_filter_list) +def test_mock_filter_collection_protocol_abide( + MockFilter1: FilterProtocol, MockFilter2: FilterProtocol +): + filter_collection = FilterCollection().append(MockFilter1()).append(MockFilter2()) + assert isinstance(filter_collection, RuntimeFilterProtocol) + + filter_collection = FilterCollection().append(MockFilter1()).append(MockFilter1()) + assert isinstance(filter_collection, RuntimeFilterProtocol) + + filter_collection = ( + FilterCollection().append(MockFilter1()).insert(0, MockFilter1()) + ) + assert isinstance(filter_collection, RuntimeFilterProtocol) + + filter_collection = ( + FilterCollection().insert(0, MockFilter1()).insert(0, MockFilter1()) + ) + assert isinstance(filter_collection, RuntimeFilterProtocol) + + +def test_filter_collection_representation(): + filter_collection = FilterCollection() + assert "Collection" in repr(filter_collection) + + tag = "whats the point of doing a PhD?" 
+ filter_collection = FilterCollection(tag=tag) + assert tag in repr(filter_collection) + + +class TestFilterCollectionMutableSequence: + @pytest.fixture(scope="class") + def load_collection(self): + flt = FilterCollection() + # Bypass check, but its fine for testing + flt.append(3) + flt.append(5.0) + flt.append("a") + flt.append(np.random.randn(3, 5)) + return flt + + def test_len(self, load_collection): + assert len(load_collection) == 4 + + def test_getitem(self, load_collection): + assert load_collection[0] == 3 + assert load_collection[2] == "a" + + @pytest.mark.xfail + def test_getitem_with_faulty_index_fails(self, load_collection): + # Fails and exception is raised + load_collection[100] + + @pytest.mark.xfail + def test_setitem_with_faulty_index_fails(self, load_collection): + # If this fails, an exception is raised + # and pytest automatically fails + load_collection[200] = 1.0 + + def test_setitem(self, load_collection): + # If this fails, an exception is raised + # and pytest automatically fails + load_collection[3] = 1.0 + + def test_insert(self, load_collection): + load_collection.insert(3, "ss") + assert load_collection[3] == "ss" + load_collection.insert(1, 1.0) + assert np.isclose(load_collection[1], 1.0) + assert load_collection[4] == "ss" + + def test_str(self, load_collection): + assert str(load_collection[0]) == "3" + + @pytest.mark.xfail + def test_delitem(self, load_collection): + del load_collection[0] + assert load_collection[0] == 3 + + +class TestFilterCollectionFunctionality: + @pytest.mark.parametrize("sampling_rate", [20, 50, 0]) + def test_empty_filter_collection_bypass(self, sampling_rate): + flt = FilterCollection() + test_signal = np.random.random([2, 50]) + filtered_signal = flt(test_signal, sampling_rate) + np.testing.assert_allclose(test_signal, filtered_signal) + + @pytest.fixture(scope="function") + def mock_filter(self): + from tests.filter import mock_filter_list + + MockFilter = type("MockFilter", (mock_filter_list[0],), {}) + return MockFilter() + + +class TestFilterCollectionCompatibility: + """ + Collection of compatibility test of FilterCollection with other filter modules + """ + + def test_filter_collection_with_butterworth_io_shape(self): + from miv.signal.filter import ButterBandpass + + flt = ( + FilterCollection() + .append(ButterBandpass(1, 2)) + .append(ButterBandpass(2, 3)) + .append(ButterBandpass(3, 4)) + ) + sig = np.random.random([2, 50]) + filt_sig = flt(sig, sampling_rate=1000) + np.testing.assert_equal(sig.shape, filt_sig.shape) + + from tests.filter.test_butterworth import ( + AnalyticalTestSet as ButterworthAnalyticalTestSet, + ) + from tests.filter.test_butterworth import ParameterSet as ButterworthParameterSet + + @pytest.mark.parametrize("lowcut, highcut, order, tag", ButterworthParameterSet[:1]) + @pytest.mark.parametrize("sig, rate, result", ButterworthAnalyticalTestSet) + def test_filter_collection_with_butterworth_value( + self, lowcut, highcut, order, tag, sig, rate, result + ): + from miv.signal.filter import ButterBandpass + + filt = FilterCollection().append(ButterBandpass(lowcut, highcut, order, tag)) + ans = filt(signal=sig, sampling_rate=rate) + np.testing.assert_allclose(ans, result) + + @pytest.mark.parametrize("lowcut, highcut, order, tag", ButterworthParameterSet) + def test_filter_collection_repr_string_with_butterworth( + self, lowcut, highcut, order, tag + ): + from miv.signal.filter import ButterBandpass + + filt = FilterCollection().append(ButterBandpass(lowcut, highcut, order, tag)) + for v in [lowcut, 
highcut, order, tag]: + assert str(v) in repr(filt) + + +class TestFilterCollectionIntegration: + """ + Collection of integration test of FilterCollection with other modules + """ + + def test_filter_collection_operate_on_dataset(self): + pass # TODO diff --git a/tests/io/__init__.py b/tests/io/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/io/mock_continuous_signal.py b/tests/io/mock_continuous_signal.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/io/test_binary_io.py b/tests/io/test_binary_io.py new file mode 100644 index 00000000..508ea40f --- /dev/null +++ b/tests/io/test_binary_io.py @@ -0,0 +1,156 @@ +import pytest + +import tempfile +import os +import numpy as np + +from miv.io.binary import ( + load_continuous_data, + load_recording, + oebin_read, + apply_channel_mask, +) + + +@pytest.mark.parametrize("signal", [np.arange(10), np.array([])]) +def test_apply_channel_mask_shape_failure(signal): + with pytest.raises(IndexError): + apply_channel_mask(signal, {0}) + + +@pytest.mark.parametrize("signal", [5, [1, 2, 3], (1, 5, 9)]) +def test_apply_channel_mask_non_numpy_ndtype(signal): + with pytest.raises(AttributeError): + apply_channel_mask(signal, {0}) + + +@pytest.mark.parametrize( + "signal, mask, solution", + [ + (np.arange(9).reshape([3, 3]), {0}, np.arange(9).reshape([3, 3])[:, [1, 2]]), + (np.arange(9).reshape([3, 3]), {1}, np.arange(9).reshape([3, 3])[:, [0, 2]]), + (np.arange(9).reshape([3, 3]), {0, 1}, np.arange(9).reshape([3, 3])[:, [2]]), + ( + np.arange(9).reshape([3, 3]), + {1, -1, 5}, + np.arange(9).reshape([3, 3])[:, [0, 2]], + ), + ], +) +def test_apply_channel_mask_functionality(signal, mask, solution): + output = apply_channel_mask(signal, mask) + np.testing.assert_allclose(output, solution) + + +def test_oebin_read_functionality(): + a = {"a": 1, "b": "string data", "c": True} + with tempfile.NamedTemporaryFile("w+t", delete=False) as fp: + fp.writelines(str(a)) + fp.seek(0) + b = oebin_read(fp.name) + assert a == b + + +@pytest.mark.parametrize("num_channels, signal_length", [(4, 100), (1, 50), (10, 5)]) +def test_load_continuous_data_temp_file_without_timestamps(num_channels, signal_length): + signal = np.arange(signal_length * num_channels).reshape( + [signal_length, num_channels] + ) + filename = os.path.join(tempfile.mkdtemp(), "continuous.dat") + fp = np.memmap(filename, dtype="int16", mode="w+", shape=signal.shape) + fp[:] = signal[:] + fp.flush() + + raw_data, timestamps = load_continuous_data(fp.filename, num_channels, 1) + np.testing.assert_allclose(timestamps, np.arange(signal_length)) + np.testing.assert_allclose(raw_data, signal) + + +@pytest.mark.parametrize( + "num_channels, signal_length, freq", [(4, 100, 1), (1, 50, 5), (10, 5, 2)] +) +def test_load_continuous_data_temp_file_with_timestamps_shift( + num_channels, signal_length, freq +): + signal = np.arange(signal_length * num_channels).reshape( + [signal_length, num_channels] + ) + dirname = tempfile.mkdtemp() + filename = os.path.join(dirname, "continuous.dat") + timestamps_filename = os.path.join(dirname, "timestamps.npy") + # Prepare continuous.dat + fp = np.memmap(filename, dtype="int16", mode="w+", shape=signal.shape) + fp[:] = signal[:] + fp.flush() + # Prepare timestamps.npy + timestamps = np.arange(signal_length) + np.pi + np.save(timestamps_filename, timestamps) + + # With shift + raw_data, out_timestamps = load_continuous_data( + fp.filename, num_channels, freq, start_at_zero=False + ) + np.testing.assert_allclose(out_timestamps, 
timestamps / freq) + np.testing.assert_allclose(raw_data, signal) + + # Without shift + raw_data, out_timestamps = load_continuous_data( + fp.filename, num_channels, freq, start_at_zero=True + ) + np.testing.assert_allclose(out_timestamps, (timestamps - np.pi) / freq) + np.testing.assert_allclose(raw_data, signal) + + +@pytest.mark.parametrize( + "num_channels, signal_length, freq", [(4, 100, 1), (1, 50, 5), (10, 5, 2)] +) +def test_load_continuous_data_temp_file_timestamps_path_test( + num_channels, signal_length, freq +): + signal = np.arange(signal_length * num_channels).reshape( + [signal_length, num_channels] + ) + dirname = tempfile.mkdtemp() + filename = os.path.join(dirname, "continuous.dat") + timestamps_filename = os.path.join(dirname, "a.npy") + # Prepare continuous.dat + fp = np.memmap(filename, dtype="int16", mode="w+", shape=signal.shape) + fp[:] = signal[:] + fp.flush() + # Prepare timestamps.npy + timestamps = np.arange(signal_length) + np.save(timestamps_filename, timestamps) + + # With shift + raw_data, out_timestamps = load_continuous_data( + fp.filename, num_channels, freq, "a.npy" + ) + np.testing.assert_allclose(out_timestamps, timestamps / freq) + np.testing.assert_allclose(raw_data, signal) + + +@pytest.mark.parametrize( + "num_channels, signal_length, freq", [(4, 100, 1), (1, 50, 5), (10, 5, 2)] +) +def test_load_recording_assertion_single_data_file(num_channels, signal_length, freq): + signal = np.arange(signal_length * num_channels).reshape( + [signal_length, num_channels] + ) + + dirname = tempfile.mkdtemp() + os.makedirs(os.path.join(dirname, "continuous", "temp1")) + os.makedirs(os.path.join(dirname, "continuous", "temp2")) + filename1 = os.path.join(dirname, "continuous", "temp1", "continuous.dat") + filename2 = os.path.join(dirname, "continuous", "temp2", "continuous.dat") + # Prepare continuous.dat + fp1 = np.memmap(filename1, dtype="int16", mode="w+", shape=signal.shape) + fp2 = np.memmap(filename2, dtype="int16", mode="w+", shape=signal.shape) + fp1[:] = 1.0 + fp2[:] = 2.0 + fp1.flush() + fp2.flush() + + with pytest.raises( + AssertionError, match=r"(?=.*temp1.*)(?=.*temp2.*)(?=There should be only one)" + ): + load_recording(dirname) diff --git a/tests/io/test_data_single_module.py b/tests/io/test_data_single_module.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/statistics/__init__.py b/tests/statistics/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/statistics/test_spikestamps_statistics.py b/tests/statistics/test_spikestamps_statistics.py new file mode 100644 index 00000000..d24dfba1 --- /dev/null +++ b/tests/statistics/test_spikestamps_statistics.py @@ -0,0 +1,22 @@ +import pytest + +import numpy as np + +from miv.statistics import spikestamps_statistics +from neo.core import SpikeTrain +import quantities as pq + +SpikestampsTestSet = [ + [[1, 2, 3]], + [[1, 2, 3], [3, 6, 9, 12]], + [SpikeTrain([4, 8, 12], units=pq.s, t_stop=120)], +] +TrueRates = [1, [1, 1.0 / 3], 1.0 / 40] + + +@pytest.mark.parametrize("spikestamps, true_rate", zip(SpikestampsTestSet, TrueRates)) +def test_spikestamps_statistics_base_function(spikestamps, true_rate): + result = spikestamps_statistics(spikestamps) + np.testing.assert_allclose(result["rates"], true_rate) + assert np.isclose(result["mean"], np.mean(true_rate)) + assert np.isclose(result["variance"], np.var(true_rate)) diff --git a/tests/visualization/__init__.py b/tests/visualization/__init__.py new file mode 100644 index 00000000..e69de29b diff --git 
a/tests/visualization/test_fft.py b/tests/visualization/test_fft.py new file mode 100644 index 00000000..b371874a --- /dev/null +++ b/tests/visualization/test_fft.py @@ -0,0 +1,5 @@ +import pytest + +import numpy as np + +from miv.visualization import plot_frequency_domain diff --git a/tests/visualization/test_waveform.py b/tests/visualization/test_waveform.py new file mode 100644 index 00000000..da6dce3f --- /dev/null +++ b/tests/visualization/test_waveform.py @@ -0,0 +1,5 @@ +import pytest + +import numpy as np + +from miv.visualization import extract_waveforms, plot_waveforms
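The new visualization test modules above currently only import the plotting helpers. A minimal smoke test could exercise extract_waveforms and plot_waveforms end to end; the sketch below is illustrative only (the synthetic signal, spike times, and test name are invented here, not part of this patch):

    import numpy as np
    from miv.visualization import extract_waveforms, plot_waveforms

    def test_extract_and_plot_waveforms_smoke():
        sampling_rate = 1000
        signal = np.random.randn(1000, 2)  # shape: [signal_length, n_channel]
        # Spike times in seconds, one array per channel
        spikestamps = [np.array([0.1, 0.5]), np.array([0.2])]
        cutouts = extract_waveforms(
            signal, spikestamps, channel=0, sampling_rate=sampling_rate
        )
        assert cutouts.ndim == 2
        assert cutouts.shape[0] == 2  # two spikes detected on channel 0
        fig = plot_waveforms(cutouts, sampling_rate)
        assert fig is not None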