
Commit

dev: Added suggestions and tested network forward result.
marisbasha authored and NickleDave committed Sep 29, 2023
1 parent 1680dc0 commit 6009b0d
Showing 7 changed files with 116 additions and 95 deletions.
8 changes: 4 additions & 4 deletions src/vak/models/ava.py
@@ -5,17 +5,17 @@
 from .. import metrics, nets
 from .decorator import model
 from .vae_model import VAEModel
-from ..nn.loss import VaeLoss
+from ..nn.loss import VaeElboLoss

 @model(family=VAEModel)
 class AVA:
     """
     """
-    network = Ava
-    loss = VaeLoss
+    network = nets.Ava
+    loss = VaeElboLoss
     optimizer = torch.optim.Adam
     metrics = {
-        "loss": VaeLoss,
+        "loss": VaeElboLoss,
         "kl": torch.nn.functional.kl_div
     }
     default_config = {"optimizer": {"lr": 0.003}}
3 changes: 1 addition & 2 deletions src/vak/models/vae_model.py
@@ -46,8 +46,7 @@ def training_step(self, batch: tuple, batch_idx: int):
         """
         """
         x = batch[0]
-        out, _ = self.network(x)
-        z, latent_dist = itemgetter('z', 'latent_dist')(_)
+        out, z, latent_dist = self.network(x)
         loss = self.loss(x, z, out, latent_dist)
         self.log("train_loss", loss)
         return loss
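After this change the training step, the network, and the loss share one tuple-based contract: the network returns (reconstruction, latent sample, approximate posterior), and the loss is called positionally as (x, z, reconstruction, latent_dist), matching the vae_elbo_loss call shown further down. A minimal, standalone sketch of that contract (shapes are illustrative, and it again assumes VaeElboLoss has default arguments):

# Illustrative contract check for the new 3-tuple interface.
import torch
from src.vak.nets.ava import Ava
from src.vak.nn.loss.vae import VaeElboLoss

net = Ava()                                     # defaults to 128x128 inputs
x = torch.zeros(4, 128, 128)                    # (batch, freq bins, time bins)
x_rec, z, latent_dist = net(x)                  # plain 3-tuple instead of (out, dict)
loss = VaeElboLoss()(x, z, x_rec, latent_dist)  # reconstruction is passed third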
2 changes: 1 addition & 1 deletion src/vak/nets/__init__.py
@@ -11,5 +11,5 @@
     "ED_TCN",
     "tweetynet",
     "TweetyNet",
-    "Ava"
+    "Ava",
 ]
163 changes: 84 additions & 79 deletions src/vak/nets/ava.py
@@ -1,119 +1,124 @@
-from __future__ import annotations
+# from __future__ import annotations

 import torch
 from torch import nn
 from torch.distributions import LowRankMultivariateNormal
+from typing import Tuple
+
+
+# Is it necessary to put this in src.vak.nn.modules?
+class BottleneckLayer(nn.Module):
+    def __init__(self, dims):
+        super().__init__()
+        self.layer = nn.Sequential(
+            nn.Linear(dims[0], dims[1]),
+            nn.ReLU(),
+            nn.Linear(dims[1], dims[2]))
+
+    def forward(self, x):
+        return self.layer(x)


 class Ava(nn.Module):
     """
     """
     def __init__(
         self,
-        hidden_dims: list[int] = [8, 8, 16, 16, 24, 24, 32],
-        fc_dims: list[int] = [1024, 256, 64, 32],
-        in_channels: int = 1,
-        in_fc: int = 8192,
-        x_shape: tuple = (128, 128)
+        hidden_dims: Tuple[int] = (8, 8, 16, 16, 24, 24),
+        fc_dims: Tuple[int] = (1024, 256, 64),
+        z_dim: int = 32,
+        in_channels: int = 1,
+        x_shape: Tuple[int] = (128, 128)
     ):
         """
         """
         super().__init__()
-        self.in_fc = in_fc
-        self.in_channels = in_channels
-        self.x_shape = x_shape
-        self.x_dim = torch.prod(x_shape)
+        fc_dims = (*fc_dims, z_dim)
+        hidden_dims = (*hidden_dims, z_dim)
+
+        self.in_channels = in_channels
+        self.fc_view = (int(fc_dims[-1]), int(fc_dims[-1] / 2), int(fc_dims[-1] / 2))
+        self.x_shape = torch.tensor(x_shape)
+        self.x_dim = torch.prod(self.x_shape)
+        self.in_fc = int(self.x_dim / 2)
+        in_fc = self.in_fc
         modules = []
         for h_dim in hidden_dims:
             stride = 2 if h_dim == in_channels else 1
             modules.append(
                 nn.Sequential(
                     nn.BatchNorm2d(in_channels),
                     nn.Conv2d(in_channels, out_channels=h_dim,
                               kernel_size=3, stride=stride, padding=1),
                     nn.ReLU())
             )
             in_channels = h_dim
         self.encoder = nn.Sequential(*modules)
         modules = []
         for fc_dim in fc_dims[:-2]:
             modules.append(
                 nn.Sequential(
                     nn.Linear(in_fc, fc_dim),
                     nn.ReLU())
             )
             in_fc = fc_dim
         self.encoder_bottleneck = nn.Sequential(*modules)
-
-        self.mu_layer = nn.Sequential(
-            nn.Linear(fc_dims[-3], fc_dims[-2]),
-            nn.ReLU(),
-            nn.Linear(fc_dims[-2], fc_dims[-1]))
-
-        self.u_layer = nn.Sequential(
-            nn.Linear(fc_dims[-3], fc_dims[-2]),
-            nn.ReLU(),
-            nn.Linear(fc_dims[-2], fc_dims[-1]))
-
-        self.d_layer = nn.Sequential(
-            nn.Linear(fc_dims[-3], fc_dims[-2]),
-            nn.ReLU(),
-            nn.Linear(fc_dims[-2], fc_dims[-1]))
-
-        fc_dims.reverse()
-        modules = []
-        for i in range(len(fc_dims)):
-            out = self.fc_in if i == len(fc_dims) else fc_dims[i+1]
+        self.mu_layer = BottleneckLayer(fc_dims[-3:])
+        self.cov_factor_layer = BottleneckLayer(fc_dims[-3:])
+        self.cov_diag_layer = BottleneckLayer(fc_dims[-3:])
+        fc_dims = fc_dims[::-1]
+        modules = []
+        for i in range(len(fc_dims)):
+            out = self.in_fc if i == len(fc_dims) - 1 else fc_dims[i+1]
             modules.append(
                 nn.Sequential(
                     nn.Linear(fc_dims[i], out),
                     nn.ReLU())
             )
         self.decoder_bottleneck = nn.Sequential(*modules)
-
-        hidden_dims.reverse()
-        modules = []
-        for i, h_dim in enumerate(hidden_dims):
-            stride = 2 if h_dim == in_channels else 1
-            output_padding = 1 if h_dim == in_channels else 0
-            modules.append(
-                nn.Sequential(
-                    nn.BatchNorm2d(in_channels),
-                    nn.ConvTranspose2d(in_channels, out_channels=h_dim,
-                                       kernel_size=3, stride=stride, padding=1, output_padding=output_padding),
-                    nn.ReLU() if i != len(hidden_dims))
-            )
+        hidden_dims = (*hidden_dims[-2::-1], self.in_channels)
+        modules = []
+        for i, h_dim in enumerate(hidden_dims):
+            stride = 2 if h_dim == in_channels else 1
+            output_padding = 1 if h_dim == in_channels else 0
+            layers = [nn.BatchNorm2d(in_channels),
+                      nn.ConvTranspose2d(in_channels, out_channels=h_dim, kernel_size=3, stride=stride, padding=1, output_padding=output_padding)]
+            if i != len(hidden_dims) - 1:
+                layers.append(nn.ReLU())
+
+            modules.append(nn.Sequential(*layers))
             in_channels = h_dim

         self.decoder = nn.Sequential(*modules)

     def encode(self, x):
         """
         """
         x = self.encoder(x.unsqueeze(self.in_channels)).view(-1, self.in_fc)
         x = self.encoder_bottleneck(x)
         mu = self.mu_layer(x)
-        u = self.u_layer(x).unsqueeze(-1)
-        d = torch.exp(self.d_layer(x))
-        z, latent_dist = self.reparametrize(mu, u, d)
+        cov_factor = self.cov_factor_layer(x).unsqueeze(-1)
+        cov_diag = torch.exp(self.cov_diag_layer(x))
+        z, latent_dist = self.reparametrize(mu, cov_factor, cov_diag)
         return z, latent_dist


     def decode(self, z):
         """
         """
-        z = self.decoder_bottleneck(z).view(-1, 32, 16, 16)
-        z = self.decoder(z).view(-1, x_dim)
+        z = self.decoder_bottleneck(z).view(-1, self.fc_view[0], self.fc_view[1], self.fc_view[2])
+        z = self.decoder(z).view(-1, self.x_dim)
         return z

-    def reparametrize(self, mu, u, d):
-        latent_dist = LowRankMultivariateNormal(mu, u, d)
+    @staticmethod
+    def reparametrize(mu, cov_factor, cov_diag):
+        latent_dist = LowRankMultivariateNormal(mu, cov_factor, cov_diag)
         z = latent_dist.rsample()
         return z, latent_dist


-    def forward(self, x, return_latent_rec=False):
+    def forward(self, x):
         z, latent_dist = self.encode(x)
-        x_rec = self.decode(z)
-        return x_rec, {'z': z, 'latent_dist': latent_dist,}
+        x_rec = self.decode(z).view(-1, self.x_shape[0], self.x_shape[1])
+        return x_rec, z, latent_dist
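The reparametrize helper builds the approximate posterior q(z|x) as a low-rank-plus-diagonal Gaussian: with cov_factor W of shape (batch, z_dim, 1) and cov_diag d of shape (batch, z_dim), the covariance is W Wᵀ + diag(d), and rsample() draws a differentiable sample. A standalone illustration (batch size and values are arbitrary examples, not taken from the commit):

# Standalone sketch of the reparameterization used above.
import torch
from torch.distributions import LowRankMultivariateNormal

batch, z_dim = 4, 32
mu = torch.zeros(batch, z_dim)
cov_factor = torch.randn(batch, z_dim, 1) * 0.1   # rank-1 factor W
cov_diag = torch.ones(batch, z_dim)               # diagonal term d (must be positive)

latent_dist = LowRankMultivariateNormal(mu, cov_factor, cov_diag)  # cov = W @ W.T + diag(d)
z = latent_dist.rsample()           # differentiable sample, shape (batch, z_dim)
entropy = latent_dist.entropy()     # shape (batch,), used by the ELBO loss below
print(z.shape, entropy.shape)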
6 changes: 3 additions & 3 deletions src/vak/nn/loss/__init__.py
@@ -1,13 +1,13 @@
 from .dice import DiceLoss, dice_loss
 from .umap import UmapLoss, umap_loss
-from .vae import VaeLoss, vae_loss
+from .vae import VaeElboLoss, vae_elbo_loss


 __all__ = [
     "DiceLoss",
     "dice_loss",
     "UmapLoss",
     "umap_loss",
-    "VaeLoss",
-    "vae_loss"
+    "VaeElboLoss",
+    "vae_elbo_loss"
 ]
9 changes: 5 additions & 4 deletions src/vak/nn/loss/vae.py
@@ -4,8 +4,9 @@
 import math
 import torch
 import numpy as np
-# vak.nn.loss.vae
-def vae_loss(
+
+
+def vae_elbo_loss(
     x: torch.Tensor,
     z: torch.Tensor,
     x_rec: torch.Tensor,
@@ -25,7 +26,7 @@ def vae_loss(
     elbo = elbo + torch.sum(latent_dist.entropy())
     return elbo

-class VaeLoss(torch.nn.Module):
+class VaeElboLoss(torch.nn.Module):
     """"""

     def __init__(
@@ -47,7 +48,7 @@ def forward(
         latent_dist: torch.Tensor,
     ):
         x_shape = x.shape
-        elbo = vae_loss(x=x, z=z, x_rec=x_rec, latent_dist=latent_dist, model_precision=self.model_precision, z_dim=self.z_dim)
+        elbo = vae_elbo_loss(x=x, z=z, x_rec=x_rec, latent_dist=latent_dist, model_precision=self.model_precision, z_dim=self.z_dim)
         if self.return_latent_rec:
             return -elbo, z.detach().cpu().numpy(), \
                 x_rec.view(-1, x_shape[0], x_shape[1]).detach().cpu().numpy()
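Only a fragment of vae_elbo_loss is visible in these hunks; the visible line adds the entropy of q(z|x) to the ELBO. For orientation, AVA-style ELBOs typically combine a standard-normal prior term, a fixed-precision Gaussian reconstruction term, and that entropy. A hedged sketch of that structure follows; it is an illustration only, not the collapsed function body above, and the default values are placeholders.

# Illustrative ELBO: log p(z) + log p(x|z) + H[q(z|x)], summed over the batch.
import numpy as np
import torch

def elbo_sketch(x, z, x_rec, latent_dist, model_precision=10.0, z_dim=32):
    x_flat = x.view(x.shape[0], -1)
    x_dim = x_flat.shape[1]
    # log p(z) under a standard-normal prior, summed over batch and latent dims
    log_pz = -0.5 * (torch.sum(z ** 2) + z.shape[0] * z_dim * np.log(2 * np.pi))
    # log p(x|z) for a spherical Gaussian likelihood with fixed precision
    sq_err = torch.sum((x_flat - x_rec.view(x.shape[0], -1)) ** 2)
    log_pxz = -0.5 * (x.shape[0] * x_dim * np.log(2 * np.pi / model_precision)
                      + model_precision * sq_err)
    # entropy of the approximate posterior q(z|x) -- the term visible in the diff
    entropy = torch.sum(latent_dist.entropy())
    return log_pz + log_pxz + entropy

VaeElboLoss.forward then returns the negative of this quantity (and, when return_latent_rec is set, the detached latents and reconstructions), as its visible return statement shows.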
20 changes: 18 additions & 2 deletions test_vae.ipynb
@@ -2,10 +2,26 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [],
-   "source": []
+   "source": [
+    "from src.vak.nets.ava import Ava\n",
+    "import torch"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "x_shape = (3, 128, 512)\n",
+    "input = torch.zeros(x_shape)\n",
+    "net = Ava(x_shape=(x_shape[1], x_shape[2]))\n",
+    "output, z, latent_dist = net.forward(input)\n",
+    "assert output.shape == x_shape, 'Error'"
+   ]
   }
  ],
  "metadata": {
