Skip to content
This repository has been archived by the owner on Apr 19, 2023. It is now read-only.

Commit

Permalink
- finish refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
nasimrahaman committed Dec 19, 2018
1 parent b96bf92 commit a4d2ebe
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 11 deletions.
4 changes: 2 additions & 2 deletions examples/plot_train_side_loss_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from inferno.trainers.basic import Trainer

from inferno.extensions.layers.convolutional import Conv2D
from inferno.extensions.model.res_unet import _ResBlock as ResBlock
from inferno.extensions.model import ResBlockUNet
from inferno.extensions.models.res_unet import _ResBlock as ResBlock
from inferno.extensions.models import ResBlockUNet
from inferno.utils.torch_utils import unwrap
from inferno.utils.python_utils import ensure_dir
import pylab
Expand Down
4 changes: 2 additions & 2 deletions examples/plot_unet_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def label_transform(x):
# With :code:`activated=False` we make sure that the last layer
# is not activated since we chain the UNet with a sigmoid
# activation function.
from inferno.extensions.model import ResBlockUNet
from inferno.extensions.models import ResBlockUNet
from inferno.extensions.layers import RemoveSingletonDimension

model = torch.nn.Sequential(
Expand Down Expand Up @@ -199,7 +199,7 @@ def predict(trainer, test_loader, save_dir=None):
# a rather exotic UNet which uses different types
# of convolutions/non-linearities in the different branches
# of the unet
from inferno.extensions.model import UNetBase
from inferno.extensions.models import UNetBase
from inferno.extensions.layers import ConvSELU2D, ConvReLU2D, ConvELU2D, ConvSigmoid2D,Conv2D

class MySimple2DUnet(UNetBase):
Expand Down
2 changes: 2 additions & 0 deletions tests/extensions/containers/graph.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import unittest
from functools import reduce
import torch


class TestGraph(unittest.TestCase):
Expand Down Expand Up @@ -99,6 +100,7 @@ def test_graph_basic(self):
model.add_output_node('output_0', previous='conv1')
ModelTester((1, 1, 100, 100), (1, 1, 100, 100))(model)

@unittest.skipUnless(torch.cuda.is_available(), "No cuda.")
def test_graph_device_transfers(self):
from inferno.extensions.containers.graph import Graph
from inferno.extensions.layers.convolutional import ConvELU2D
Expand Down
10 changes: 5 additions & 5 deletions tests/extensions/model/res_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,22 @@

class ResUNetTest(unittest.TestCase):
def test_res_unet_2d(self):
from inferno.extensions.model import ResBlockUNet
from inferno.extensions.models import ResBlockUNet
tester = ModelTester((1, 1, 256, 256), (1, 1, 256, 256))
if cuda.is_available():
tester.cuda()
tester(ResBlockUNet(in_channels=1, out_channels=1, dim=2))

def test_res_unet_3d(self):
from inferno.extensions.model import ResBlockUNet
from inferno.extensions.models import ResBlockUNet
tester = ModelTester((1, 1, 16, 64, 64), (1, 1, 16, 64, 64))
if cuda.is_available():
tester.cuda()
# test default unet 3d
tester(ResBlockUNet(in_channels=1, out_channels=1, dim=3))

def test_2d_side_out_bot_up(self):
from inferno.extensions.model import ResBlockUNet
from inferno.extensions.models import ResBlockUNet
depth = 3
in_channels = 3

Expand All @@ -40,7 +40,7 @@ def test_2d_side_out_bot_up(self):
self.assertEqual(list(out_list[3].size()), [1, 8, 64, 32])

def test_2d_side_out_up(self):
from inferno.extensions.model import ResBlockUNet
from inferno.extensions.models import ResBlockUNet
depth = 3
in_channels = 3

Expand All @@ -58,7 +58,7 @@ def test_2d_side_out_up(self):
self.assertEqual(list(out_list[2].size()), [1, 8, 64, 32])

def test_2d_side_out_down(self):
from inferno.extensions.model import ResBlockUNet
from inferno.extensions.models import ResBlockUNet
depth = 3
in_channels = 3

Expand Down
4 changes: 2 additions & 2 deletions tests/extensions/model/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@

class UNetTest(unittest.TestCase):
def test_unet_2d(self):
from inferno.extensions.model import UNet
from inferno.extensions.models import UNet
tester = ModelTester((1, 1, 256, 256), (1, 1, 256, 256))
if cuda.is_available():
tester.cuda()
tester(UNet(1, 1, dim=2, initial_features=32))

def test_unet_3d(self):
from inferno.extensions.model import UNet
from inferno.extensions.models import UNet
tester = ModelTester((1, 1, 16, 64, 64), (1, 1, 16, 64, 64))
if cuda.is_available():
tester.cuda()
Expand Down

0 comments on commit a4d2ebe

Please sign in to comment.