The model "loses" or "degrades" its multimodality when you increase the size of the model #33

Open
F4k3r22 opened this issue Dec 29, 2024 · 5 comments

Comments


F4k3r22 commented Dec 29, 2024

The model "loses" or "degrades" its multimodality when you increase the size of the model, the maximum that can be increased is this:

model = Transfusion(
    num_text_tokens=256,  # Increased for text vocabulary
    dim_latent=dim_latent,
    modality_default_shape=(16, 16),  # Adjusted for 256x256 images
    modality_encoder=encoder,
    modality_decoder=decoder,
    add_pos_emb=True,
    modality_num_dim=2,
    transformer=dict(
        dim=256,  # Increased transformer dimensions
        depth=8,  # More layers
        dim_head=64,
        heads=8,
    )
).to(device)

I noticed this because I wanted to train with the following setup, and I could never get multimodal output when sampling during training:

model = Transfusion(
    num_text_tokens=256,  # Increased for text vocabulary
    dim_latent=dim_latent,
    modality_default_shape=(16, 16),  # Adjusted for 256x256 images
    modality_encoder=encoder,
    modality_decoder=decoder,
    add_pos_emb=True,
    modality_num_dim=2,
    transformer=dict(
        dim=768,  # Increased transformer dimensions
        depth=12,  # More layers
        dim_head=64,
        heads=12,
    )
).to(device)

Maybe it's because of transfusion_attn_mask; we could give more weight to the model's multimodal interactions to compensate for the larger configurations.

I think something like this could work:

def transfusion_attn_mask(modalities: Int['b m 3']):
    modalities = modalities.long()

    def mask_mod(b, h, q_idx, kv_idx):
        causal_mask = causal(b, h, q_idx, kv_idx)

        modality_mask = torch.zeros_like(causal_mask, dtype=torch.bool)
        modality_batch = modalities[b]

        # factor intended to up-weight multimodal interactions
        modality_attention_factor = 1.5

        for mod_type, offset, length in modality_batch:
            current_modality_mask = modality(offset, length)(b, h, q_idx, kv_idx)

            # non-text modalities attend bidirectionally within their own span
            if mod_type > 0:
                current_modality_mask = current_modality_mask | current_modality_mask.transpose(-1, -2)

            modality_mask = modality_mask | current_modality_mask

        final_mask = causal_mask | (modality_mask * modality_attention_factor)

        return final_mask

    return mask_mod
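
Side note: if this goes through PyTorch's flex attention, I believe a mask_mod has to return booleans, so a weighting factor like the 1.5 above would probably have to live in a score_mod rather than in the mask. A rough sketch of the idea, assuming a single hypothetical (offset, length) modality span (this is only an illustration, not the repo's actual API):

import torch

def modality_boost_score_mod(offset, length, boost = 0.5):
    # add a bonus to the pre-softmax score whenever both query and key fall inside the modality span
    def score_mod(score, b, h, q_idx, kv_idx):
        inside = (q_idx >= offset) & (q_idx < offset + length) & (kv_idx >= offset) & (kv_idx < offset + length)
        return torch.where(inside, score + boost, score)
    return score_mod

The boolean mask would then stay purely causal/bidirectional, and the up-weighting would happen in the attention scores instead.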

Even if this isn't the correct solution, I hope the observation helps. Have a nice day!

lucidrains (Owner) commented:

@F4k3r22 turns out I introduced a bug when adding the meta tokens from the Hymba paper

would you like to try again?

F4k3r22 (Author) commented Dec 29, 2024

OK, I'm going to run some training tests, and then I'll let you know if it still has multimodality.

F4k3r22 (Author) commented Dec 29, 2024

Size test:
Small model:

⚡ ~ python main.py            
training autoencoder
loss: 0.38245: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [00:31<00:00, 15.74it/s]
training transfusion with autoencoder
decoding modality [0]: : 336it [00:12, 27.09it/s]                                                                                         | 5/100000 [00:01<9:07:43,  3.04it/s]
2024-12-29 17:17:46.610 | INFO     | transfusion_pytorch.transfusion:sample:1731 - sampling stopped at length: 336 / 256                                                       
2024-12-29 17:17:46.616 | INFO     | transfusion_pytorch.transfusion:print_modality_sample:253 - [('text', torch.Size([81])), ('modality:0', torch.Size([3, 256, 256])), ('text', torch.Size([1]))]
decoding modality [0]: : 330it [00:03, 108.58it/s]                                                                                      | 10/100000 [00:16<37:21:10,  1.34s/it]
2024-12-29 17:17:52.060 | INFO     | transfusion_pytorch.transfusion:sample:1731 - sampling stopped at length: 330 / 256                                                       
2024-12-29 17:17:52.062 | INFO     | transfusion_pytorch.transfusion:print_modality_sample:253 - [('text', torch.Size([75])), ('modality:0', torch.Size([3, 256, 256])), ('text', torch.Size([1]))]

Large model:

⚡ ~ python main.py
training autoencoder
loss: 0.35039: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [00:31<00:00, 15.88it/s]
training transfusion with autoencoder
decoding text: : 257it [00:34,  7.40it/s]                                                                                                | 5/100000 [00:06<33:45:19,  1.22s/it]
2024-12-29 17:22:24.087 | INFO     | transfusion_pytorch.transfusion:sample:1731 - sampling stopped at length: 257 / 256                                                       
2024-12-29 17:22:24.088 | INFO     | transfusion_pytorch.transfusion:print_modality_sample:253 - [('text', torch.Size([258]))]
decoding text: : 257it [00:09, 25.86it/s]                                                                                                                                      
2024-12-29 17:22:45.075 | INFO     | transfusion_pytorch.transfusion:sample:1731 - sampling stopped at length: 257 / 256
2024-12-29 17:22:45.077 | INFO     | transfusion_pytorch.transfusion:print_modality_sample:253 - [('text', torch.Size([258]))]
loss: 1.211:   0%|                                                                                                                     | 14/100000 [01:13<144:53:56,  5.22s/it]

And in case it helps, I'll leave the training file here, in case the error is in it:
Small model:

from shutil import rmtree
from pathlib import Path
import random
import torch
from torch import nn, tensor
from torch.nn import Module
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
import torchvision.transforms as T
from torchvision.utils import save_image
from tqdm import tqdm
from PIL import Image
import pandas as pd
import ast
from PrometheusCore import EncoderV1, DecoderV1
import numpy as np
from transfusion_pytorch import Transfusion, print_modality_sample
from torch.nn.utils.rnn import pad_sequence

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Setup directories
rmtree('./results', ignore_errors=True)
results_folder = Path('./results')
results_folder.mkdir(exist_ok=True, parents=True)

# Checkpoint functions remain the same
def save_checkpoint(model, epoch, loss, checkpoint_dir='checkpoints'):
    checkpoint_path = Path(checkpoint_dir)
    checkpoint_path.mkdir(exist_ok=True, parents=True)

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'loss': loss
    }

    checkpoint_file = checkpoint_path / f'checkpoint_epoch_{epoch}.pt'
    torch.save(checkpoint, checkpoint_file)
    latest_checkpoint = checkpoint_path / 'checkpoint_latest.pt'
    torch.save(checkpoint, latest_checkpoint)

def collate_fn(batch):
    # Separate the captions from the images
    captions, images = zip(*batch)

    # Pad the captions to the same length
    padded_captions = pad_sequence(captions, batch_first=True, padding_value=0)

    # Stack the images into a single tensor (fine since they all have the same size)
    images = torch.stack(images)

    return padded_captions, images

def load_checkpoint(model, checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    return checkpoint['epoch'], checkpoint['loss']

def divisible_by(num, den):
    return (num % den) == 0

def cycle(iter_dl):
    while True:
        for batch in iter_dl:
            yield batch

def decode_token(token):
    return str(chr(max(32, token)))

def decode_tokens(tokens):
    return "".join(list(map(decode_token, tokens)))

class Normalize(Module):
    def forward(self, x):
        return F.normalize(x, dim = -1)

class Flickr30kDataset(Dataset):
    def __init__(self, csv_path, images_dir):
        """
        Args:
            csv_path: Path to the CSV file with the Flickr30k data
            images_dir: Directory containing the images
        """
        self.images_dir = Path(images_dir)

        # Load and process the CSV
        self.df = pd.read_csv(csv_path)
        # Convert the list-strings into actual Python lists
        self.df['raw'] = self.df['raw'].apply(ast.literal_eval)

        # Filter to split == 'train' if needed
        self.df = self.df[self.df['split'] == 'train'].reset_index(drop=True)

        self.transform = T.Compose([
            T.Resize(286),  # Resize larger than the crop size
            T.RandomCrop(256),  # Random crop for data augmentation
            T.RandomHorizontalFlip(),
            T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # Load the image
        image_path = self.images_dir / row['filename']
        image = Image.open(image_path).convert('RGB')
        image_tensor = self.transform(image)

        # Pick a random caption from the 5 available
        caption = random.choice(row['raw'])
        caption = np.frombuffer(caption.encode('utf-8'), dtype=np.uint8)
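        # the caption is now a sequence of raw UTF-8 bytes (values 0-255), which is what num_text_tokens = 256 covers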
        caption = tensor(caption, dtype=torch.long)

        return caption, image_tensor

# Larger encoder for 256x256 images
dim_latent = 256  # Increased latent dimension for more complex images

encoder = EncoderV1(dim_latent=dim_latent)

# Decoder adjusted to handle the dimensions correctly
decoder = DecoderV1(dim_latent=dim_latent)

# Create dataset and dataloaders
dataset = Flickr30kDataset(csv_path="/teamspace/studios/this_studio/flickr30k/flickr_annotations_30k.csv", images_dir="/teamspace/studios/this_studio/flickr30k/flickr30k-images/")

# Training parameters
batch_size = 8  # Reduced batch size due to larger images
accum_steps = 4
autoencoder_train_steps = 500  # Increased steps for more complex dataset
learning_rate = 1e-4  # Adjusted learning rate

autoencoder_optimizer = AdamW([*encoder.parameters(), *decoder.parameters()], lr=learning_rate)
autoencoder_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, collate_fn=collate_fn)
autoencoder_iter_dl = cycle(autoencoder_dataloader)

print('training autoencoder')

with tqdm(total=autoencoder_train_steps) as pbar:
    for step in range(autoencoder_train_steps):
        total_loss = 0
        autoencoder_optimizer.zero_grad()

        for _ in range(accum_steps):
            _, images = next(autoencoder_iter_dl)
            images = images.to(device)

            latents = encoder(images)
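            # blend each latent element slightly toward Gaussian noise (elementwise weight up to 0.1) before decoding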
            latents = latents.lerp(torch.randn_like(latents), torch.rand_like(latents) * 0.1)
            #print("Latents shape:", latents.shape)
            reconstructed = decoder(latents)

            loss = F.mse_loss(images, reconstructed)
            loss = loss / accum_steps  # Normalize the loss for gradient accumulation
            total_loss += loss.item()

            loss.backward()

        # Clip gradients
        torch.nn.utils.clip_grad_norm_([*encoder.parameters(), *decoder.parameters()], 1.0)

        autoencoder_optimizer.step()
        autoencoder_optimizer.zero_grad()

        pbar.set_description(f'loss: {total_loss:.5f}')
        pbar.update()

        # Save samples periodically
        if step % 500 == 0:
            with torch.no_grad():
                save_image(
                    reconstructed[0].cpu().clamp(min=-1., max=1.),
                    str(results_folder / f'reconstruction_{step}.png'),
                    normalize=True
                )

# Initialize Prometheus with new parameters
model = Transfusion(
    num_text_tokens=256,  # Increased for text vocabulary
    dim_latent=dim_latent,
    modality_default_shape=(16, 16),  # Adjusted for 256x256 images
    modality_encoder=encoder,
    modality_decoder=decoder,
    add_pos_emb=True,
    modality_num_dim=2,
    transformer=dict(
        dim=256,  # Increased transformer dimensions
        depth=8,  # More layers
        dim_head=64,
        heads=8,
    )
).to(device)

dataloader = model.create_dataloader(dataset, batch_size=batch_size, shuffle=True)
iter_dl = cycle(dataloader)

optimizer = AdamW(model.parameters_without_encoder_decoder(), lr=learning_rate)
transfusion_train_steps = 100000  # Increased steps

print('training transfusion with autoencoder')

with tqdm(total=transfusion_train_steps) as pbar:
    for index in range(transfusion_train_steps):
        step = index + 1
        model.train()

        data = next(iter_dl)
        #print(f"Data from dataloader: {data}")
        loss = model(data)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # Clip gradients

        optimizer.step()
        optimizer.zero_grad()

        pbar.set_description(f'loss: {loss.item():.3f}')
        pbar.update()

        if divisible_by(step, 5):
            save_checkpoint(model=model, epoch=step, loss=loss.item())
            one_multimodal_sample = model.sample(max_length=256)  # Increased for longer captions

            print_modality_sample(one_multimodal_sample)

            if len(one_multimodal_sample) < 2:
                continue

            caption, image, *_ = one_multimodal_sample
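            # each element of the sample is a (name, tensor) pair, e.g. ('modality:0', image_tensor), so image[1] is the image tensor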

            filename = f'{step}.png'
            save_image(
                image[1].cpu().clamp(min=-1., max=1.),
                str(results_folder / filename),
                normalize=True
            )

Large model (identical to the small-model script above except for the Transfusion transformer config):

model = Transfusion(
    num_text_tokens=256,  # Increased for text vocabulary
    dim_latent=dim_latent,
    modality_default_shape=(16, 16),  # Adjusted for 256x256 images
    modality_encoder=encoder,
    modality_decoder=decoder,
    add_pos_emb=True,
    modality_num_dim=2,
    transformer=dict(
        dim=768,  # Increased transformer dimensions
        depth=12,  # More layers
        dim_head=64,
        heads=12,
    )
).to(device)

F4k3r22 (Author) commented Dec 29, 2024

Hey, I'm going to run a test that increases the autoencoder's dimension along with the model's; maybe this mismatch in dimensionality is what causes the problem.

F4k3r22 (Author) commented Dec 29, 2024

I already ran the test, and it still seems to lose multimodality, even though the autoencoder and Transfusion now share the same dimensionality:

model = Transfusion(
    num_text_tokens=256,  # Increased for text vocabulary
    dim_latent=dim_latent,
    modality_default_shape=(16, 16),  # Adjusted for 256x256 images
    modality_encoder=encoder,
    modality_decoder=decoder,
    add_pos_emb=True,
    modality_num_dim=2,
    transformer=dict(
        dim=768,  # Increased transformer dimensions
        depth=12,  # More layers
        dim_head=64,
        heads=12,
    )
).to(device)

AutoEncoder:

# Larger encoder for 256x256 images
dim_latent = 768  # Increased latent dimension for more complex images

encoder = EncoderV1(dim_latent=dim_latent)


decoder = DecoderV1(dim_latent=dim_latent)
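
For what it's worth, this is the quick shape check I would use to confirm the autoencoder latents line up with what Transfusion expects; a hypothetical snippet, assuming a 256x256 input and that the latent grid carries dim_latent channels over a 16x16 grid:

with torch.no_grad():
    dummy = torch.randn(1, 3, 256, 256)
    latents = encoder(dummy)
    # the printed shape should contain dim_latent channels and a 16x16 spatial grid,
    # matching the dim_latent and modality_default_shape passed to Transfusion
    print('latent shape:', latents.shape)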
