dev: Added suggestions and tested network forward result.
1 parent 1680dc0, commit 6009b0d
Showing 7 changed files with 116 additions and 95 deletions.
@@ -11,5 +11,5 @@
     "ED_TCN",
     "tweetynet",
     "TweetyNet",
-    "Ava"
+    "Ava",
 ]
@@ -1,119 +1,124 @@
-from __future__ import annotations
+# from __future__ import annotations
 
 import torch
 from torch import nn
 from torch.distributions import LowRankMultivariateNormal
+from typing import Tuple
 
 
+# Is it necessary to put this in src.vak.nn.modules?
+class BottleneckLayer(nn.Module):
+    def __init__(self, dims):
+        super().__init__()
+        self.layer = nn.Sequential(
+            nn.Linear(dims[0], dims[1]),
+            nn.ReLU(),
+            nn.Linear(dims[1], dims[2]))
+
+    def forward(self, x):
+        return self.layer(x)
+
+
 class Ava(nn.Module):
     """
     """
     def __init__(
         self,
-        hidden_dims: list[int] = [8, 8, 16, 16, 24, 24, 32],
-        fc_dims: list[int] = [1024, 256, 64, 32],
-        in_channels: int = 1,
-        in_fc: int = 8192,
-        x_shape: tuple = (128, 128)
+        hidden_dims: Tuple[int] = (8, 8, 16, 16, 24, 24),
+        fc_dims: Tuple[int] = (1024, 256, 64),
+        z_dim: int = 32,
+        in_channels: int = 1,
+        x_shape: Tuple[int] = (128, 128)
     ):
         """
         """
         super().__init__()
-        self.in_fc = in_fc
-        self.in_channels = in_channels
-        self.x_shape = x_shape
-        self.x_dim = torch.prod(x_shape)
+        fc_dims = (*fc_dims, z_dim)
+        hidden_dims = (*hidden_dims, z_dim)
+
+        self.in_channels = in_channels
+        self.fc_view = (int(fc_dims[-1]), int(fc_dims[-1]/2), int(fc_dims[-1]/2))
+        self.x_shape = torch.tensor(x_shape)
+        self.x_dim = torch.prod(self.x_shape)
+        self.in_fc = int(self.x_dim / 2)
+        in_fc = self.in_fc
         modules = []
         for h_dim in hidden_dims:
             stride = 2 if h_dim == in_channels else 1
             modules.append(
                 nn.Sequential(
                     nn.BatchNorm2d(in_channels),
                     nn.Conv2d(in_channels, out_channels=h_dim,
                               kernel_size=3, stride=stride, padding=1),
                     nn.ReLU())
             )
             in_channels = h_dim
         self.encoder = nn.Sequential(*modules)
         modules = []
         for fc_dim in fc_dims[:-2]:
             modules.append(
                 nn.Sequential(
                     nn.Linear(in_fc, fc_dim),
                     nn.ReLU())
             )
             in_fc = fc_dim
         self.encoder_bottleneck = nn.Sequential(*modules)
-
-        self.mu_layer = nn.Sequential(
-            nn.Linear(fc_dims[-3], fc_dims[-2]),
-            nn.ReLU(),
-            nn.Linear(fc_dims[-2], fc_dims[-1]))
-
-        self.u_layer = nn.Sequential(
-            nn.Linear(fc_dims[-3], fc_dims[-2]),
-            nn.ReLU(),
-            nn.Linear(fc_dims[-2], fc_dims[-1]))
-
-        self.d_layer = nn.Sequential(
-            nn.Linear(fc_dims[-3], fc_dims[-2]),
-            nn.ReLU(),
-            nn.Linear(fc_dims[-2], fc_dims[-1]))
-
-        fc_dims.reverse()
+        self.mu_layer = BottleneckLayer(fc_dims[-3:])
+        self.cov_factor_layer = BottleneckLayer(fc_dims[-3:])
+        self.cov_diag_layer = BottleneckLayer(fc_dims[-3:])
+        fc_dims = fc_dims[::-1]
         modules = []
         for i in range(len(fc_dims)):
-            out = self.fc_in if i == len(fc_dims) else fc_dims[i+1]
+            out = self.in_fc if i == len(fc_dims) - 1 else fc_dims[i+1]
             modules.append(
                 nn.Sequential(
                     nn.Linear(fc_dims[i], out),
                     nn.ReLU())
             )
         self.decoder_bottleneck = nn.Sequential(*modules)
-
-        hidden_dims.reverse()
+        hidden_dims = (*hidden_dims[-2::-1], self.in_channels)
         modules = []
         for i, h_dim in enumerate(hidden_dims):
             stride = 2 if h_dim == in_channels else 1
             output_padding = 1 if h_dim == in_channels else 0
-            modules.append(
-                nn.Sequential(
-                    nn.BatchNorm2d(in_channels),
-                    nn.ConvTranspose2d(in_channels, out_channels=h_dim,
-                                       kernel_size=3, stride=stride, padding=1, output_padding=output_padding),
-                    nn.ReLU() if i != len(hidden_dims))
-            )
+            layers = [nn.BatchNorm2d(in_channels),
+                      nn.ConvTranspose2d(in_channels, out_channels=h_dim, kernel_size=3, stride=stride, padding=1, output_padding=output_padding)]
+            if i != len(hidden_dims) - 1:
+                layers.append(nn.ReLU())
+
+            modules.append(nn.Sequential(*layers))
+            in_channels = h_dim
+
         self.decoder = nn.Sequential(*modules)
 
     def encode(self, x):
         """
         """
         x = self.encoder(x.unsqueeze(self.in_channels)).view(-1, self.in_fc)
         x = self.encoder_bottleneck(x)
         mu = self.mu_layer(x)
-        u = self.u_layer(x).unsqueeze(-1)
-        d = torch.exp(self.d_layer(x))
-        z, latent_dist = self.reparametrize(mu, u, d)
+        cov_factor = self.cov_factor_layer(x).unsqueeze(-1)
+        cov_diag = torch.exp(self.cov_diag_layer(x))
+        z, latent_dist = self.reparametrize(mu, cov_factor, cov_diag)
         return z, latent_dist
 
     def decode(self, z):
         """
         """
-        z = self.decoder_bottleneck(z).view(-1, 32, 16, 16)
-        z = self.decoder(z).view(-1, x_dim)
+        z = self.decoder_bottleneck(z).view(-1, self.fc_view[0], self.fc_view[1], self.fc_view[2])
+        z = self.decoder(z).view(-1, self.x_dim)
         return z
 
-    def reparametrize(self, mu, u, d):
-        latent_dist = LowRankMultivariateNormal(mu, u, d)
+    @staticmethod
+    def reparametrize(mu, cov_factor, cov_diag):
+        latent_dist = LowRankMultivariateNormal(mu, cov_factor, cov_diag)
         z = latent_dist.rsample()
         return z, latent_dist
 
-    def forward(self, x, return_latent_rec=False):
+    def forward(self, x):
         z, latent_dist = self.encode(x)
-        x_rec = self.decode(z)
-        return x_rec, {'z': z, 'latent_dist': latent_dist,}
+        x_rec = self.decode(z).view(-1, self.x_shape[0], self.x_shape[1])
+        return x_rec, z, latent_dist
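
The commit message says the network forward result was tested. As a quick illustration of the new forward contract, here is a hedged sketch, not code from this commit: the import path, batch size, and dummy input are assumptions, while the output shapes follow from the default arguments shown in the diff.

import torch

# from ... import Ava  # import path depends on where this module lives in the package

model = Ava()                         # defaults: x_shape=(128, 128), z_dim=32
x = torch.rand(4, 128, 128)           # dummy batch of 4 spectrogram-like inputs

x_rec, z, latent_dist = model(x)      # forward now returns a tuple instead of a dict

assert x_rec.shape == (4, 128, 128)   # reconstruction reshaped back to x_shape
assert z.shape == (4, 32)             # one z_dim-dimensional latent sample per input
log_q_z = latent_dist.log_prob(z)     # latent_dist is a LowRankMultivariateNormal; result has shape (4,)

Note that the previous forward returned x_rec plus a dict holding 'z' and 'latent_dist', so any downstream code that unpacked that dict would need a matching update.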
@@ -1,13 +1,13 @@
 from .dice import DiceLoss, dice_loss
 from .umap import UmapLoss, umap_loss
-from .vae import VaeLoss, vae_loss
+from .vae import VaeElboLoss, vae_elbo_loss
 
 
 __all__ = [
     "DiceLoss",
     "dice_loss",
     "UmapLoss",
     "umap_loss",
-    "VaeLoss",
-    "vae_loss"
+    "VaeElboLoss",
+    "vae_elbo_loss"
 ]
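
The .vae module that defines the renamed VaeElboLoss / vae_elbo_loss is not part of this diff, so the following is only a hypothetical sketch of an ELBO-style loss built from the tuple that Ava.forward now returns; the function name and the squared-error reconstruction term are assumptions, not the library's actual implementation.

import torch
from torch.distributions import LowRankMultivariateNormal

def vae_elbo_loss_sketch(x, x_rec, z, latent_dist):
    # Hypothetical: negative ELBO = reconstruction error + KL(q(z|x) || p(z)),
    # with the KL term estimated from the single sample z.
    recon = ((x - x_rec) ** 2).flatten(1).sum(dim=1)      # per-sample squared error
    z_dim = z.shape[-1]
    prior = LowRankMultivariateNormal(                    # standard normal prior N(0, I)
        loc=torch.zeros(z_dim),
        cov_factor=torch.zeros(z_dim, 1),
        cov_diag=torch.ones(z_dim),
    )
    kl_est = latent_dist.log_prob(z) - prior.log_prob(z)  # single-sample KL estimate
    return (recon + kl_est).mean()

Under these assumptions, the sketch could be called directly on the values returned by the forward pass shown earlier, e.g. vae_elbo_loss_sketch(x, x_rec, z, latent_dist).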