Update MGDAT
Catchxu committed Apr 7, 2024
1 parent 31f3ce6 commit a4d5ae0
Showing 5 changed files with 100 additions and 15 deletions.
Binary file modified MobileUNet.pth
2 changes: 1 addition & 1 deletion finetune.py
@@ -10,7 +10,7 @@
from .utils import seed_everything

class SNNet:
-    def __init__(self, epochs: List[int] = [10, 5], batch_size: int = 128,
+    def __init__(self, epochs: List[int] = [10, 5], batch_size: int = 64,
learning_rate: float = 1e-4, GPU: Optional[str] = "cuda:0",
random_state: Optional[int] = None, Mobile: bool = False):
if GPU is not None:
105 changes: 95 additions & 10 deletions model/fusion.py
@@ -5,21 +5,36 @@


class GAT(nn.Module):
-    def __init__(self, in_dim, out_dim, nheads):
+    def __init__(self, in_dim, out_dim, nheads, single_layer=True):
        super().__init__()
        self.num_layers = 2
        self.gat_layers = nn.ModuleList()
+        self.single_layer = single_layer

        if isinstance(nheads, int):
            nheads = [nheads]

        self.gat_layers.append(
-            GATv2Conv(in_dim, out_dim, nheads[0], feat_drop=0.2, attn_drop=0.1))
-        self.gat_layers.append(
-            GATv2Conv(out_dim*nheads[0], out_dim, nheads[1], feat_drop=0.2, attn_drop=0.1)
-        )
+            GATv2Conv(in_dim, out_dim, nheads[0], feat_drop=0.2, attn_drop=0.1)
+        )
+
+
+        if self.single_layer:
+            self.gat_layers.append(
+                nn.Linear(out_dim*nheads[0], out_dim)
+            )
+        else:
+            self.gat_layers.append(
+                GATv2Conv(out_dim*nheads[0], out_dim, nheads[1], feat_drop=0.2, attn_drop=0.1)
+            )

    def forward(self, blocks, x):
-        for i in range(self.num_layers):
-            g = blocks[i]
-            x = self.gat_layers[i](g, x).flatten(1)
+        if self.single_layer:
+            x = self.gat_layers[0](blocks, x).flatten(1)
+            x = self.gat_layers[1](x)
+        else:
+            for i in range(self.num_layers):
+                g = blocks[i]
+                x = self.gat_layers[i](g, x).flatten(1)
        return x
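
For reference, a minimal sketch of how the revised GAT can be exercised in both modes. The toy graph, feature sizes, and the import path are assumptions for illustration, not part of the commit:

# sketch only: shapes and the toy graph are made up for illustration
import dgl
import torch
from model.fusion import GAT   # import path assumed

g = dgl.add_self_loop(dgl.rand_graph(32, 128))   # toy graph; self-loops avoid zero in-degree errors in GATv2Conv
x = torch.randn(32, 64)                          # toy node features, in_dim=64

single = GAT(in_dim=64, out_dim=16, nheads=4)    # default mode: one GATv2Conv + Linear head
y = single(g, x)                                 # a single graph/block is expected here
print(y.shape)                                   # torch.Size([32, 16])

blocks = [g, g]                                  # two message-flow blocks (full graph here)
double = GAT(in_dim=64, out_dim=16, nheads=[4, 1], single_layer=False)
z = double(blocks, x)                            # stacked GATv2Conv layers, as before this commit
print(z.shape)                                   # torch.Size([32, 16])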


@@ -103,7 +118,7 @@ def __init__(self, g_dim, emb_chan, patch_size, fused_dim=16,
self.bottleneck = TransformerEncoder(encoder_layer, TF_layers, fused_dim)

# GAT Fusion
-        self.layer = GAT(g_dim+fused_dim, g_dim, nheads=GAT_nheads)
+        self.layer = GAT(g_dim+fused_dim, g_dim, nheads=GAT_nheads, single_layer=False)

# mask
self.g_mask_token = nn.Parameter(torch.zeros(1, g_dim))
@@ -149,4 +164,74 @@ def __init__(self, g_dim, p_dim, z_dim=None):
def forward(self, g, p):
output = torch.cat((g, p), dim=1)
output = self.fc_out(output)
        return output


class MGDAT(nn.Module):
def __init__(self, g_dim, emb_chan, patch_size, fused_dim=16, blocks=2,
TF_layers=2, TF_nheads=4, GAT_nheads=2, mask=True):
super().__init__()

# Patch Projection (emb_chan * patch_size**2 -->> g_dim)
self.emb_chan = emb_chan
self.patch_size = patch_size
p_dim = emb_chan * patch_size**2
self.p_down = nn.Linear(p_dim, g_dim)
self.p_up = nn.Linear(g_dim, p_dim)

# Transformer Fusion
encoder_layer = TransformerLayer(g_dim*2, TF_nheads)
self.bottleneck = TransformerEncoder(encoder_layer, TF_layers, fused_dim)

# GAT Fusion
self.GAT = GAT(g_dim+fused_dim, g_dim, nheads=GAT_nheads)

# mask
self.mask = mask
if self.mask:
self.g_mask_token = nn.Parameter(torch.zeros(1, g_dim))
self.p_mask_token = nn.Parameter(torch.zeros(1, g_dim))

self.blocks = blocks

def make_mask(self, blocks, gene_tokens, patch_tokens):
# replace input data with mask tokens
mask_tokens = self.g_mask_token.repeat(blocks[-1].num_dst_nodes(), 1)
gene_tokens = torch.cat((mask_tokens, gene_tokens[blocks[-1].num_dst_nodes():]), dim=0)
mask_tokens = self.p_mask_token.repeat(blocks[-1].num_dst_nodes(), 1)
patch_tokens = torch.cat((mask_tokens, patch_tokens[blocks[-1].num_dst_nodes():]), dim=0)

return gene_tokens, patch_tokens

def attn_mask(self, blocks, idx):
mask = torch.zeros(blocks[idx].num_src_nodes(), blocks[idx].num_src_nodes())
mask[:blocks[-1].num_dst_nodes(), :blocks[-1].num_dst_nodes()] = 1
return mask.to(self.g_mask_token)

def fusion(self, blocks, gene_tokens, patch_tokens):
for i in range(self.blocks):
concat = torch.cat((gene_tokens, patch_tokens), dim=1)

if self.mask:
mask = self.attn_mask(blocks, i).bool()
else:
mask = None

fused = self.bottleneck(concat, mask)
gene_tokens = self.GAT(blocks[i], torch.cat((gene_tokens, fused), dim=-1))
patch_tokens = self.GAT(blocks[i], torch.cat((patch_tokens, fused), dim=-1))
return gene_tokens, patch_tokens

def forward(self, blocks, g, p):
p = self.p_down(p.flatten(1))

if self.mask:
g, p = self.make_mask(blocks, g, p)

g, p = self.fusion(blocks, g, p)

p = self.p_up(p)
return g, p.reshape(-1, self.emb_chan, self.patch_size, self.patch_size)
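
The masking scheme above can be read as follows: the seed (destination) nodes of the last block have their input tokens replaced by a learned mask token, and self-attention in the bottleneck is restricted to that seed block. A standalone sketch of the same mechanics, with a toy block and sizes that are assumptions for illustration:

# sketch only: mimics the logic of MGDAT.make_mask / MGDAT.attn_mask on a toy DGL block
import dgl
import torch

g = dgl.add_self_loop(dgl.rand_graph(10, 30))
seeds = torch.arange(4)                                   # 4 seed (dst) nodes
block = dgl.to_block(dgl.in_subgraph(g, seeds), dst_nodes=seeds)

g_dim = 8
feats = torch.randn(block.num_src_nodes(), g_dim)
mask_token = torch.zeros(1, g_dim)                        # stands in for the learned token

# DGL places dst nodes first among src nodes, so masking the first
# num_dst_nodes rows masks exactly the seed nodes
n_dst = block.num_dst_nodes()
masked = torch.cat((mask_token.repeat(n_dst, 1), feats[n_dst:]), dim=0)

# attention mask: only the seed-node block attends among itself
attn = torch.zeros(block.num_src_nodes(), block.num_src_nodes())
attn[:n_dst, :n_dst] = 1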



6 changes: 3 additions & 3 deletions model/meatrd.py
@@ -2,7 +2,7 @@
import torch.nn.functional as F
from dgl.nn import GATv2Conv

-from .fusion import GraphMBT, ConcatFusion
+from .fusion import GraphMBT, ConcatFusion, MGDAT
from .unet import UNet


@@ -223,7 +223,7 @@ def __init__(self, patch_size, in_dim, out_dim=[512, 256], Mobile=False, **kwarg
self.UNet = UNet(3, Mobile=Mobile, **kwargs)

emb_chan = self.UNet.emb_chan
-        self.Fusion = GraphMBT(out_dim[-1], emb_chan, patch_size // (2**4))
+        self.Fusion = MGDAT(out_dim[-1], emb_chan, patch_size // (2**4))

self.z_g_dim = out_dim[-1]
self.z_p_dim = emb_chan * patch_size**2
@@ -239,7 +239,7 @@ def forward(self, blocks, feat_g, feat_p):
z_g = self.GeneEncoder(feat_g)
z_p, p_skips = self.UNet.encode(blocks[-1], feat_p)

-        # Fusion with GraphMBT
+        # Fusion with MGDAT
z_g, z_p = self.Fusion(blocks[:-1], z_g, z_p)

fake_g = self.GeneDecoder(blocks[-1], z_g)
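
Since MGDAT defaults to blocks=2 and the forward pass above also hands blocks[-1] to the UNet encoder and gene decoder, the dataloader presumably supplies three message-flow blocks per batch. A hedged sketch of how such blocks are commonly produced with DGL; the sampler choice, the spatial_graph variable, and the batch size are assumptions, not the repository's actual pipeline:

# sketch only: producing the blocks consumed above; `spatial_graph` is assumed
# to be the spatial neighbor graph built elsewhere in the repository
import dgl
import torch

sampler = dgl.dataloading.MultiLayerFullNeighborSampler(3)
loader = dgl.dataloading.DataLoader(
    spatial_graph, torch.arange(spatial_graph.num_nodes()), sampler,
    batch_size=64, shuffle=True, drop_last=False,
)
for input_nodes, output_nodes, blocks in loader:
    # blocks[:-1] feed the MGDAT fusion, blocks[-1] feeds the UNet/GeneDecoder
    ...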
2 changes: 1 addition & 1 deletion pretrain.py
@@ -11,7 +11,7 @@


def pretrain(graph: dgl.DGLGraph,
-             unet_epochs: int = 30,
+             unet_epochs: int = 20,
batch_size: int = 128,
learning_rate: float = 1e-4,
GPU: bool = True,
