Commit
add script for generating dataset
kennethnym committed Jan 12, 2025
1 parent 14c6f26 commit d2b6e70
Showing 4 changed files with 164 additions and 13 deletions.
107 changes: 107 additions & 0 deletions moondream/generate_dataset.py
@@ -0,0 +1,107 @@
import os
import torch
import datasets
import diffusers
import transformers
from .hyperparams import MOONDREAM_REVISION

auth_token = os.getenv("HF_ACCESSTOKEN".replace("ACCESSTOKEN", "ACCESS_TOKEN"))

tokenizer = transformers.AutoTokenizer.from_pretrained("vikhyatk/moondream2")
moondream = transformers.AutoModelForCausalLM.from_pretrained(
    "vikhyatk/moondream2",
    revision=MOONDREAM_REVISION,
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
).to("cuda")
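# Note: attn_implementation="flash_attention_2" assumes the flash-attn package is
# installed and the GPU supports it (Ampere or newer); dropping the argument falls
# back to the default attention implementation.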

def collate(batch):
    images = []
    questions = []

    for sample in batch:
        images.append(sample["image"])
        questions.append("Describe this image.")

    return images, questions

flickr_dataset = datasets.load_dataset("nlphuji/flickr30k", split="test", streaming=True)\
    .select_columns(["image"])\
    .take(1)

wiki_art_dataset = datasets.load_dataset("huggan/wikiart", split="train", streaming=True)\
    .select_columns(["image"])\
    .take(1)

anime_dataset = datasets.load_dataset("animelover/danbooru2022", "1-full", trust_remote_code=True, split="train", streaming=True)\
    .select_columns(["image"])\
    .take(1)

coco_dataset = datasets.load_dataset("detection-datasets/coco", split="train", streaming=True)\
    .select_columns(["image"])\
    .take(1)

movie_poster_dataset = datasets.load_dataset("skvarre/movie_posters-100k", split="train", streaming=True)\
    .select_columns(["image"])\
    .take(1)

cars_dataset = datasets.load_dataset("tanganke/stanford_cars", split="train", streaming=True)\
    .select_columns(["image"])\
    .take(1)

website_dataset = datasets.load_dataset("silatus/1k_Website_Screenshots_and_Metadata", split="train", streaming=True)\
    .select_columns(["image"])\
    .take(1)

movie_scene_dataset = datasets.load_dataset("unography/movie-scenes-resized-captioned", split="train", streaming=True)\
    .select_columns(["image"])\
    .take(1)

ds = datasets.concatenate_datasets([
    flickr_dataset,
    wiki_art_dataset,
    anime_dataset,
    coco_dataset,
    movie_poster_dataset,
    cars_dataset,
    website_dataset,
    movie_scene_dataset,
])

data_loader = torch.utils.data.DataLoader(
    ds,
    batch_size=8,
    collate_fn=collate,
)

captions = []
for batch in data_loader:
    images, questions = batch
    answers = moondream.batch_answer(
        images=images,
        prompts=questions,
        tokenizer=tokenizer,
    )

    for ans in answers:
        print(ans)
        print()

    captions.extend(answers)

ds = ds.add_column("caption", captions)

del moondream

pipe = diffusers.StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large",
    torch_dtype=torch.bfloat16,
    token=auth_token,
).to("cuda")

image = pipe(
    "A capybara holding a sign that reads Hello World",
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
image.save("capybara.png")
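The script ends with a one-off smoke-test render (the capybara prompt) and does not yet persist the captioned dataset. Assuming the captions collected above are meant to drive SD 3.5 to produce the AI-generated half of the dataset, a minimal sketch of that loop might look like this; the generated/ output directory and filename scheme are assumptions, not part of this commit:

# Hypothetical continuation, not in the commit: render one SD 3.5 image per
# Moondream caption. Output directory and naming scheme are assumptions.
import pathlib

out_dir = pathlib.Path("generated")
out_dir.mkdir(parents=True, exist_ok=True)

for i, caption in enumerate(captions):
    generated = pipe(
        caption,
        num_inference_steps=28,
        guidance_scale=3.5,
    ).images[0]
    generated.save(out_dir / f"{i:05d}.png")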
6 changes: 4 additions & 2 deletions moondream/hyperparams.py
@@ -1,16 +1,18 @@
MOONDREAM_REVISION = "2024-08-26"

TEST_SIZE = 0.2

# Number of times to repeat the training dataset. Increasing this may cause the model to overfit or
# lose generalization due to catastrophic forgetting. Decreasing it may cause the model to underfit.
-EPOCHS = 1
+EPOCHS = 2

# Number of samples to process in each batch. Set this to the highest value that doesn't cause an
# out-of-memory error. Decrease it if you're running out of memory.
BATCH_SIZE = 8

# Number of batches to process before updating the model. You can use this to simulate a higher batch
# size than your GPU can handle. Set this to 1 to disable gradient accumulation.
-GRAD_ACCUM_STEPS = 2
+GRAD_ACCUM_STEPS = 1

# Learning rate for the Adam optimizer. Needs to be tuned on a case-by-case basis. As a general rule
# of thumb, increase it by 1.4 times each time you double the effective batch size.
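Taken together, the comments above imply effective batch size = BATCH_SIZE * GRAD_ACCUM_STEPS, with the learning rate scaling by roughly 1.4x per doubling of that product. A small illustration; the base learning rate here is an assumed placeholder, since the real value is truncated out of this diff:

# Illustration of the tuning rule described above. BASE_LR is an assumed
# placeholder; the actual learning rate is below the fold of this diff.
BASE_LR = 3e-5

effective_batch_size = BATCH_SIZE * GRAD_ACCUM_STEPS  # 8 * 1 = 8
# Doubling the effective batch size (e.g. GRAD_ACCUM_STEPS = 2) suggests:
scaled_lr = BASE_LR * 1.4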
2 changes: 1 addition & 1 deletion moondream/test.py
@@ -23,7 +23,7 @@
img = Image.open("samples/Untitled.jpg")
md_answer = moondream.answer_question(
    moondream.encode_image(img),
-    "Describe this image.",
+    "Is this image AI generated?",
    tokenizer=tokenizer,
    num_beams=4,
    no_repeat_ngram_size=5,
62 changes: 52 additions & 10 deletions moondream/train.py
@@ -13,7 +13,7 @@
DEVICE = "cuda"
DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16 # CPU doesn't support float16
MD_REVISION = "2024-07-23"
-TOTAL_DATA_SIZE = 8000
+TOTAL_DATA_SIZE = 20000

diffusion_db_dataset = datasets.load_dataset("poloclub/diffusiondb", "2m_random_5k", split="train", trust_remote_code=True, streaming=True)\
    .select_columns(["image"])\
@@ -24,7 +24,7 @@
            "answer": "Yes."
        }
    })
-diffusion_db_dataset = utils.datasets.split_streaming_dataset(diffusion_db_dataset, total_size=2000, test_size=TEST_SIZE)
+diffusion_db_dataset = utils.datasets.split_streaming_dataset(diffusion_db_dataset, total_size=5000, test_size=TEST_SIZE)

midjourney_dataset = datasets.load_dataset("brivangl/midjourney-v6-llava", split="train", streaming=True)\
    .select_columns(["image"])\
@@ -35,7 +35,7 @@
            "answer": "Yes."
        }
    })
-midjourney_dataset = utils.datasets.split_streaming_dataset(midjourney_dataset, total_size=2000, test_size=TEST_SIZE)
+midjourney_dataset = utils.datasets.split_streaming_dataset(midjourney_dataset, total_size=5000, test_size=TEST_SIZE)

flickr_dataset = datasets.load_dataset("nlphuji/flickr30k", split="test", streaming=True)\
    .select_columns(["image"])\
@@ -46,7 +46,7 @@
            "answer": "No."
        }
    })
-flickr_dataset = utils.datasets.split_streaming_dataset(flickr_dataset, total_size=800, test_size=TEST_SIZE)
+flickr_dataset = utils.datasets.split_streaming_dataset(flickr_dataset, total_size=1250, test_size=TEST_SIZE)

wiki_art_dataset = datasets.load_dataset("huggan/wikiart", split="train", streaming=True)\
    .select_columns(["image"])\
@@ -57,7 +57,7 @@
            "answer": "No."
        }
    })
-wiki_art_dataset = utils.datasets.split_streaming_dataset(wiki_art_dataset, total_size=800, test_size=TEST_SIZE)
+wiki_art_dataset = utils.datasets.split_streaming_dataset(wiki_art_dataset, total_size=1250, test_size=TEST_SIZE)

anime_dataset = datasets.load_dataset("animelover/danbooru2022", "1-full", trust_remote_code=True, split="train", streaming=True)\
    .select_columns(["image"])\
@@ -68,7 +68,7 @@
            "answer": "No."
        }
    })
-anime_dataset = utils.datasets.split_streaming_dataset(anime_dataset, total_size=800, test_size=TEST_SIZE)
+anime_dataset = utils.datasets.split_streaming_dataset(anime_dataset, total_size=1250, test_size=TEST_SIZE)

coco_dataset = datasets.load_dataset("detection-datasets/coco", split="train", streaming=True)\
    .select_columns(["image"])\
@@ -79,18 +79,51 @@
            "answer": "No."
        }
    })
-coco_dataset = utils.datasets.split_streaming_dataset(coco_dataset, total_size=800, test_size=TEST_SIZE)
+coco_dataset = utils.datasets.split_streaming_dataset(coco_dataset, total_size=1250, test_size=TEST_SIZE)

movie_poster_dataset = datasets.load_dataset("skvarre/movie_posters-100k", split="train", streaming=True)\
-    .select_columns(["age"])\
+    .select_columns(["image"])\
    .map(lambda row: {
        **row,
        "qa": {
            "question": "Is this image AI generated?",
            "answer": "No."
        }
    })
-movie_poster_dataset = utils.datasets.split_streaming_dataset(movie_poster_dataset, total_size=800, test_size=TEST_SIZE)
+movie_poster_dataset = utils.datasets.split_streaming_dataset(movie_poster_dataset, total_size=1250, test_size=TEST_SIZE)

cars_dataset = datasets.load_dataset("tanganke/stanford_cars", split="train", streaming=True)\
    .select_columns(["image"])\
    .map(lambda row: {
        **row,
        "qa": {
            "question": "Is this image AI generated?",
            "answer": "No."
        }
    })
cars_dataset = utils.datasets.split_streaming_dataset(cars_dataset, total_size=1250, test_size=TEST_SIZE)

website_dataset = datasets.load_dataset("silatus/1k_Website_Screenshots_and_Metadata", split="train", streaming=True)\
    .select_columns(["image"])\
    .map(lambda row: {
        **row,
        "qa": {
            "question": "Is this image AI generated?",
            "answer": "No.",
        }
    })
website_dataset = utils.datasets.split_streaming_dataset(website_dataset, total_size=1250, test_size=TEST_SIZE)

movie_scene_dataset = datasets.load_dataset("unography/movie-scenes-resized-captioned", split="train", streaming=True)\
    .select_columns(["image"])\
    .map(lambda row: {
        **row,
        "qa": {
            "question": "Is this image AI generated?",
            "answer": "No.",
        }
    })
movie_scene_dataset = utils.datasets.split_streaming_dataset(movie_scene_dataset, total_size=1250, test_size=TEST_SIZE)

training_dataset = datasets.interleave_datasets([
    diffusion_db_dataset["train"],
@@ -100,6 +100,9 @@
    anime_dataset["train"],
    coco_dataset["train"],
    movie_poster_dataset["train"],
    cars_dataset["train"],
    website_dataset["train"],
    movie_scene_dataset["train"],
], stopping_strategy="all_exhausted").cast_column("image", datasets.Image(decode=True))
test_dataset = datasets.interleave_datasets([
    diffusion_db_dataset["test"],
@@ -109,6 +145,9 @@
    anime_dataset["test"],
    coco_dataset["test"],
    movie_poster_dataset["test"],
    cars_dataset["test"],
    website_dataset["test"],
    movie_scene_dataset["test"],
], stopping_strategy="all_exhausted").cast_column("image", datasets.Image(decode=True))

print("Training and test dataset prepared.")
@@ -242,8 +281,11 @@ def lr_schedule(step, max_steps):
moondream.eval()
pathlib.Path("./samples").mkdir(parents=True, exist_ok=True)

total = 0
correct_predictions = 0
for sample in tqdm(test_dataset, desc="Validation"):
    total += 1

    md_answer = moondream.answer_question(
        moondream.encode_image(sample['image']),
        sample['qa']['question'],
@@ -257,6 +299,6 @@ def lr_schedule(step, max_steps):
    if md_answer == ground_truth:
        correct_predictions += 1

-accuracy = correct_predictions * 100 / (TOTAL_DATA_SIZE * TEST_SIZE)
+accuracy = correct_predictions * 100 / total

print(f"Model accuracy: {accuracy}%")
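A note on the two changes above: interleave_datasets with stopping_strategy="all_exhausted" keeps sampling, oversampling shorter sources, until every dataset is spent, so the test split is not exactly TOTAL_DATA_SIZE * TEST_SIZE samples; counting samples directly makes the accuracy denominator correct. The helper utils.datasets.split_streaming_dataset is not part of this diff; a minimal sketch of what it presumably does with streaming datasets follows, where the signature and behavior are assumptions:

# Hypothetical sketch of utils.datasets.split_streaming_dataset; the real
# helper lives outside this diff, so this is an assumption, not its source.
import datasets

def split_streaming_dataset(ds: datasets.IterableDataset, total_size: int, test_size: float) -> dict:
    test_count = int(total_size * test_size)
    return {
        "test": ds.take(test_count),  # first test_count samples
        "train": ds.skip(test_count).take(total_size - test_count),  # the rest
    }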
