From 6b53cb04115c9ed92af5c5778a544bc98cc854dc Mon Sep 17 00:00:00 2001
From: Kenneth
Date: Sun, 8 Dec 2024 17:33:37 +0000
Subject: [PATCH] feat: add code for finetuning moondream

---
Review notes: line breaks and indentation were restored from a whitespace-collapsed copy of this
patch. The `index` blob ids of the three new moondream files predate the review edits, and the
position of the blank context line in the .gitignore hunk is a best guess; regenerate with
`git format-patch` once applied.

 .gitignore                                |   2 +
 moondream/hyperparams.py                  |  33 ++++
 moondream/test.py                         |  69 ++++++++
 moondream/train.py                        | 195 ++++++++++++++++++++++
 augmentation.py => resnet/augmentation.py |   0
 hyperparams.py => resnet/hyperparams.py   |   0
 inference.py => resnet/inference.py       |   0
 label.py => resnet/label.py               |   0
 model.py => resnet/model.py               |   0
 resnet.py => resnet/resnet.py             |   0
 train.py => resnet/train.py               |   0
 11 files changed, 299 insertions(+)
 create mode 100644 moondream/hyperparams.py
 create mode 100644 moondream/test.py
 create mode 100644 moondream/train.py
 rename augmentation.py => resnet/augmentation.py (100%)
 rename hyperparams.py => resnet/hyperparams.py (100%)
 rename inference.py => resnet/inference.py (100%)
 rename label.py => resnet/label.py (100%)
 rename model.py => resnet/model.py (100%)
 rename resnet.py => resnet/resnet.py (100%)
 rename train.py => resnet/train.py (100%)

diff --git a/.gitignore b/.gitignore
index db5f73d..c1f5a1f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -250,3 +250,5 @@ $RECYCLE.BIN/
 
 # End of https://www.toptal.com/developers/gitignore/api/python,macos,linux,windows
 test_images/
