Skip to content

Commit

Permalink
add pytests
Browse files Browse the repository at this point in the history
  • Loading branch information
guanhuaw committed Jul 28, 2024
1 parent 85c7722 commit 675cdf7
Show file tree
Hide file tree
Showing 35 changed files with 5,702 additions and 476 deletions.
50 changes: 32 additions & 18 deletions .github/workflows/python-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,49 @@ name: Python-CI
on:
push:
branches:
- main
- master
- develop
- feature/*
pull_request:
branches:
- main
- master
- develop
- feature/*

jobs:
build:

runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.x' # Specify the Python version you need

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e .
pip install ruff pytest
- name: Lint with Ruff
run: |
ruff ./mirtorch
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.10" # Specify the Python version you need

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e .
pip install ruff pytest
- name: Lint with Ruff
run: |
ruff ./mirtorch
- name: Test with pytest
run: |
pytest ./tests
- name: Automated Version Bump
if: github.ref == 'refs/heads/master'
uses: phips28/gh-action-bump-version@master
with:
tag-prefix: "v" # or an empty string if you don't want a prefix
filename: "pyproject.toml"
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

- name: Push changes
if: github.ref == 'refs/heads/master'
run: git push && git push --tags
29 changes: 19 additions & 10 deletions .github/workflows/python-publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,23 @@ jobs:
uses: actions/setup-python@v3
with:
python-version: '3.x'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install build
- name: Build package
run: python -m build
- name: Publish package
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
- name: Install pypa/build
run: >-
python3 -m
pip install
build
--user
- name: Build a binary wheel and a source tarball
run: python3 -m build
- name: Store the distribution packages
uses: actions/upload-artifact@v3
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}
name: python-package-distributions
path: dist/
- name: Download all the dists
uses: actions/download-artifact@v3
with:
name: python-package-distributions
path: dist/
- name: Publish distribution 📦 to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,4 @@ mrt/
docs/_build
docs/_static
docs/_templates
.ruff_cache
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,4 @@ repos:
rev: v2.1.0
hooks:
- id: codespell
exclude: ^(?:tests|docs|examples)/
exclude: ^(?:tests|docs|examples|mirtorch/vendors)/
29 changes: 29 additions & 0 deletions LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,32 @@ Version 3, 29 June 2007
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.

----------------------------- LICENSE for pytorch_wavelets----------------------------
This licence applies to any parts of this library which are novel in comparison
to the original DTCWT MATLAB toolbox written by Nick Kingsbury and Cian
Shaffrey. See the Provenance section of README.rst file for details on any further
restrictions of use. If you wish to use the DTCWT, you should read that license as well.
The DWT sections come under this license.

MIT License

Copyright (c) 2020 Fergal Cotter

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ Documentation: https://mirtorch.readthedocs.io/en/latest/
### Installation

We recommend to [pre-install `PyTorch` first](https://pytorch.org/).
To install the `MIRTorch` package, after cloning the repo, please try `pip install -e .`(one may modify the package locally with this option.)
Use `pip install mirtorch` to install.
To install the `MIRTorch` locally, after cloning the repo, please try `pip install -e .`(one may modify the package locally with this option.)

------

Expand Down
4 changes: 4 additions & 0 deletions mirtorch/linear/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
ConjTranspose,
BlockDiagonal,
Kron,
Vstack,
Hstack,
)
from .basics import (
Diff1d,
Expand Down Expand Up @@ -51,4 +53,6 @@
"Gmri",
"GmriGram",
"Sense",
"Vstack",
"Hstack",
]
90 changes: 80 additions & 10 deletions mirtorch/linear/linearmaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,20 +63,20 @@ def __repr__(self):
f"<LinearMap {self.__class__.__name__} of {self.size_out}x{self.size_in}>"
)

def __call__(self, x) -> Tensor:
def __call__(self, x: Tensor) -> Tensor:
# for a instance A, we can apply it by calling A(x). Equal to A*x
return self.apply(x)

def _apply(self, x) -> Tensor:
def _apply(self, x: Tensor) -> Tensor:
# worth noting that the function here should be differentiable,
# for example, composed of native torch functions,
# or torch.autograd.Function, or nn.module
raise NotImplementedError

def _apply_adjoint(self, x) -> Tensor:
def _apply_adjoint(self, x: Tensor) -> Tensor:
raise NotImplementedError

def apply(self, x) -> Tensor:
def apply(self, x: Tensor) -> Tensor:
r"""
Apply the forward operator
"""
Expand All @@ -85,7 +85,7 @@ def apply(self, x) -> Tensor:
), f"Shape of input data {x.shape} and forward linear op {self.size_in} do not match!"
return self._apply(x)

def adjoint(self, x) -> Tensor:
def adjoint(self, x: Tensor) -> Tensor:
r"""
Apply the adjoint operator
"""
Expand All @@ -109,7 +109,7 @@ def __add__(self: LinearMap, other: LinearMap) -> LinearMap:

def __mul__(
self: LinearMap, other: Union[str, int, LinearMap, Tensor]
) -> LinearMap:
) -> Union[LinearMap, Tensor]:
r"""
Reload the * symbol.
"""
Expand Down Expand Up @@ -347,10 +347,80 @@ def _apply_adjoint(self, x: Tensor):


class Vstack(LinearMap):
# TODO
pass
r"""
Vertical stacking of linear operators.
.. math::
[A1; A2; ...; An] * x = [A1(x); A2(x); ...; An(x)]
Attributes:
A: List of LinearMaps to be stacked vertically
dim: the dimension along which to stack the LinearMaps
"""

def __init__(self, A: List[LinearMap], dim: int = 0):
self.A = A

# Check that all input sizes are the same
assert all(
[A[i].size_in == A[0].size_in for i in range(len(A))]
), "All input sizes must be the same"

# Calculate the total output size
size_out = [sum(A[i].size_out[0] for i in range(len(A)))] + list(
A[0].size_out[1:]
)

self.dim = dim

super().__init__(A[0].size_in, size_out)

def _apply(self, x: Tensor) -> Tensor:
return torch.cat([A_i(x) for A_i in self.A], dim=self.dim)

def _apply_adjoint(self, x: Tensor) -> Tensor:
outputs = []
start = 0
for A_i in self.A:
end = start + A_i.size_out[0]
outputs.append(A_i.H(x[start:end]))
start = end
return sum(outputs)


class Hstack(LinearMap):
# TODO
pass
r"""
Horizontal stacking of linear operators.
.. math::
[A1, A2, ..., An] * [x1; x2; ...; xn] = A1(x1) + A2(x2) + ... + An(xn)
Attributes:
A: List of LinearMaps to be stacked horizontally
"""

def __init__(self, A: List[LinearMap], dim: int = 0):
self.A = A

# Check that all output sizes are the same
assert all(
[A[i].size_out == A[0].size_out for i in range(len(A))]
), "All output sizes must be the same"

# Calculate the total input size
size_in = [sum(A[i].size_in[0] for i in range(len(A)))] + list(A[0].size_in[1:])
self.dim = dim

super().__init__(size_in, A[0].size_out)

def _apply(self, x: Tensor) -> Tensor:
outputs = []
start = 0
for A_i in self.A:
end = start + A_i.size_in[0]
outputs.append(A_i(x[start:end]))
start = end
return sum(outputs)

def _apply_adjoint(self, x: Tensor) -> Tensor:
return torch.cat([A_i.H(x) for A_i in self.A], dim=self.dim)
17 changes: 5 additions & 12 deletions mirtorch/linear/mri.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""

import math
from typing import Union, List
from typing import Union, List, Tuple

import numpy as np
import torch
Expand All @@ -31,24 +31,19 @@ def __init__(
self,
size_in: List[int],
size_out: List[int],
dims: Union[int, List[int]] | None = None,
dims: Tuple[int] | None = None,
norm: str = "ortho",
):
super(FFTCn, self).__init__(size_in, size_out)
self.norm = norm
self.dims = dims

@torch.jit.script
def fwd(x: Tensor, dims: Union[int, List[int]], norm: str) -> Tensor:
x = ifftshift(x, dims)
x = fftn(x, dim=dims, norm=norm)
x = fftshift(x, dims)
return x

def _apply(self, x: Tensor) -> Tensor:
x = ifftshift(x, self.dims)
x = fftn(x, dim=self.dims, norm=self.norm)
x = fftshift(x, self.dims)
return x

@torch.jit.script
def _apply_adjoint(self, x: Tensor) -> Tensor:
x = ifftshift(x, self.dims)
if self.norm == "ortho":
Expand Down Expand Up @@ -101,7 +96,6 @@ def __init__(
self.smaps = smaps
self.batchmode = batchmode

@torch.jit.script
def _apply(self, x: Tensor) -> Tensor:
r"""
Args:
Expand All @@ -115,7 +109,6 @@ def _apply(self, x: Tensor) -> Tensor:
k = fftshift(k, self.dims) * self.masks
return k

@torch.jit.script
def _apply_adjoint(self, k: Tensor) -> Tensor:
r"""
Args:
Expand Down
6 changes: 3 additions & 3 deletions mirtorch/linear/util.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union, List
from typing import Union, Tuple, List

import numpy as np
import torch
Expand Down Expand Up @@ -39,7 +39,7 @@ def finitediff_adj(y: Tensor, dim: int = -1, mode="reflexive"):
Returns:
y: the first-order finite difference of x
"""
if mode == "reflexibe":
if mode == "reflexive":
len_dim = y.shape[dim]
return torch.cat(
(
Expand Down Expand Up @@ -105,7 +105,7 @@ def fftshift(x: Tensor, dims: Union[int, List[int]] | None = None):
return torch.roll(x, shifts, dims)


def ifftshift(x: Tensor, dims: Union[int, List[int]] | None = None):
def ifftshift(x: Tensor, dims: Union[int, Tuple[int]] | None = None):
"""
Similar to np.fft.ifftshift but applies to PyTorch tensors. From fastMRI code.
"""
Expand Down
2 changes: 1 addition & 1 deletion mirtorch/linear/wavelets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Sequence, Tuple, List

import torch
from pytorch_wavelets import DWTForward, DWTInverse
from mirtorch.vendors.pytorch_wavelets import DWTForward, DWTInverse
from torch import Tensor

from .linearmaps import LinearMap
Expand Down
Loading

0 comments on commit 675cdf7

Please sign in to comment.