From dc6d4cb8004a35cacf6c1d20fff9c1d63f6cc112 Mon Sep 17 00:00:00 2001 From: nie3e Date: Tue, 30 Jan 2024 21:09:29 +0100 Subject: [PATCH 1/2] Add max tokens --- .gitignore | 3 ++- gradio_demo.py | 16 +++++++++++----- moondream/text_model.py | 5 +++-- 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/.gitignore b/.gitignore index 0e5ac793..3b4dac95 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ .venv -__pycache__ \ No newline at end of file +__pycache__ +.idea/ \ No newline at end of file diff --git a/gradio_demo.py b/gradio_demo.py index c4fb47d4..5e5ba4e0 100644 --- a/gradio_demo.py +++ b/gradio_demo.py @@ -17,12 +17,15 @@ text_model = TextModel(model_path).to(device=device, dtype=dtype) -def moondream(img, prompt): +def moondream(img, prompt, max_tokens): image_embeds = vision_encoder(img) streamer = TextIteratorStreamer(text_model.tokenizer, skip_special_tokens=True) thread = Thread( target=text_model.answer_question, - kwargs={"image_embeds": image_embeds, "question": prompt, "streamer": streamer}, + kwargs={ + "image_embeds": image_embeds, "question": prompt, + "streamer": streamer, "max_new_tokens": max_tokens + }, ) thread.start() @@ -41,12 +44,15 @@ def moondream(img, prompt): """ ) with gr.Row(): - prompt = gr.Textbox(label="Input Prompt", placeholder="Type here...", scale=4) + with gr.Column(scale=4): + prompt = gr.Textbox(label="Input Prompt", placeholder="Type here...") + max_tokens = gr.Slider(label="Max tokens", minimum=128, + maximum=2048, value=128) submit = gr.Button("Submit") with gr.Row(): img = gr.Image(type="pil", label="Upload an Image") output = gr.TextArea(label="Response", info="Please wait for a few seconds..") - submit.click(moondream, [img, prompt], output) - prompt.submit(moondream, [img, prompt], output) + submit.click(moondream, [img, prompt, max_tokens], output) + prompt.submit(moondream, [img, prompt, max_tokens], output) demo.queue().launch(debug=True) diff --git a/moondream/text_model.py b/moondream/text_model.py index 527458b0..0120e5b0 100644 --- a/moondream/text_model.py +++ b/moondream/text_model.py @@ -77,13 +77,14 @@ def generate( return self.tokenizer.batch_decode(output_ids, skip_special_tokens=True) - def answer_question(self, image_embeds, question, **kwargs): + def answer_question(self, image_embeds, question, max_new_tokens=128, + **kwargs): prompt = f"\n\nQuestion: {question}\n\nAnswer:" answer = self.generate( image_embeds, prompt, eos_text="", - max_new_tokens=128, + max_new_tokens=max_new_tokens, **kwargs, )[0] return re.sub("<$", "", re.sub("END$", "", answer)).strip() From 374b60abefd8578c7c4c6bb414059ddbc13612e6 Mon Sep 17 00:00:00 2001 From: nie3e Date: Tue, 30 Jan 2024 23:02:13 +0100 Subject: [PATCH 2/2] Add max tokens --- .gitignore | 3 ++- gradio_demo.py | 16 +++++++++++----- moondream/text_model.py | 5 +++-- 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/.gitignore b/.gitignore index 0e5ac793..3b4dac95 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ .venv -__pycache__ \ No newline at end of file +__pycache__ +.idea/ \ No newline at end of file diff --git a/gradio_demo.py b/gradio_demo.py index c4fb47d4..5e5ba4e0 100644 --- a/gradio_demo.py +++ b/gradio_demo.py @@ -17,12 +17,15 @@ text_model = TextModel(model_path).to(device=device, dtype=dtype) -def moondream(img, prompt): +def moondream(img, prompt, max_tokens): image_embeds = vision_encoder(img) streamer = TextIteratorStreamer(text_model.tokenizer, skip_special_tokens=True) thread = Thread( target=text_model.answer_question, - kwargs={"image_embeds": image_embeds, "question": prompt, "streamer": streamer}, + kwargs={ + "image_embeds": image_embeds, "question": prompt, + "streamer": streamer, "max_new_tokens": max_tokens + }, ) thread.start() @@ -41,12 +44,15 @@ def moondream(img, prompt): """ ) with gr.Row(): - prompt = gr.Textbox(label="Input Prompt", placeholder="Type here...", scale=4) + with gr.Column(scale=4): + prompt = gr.Textbox(label="Input Prompt", placeholder="Type here...") + max_tokens = gr.Slider(label="Max tokens", minimum=128, + maximum=2048, value=128) submit = gr.Button("Submit") with gr.Row(): img = gr.Image(type="pil", label="Upload an Image") output = gr.TextArea(label="Response", info="Please wait for a few seconds..") - submit.click(moondream, [img, prompt], output) - prompt.submit(moondream, [img, prompt], output) + submit.click(moondream, [img, prompt, max_tokens], output) + prompt.submit(moondream, [img, prompt, max_tokens], output) demo.queue().launch(debug=True) diff --git a/moondream/text_model.py b/moondream/text_model.py index c1b26bdb..404ccbf6 100644 --- a/moondream/text_model.py +++ b/moondream/text_model.py @@ -78,14 +78,15 @@ def generate( return self.tokenizer.batch_decode(output_ids, skip_special_tokens=True) def answer_question( - self, image_embeds, question, chat_history="", result_queue=None, **kwargs + self, image_embeds, question, chat_history="", result_queue=None, + max_new_tokens=128, **kwargs ): prompt = f"\n\n{chat_history}Question: {question}\n\nAnswer:" answer = self.generate( image_embeds, prompt, eos_text="", - max_new_tokens=128, + max_new_tokens=max_new_tokens, **kwargs, )[0] cleaned_answer = re.sub("<$", "", re.sub("END$", "", answer)).strip()