Skip to content

Commit

Permalink
feat: add code for finetuning moondream
Browse files Browse the repository at this point in the history
  • Loading branch information
kennethnym committed Dec 8, 2024
1 parent 4bf2c89 commit 6b53cb0
Show file tree
Hide file tree
Showing 11 changed files with 261 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -250,3 +250,5 @@ $RECYCLE.BIN/
# End of https://www.toptal.com/developers/gitignore/api/python,macos,linux,windows

test_images/
checkpoints/
samples/
31 changes: 31 additions & 0 deletions moondream/hyperparams.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
TEST_SIZE = 0.2

# Number of times to repeat the training dataset. Increasing this may cause the model to overfit or
# lose generalization due to catastrophic forgetting. Decreasing it may cause the model to underfit.
EPOCHS = 1

# Number of samples to process in each batch. Set this to the highest value that doesn't cause an
# out-of-memory error. Decrease it if you're running out of memory.
BATCH_SIZE = 8

# Number of batches to process before updating the model. You can use this to simulate a higher batch
# size than your GPU can handle. Set this to 1 to disable gradient accumulation.
GRAD_ACCUM_STEPS = 2

# Learning rate for the Adam optimizer. Needs to be tuned on a case-by-case basis. As a general rule
# of thumb, increase it by 1.4 times each time you double the effective batch size.
#
# Source: https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/
#
# Note that we linearly warm the learning rate up from 0.1 * LR to LR over the first 10% of the
# training run, and then decay it back to 0.1 * LR over the last 90% of the training run using a
# cosine schedule.
LR = 1e-5

# Whether to use Weights and Biases for logging training metrics.
USE_WANDB = False

ANSWER_EOS = "<|endoftext|>"

# Number of tokens used to represent each image.
IMG_TOKENS = 729
64 changes: 64 additions & 0 deletions moondream/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import torch
import datasets
import transformers
import pathlib

DEVICE = "cuda"
DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16 # CPU doesn't support float16
MD_REVISION = "2024-07-23"

tokenizer = transformers.AutoTokenizer.from_pretrained("vikhyatk/moondream2")
moondream = transformers.AutoModelForCausalLM.from_pretrained(
"./checkpoints/moondream-mai",
trust_remote_code=True,
attn_implementation="flash_attention_2",
torch_dtype=DTYPE,
device_map={"": DEVICE},
)

diffusion_db_dataset = datasets.load_dataset("poloclub/diffusiondb", "2m_random_5k", trust_remote_code=True, split="train")\
.shuffle()\
.take(100)\
.select_columns(["image"])\
.map(lambda row: {
**row,
"qa": {
"question": "Describe this image.",
"answer": "This is an AI image."
}
})

flickr_dataset = datasets.load_dataset("nlphuji/flickr30k", split="test")\
.shuffle()\
.take(100)\
.select_columns(["image"])\
.map(lambda row: {
**row,
"qa": {
"question": "Describe this image.",
"answer": "This is a real image."
}
})

dataset = datasets.concatenate_datasets([diffusion_db_dataset, flickr_dataset]).shuffle()

pathlib.Path("./samples").mkdir(parents=True, exist_ok=True)

for i, sample in enumerate(dataset):
sample['image'].save(f"samples/{i}.png", "PNG")

md_answer = moondream.answer_question(
moondream.encode_image(sample['image']),
sample['qa']['question'],
tokenizer=tokenizer,
num_beams=4,
no_repeat_ngram_size=5,
early_stopping=True
)

if i < 3:
print('Question:', sample['qa']['question'])
print('Ground Truth:', sample['qa']['answer'])
print('Moondream:', md_answer)
else:
break
164 changes: 164 additions & 0 deletions moondream/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
import math
import torch
import datasets
import transformers
import bitsandbytes
from tqdm import tqdm
from .hyperparams import TEST_SIZE, ANSWER_EOS, IMG_TOKENS, LR, BATCH_SIZE, EPOCHS, GRAD_ACCUM_STEPS

DEVICE = "cuda"
DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16 # CPU doesn't support float16
MD_REVISION = "2024-07-23"

diffusion_db_dataset = datasets.load_dataset("poloclub/diffusiondb", "2m_random_5k", trust_remote_code=True, split="train")\
.select_columns(["image"])\
.map(lambda row: {
**row,
"qa": {
"question": "Describe this image.",
"answer": "This is an AI image."
}
})\
.train_test_split(test_size=TEST_SIZE)

