From b4ca6d63b5305cc724eadc4c0321b22cd114164e Mon Sep 17 00:00:00 2001
From: Kenneth
Date: Sat, 15 Feb 2025 20:06:57 +0000
Subject: [PATCH] modify data gen script + add ds upload script

---
 moondream/generate_dataset.py | 234 ++++++++++++++++++++++------------
 moondream/hyperparams.py      |   2 +-
 moondream/upload_dataset.py   |  22 ++++
 3 files changed, 179 insertions(+), 79 deletions(-)
 create mode 100644 moondream/upload_dataset.py

diff --git a/moondream/generate_dataset.py b/moondream/generate_dataset.py
index 76c1c47..373f720 100644
--- a/moondream/generate_dataset.py
+++ b/moondream/generate_dataset.py
@@ -1,107 +1,185 @@
+import dotenv
+dotenv.load_dotenv()  # load .env before the HF libraries read HF_HOME
+
 import os
+import sys
 import torch
 import datasets
 import diffusers
+import transformers
+import argparse

 from .hyperparams import MOONDREAM_REVISION

+print(f"HF_HOME set to {os.getenv('HF_HOME')}")
+
+# DATASET_SIZE = 10000
+# ROWS_PER_DS = 1250
+BATCH_SIZE = 4
+PARQUET_BATCH_SIZE = 200  # rows buffered per parquet shard
+SKIP_PARQUET_BATCH = 203  # shards already written by earlier runs; numbering resumes here
+
+parser = argparse.ArgumentParser()
+parser.add_argument("-m", "--message")
+
+args = parser.parse_args()

 auth_token = os.getenv("HF_ACCESS_TOKEN")
+if not auth_token:
+    print("Hugging Face access token not provided! Please set the HF_ACCESS_TOKEN env var.")
+    sys.exit(1)
+else:
+    print("Hugging Face access token loaded!")

 tokenizer = transformers.AutoTokenizer.from_pretrained("vikhyatk/moondream2")
 moondream = transformers.AutoModelForCausalLM.from_pretrained(
     "vikhyatk/moondream2",
     revision=MOONDREAM_REVISION,
     trust_remote_code=True,
-    attn_implementation="flash_attention_2",
-    torch_dtype=torch.float16,
-).to("cuda")
+    torch_dtype=torch.bfloat16,
+    device_map={"": "cuda"},
+)
+
+pipe = diffusers.StableDiffusion3Pipeline.from_pretrained(
+    "stabilityai/stable-diffusion-3.5-large",
+    torch_dtype=torch.bfloat16,
+    token=auth_token,
+    device_map="balanced",
+)

 def collate(batch):
     images = []
-    questions = []
+    keywords = []

     for sample in batch:
         images.append(sample["image"])
-        questions.append("Describe this image.")
+        keywords.append([""])  # wikiart rows carry no extra prompt keywords

-    return images, questions
+    return images, keywords

-flickr_dataset = datasets.load_dataset("nlphuji/flickr30k", split="test", streaming=True)\
-    .select_columns(["image"])\
-    .take(1)
+# flickr_dataset = datasets.load_dataset("nlphuji/flickr30k", split="test", streaming=True)\
+#     .select_columns(["image"])\

 wiki_art_dataset = datasets.load_dataset("huggan/wikiart", split="train", streaming=True)\
-    .select_columns(["image"])\
-    .take(1)
-
-anime_dataset = datasets.load_dataset("animelover/danbooru2022", "1-full", trust_remote_code=True, split="train", streaming=True)\
-    .select_columns(["image"])\
-    .take(1)
-
-coco_dataset = datasets.load_dataset("detection-datasets/coco", split="train", streaming=True)\
-    .select_columns(["image"])\
-    .take(1)
-
-movie_poster_dataset = datasets.load_dataset("skvarre/movie_posters-100k", split="train", streaming=True)\
-    .select_columns(["image"])\
-    .take(1)
-
-cars_dataset = datasets.load_dataset("tanganke/stanford_cars", split="train", streaming=True)\
-    .select_columns(["image"])\
-    .take(1)
-
-website_dataset = datasets.load_dataset("silatus/1k_Website_Screenshots_and_Metadata", split="train", streaming=True)\
-    .select_columns(["image"])\
-    .take(1)
-
-movie_scene_dataset = datasets.load_dataset("unography/movie-scenes-resized-captioned", split="train", streaming=True)\
-    .select_columns(["image"])\
-    .take(1)
-
-ds = datasets.concatenate_datasets([
-    flickr_dataset,
-    wiki_art_dataset,
-    anime_dataset,
-    coco_dataset,
-    movie_poster_dataset,
-    cars_dataset,
-    website_dataset,
-    movie_scene_dataset,
-])
+    .select_columns(["image"])
+
+# anime_dataset_ft = datasets.Features({"image": datasets.Image(decode=True)})
+# anime_dataset = datasets.load_dataset("animelover/danbooru2022", "1-full", trust_remote_code=True, split="train", streaming=True, features=anime_dataset_ft)\
+#     .select_columns(["image"])\
+#     .take(ROWS_PER_DS)\
+#     .add_column("question", ["Describe this image in one sentence. Include the word anime in the sentence."] * ROWS_PER_DS)\
+#     .add_column("keywords", [["anime"]] * ROWS_PER_DS)
+
+# coco_dataset = datasets.load_dataset("detection-datasets/coco", split="train", streaming=True)\
+#     .select_columns(["image"])\
+#     .take(ROWS_PER_DS)\
+#     .add_column("question", ["Describe this image in one sentence."] * ROWS_PER_DS)\
+#     .add_column("keywords", [[""]] * ROWS_PER_DS)
+
+# movie_poster_dataset = datasets.load_dataset("skvarre/movie_posters-100k", split="train", streaming=True)\
+#     .select_columns(["image"])\
+#     .take(ROWS_PER_DS)\
+#     .add_column("question", ["Describe this image in one sentence."] * ROWS_PER_DS)\
+#     .add_column("keywords", [[""]] * ROWS_PER_DS)
+
+# cars_dataset = datasets.load_dataset("tanganke/stanford_cars", split="train", streaming=True)\
+#     .select_columns(["image"])\
+#     .take(ROWS_PER_DS)\
+#     .add_column("question", ["Describe this image in one sentence."] * ROWS_PER_DS)\
+#     .add_column("keywords", [[""]] * ROWS_PER_DS)
+
+# website_dataset = datasets.load_dataset("silatus/1k_Website_Screenshots_and_Metadata", split="train", streaming=True)\
+#     .select_columns(["image"])\
+#     .take(ROWS_PER_DS)\
+#     .add_column("question", ["Describe this image in one sentence."] * ROWS_PER_DS)\
+#     .add_column("keywords", [[""]] * ROWS_PER_DS)
+
+# movie_scene_dataset = datasets.load_dataset("unography/movie-scenes-resized-captioned", split="train", streaming=True)\
+#     .select_columns(["image"])\
+#     .take(ROWS_PER_DS)\
+#     .add_column("question", ["Describe this image in one sentence."] * ROWS_PER_DS)\
+#     .add_column("keywords", [[""]] * ROWS_PER_DS)
+
+# ds = datasets.concatenate_datasets([
+#     flickr_dataset,
+#     wiki_art_dataset,
+#     anime_dataset,
+#     coco_dataset,
+#     movie_poster_dataset,
+#     cars_dataset,
+#     website_dataset,
+#     movie_scene_dataset,
+# ]).cast_column("image", datasets.Image(decode=True)).skip(SKIP_PARQUET_BATCH * PARQUET_BATCH_SIZE)
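+# Only the wikiart stream is active in this run; the sources above are kept
+# commented out so they can be re-enabled for a combined dataset later.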
+ds = wiki_art_dataset.cast_column("image", datasets.Image(decode=True))

 data_loader = torch.utils.data.DataLoader(
     ds,
-    batch_size=8,
+    batch_size=BATCH_SIZE,
     collate_fn=collate
 )

-captions = []
-for batch in data_loader:
-    images, questions = batch
-    answers = moondream.batch_answer(
-        images=images,
-        prompts=questions,
-        tokenizer=tokenizer
-    )
-
-    for ans in answers:
-        print(ans)
-        print()
-
-    captions.extend(answers)
-
-ds = ds.add_column("caption", captions)
-
-del moondream
-
-pipe = diffusers.StableDiffusion3Pipeline.from_pretrained(
-    "stabilityai/stable-diffusion-3.5-large",
-    torch_dtype=torch.bfloat16,
-    token=auth_token,
-).to("cuda")
-
-image = pipe(
-    "A capybara holding a sign that reads Hello World",
-    num_inference_steps=28,
-    guidance_scale=3.5,
-).images[0]
-image.save("capybara.png")
+temp_ds = {
+    "image": [],
+    "keywords": [],
+    "caption": [],
+    "generated_image": []
+}
+temp_ds_size = 0
+
+ds_features = datasets.Features({
+    "image": datasets.Image(),
+    "keywords": datasets.Sequence(datasets.Value(dtype="string")),
+    "caption": datasets.Value(dtype="string"),
+    "generated_image": datasets.Image(),
+})
+
+generator = torch.Generator(device="cpu").manual_seed(12321313)
+
+batch_count = SKIP_PARQUET_BATCH
+
+os.makedirs("data", exist_ok=True)  # to_parquet does not create the output dir
+
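+# Caption each image with moondream, prepend any keywords missing from the
+# caption, render one SD3.5 image per prompt, and write every
+# PARQUET_BATCH_SIZE rows to a new parquet shard, numbered from
+# SKIP_PARQUET_BATCH so reruns do not overwrite earlier output.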
+for batch_index, batch in enumerate(data_loader):
+    images, keywords = batch
+
+    prompts = []
+    for i, img in enumerate(images):
+        caption = moondream.caption(img, length="normal")["caption"]
+
+        # prepend keywords only when present and not already in the caption
+        add_keywords = len(keywords[i]) > 0 and keywords[i][0] != ""
+        for k in keywords[i]:
+            if k and k in caption:
+                add_keywords = False
+                break
+
+        prompt = caption
+        if add_keywords:
+            prompt = f"{', '.join(keywords[i])}, {caption}"
+
+        prompts.append(prompt)
+
+    gen_imgs = pipe(
+        prompts,
+        num_inference_steps=28,
+        guidance_scale=3.5,
+        generator=generator,
+        max_sequence_length=512,
+    ).images
+
+    temp_ds["image"].extend(images)
+    temp_ds["caption"].extend(prompts)
+    temp_ds["keywords"].extend(keywords)
+    temp_ds["generated_image"].extend(gen_imgs)
+
+    temp_ds_size += len(images)  # the final dataloader batch may be short
+
+    if temp_ds_size >= PARQUET_BATCH_SIZE:
+        batch_ds = datasets.Dataset.from_dict(temp_ds, features=ds_features)
+        batch_ds.to_parquet(
+            f"data/batch_{batch_count}.parquet",
+        )
+        temp_ds_size = 0
+        temp_ds["image"].clear()
+        temp_ds["caption"].clear()
+        temp_ds["keywords"].clear()
+        temp_ds["generated_image"].clear()
+
+        batch_count += 1
+
+# flush any leftover rows that didn't fill a complete shard
+if temp_ds_size > 0:
+    batch_ds = datasets.Dataset.from_dict(temp_ds, features=ds_features)
+    batch_ds.to_parquet(f"data/batch_{batch_count}.parquet")
diff --git a/moondream/hyperparams.py b/moondream/hyperparams.py
index 2a7faa6..3fdf430 100644
--- a/moondream/hyperparams.py
+++ b/moondream/hyperparams.py
@@ -1,4 +1,4 @@
-MOONDREAM_REVISION = "2024-08-26"
+MOONDREAM_REVISION = "2025-01-09"

 TEST_SIZE = 0.2

diff --git a/moondream/upload_dataset.py b/moondream/upload_dataset.py
new file mode 100644
index 0000000..4345968
--- /dev/null
+++ b/moondream/upload_dataset.py
@@ -0,0 +1,22 @@
+import sys
+import os
+import dotenv
+dotenv.load_dotenv()
+
+access_token = os.getenv("HF_ACCESS_TOKEN")
+if not access_token:
+    print("Please provide a Hugging Face access token via HF_ACCESS_TOKEN.")
+    sys.exit(1)
+
+from huggingface_hub import HfApi
+from huggingface_hub.constants import REPO_TYPE_DATASET
+
+api = HfApi(token=access_token)
+
+# upload_large_folder is resumable and uploads files in parallel workers
+api.upload_large_folder(
+    repo_id="athenlab/reva",
+    folder_path="./data",
+    repo_type=REPO_TYPE_DATASET,
+    private=False,
+    print_report=True,
+)