feat: add code for finetuning moondream
1 parent 4bf2c89 · commit 6b53cb0
Showing 11 changed files with 261 additions and 0 deletions.
@@ -0,0 +1,31 @@
# Fraction of each dataset to hold out for evaluation.
TEST_SIZE = 0.2

# Number of times to repeat the training dataset. Increasing this may cause the model to overfit or
# lose generalization due to catastrophic forgetting. Decreasing it may cause the model to underfit.
EPOCHS = 1

# Number of samples to process in each batch. Set this to the highest value that doesn't cause an
# out-of-memory error. Decrease it if you're running out of memory.
BATCH_SIZE = 8

# Number of batches to process before updating the model. You can use this to simulate a higher batch
# size than your GPU can handle. Set this to 1 to disable gradient accumulation.
GRAD_ACCUM_STEPS = 2

# Learning rate for the Adam optimizer. Needs to be tuned on a case-by-case basis. As a general rule
# of thumb, increase it by 1.4x each time you double the effective batch size.
#
# Source: https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/
#
# Note that we linearly warm the learning rate up from 0.1 * LR to LR over the first 10% of the
# training run, and then decay it back to 0.1 * LR over the remaining 90% using a cosine schedule.
LR = 1e-5

# Whether to use Weights & Biases for logging training metrics.
USE_WANDB = False

# Token that marks the end of an answer during training.
ANSWER_EOS = "<|endoftext|>"

# Number of tokens used to represent each image.
IMG_TOKENS = 729
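The 1.4x rule of thumb above is square-root scaling: the effective batch size here is BATCH_SIZE * GRAD_ACCUM_STEPS = 16, and doubling it multiplies the learning rate by sqrt(2) ≈ 1.4. A minimal sketch of that rule; the helper name and keyword defaults are illustrative, not part of this commit:

import math

def scaled_lr(new_effective_batch_size, base_lr=1e-5, base_batch_size=16):
    # Square-root LR scaling for Adam: doubling the batch size
    # multiplies the learning rate by sqrt(2) ~= 1.4.
    return base_lr * math.sqrt(new_effective_batch_size / base_batch_size)

print(scaled_lr(32))  # ~1.4e-5
print(scaled_lr(64))  # ~2.0e-5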
@@ -0,0 +1,64 @@
import torch
import datasets
import transformers
import pathlib

DEVICE = "cuda"
DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16  # CPU doesn't support float16
MD_REVISION = "2024-07-23"

tokenizer = transformers.AutoTokenizer.from_pretrained("vikhyatk/moondream2", revision=MD_REVISION)
moondream = transformers.AutoModelForCausalLM.from_pretrained(
    "./checkpoints/moondream-mai",
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
    torch_dtype=DTYPE,
    device_map={"": DEVICE},
)

# 100 random AI-generated images from DiffusionDB, labeled as AI images.
diffusion_db_dataset = datasets.load_dataset("poloclub/diffusiondb", "2m_random_5k", trust_remote_code=True, split="train")\
    .shuffle()\
    .take(100)\
    .select_columns(["image"])\
    .map(lambda row: {
        **row,
        "qa": {
            "question": "Describe this image.",
            "answer": "This is an AI image."
        }
    })

# 100 random photos from Flickr30k, labeled as real images.
flickr_dataset = datasets.load_dataset("nlphuji/flickr30k", split="test")\
    .shuffle()\
    .take(100)\
    .select_columns(["image"])\
    .map(lambda row: {
        **row,
        "qa": {
            "question": "Describe this image.",
            "answer": "This is a real image."
        }
    })

dataset = datasets.concatenate_datasets([diffusion_db_dataset, flickr_dataset]).shuffle()

pathlib.Path("./samples").mkdir(parents=True, exist_ok=True)
# Spot-check the fine-tuned model on the first three samples.
for i, sample in enumerate(dataset):
    if i >= 3:
        break

    sample['image'].save(f"samples/{i}.png", "PNG")

    md_answer = moondream.answer_question(
        moondream.encode_image(sample['image']),
        sample['qa']['question'],
        tokenizer=tokenizer,
        num_beams=4,
        no_repeat_ngram_size=5,
        early_stopping=True
    )

    print('Question:', sample['qa']['question'])
    print('Ground Truth:', sample['qa']['answer'])
    print('Moondream:', md_answer)
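To turn this spot check into a number, a minimal sketch of an accuracy loop over the same merged dataset built above; the substring heuristic for matching answers to labels is illustrative, not part of this commit:

correct = 0
total = 0
for sample in dataset:
    md_answer = moondream.answer_question(
        moondream.encode_image(sample["image"]),
        sample["qa"]["question"],
        tokenizer=tokenizer,
    )
    # Illustrative heuristic: the prediction counts as correct if it assigns
    # the same class ("AI" vs. real) as the ground-truth answer.
    predicted_ai = "AI" in md_answer
    labeled_ai = sample["qa"]["answer"] == "This is an AI image."
    correct += int(predicted_ai == labeled_ai)
    total += 1
print(f"Accuracy: {correct / total:.2%} ({correct}/{total})")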
@@ -0,0 +1,164 @@
import math
import torch
import datasets
import transformers
import bitsandbytes
from tqdm import tqdm
from .hyperparams import TEST_SIZE, ANSWER_EOS, IMG_TOKENS, LR, BATCH_SIZE, EPOCHS, GRAD_ACCUM_STEPS

DEVICE = "cuda"
DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16  # CPU doesn't support float16
MD_REVISION = "2024-07-23"
# AI-generated images from DiffusionDB, labeled as AI images.
diffusion_db_dataset = datasets.load_dataset("poloclub/diffusiondb", "2m_random_5k", trust_remote_code=True, split="train")\
    .select_columns(["image"])\
    .map(lambda row: {
        **row,
        "qa": {
            "question": "Describe this image.",
            "answer": "This is an AI image."
        }
    })\
    .train_test_split(test_size=TEST_SIZE)

# Real photos from Flickr30k, labeled as real images.
flickr_dataset = datasets.load_dataset("nlphuji/flickr30k", split="test")\
    .take(5000)\
    .select_columns(["image"])\
    .map(lambda row: {
        **row,
        "qa": {
            "question": "Describe this image.",
            "answer": "This is a real image."
        }
    })\
    .train_test_split(test_size=TEST_SIZE)

training_dataset = datasets.concatenate_datasets([diffusion_db_dataset["train"], flickr_dataset["train"]]).shuffle()
test_dataset = datasets.concatenate_datasets([diffusion_db_dataset["test"], flickr_dataset["test"]]).shuffle()

tokenizer = transformers.AutoTokenizer.from_pretrained("vikhyatk/moondream2", revision=MD_REVISION)
moondream = transformers.AutoModelForCausalLM.from_pretrained(
    "vikhyatk/moondream2",
    revision=MD_REVISION,
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
    torch_dtype=DTYPE,
    device_map={"": DEVICE},
)
def collate(batch):
    images = []
    all_tokens = []
    all_labels = []

    for sample in batch:
        images.append(sample["image"])

        # Sequence layout: <BOS>, then IMG_TOKENS image tokens, then the
        # question/answer text. Labels of -100 are ignored by the loss, so
        # only the answer tokens contribute to it.
        tokens = [tokenizer.bos_token_id]
        labels = [-100] * (IMG_TOKENS + 1)

        qa = sample["qa"]
        q_t = tokenizer(
            f"\n\nQuestion: {qa['question']}\n\nAnswer:",
            add_special_tokens=False,
        ).input_ids
        tokens.extend(q_t)
        labels.extend([-100] * len(q_t))

        a_t = tokenizer(
            f" {qa['answer']}{ANSWER_EOS}",
            add_special_tokens=False,
        ).input_ids
        tokens.extend(a_t)
        labels.extend(a_t)

        all_tokens.append(tokens)
        all_labels.append(labels)

    longest_label_len = -1
    for label in all_labels:
        longest_label_len = max(longest_label_len, len(label))

    # Right-pad every sequence to the longest one in the batch and mask out
    # the padding in both the labels and the attention mask.
    all_attn_masks = []
    for i in range(len(batch)):
        label_len = len(all_labels[i])
        pad_len = longest_label_len - label_len

        all_labels[i].extend([-100] * pad_len)
        all_tokens[i].extend([tokenizer.eos_token_id] * pad_len)
        all_attn_masks.append([1] * label_len + [0] * pad_len)

    return (
        images,
        torch.stack([torch.tensor(token, dtype=torch.long) for token in all_tokens]),
        torch.stack([torch.tensor(label, dtype=torch.long) for label in all_labels]),
        torch.stack([torch.tensor(mask, dtype=torch.bool) for mask in all_attn_masks]),
    )
def compute_loss(batch):
    images, tokens, labels, masks = batch

    tokens = tokens.to(DEVICE)
    labels = labels.to(DEVICE)
    masks = masks.to(DEVICE)

    # The vision encoder is not being trained, so no gradients are needed here.
    with torch.no_grad():
        img_embeds = moondream.vision_encoder(images)

    token_embeds = moondream.text_model.get_input_embeddings()(tokens)

    # Start with the embedding vector that represents <BOS>, then insert the
    # image embeds, then the rest of the token embeds:
    # <BOS> + the image + all the text tokens.
    inputs_embeds = torch.cat((token_embeds[:, 0:1, :], img_embeds, token_embeds[:, 1:, :]), dim=1)

    outputs = moondream.text_model(
        inputs_embeds=inputs_embeds,
        labels=labels,
        attention_mask=masks,
    )

    return outputs.loss

def lr_schedule(step, max_steps):
    x = step / max_steps
    if x < 0.1:
        # Linear warmup from 0.1 * LR to LR over the first 10% of training.
        return 0.1 * LR + 0.9 * LR * x / 0.1
    else:
        # Cosine decay from LR back to 0.1 * LR over the remaining 90%;
        # (x - 0.1) / 0.9 rescales the decay phase to [0, 1] so the schedule
        # ends at exactly 0.1 * LR, as described in hyperparams.py.
        return 0.1 * LR + 0.9 * LR * (1 + math.cos(math.pi * (x - 0.1) / 0.9)) / 2
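# Example values this schedule produces with LR = 1e-5 (illustrative worked
# check, not part of the original file):
#   x = 0.00 -> 1.0e-6  (start of warmup)
#   x = 0.05 -> 5.5e-6  (mid-warmup)
#   x = 0.10 -> 1.0e-5  (peak)
#   x = 0.55 -> 5.5e-6  (mid-decay)
#   x = 1.00 -> 1.0e-6  (end of training)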
dataloaders = {
    "train": torch.utils.data.DataLoader(
        training_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collate,
    )
}

moondream.text_model.train()
moondream.text_model.transformer.gradient_checkpointing_enable()

total_steps = EPOCHS * len(dataloaders["train"]) // GRAD_ACCUM_STEPS
optimizer = bitsandbytes.optim.Adam8bit(
    [{"params": moondream.text_model.parameters()}],
    lr=LR * 0.1,  # matches the warmup start of lr_schedule
    betas=(0.9, 0.95),
    eps=1e-6,
)

i = 0
for epoch in range(EPOCHS):
    for batch in tqdm(dataloaders["train"], desc=f"Epoch {epoch + 1}/{EPOCHS}"):
        i += 1

        loss = compute_loss(batch)
        loss.backward()

        # Only step the optimizer (and update the LR) every GRAD_ACCUM_STEPS
        # batches, to simulate a larger effective batch size.
        if i % GRAD_ACCUM_STEPS == 0:
            optimizer.step()
            optimizer.zero_grad()

            lr = lr_schedule(i / GRAD_ACCUM_STEPS, total_steps)
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr

moondream.save_pretrained("checkpoints/moondream-mai")
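USE_WANDB is defined in hyperparams.py but never consumed by this training loop. A minimal sketch of how it could be wired in, assuming the standard wandb API; the project name is illustrative, not part of this commit:

from .hyperparams import USE_WANDB

if USE_WANDB:
    import wandb
    # Project name is illustrative.
    wandb.init(
        project="moondream-finetune",
        config={"EPOCHS": EPOCHS, "BATCH_SIZE": BATCH_SIZE,
                "GRAD_ACCUM_STEPS": GRAD_ACCUM_STEPS, "LR": LR},
    )

# ...then, inside the training loop, right after each optimizer step:
if USE_WANDB and i % GRAD_ACCUM_STEPS == 0:
    wandb.log({"loss": loss.item(), "lr": optimizer.param_groups[0]["lr"]})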
7 files renamed without changes.