+checkpoints/
+samples/
diff --git a/moondream/hyperparams.py b/moondream/hyperparams.py
new file mode 100644
index 0000000..81dcc7b
--- /dev/null
+++ b/moondream/hyperparams.py
@@ -0,0 +1,33 @@
+# Fraction of each source dataset held out as the test split.
+TEST_SIZE = 0.2
+
+# Number of times to repeat the training dataset. Increasing this may cause the model to overfit or
+# lose generalization due to catastrophic forgetting. Decreasing it may cause the model to underfit.
+EPOCHS = 1
+
+# Number of samples to process in each batch. Set this to the highest value that doesn't cause an
+# out-of-memory error. Decrease it if you're running out of memory.
+BATCH_SIZE = 8
+
+# Number of batches to process before updating the model. You can use this to simulate a higher batch
+# size than your GPU can handle. Set this to 1 to disable gradient accumulation.
+GRAD_ACCUM_STEPS = 2
+
+# Learning rate for the Adam optimizer. Needs to be tuned on a case-by-case basis. As a general rule
+# of thumb, increase it by 1.4 times each time you double the effective batch size.
+#
+# Source: https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/
+#
+# Note that we linearly warm the learning rate up from 0.1 * LR to LR over the first 10% of the
+# training run, and then decay it back to 0.1 * LR over the last 90% of the training run using a
+# cosine schedule.
+LR = 1e-5
+
+# Whether to use Weights and Biases for logging training metrics.
+USE_WANDB = False
+
+# End-of-answer marker appended to every training answer.
+ANSWER_EOS = "<|endoftext|>"
+
+# Number of tokens used to represent each image.
+IMG_TOKENS = 729
diff --git a/moondream/test.py b/moondream/test.py
new file mode 100644
index 0000000..cc08f29
--- /dev/null
+++ b/moondream/test.py
@@ -0,0 +1,69 @@
+import torch
+import datasets
+import transformers
+import pathlib
+
+DEVICE = "cuda"
+DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16 # CPU doesn't support float16
+MD_REVISION = "2024-07-23"
+
+# Number of shuffled samples to save, run through the model and print.
+NUM_SAMPLES = 3
+
+# Pin the tokenizer to MD_REVISION so the remote repo cannot drift from the fine-tuned checkpoint.
+tokenizer = transformers.AutoTokenizer.from_pretrained("vikhyatk/moondream2", revision=MD_REVISION)
+moondream = transformers.AutoModelForCausalLM.from_pretrained(
+    "./checkpoints/moondream-mai",
+    trust_remote_code=True,
+    attn_implementation="flash_attention_2",
+    torch_dtype=DTYPE,
+    device_map={"": DEVICE},
+)
+
+diffusion_db_dataset = datasets.load_dataset("poloclub/diffusiondb", "2m_random_5k", trust_remote_code=True, split="train")\
+    .shuffle()\
+    .take(100)\
+    .select_columns(["image"])\
+    .map(lambda row: {
+        **row,
+        "qa": {
+            "question": "Describe this image.",
+            "answer": "This is an AI image."
+        }
+    })
+
+flickr_dataset = datasets.load_dataset("nlphuji/flickr30k", split="test")\
+    .shuffle()\
+    .take(100)\
+    .select_columns(["image"])\
+    .map(lambda row: {
+        **row,
+        "qa": {
+            "question": "Describe this image.",
+            "answer": "This is a real image."
+        }
+    })
+
+dataset = datasets.concatenate_datasets([diffusion_db_dataset, flickr_dataset]).shuffle()
+
+pathlib.Path("./samples").mkdir(parents=True, exist_ok=True)
+
+for i, sample in enumerate(dataset):
+    # Stop before saving or running inference on samples that are never reported.
+    if i >= NUM_SAMPLES:
+        break
+
+    sample['image'].save(f"samples/{i}.png", "PNG")
+
+    md_answer = moondream.answer_question(
+        moondream.encode_image(sample['image']),
+        sample['qa']['question'],
+        tokenizer=tokenizer,
+        num_beams=4,
+        no_repeat_ngram_size=5,
+        early_stopping=True
+    )
+
+    print('Question:', sample['qa']['question'])
+    print('Ground Truth:', sample['qa']['answer'])
+    print('Moondream:', md_answer)
diff --git a/moondream/train.py b/moondream/train.py
new file mode 100644
index 0000000..b9c159f
--- /dev/null
+++ b/moondream/train.py
@@ -0,0 +1,195 @@
+import math
+import torch
+import datasets
+import transformers
+import bitsandbytes
+from tqdm import tqdm
+from .hyperparams import TEST_SIZE, ANSWER_EOS, IMG_TOKENS, LR, BATCH_SIZE, EPOCHS, GRAD_ACCUM_STEPS
+
+DEVICE = "cuda"
+DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16 # CPU doesn't support float16
+MD_REVISION = "2024-07-23"
+
+diffusion_db_dataset = datasets.load_dataset("poloclub/diffusiondb", "2m_random_5k", trust_remote_code=True, split="train")\
+    .select_columns(["image"])\
+    .map(lambda row: {
+        **row,
+        "qa": {
+            "question": "Describe this image.",
+            "answer": "This is an AI image."
+        }
+    })\
+    .train_test_split(test_size=TEST_SIZE)
+
+flickr_dataset = datasets.load_dataset("nlphuji/flickr30k", split="test")\
+    .take(5000)\
+    .select_columns(["image"])\
+    .map(lambda row: {
+        **row,
+        "qa": {
+            "question": "Describe this image.",
+            "answer": "This is a real image."
+        }
+    })\
+    .train_test_split(test_size=TEST_SIZE)
+
+training_dataset = datasets.concatenate_datasets([diffusion_db_dataset["train"], flickr_dataset["train"]]).shuffle()
+test_dataset = datasets.concatenate_datasets([diffusion_db_dataset["test"], flickr_dataset["test"]]).shuffle()
+
+# Pin the remote model code and weights to MD_REVISION: this script relies on model attributes
+# (`vision_encoder`, `text_model`) defined by that remote code, so it must not change underneath us.
+tokenizer = transformers.AutoTokenizer.from_pretrained("vikhyatk/moondream2", revision=MD_REVISION)
+moondream = transformers.AutoModelForCausalLM.from_pretrained(
+    "vikhyatk/moondream2",
+    revision=MD_REVISION,
+    trust_remote_code=True,
+    attn_implementation="flash_attention_2",
+    torch_dtype=DTYPE,
+    device_map={"": DEVICE},
+)
+
+def collate(batch):
+    """Turn a list of dataset samples into padded tensors for one training step.
+
+    Each sample supplies an "image" and a "qa" dict with "question" and "answer" strings.
+
+    Returns a tuple (images, tokens, labels, attention_masks):
+      - images: the untouched sample images, in batch order.
+      - tokens: BOS + question prompt + answer ids, right-padded with the tokenizer's EOS id.
+      - labels: -100 for the BOS, image, question and padding positions, answer ids elsewhere.
+        Labels are IMG_TOKENS longer than tokens because compute_loss inserts the image
+        embeddings right after BOS.
+      - attention_masks: True for real positions (including the image slots), False for padding.
+    """
+    images = []
+    all_tokens = []
+    all_labels = []
+
+    for sample in batch:
+        images.append(sample["image"])
+
+        tokens = [tokenizer.bos_token_id]
+        # BOS and the IMG_TOKENS image positions are never predicted.
+        labels = [-100] * (IMG_TOKENS + 1)
+
+        qa = sample["qa"]
+        q_t = tokenizer(
+            f"\n\nQuestion: {qa['question']}\n\nAnswer:",
+            add_special_tokens=False,
+        ).input_ids
+        tokens.extend(q_t)
+        labels.extend([-100] * len(q_t))
+
+        a_t = tokenizer(
+            f" {qa['answer']}{ANSWER_EOS}",
+            add_special_tokens=False,
+        ).input_ids
+        tokens.extend(a_t)
+        labels.extend(a_t)
+
+        all_tokens.append(tokens)
+        all_labels.append(labels)
+
+    longest_label_len = -1
+    for label in all_labels:
+        longest_label_len = max(longest_label_len, len(label))
+
+    all_attn_masks = []
+    for i in range(len(batch)):
+        label_len = len(all_labels[i])
+        pad_len = longest_label_len - label_len
+
+        all_labels[i].extend([-100] * pad_len)
+        all_tokens[i].extend([tokenizer.eos_token_id] * pad_len)
+        all_attn_masks.append([1] * label_len + [0] * pad_len)
+
+    return (
+        images,
+        torch.stack([torch.tensor(token, dtype=torch.long) for token in all_tokens]),
+        torch.stack([torch.tensor(label, dtype=torch.long) for label in all_labels]),
+        torch.stack([torch.tensor(mask, dtype=torch.bool) for mask in all_attn_masks]),
+    )
+
+def compute_loss(batch):
+    """Return the language-modelling loss for one collated batch.
+
+    The vision encoder runs under torch.no_grad(), so no gradients flow into it.
+    """
+    images, tokens, labels, masks = batch
+
+    tokens = tokens.to(DEVICE)
+    labels = labels.to(DEVICE)
+    masks = masks.to(DEVICE)
+
+    with torch.no_grad():
+        img_embeds = moondream.vision_encoder(images)
+
+    token_embeds = moondream.text_model.get_input_embeddings()(tokens)
+
+    # start with embedding vector that represents bos, then insert image embeds, then the rest of the token embeds
+    # + the image + all the tokens
+    inputs_embeds = torch.cat((token_embeds[:, 0:1, :], img_embeds, token_embeds[:, 1:, :]), dim=1)
+
+    outputs = moondream.text_model(
+        inputs_embeds=inputs_embeds,
+        labels=labels,
+        attention_mask=masks,
+    )
+
+    return outputs.loss
+
+def lr_schedule(step, max_steps):
+    """Return the learning rate for `step` out of `max_steps` optimizer steps.
+
+    Warms up linearly from 0.1 * LR to LR over the first 10% of training, then follows a cosine
+    curve from LR back down to 0.1 * LR over the remaining 90%.
+    """
+    x = step / max_steps
+    if x < 0.1:
+        return 0.1 * LR + 0.9 * LR * x / 0.1
+    else:
+        # Rescale the remaining 90% of the run onto [0, 1] so the cosine reaches its minimum at x == 1.
+        return 0.1 * LR + 0.9 * LR * (1 + math.cos(math.pi * (x - 0.1) / 0.9)) / 2
+
+dataloaders = {
+    "train": torch.utils.data.DataLoader(
+        training_dataset,
+        batch_size=BATCH_SIZE,
+        shuffle=True,
+        collate_fn=collate,
+    )
+}
+
+moondream.text_model.train()
+moondream.text_model.transformer.gradient_checkpointing_enable()
+
+total_steps = EPOCHS * len(dataloaders["train"]) // GRAD_ACCUM_STEPS
+optimizer = bitsandbytes.optim.Adam8bit(
+    [{"params": moondream.text_model.parameters()}],
+    lr=LR*0.1,
+    betas=(0.9, 0.95),
+    eps=1e-6,
+)
+
+i = 0
+for epoch in range(EPOCHS):
+    for batch in tqdm(dataloaders["train"], desc=f"Epoch {epoch + 1}/{EPOCHS}"):
+        i += 1
+
+        loss = compute_loss(batch)
+        loss.backward()
+
+        if i % GRAD_ACCUM_STEPS == 0:
+            optimizer.step()
+            optimizer.zero_grad()
+
+        lr = lr_schedule(i / GRAD_ACCUM_STEPS, total_steps)
+        for param_group in optimizer.param_groups:
+            param_group["lr"] = lr
+
+# Apply gradients left over when the batch count is not a multiple of GRAD_ACCUM_STEPS.
+if i % GRAD_ACCUM_STEPS != 0:
+    optimizer.step()
+    optimizer.zero_grad()
+
+moondream.save_pretrained("checkpoints/moondream-mai")
diff --git a/augmentation.py b/resnet/augmentation.py
similarity index 100%
rename from augmentation.py
rename to resnet/augmentation.py
diff --git a/hyperparams.py b/resnet/hyperparams.py
similarity index 100%
rename from hyperparams.py
rename to resnet/hyperparams.py
diff --git a/inference.py b/resnet/inference.py
similarity index 100%
rename from inference.py
rename to resnet/inference.py
diff --git a/label.py b/resnet/label.py
similarity index 100%
rename from label.py
rename to resnet/label.py
diff --git a/model.py b/resnet/model.py
similarity index 100%
rename from model.py
rename to resnet/model.py
diff --git a/resnet.py b/resnet/resnet.py
similarity index 100%
rename from resnet.py
rename to resnet/resnet.py
diff --git a/train.py b/resnet/train.py
similarity index 100%
rename from train.py
rename to resnet/train.py