flickr_dataset = datasets.load_dataset("nlphuji/flickr30k", split="test")\
.take(5000)\
.select_columns(["image"])\
.map(lambda row: {
**row,
"qa": {
"question": "Describe this image.",
"answer": "This is a real image."
}
})\
.train_test_split(test_size=TEST_SIZE)

training_dataset = datasets.concatenate_datasets([diffusion_db_dataset["train"], flickr_dataset["train"]]).shuffle()
test_dataset = datasets.concatenate_datasets([diffusion_db_dataset["test"], flickr_dataset["test"]]).shuffle()

tokenizer = transformers.AutoTokenizer.from_pretrained("vikhyatk/moondream2")
moondream = transformers.AutoModelForCausalLM.from_pretrained(
"vikhyatk/moondream2",
trust_remote_code=True,
attn_implementation="flash_attention_2",
torch_dtype=DTYPE,
device_map={"": DEVICE},
)

def collate(batch):
images = []
all_tokens = []
all_labels = []

for sample in batch:
images.append(sample["image"])

tokens = [tokenizer.bos_token_id]
labels = [-100] * (IMG_TOKENS + 1)

qa = sample["qa"]
q_t = tokenizer(
f"\n\nQuestion: {qa['question']}\n\nAnswer:",
add_special_tokens=False,
).input_ids
tokens.extend(q_t)
labels.extend([-100] * len(q_t))

a_t = tokenizer(
f" {qa['answer']}{ANSWER_EOS}",
add_special_tokens=False,
).input_ids
tokens.extend(a_t)
labels.extend(a_t)

all_tokens.append(tokens)
all_labels.append(labels)

longest_label_len = -1
for label in all_labels:
longest_label_len = max(longest_label_len, len(label))

all_attn_masks = []
for i in range(len(batch)):
label_len = len(all_labels[i])
pad_len = longest_label_len - label_len

all_labels[i].extend([-100] * pad_len)
all_tokens[i].extend([tokenizer.eos_token_id] * pad_len)
all_attn_masks.append([1] * label_len + [0] * pad_len)

return (
images,
torch.stack([torch.tensor(token, dtype=torch.long) for token in all_tokens]),
torch.stack([torch.tensor(label, dtype=torch.long) for label in all_labels]),
torch.stack([torch.tensor(mask, dtype=torch.bool) for mask in all_attn_masks]),
)

def compute_loss(batch):
images, tokens, labels, masks = batch

tokens = tokens.to(DEVICE)
labels = labels.to(DEVICE)
masks = masks.to(DEVICE)

with torch.no_grad():
img_embeds = moondream.vision_encoder(images)

token_embeds = moondream.text_model.get_input_embeddings()(tokens)

# start with embedding vector that represents bos, then insert image embeds, then the rest of the token embeds
# <BOS> + the image + all the tokens
inputs_embeds = torch.cat((token_embeds[:, 0:1, :], img_embeds, token_embeds[:, 1:, :]), dim=1)

outputs = moondream.text_model(
inputs_embeds=inputs_embeds,
labels=labels,
attention_mask=masks,
)

return outputs.loss

def lr_schedule(step, max_steps):
x = step / max_steps
if x < 0.1:
return 0.1 * LR + 0.9 * LR * x / 0.1
else:
return 0.1 * LR + 0.9 * LR * (1 + math.cos(math.pi * (x - 0.1))) / 2

dataloaders = {
"train": torch.utils.data.DataLoader(
training_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
collate_fn=collate,
)
}

moondream.text_model.train()
moondream.text_model.transformer.gradient_checkpointing_enable()

total_steps = EPOCHS * len(dataloaders["train"]) // GRAD_ACCUM_STEPS
optimizer = bitsandbytes.optim.Adam8bit(
[{"params": moondream.text_model.parameters()}],
lr=LR*0.1,
betas=(0.9, 0.95),
eps=1e-6,
)

i = 0
for epoch in range(EPOCHS):
for batch in tqdm(dataloaders["train"], desc=f"Epoch {epoch + 1}/{EPOCHS}"):
i += 1

loss = compute_loss(batch)
loss.backward()

if i % GRAD_ACCUM_STEPS == 0:
optimizer.step()
optimizer.zero_grad()

lr = lr_schedule(i / GRAD_ACCUM_STEPS, total_steps)
for param_group in optimizer.param_groups:
param_group["lr"] = lr

moondream.save_pretrained("checkpoints/moondream-mai")
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

0 comments on commit 6b53cb0

Please sign in to comment.