diff --git a/deploy/qwen2_vl/README.md b/deploy/qwen2_vl/README.md
new file mode 100644
index 000000000..2eb092f4d
--- /dev/null
+++ b/deploy/qwen2_vl/README.md
@@ -0,0 +1,54 @@
+# Qwen2-VL
+
+## 1. Model Introduction
+
+[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/) is a large-scale vision-language model. It takes images, text, bounding boxes and videos as input, and produces text and bounding boxes as output. This repository provides the Paddle versions of the `Qwen2-VL-2B-Instruct` and `Qwen2-VL-7B-Instruct` models.
+
+## 2. Environment Setup
+- **python >= 3.10**
+- **paddlepaddle-gpu (develop build required)**
+```
+# Installation example
+python -m pip install paddlepaddle-gpu==0.0.0.post118 -f https://www.paddlepaddle.org.cn/whl/linux/gpu/develop.html
+```
+- **paddlenlp**
+```
+# Installation example
+git submodule update --init --recursive
+cd PaddleNLP
+git reset --hard e91c2d3d634b12769c30aa419ddf931c20b7ca9f
+pip install -e .
+cd csrc
+python setup_cuda.py install
+```
+
+> Note:
+* Make sure the dependencies above are installed, otherwise the scripts cannot run. In addition, install the custom ops under paddlemix/external_ops with `python setup.py install`. If the ops are still not found after installation, set PYTHONPATH accordingly.
+* flash_attn is enabled by default and requires an A100/A800 or H20 GPU.
+
+## 3. High-Performance Inference
+
+In the Qwen2-VL inference optimization, the vision model still uses the network implementation from PaddleMIX, while the language model part calls the high-performance Qwen2 language model from PaddleNLP, giving a high-performance Qwen2-VL inference pipeline.
+
+### a. High-performance inference with text & a single image
+```bash
+python deploy/qwen2_vl/single_image_infer.py \
+    --model_name_or_path Qwen/Qwen2-VL-2B-Instruct \
+    --dtype bfloat16 \
+    --benchmark 1
+```
+
+- Performance measured on an NVIDIA A100-SXM4-80GB:
+
+- Qwen2-VL-2B-Instruct
+
+| Paddle Inference | PyTorch | Paddle dygraph |
+| ---------------- | ------- | -------------- |
+| 1.44 s | 2.35 s | 5.215 s |
+
+- Qwen2-VL-7B-Instruct
+
+| Paddle Inference | PyTorch | Paddle dygraph |
+| ---------------- | ------- | -------------- |
+| 1.73 s | 4.4 s | 6.339 s |
diff --git a/deploy/qwen2_vl/single_image_infer.py b/deploy/qwen2_vl/single_image_infer.py
new file mode 100644
index 000000000..76b679a99
--- /dev/null
+++ b/deploy/qwen2_vl/single_image_infer.py
@@ -0,0 +1,280 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime
+from dataclasses import dataclass, field
+
+import numpy as np
+import paddle
+from paddlenlp.generation import GenerationConfig
+from paddlenlp.trainer import PdArgumentParser
+from paddlenlp.transformers import (
+    AutoConfig,
+    AutoInferenceModelForCausalLM,
+    Qwen2Tokenizer,
+)
+from paddlenlp.trl import llm_utils
+
+from paddlemix.models.qwen2_vl.modeling_qwen2_vl import (
+    Qwen2RotaryEmbedding,
+    Qwen2VLForConditionalGeneration,
+)
+from paddlemix.processors.qwen2_vl_processing import (
+    Qwen2VLImageProcessor,
+    Qwen2VLProcessor,
+    process_vision_info,
+)
+
+MODEL_NAME = "Qwen/Qwen2-VL-2B-Instruct"
+# MODEL_NAME = "Qwen/Qwen2-VL-7B-Instruct"
+vl_model = Qwen2VLForConditionalGeneration.from_pretrained(MODEL_NAME, dtype="bfloat16")
+
+# NOTE: (zhoukangkang、changwenbin) Because we only use the visual model here,
+# in order to reduce video memory, we delete the language model.
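+# The deleted language model is re-attached later in this script as PaddleNLP's
+# high-performance Qwen2 inference model (see `vl_model.model = fast_llm_model` below).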
+del vl_model.model +paddle.device.cuda.empty_cache() + +image_processor = Qwen2VLImageProcessor() +tokenizer = Qwen2Tokenizer.from_pretrained(MODEL_NAME) +processor = Qwen2VLProcessor(image_processor, tokenizer) + +# min_pixels = 256*28*28 # 200704 +# max_pixels = 1280*28*28 # 1003520 +# processor = Qwen2VLProcessor(image_processor, tokenizer, min_pixels=min_pixels, max_pixels=max_pixels) + +messages = [ + { + "role": "user", + "content": [ + { + "type": "image", + "image": "paddlemix/demo_images/examples_image1.jpg", + }, + {"type": "text", "text": "Describe this image."}, + ], + } +] + +# Preparation for inference +image_inputs, video_inputs = process_vision_info(messages) + +question = "Describe this image." +image_pad_token = "<|vision_start|><|image_pad|><|vision_end|>" +text = f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{image_pad_token}{question}<|im_end|>\n<|im_start|>assistant\n" + + +@dataclass +class PredictorArgument: + # NOTE: (zhoukangkang、changwenbin) + # These parameters are all copied from https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/predict/predictor.py + # For simplicity and ease of use, only the necessary parameters are retained here. + # If you want to know the exact meaning of these parameters, please refer to the link above. + + model_name_or_path: str = field(default=None, metadata={"help": "The directory of model."}) + src_length = 1024 + min_length = 2 + max_length = 200 + top_k = 0 + top_p = 0.0 + temperature = 0.95 + repetition_penalty = 1.0 + dtype: str = field(default=None, metadata={"help": "Model dtype"}) + decode_strategy = "sampling" + mode = "dynamic" + inference_model = True + quant_type = "" + benchmark: bool = field( + default=False, + metadata={ + "help": "If benchmark set as `True`, we will force model decode to max_length, which is helpful to compute throughput. 
" + }, + ) + use_fake_parameter = False + block_attn = True + block_size = 64 + cachekv_int8_type = None + append_attn = True + total_max_length = 4096 + speculate_method = None + + +@dataclass +class ModelArgument: + model_type: str = field( + default=None, + metadata={"help": "the type of the model, which can be one of ['gpt-3', 'ernie-3.5-se', 'llama-img2txt']"}, + ) + + +def init_llm_model_inputs(vision_model_inputs, inputs_embeds, arg_config: PredictorArgument): + assert len(inputs_embeds.shape) == 3 + batch_size = inputs_embeds.shape[0] + + model_inputs = {} + model_inputs["input_ids"] = paddle.zeros(shape=[batch_size, arg_config.total_max_length], dtype="int64") + model_inputs["inputs_embeds"] = inputs_embeds + + # I dislike write (arg_config.total_max_length + arg_config.block_size -1 ) // arg_config.block_size + assert arg_config.total_max_length % arg_config.block_size == 0 + + model_inputs["top_p"] = paddle.full(shape=[batch_size, 1], fill_value=arg_config.top_p, dtype="float32") + model_inputs["temperature"] = paddle.full( + shape=[batch_size, 1], fill_value=arg_config.temperature, dtype="float32" + ) + model_inputs["eos_token_id"] = paddle.to_tensor( + np.array(llm_utils.get_eos_token_id(tokenizer, generation_config)).reshape(-1, 1).astype("int64") + ) + model_inputs["penalty_score"] = paddle.full( + shape=[batch_size, 1], fill_value=arg_config.repetition_penalty, dtype="float32" + ) + model_inputs["frequency_score"] = paddle.full(shape=[batch_size, 1], fill_value=0.0, dtype="float32") + model_inputs["presence_score"] = paddle.full(shape=[batch_size, 1], fill_value=0.0, dtype="float32") + model_inputs["min_length"] = paddle.full(shape=[batch_size, 1], fill_value=arg_config.min_length, dtype="int64") + model_inputs["max_length"] = paddle.full(shape=[batch_size, 1], fill_value=arg_config.max_length, dtype="int64") + + position_ids, _ = vl_model.get_rope_index( + config.vision_config["spatial_merge_size"], + config.image_token_id, + config.video_token_id, + config.vision_start_token_id, + vision_model_inputs.get("input_ids"), + vision_model_inputs.get("image_grid_thw"), + vision_model_inputs.get("video_grid_thw", None), + vision_model_inputs.get("attention_mask"), + ) + position_start = position_ids[0][0][-1].item() + position_end = 4096 - position_ids.shape[-1] + position_start + position_value = ( + paddle.arange(position_start, position_end).reshape([1, 1, -1]).expand([position_ids.shape[0], 1, -1]) + ) + position_ids = paddle.concat([position_ids, position_value], axis=-1) + + head_dim = config.hidden_size // config.num_attention_heads + qwen2_Embedding = Qwen2RotaryEmbedding(head_dim, 4096, config.rope_theta) + cos = qwen2_Embedding.cos_cached + sin = qwen2_Embedding.sin_cached + + # NOTE: (zhoukangkang、changwenbin) Copied from PaddleMIX/paddlemix/models/qwen2_vl/modeling_qwen2_vl.py, + # for calculating M-ROPE. 
+ cos = cos[position_ids] + sin = sin[position_ids] + mrope_section = config.rope_scaling["mrope_section"] * 2 + cos = paddle.concat(x=[m[i % 3] for i, m in enumerate(cos.split(mrope_section, axis=-1))], axis=-1) + sin = paddle.concat(x=[m[i % 3] for i, m in enumerate(sin.split(mrope_section, axis=-1))], axis=-1) + + rope_emb = paddle.stack([cos, sin], axis=0) + rope_emb = rope_emb.reshape([rope_emb.shape[0], 1, rope_emb.shape[2], 1, rope_emb.shape[-1]]) + model_inputs["rope_emb"] = rope_emb + + model_inputs["bad_tokens"] = paddle.to_tensor([-1], dtype="int64") + model_inputs["is_block_step"] = paddle.full(shape=[batch_size], fill_value=False, dtype="bool") + + cache_kvs_shape = fast_llm_model.get_cache_kvs_shape(fast_llm_model.config, batch_size) + cachekv_dtype = config.dtype if arg_config.cachekv_int8_type is None else "uint8" + model_inputs["cache_kvs"] = [paddle.zeros(shape, dtype=cachekv_dtype) for shape in cache_kvs_shape] + + block_nums = arg_config.total_max_length // arg_config.block_size + model_inputs["block_tables"] = paddle.arange(block_nums, dtype="int32").tile([batch_size, 1]) + + seq_lens = inputs_embeds.shape[1] + model_inputs["seq_lens_this_time"] = paddle.to_tensor(np.array(seq_lens).astype("int32").reshape(-1, 1)) + model_inputs["seq_lens_encoder"] = paddle.to_tensor(np.array(seq_lens).astype("int32").reshape(-1, 1)) + model_inputs["seq_lens_decoder"] = paddle.full(shape=[batch_size, 1], fill_value=0, dtype="int32") + model_inputs["step_idx"] = paddle.full(shape=[batch_size, 1], fill_value=0, dtype="int64") + model_inputs["not_need_stop"] = paddle.full(shape=[1], fill_value=True, dtype="bool") + model_inputs["stop_flags"] = paddle.full(shape=[batch_size, 1], fill_value=False, dtype="bool") + model_inputs["stop_nums"] = paddle.full(shape=[1], fill_value=batch_size, dtype="int64") + model_inputs["pre_ids"] = paddle.full(shape=[batch_size, arg_config.max_length], fill_value=-1, dtype="int64") + model_inputs["next_tokens"] = paddle.full(shape=[batch_size, 1], fill_value=-1, dtype="int64") + + return model_inputs + + +parser = PdArgumentParser((PredictorArgument, ModelArgument)) +predictor_args, model_args = parser.parse_args_into_dataclasses() + +paddle.set_default_dtype(predictor_args.dtype) +config = AutoConfig.from_pretrained(MODEL_NAME) + +# NOTE: (changwenbin) This is for using the inference optimization of paddlenlp qwen2. 
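+# Overriding model_type makes AutoInferenceModelForCausalLM resolve to PaddleNLP's
+# high-performance Qwen2 inference implementation; the vision tower above still comes from PaddleMIX.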
+config.model_type = "qwen2" +generation_config = GenerationConfig.from_pretrained(MODEL_NAME) +fast_llm_model = AutoInferenceModelForCausalLM.from_pretrained( + MODEL_NAME, + config=config, + predictor_args=predictor_args, + model_args=model_args, + dtype=predictor_args.dtype, + tensor_parallel_degree=1, + tensor_parallel_rank=0, +) +fast_llm_model.eval() + +vl_model.model = fast_llm_model + + +def run_model(): + + vision_model_inputs = processor( + text=[text], + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pd", + ) + inputs_embeds = vl_model.vision_forward(**vision_model_inputs) + llm_model_inputs = init_llm_model_inputs(vision_model_inputs, inputs_embeds, arg_config=predictor_args) + generated_text = "" + while llm_model_inputs["not_need_stop"]: + generated_ids = fast_llm_model.generate(**llm_model_inputs) # already trimmed in paddle + llm_model_inputs["input_ids"] = generated_ids + llm_model_inputs["inputs_embeds"] = None + new_text_piece = processor.batch_decode( + generated_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False + )[0] + if new_text_piece == "<|im_end|>": + break + generated_text += new_text_piece + return generated_text + + +if predictor_args.benchmark: + print(f"Benchmarking {MODEL_NAME} ...") + warm_up = 3 + repeat_times = 10 + sumtime = 0.0 + times = repeat_times + warm_up + for i in range(times): + if i > 2: + paddle.device.synchronize() + starttime = datetime.datetime.now() + generated_text = run_model() + if i > 2: + paddle.device.synchronize() + endtime = datetime.datetime.now() + print("Final output_text:\n", generated_text) + + if i > 2: + duringtime = endtime - starttime + duringtime = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0 + sumtime += duringtime + print(f"Single {MODEL_NAME} end to end time : ", duringtime, "ms") + inference_global_mem = paddle.device.cuda.memory_reserved() / (1024**3) + print(f"Inference used CUDA memory : {inference_global_mem:.3f} GiB") + + print(f"Single {MODEL_NAME} ave end to end time : ", sumtime / repeat_times, "ms") + +else: + generated_text = run_model() + print("Final output_text:\n", generated_text) diff --git a/paddlemix/MULLM_WebUI/README.md b/paddlemix/MULLM_WebUI/README.md new file mode 100644 index 000000000..4520e6d2c --- /dev/null +++ b/paddlemix/MULLM_WebUI/README.md @@ -0,0 +1,92 @@ +# PaddleMIX MULLM WebUI + +## 1. 简介 +PaddleMIX MULLM_WebUI 是一个基于PaddleMIX套件的交互式平台,主要支持多模态理解任务的模型微调与推理功能。MULLM_WebUI 提供了丰富的可视化操作界面,支持用户进行模型微调、推理等操作。 +![overview](./fig/overview.jpg) + +#### 支持模型 +| Model |Model Size |Inference | SFT | LoRA | +|-------|------------|-------|---|-----| +| qwen2_vl|2B/7B| ✅ | ✅ | ✅ || + +>* ✅: Supported +>* 🚧: In Progress +>* ❌: Not Supported + +## 2. 安装 +* 安装Paddle和PaddleMIX依赖 + +* 安装PaddleMIX MULLM WebUI依赖 +``` +pip install -r paddlemix/MULLM_WebUI/requirements.txt +``` + +## 3. 
快速使用 + +### 3.1 启动 +``` +CUDA_VISIBLE_DEVICES=0 \ +GRADIO_SHARE=1 \ +GRADIO_SERVER_NAME=0.0.0.0 \ +GRADIO_ANALYTICS_ENABLED=0 \ +GRADIO_SERVER_PORT=8260 python paddlemix/MULLM_WebUI/run_web.py +``` +### 3.2 使用教程 +#### 3.2.1 新增数据集 + +* 下载 [Pokemon](https://huggingface.co/datasets/llamafactory/pokemon-gpt4o-captions/tree/main) 数据集。Pokemon-gpt4o-captions 是一个基于精灵宝可梦的中英双语视觉问答数据集,其问答结果由gpt4o生成。其中中文问答数据共计833条,数据集大小80.8M。 +* 放置中文数据集文件到 `./data/pokemon_gpt4o_zh/pokemon_gpt4o_zh.parquet` + +* 运行转换数据集脚本 +``` +python paddlemix/MULLM_WebUI/scripts/convert_dataset.py \ + --data_dir ./data \ + --dataset_dir pokemon_gpt4o_zh \ + --file_name ./data/pokemon_gpt4o_zh/pokemon_gpt4o_zh.parquet +``` +> 注:目前MULLM WebUI只支持单卡微调,为了达到更佳的训练效果,建议自己构建数据集或者按照[qwen2_vl ](https://github.com/PaddlePaddle/PaddleMIX/tree/develop/paddlemix/examples/qwen2_vl)样例中提供的脚本进行微调。 +#### 3.2.2 模型微调 +1) 模型选择 + +![模型选择](./fig/train_1.jpg) + + +2) 超参数设置 +![超参数设置](./fig/train_2.jpg) + + +3) LoRA参数设置与模型训练 +![模型训练](./fig/train_3.jpg) + +#### 3.2.3 模型推理 + +1) 模型加载 +![模型加载](./fig/chat_1.jpg) + + +2) 多模态理解 +![多模态理解](./fig/chat_2.jpg) + +## 4. 使用展示 + + +1) 模型微调 +![模型微调样例](./fig/example_train.jpg) + + +2) 模型推理 +![模型推理样例](./fig/example_chat.jpg) + +## 参考文献 + +```BibTeX +@inproceedings{zheng2024llamafactory, + title={LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models}, + author={Yaowei Zheng and Richong Zhang and Junhao Zhang and Yanhan Ye and Zheyan Luo and Zhangchi Feng and Yongqiang Ma}, + booktitle={Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)}, + address={Bangkok, Thailand}, + publisher={Association for Computational Linguistics}, + year={2024}, + url={http://arxiv.org/abs/2403.13372} +} +``` diff --git a/paddlemix/MULLM_WebUI/__init__.py b/paddlemix/MULLM_WebUI/__init__.py new file mode 100644 index 000000000..fd05a9208 --- /dev/null +++ b/paddlemix/MULLM_WebUI/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlemix/MULLM_WebUI/chatter.py b/paddlemix/MULLM_WebUI/chatter.py new file mode 100644 index 000000000..200e55f72 --- /dev/null +++ b/paddlemix/MULLM_WebUI/chatter.py @@ -0,0 +1,250 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
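+
+# WebChatModel (below) wraps model loading/unloading and streaming multi-round chat
+# for the MULLM WebUI inference tab.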
+ +import os +from threading import Thread +from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Tuple + +import gradio as gr +import paddle +from paddlenlp.generation import TextIteratorStreamer +from paddlenlp.peft import LoRAModel +from paddlenlp.transformers import Qwen2Tokenizer +from paddlenlp.utils.import_utils import import_module + +from ..processors.qwen2_vl_processing import ( + Qwen2VLImageProcessor, + Qwen2VLProcessor, + process_vision_info, +) +from .common import ChatState, change_checkbox, chat_ready, get_save_dir +from .extras.constants import FINAL_CHECKPOINT_NAME, MODEL_MAPPING +from .locales import ALERTS, LOCALES + +if TYPE_CHECKING: + from .manager import Manager + + +class WebChatModel: + def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None: + self.manager = manager + self.demo_mode = demo_mode + self.engine = None + self.processor = None + self.tokenizer = None + self.terminators = ["<|im_end|>"] + # self.min_pixels = 256 * 28 * 28 # 200704 + # self.max_pixels = 1280 * 28 * 28 # 1003520 + + @property + def loaded(self) -> bool: + return self.engine is not None + + def load_model(self, data) -> Generator[str, None, None]: + engine_cls = self.get_model(data) + get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)] + lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path") + finetuning_type = get("top.finetuning_type") + infer_dtype = get("infer.infer_dtype") + checkpoint_path = get("top.checkpoint_path") + state_checkbox_group = get("infer.state_checkbox_group") + selected_ckpt = get("infer.ckpt_box") + if selected_ckpt == FINAL_CHECKPOINT_NAME: + ckpt_path = os.path.join(get_save_dir(model_name, finetuning_type), checkpoint_path) + elif selected_ckpt != "": + ckpt_path = os.path.join(get_save_dir(model_name, finetuning_type), checkpoint_path, selected_ckpt) + else: + ckpt_path = "" + + error = "" + yield ALERTS["info_loading"][lang], state_checkbox_group + + if self.loaded: + error = ALERTS["err_exists"][lang] + yield error, change_checkbox(state_checkbox_group, True, choice_type=LOCALES["model_tag"][lang]) + return + elif not model_name: + error = ALERTS["err_no_model"][lang] + elif not model_path: + error = ALERTS["err_no_path"][lang] + + self.engine = engine_cls.from_pretrained(model_path, dtype=infer_dtype) + self.processor = self.get_processor(model_path) + + # load lora + if ckpt_path != "": + self.engine = LoRAModel.from_pretrained(model=self.engine, lora_path=ckpt_path) + + if error: + gr.Warning(error) + yield error + return + + yield ALERTS["info_loaded"][lang], change_checkbox( + state_checkbox_group, True, choice_type=LOCALES["model_tag"][lang] + ) + + def unload_model(self, data) -> Generator[str, None, None]: + lang = data[self.manager.get_elem_by_id("top.lang")] + state_checkbox_group = data[self.manager.get_elem_by_id("infer.state_checkbox_group")] + + if not self.loaded: + yield ALERTS["info_unload_error"][lang], state_checkbox_group + return + + yield ALERTS["info_unloading"][lang], state_checkbox_group + self.engine = None + state_checkbox_group.remove(LOCALES["model_tag"][lang]) + paddle.device.cuda.empty_cache() + yield ALERTS["info_unloaded"][lang], state_checkbox_group + + def multi_round_chat( + self, + lang, + chatbot, + messages, + question_box, + question_type, + image, + video, + chat_checkbox, + max_new_tokens, + top_p, + temperature, + seed, + info_box, + ) -> Tuple[List[List[Optional[str]]], List[Dict[str, str]], str]: + chat_state = { 
+ "model": LOCALES["model_tag"][lang], + "image": LOCALES["image_tag"][lang], + "video": LOCALES["video_tag"][lang], + "question": LOCALES["question_tag"][lang], + } + + check_result = chat_ready(chat_checkbox, chat_state) + if check_result == ChatState.MISSING_QUESTION: + yield chatbot, messages, gr.update(value=question_box), gr.update( + value=ALERTS["info_query"][lang] + ), gr.update(interactive=True) + return + if check_result == ChatState.MISSING_MODEL: + yield chatbot, messages, gr.update(value=question_box), gr.update( + value=ALERTS["info_upload_model"][lang] + ), gr.update(interactive=True) + return + if check_result == ChatState.MISSING_FILE: + yield chatbot, messages, gr.update(value=question_box), gr.update( + value=ALERTS["info_upload_file"][lang] + ), gr.update(interactive=True) + return + msg = { + "role": "user", + "content": [], + } + last_img_inp = None + last_video_inp = None + + # find last image and video input + for m in messages[::-1]: + for content in m["content"]: + if "video" in content.keys(): + last_video_inp = content["video"] + + if "image" in content.keys(): + last_img_inp = content["image"] + if last_img_inp is not None and last_video_inp is not None: + break + + if image is not None and image == last_img_inp: + image = None + + if video is not None and video == last_video_inp: + video = None + + if question_type == "image" and image is not None: + msg["content"].append({"type": "image", "image": image}) + + if question_type == "video" and video is not None: + msg["content"].append({"type": "video", "video": video, "fps": 1, "max_pixels": 360 * 420}) + + chatbot += [[question_box, None]] + msg["content"].append({"type": "text", "text": f"{question_box}"}) + + messages.append(msg) + paddle.seed(seed=seed) + generate_cfg = dict( + max_new_tokens=max_new_tokens, + top_p=top_p, + temperature=temperature, + ) + response = "" + res = self.generate(messages, generate_cfg) + for text in res: + response += text + yield chatbot + [[None, response]], messages + [ + {"role": "assistant", "content": [{"type": "text", "text": response}]} + ], gr.update(value=question_box), gr.update(value=ALERTS["info_generating"][lang]), gr.update( + interactive=False + ) + + yield chatbot + [[None, response]], messages + [ + {"role": "assistant", "content": [{"type": "text", "text": response}]} + ], gr.update(value=question_box), gr.update(value=ALERTS["info_generated"][lang]), gr.update(interactive=True) + + def get_model(self, data): + get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)] + model_name = get("top.model_name") + model_module = import_module(f"paddlemix.models.{MODEL_MAPPING[model_name]}") + + return model_module + + def get_processor(self, model_path): + image_processor = Qwen2VLImageProcessor() + tokenizer = Qwen2Tokenizer.from_pretrained(model_path) + # processor = Qwen2VLProcessor(image_processor, tokenizer,min_pixels=self.min_pixels, max_pixels=self.max_pixels) + processor = Qwen2VLProcessor(image_processor, tokenizer) + self.tokenizer = tokenizer + + return processor + + def generate(self, messages, generate_cfg): + image_inputs, video_inputs = process_vision_info(messages) + text = self.processor.tokenizer.apply_chat_template(messages, tokenize=False) + inputs = self.processor( + text=[text], + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pd", + ) + inputs = self.processor( + text=[text], + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pd", + ) + + streamer = 
TextIteratorStreamer(tokenizer=self.tokenizer, skip_special_tokens=True) + generation_kwargs = { + "streamer": streamer, + } + + generation_kwargs.update(generate_cfg) + generation_kwargs.update(inputs) + + thread = Thread(target=self.engine.generate, kwargs=generation_kwargs) + """Class Method: *.start, can not convert, please check whether it is torch.Tensor.*/Optimizer.*/nn.Module.*/torch.distributions.Distribution.*/torch.autograd.function.FunctionCtx.*/torch.profiler.profile.*/torch.autograd.profiler.profile.*, and convert manually""" + thread.start() + return streamer diff --git a/paddlemix/MULLM_WebUI/common.py b/paddlemix/MULLM_WebUI/common.py new file mode 100644 index 000000000..84cf9037d --- /dev/null +++ b/paddlemix/MULLM_WebUI/common.py @@ -0,0 +1,648 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import sys +from dataclasses import dataclass +from datetime import datetime +from enum import Enum +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Literal, + Optional, + Sequence, + Tuple, + Union, +) + +import gradio as gr +import numpy as np +import paddle +from paddlenlp.transformers.utils import cached_file +from paddlenlp.utils.import_utils import is_datasets_available + +from ..utils.log import logger +from .extras.constants import ( + DATA_CONFIG, + DEFAULT_TEMPLATE, + FILEEXT2TYPE, + FINAL_CHECKPOINT_NAME, + PEFT_METHODS, + STAGES_USE_PAIR_DATA, + SUPPORTED_MODELS, + TRAINING_STAGES, +) +from .extras.data import align_dataset, has_tokenized_data, merge_dataset, split_dataset +from .extras.packages import use_modelscope, use_openmind +from .extras.preprocess import get_preprocess_and_print_func +from .locales import LOCALES + +if is_datasets_available(): + from datasets import DatasetDict, load_dataset, load_from_disk + +if TYPE_CHECKING: + from datasets import Dataset, DatasetModule, IterableDataset + from paddlenlp.trainer import Seq2SeqTrainingArguments + from paddlenlp.transformers import PretrainedTokenizer, ProcessorMixin + + from .extras.args import DataArguments, ModelArguments + from .extras.template import Template + +DEFAULT_CACHE_DIR = "cache" +DEFAULT_CONFIG_DIR = "config" +DEFAULT_DATA_DIR = "data" +DEFAULT_SAVE_DIR = "saves" +USER_CONFIG = "user_config.yaml" + + +class ChatState(Enum): + READY = 1 + MISSING_QUESTION = -1 + MISSING_MODEL = -2 + MISSING_IMAGE = -3 + MISSING_VIDEO = -4 + MISSING_FILE = -5 + + +@dataclass +class DatasetAttr: + r""" + Dataset attributes. 
+ """ + + # basic configs + load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"] + dataset_name: str + formatting: Literal["alpaca", "sharegpt"] = "alpaca" + ranking: bool = False + # extra configs + subset: Optional[str] = None + split: str = "train" + folder: Optional[str] = None + num_samples: Optional[int] = None + # common columns + system: Optional[str] = None + tools: Optional[str] = None + images: Optional[str] = None + videos: Optional[str] = None + # rlhf columns + chosen: Optional[str] = None + rejected: Optional[str] = None + kto_tag: Optional[str] = None + # alpaca columns + prompt: Optional[str] = "instruction" + query: Optional[str] = "input" + response: Optional[str] = "output" + history: Optional[str] = None + # sharegpt columns + messages: Optional[str] = "conversations" + # sharegpt tags + role_tag: Optional[str] = "from" + content_tag: Optional[str] = "value" + user_tag: Optional[str] = "human" + assistant_tag: Optional[str] = "gpt" + observation_tag: Optional[str] = "observation" + function_tag: Optional[str] = "function_call" + system_tag: Optional[str] = "system" + + def __repr__(self) -> str: + return self.dataset_name + + def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None: + setattr(self, key, obj.get(key, default)) + + +def get_save_dir(*paths: str) -> os.PathLike: + r""" + Gets the path to saved model checkpoints. + """ + if os.path.sep in paths[-1]: + logger(30, "Found complex path, some features may be not available.") + return paths[-1] + + paths = (path.replace(" ", "").strip() for path in paths) + return os.path.join(DEFAULT_SAVE_DIR, *paths) + + +def get_model_path(model_name: str) -> str: + r""" + Gets the model path according to the model name. + """ + model_path: str = SUPPORTED_MODELS.get(model_name) + return model_path + + +def get_model_info(model_name: str) -> Tuple[str, str]: + r""" + Gets the necessary information of this model. + + Returns: + model_path (str) + template (str) + """ + return get_model_path(model_name), get_template(model_name) + + +def get_template(model_name: str) -> str: + r""" + Gets the template name if the model is a chat model. + """ + return DEFAULT_TEMPLATE.get(model_name, "default") + + +def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown": + r""" + Lists all available checkpoints. + """ + checkpoints = [] + if model_name: + save_dir = get_save_dir(model_name, finetuning_type) + if save_dir and os.path.isdir(save_dir): + for checkpoint in os.listdir(save_dir): + if os.path.isdir(os.path.join(save_dir, checkpoint)): + checkpoints.append(checkpoint) + if finetuning_type in PEFT_METHODS: + yield gr.Dropdown(value=None, choices=checkpoints, multiselect=False) + return + else: + yield gr.Dropdown(value=None, choices=checkpoints, multiselect=False) + return + + +def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]: + r""" + Loads dataset_info.json. + """ + if dataset_dir == "ONLINE" or dataset_dir.startswith("REMOTE:"): + logger(20, f"dataset_dir is {dataset_dir}, using online dataset.") + return {} + + try: + with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f: + return json.load(f) + except Exception as err: + logger(30, f"Cannot open {os.path.join(dataset_dir, DATA_CONFIG)} due to {str(err)}.") + return {} + + +def list_datasets(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Dropdown": + r""" + Lists all available datasets in the dataset dir for the training stage. 
+ """ + dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR) + ranking = TRAINING_STAGES[training_stage] in STAGES_USE_PAIR_DATA + datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking] + return gr.Dropdown(choices=datasets) + + +def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]: + r""" + Gets the attributes of the datasets. + """ + if dataset_names is None: + dataset_names = [] + + if dataset_dir == "ONLINE": + dataset_info = None + else: + if dataset_dir.startswith("REMOTE:"): + config_path = cached_file(path_or_repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset") + else: + config_path = os.path.join(dataset_dir, DATA_CONFIG) + + try: + with open(config_path) as f: + dataset_info = json.load(f) + except Exception as err: + if len(dataset_names) != 0: + raise ValueError(f"Cannot open {config_path} due to {str(err)}.") + + dataset_info = None + + dataset_list: List["DatasetAttr"] = [] + for name in dataset_names: + if dataset_info is None: # dataset_dir is ONLINE + if use_modelscope(): + load_from = "ms_hub" + elif use_openmind(): + load_from = "om_hub" + else: + load_from = "hf_hub" + dataset_attr = DatasetAttr(load_from, dataset_name=name) + dataset_list.append(dataset_attr) + continue + + if name not in dataset_info: + raise ValueError(f"Undefined dataset {name} in {DATA_CONFIG}.") + + has_hf_url = "hf_hub_url" in dataset_info[name] + has_ms_url = "ms_hub_url" in dataset_info[name] + has_om_url = "om_hub_url" in dataset_info[name] + + if has_hf_url or has_ms_url or has_om_url: + if has_ms_url and (use_modelscope() or not has_hf_url): + dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"]) + elif has_om_url and (use_openmind() or not has_hf_url): + dataset_attr = DatasetAttr("om_hub", dataset_name=dataset_info[name]["om_hub_url"]) + else: + dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"]) + elif "script_url" in dataset_info[name]: + dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"]) + else: + dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"]) + + dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca") + dataset_attr.set_attr("ranking", dataset_info[name], default=False) + dataset_attr.set_attr("subset", dataset_info[name]) + dataset_attr.set_attr("split", dataset_info[name], default="train") + dataset_attr.set_attr("folder", dataset_info[name]) + dataset_attr.set_attr("num_samples", dataset_info[name]) + + if "columns" in dataset_info[name]: + column_names = ["system", "tools", "images", "videos", "chosen", "rejected", "kto_tag"] + if dataset_attr.formatting == "alpaca": + column_names.extend(["prompt", "query", "response", "history"]) + else: + column_names.extend(["messages"]) + + for column_name in column_names: + dataset_attr.set_attr(column_name, dataset_info[name]["columns"]) + + if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]: + tag_names = ( + "role_tag", + "content_tag", + "user_tag", + "assistant_tag", + "observation_tag", + "function_tag", + "system_tag", + ) + for tag in tag_names: + dataset_attr.set_attr(tag, dataset_info[name]["tags"]) + + dataset_list.append(dataset_attr) + + return dataset_list + + +def change_checkbox(checkbox, x, lang=None, tag=None, choice_type=None): + if choice_type is None: + choice_type = LOCALES[tag][lang] + if (x == "" or x is None) and 
choice_type in checkbox: + checkbox.remove(choice_type) + elif (x != "" or (tag != "question_tag" and x is not None)) and choice_type not in checkbox: + checkbox.append(choice_type) + return checkbox + + +def chat_ready(checkbox, state): + + if state["model"] not in checkbox: + return ChatState.MISSING_MODEL + + if state["question"] not in checkbox: + return ChatState.MISSING_QUESTION + + if state["image"] not in checkbox and state["video"] not in checkbox: + return ChatState.MISSING_FILE + + return ChatState.READY + + +# train +def get_device_count(): + return paddle.device.cuda.device_count() + + +# dataset +def _load_single_dataset( + dataset_attr: "DatasetAttr", + model_args: "ModelArguments", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", +) -> Union["Dataset", "IterableDataset"]: + r""" + Loads a single dataset and aligns it to the standard format. + """ + logger(20, f"Loading dataset {dataset_attr}...") + data_path, data_name, data_dir, data_files = None, None, None, None + if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]: + data_path = dataset_attr.dataset_name + data_name = dataset_attr.subset + data_dir = dataset_attr.folder + + elif dataset_attr.load_from == "script": + data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name) + data_name = dataset_attr.subset + data_dir = dataset_attr.folder + + elif dataset_attr.load_from == "file": + data_files = [] + local_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name) + if os.path.isdir(local_path): # is directory + for file_name in os.listdir(local_path): + data_files.append(os.path.join(local_path, file_name)) + elif os.path.isfile(local_path): # is file + data_files.append(local_path) + else: + raise ValueError(f"File {local_path} not found.") + + data_path = FILEEXT2TYPE.get(os.path.splitext(data_files[0])[-1][1:], None) + if data_path is None: + raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys()))) + + if any(data_path != FILEEXT2TYPE.get(os.path.splitext(data_file)[-1][1:], None) for data_file in data_files): + raise ValueError("File types should be identical.") + else: + raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.") + + if dataset_attr.load_from == "ms_hub": + # require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0") + from modelscope import MsDataset # type: ignore + from modelscope.utils.config_ds import MS_DATASETS_CACHE # type: ignore + + cache_dir = model_args.cache_dir or MS_DATASETS_CACHE + dataset = MsDataset.load( + dataset_name=data_path, + subset_name=data_name, + data_dir=data_dir, + data_files=data_files, + split=dataset_attr.split, + cache_dir=cache_dir, + token=model_args.ms_hub_token, + use_streaming=data_args.streaming, + ) + if isinstance(dataset, MsDataset): + dataset = dataset.to_hf_dataset() + + elif dataset_attr.load_from == "om_hub": + # require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0") + from openmind import OmDataset # type: ignore + from openmind.utils.hub import OM_DATASETS_CACHE # type: ignore + + cache_dir = model_args.cache_dir or OM_DATASETS_CACHE + dataset = OmDataset.load_dataset( + path=data_path, + name=data_name, + data_dir=data_dir, + data_files=data_files, + split=dataset_attr.split, + cache_dir=cache_dir, + token=model_args.om_hub_token, + streaming=data_args.streaming, + ) + else: + dataset = load_dataset( + path=data_path, + name=data_name, + data_dir=data_dir, + data_files=data_files, + split=dataset_attr.split, 
+ cache_dir=model_args.cache_dir, + token=model_args.hf_hub_token, + streaming=data_args.streaming, + num_proc=data_args.preprocessing_num_workers, + ) + + if dataset_attr.num_samples is not None and not data_args.streaming: + target_num = dataset_attr.num_samples + indexes = np.random.permutation(len(dataset))[:target_num] # all samples should be included + target_num -= len(indexes) + if target_num > 0: + expand_indexes = np.random.choice(len(dataset), target_num) + indexes = np.concatenate((indexes, expand_indexes), axis=0) + + assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched." + dataset = dataset.select(indexes) + logger(20, f"Sampled {dataset_attr.num_samples} examples from dataset {dataset_attr}.") + + if data_args.max_samples is not None: # truncate dataset + max_samples = min(data_args.max_samples, len(dataset)) + dataset = dataset.select(range(max_samples)) + + return align_dataset(dataset, dataset_attr, data_args, training_args) + + +def _get_merged_dataset( + dataset_names: Optional[Sequence[str]], + model_args: "ModelArguments", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + stage: Literal["pt", "sft", "rm", "ppo", "kto"], +) -> Optional[Union["Dataset", "IterableDataset"]]: + r""" + Gets the merged datasets in the standard format. + """ + if dataset_names is None: + return None + + datasets = [] + for dataset_attr in get_dataset_list(dataset_names, data_args.dataset_dir): + if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True): + raise ValueError("The dataset is not applicable in the current training stage.") + + datasets.append(_load_single_dataset(dataset_attr, model_args, data_args, training_args)) + + return merge_dataset(datasets, data_args, seed=training_args.seed) + + +def _get_preprocessed_dataset( + dataset: Optional[Union["Dataset", "IterableDataset"]], + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + stage: Literal["sft"], + template: "Template", + tokenizer: "PretrainedTokenizer", + processor: Optional["ProcessorMixin"] = None, + is_eval: bool = False, +) -> Optional[Union["Dataset", "IterableDataset"]]: + r""" + Preprocesses the dataset, including format checking and tokenization. 
+ """ + if dataset is None: + return None + + preprocess_func, print_function = get_preprocess_and_print_func( + data_args, stage, template, tokenizer, processor, do_generate=(training_args.predict_with_generate and is_eval) + ) + column_names = list(next(iter(dataset)).keys()) + kwargs = {} + if not data_args.streaming: + # kwargs = dict( + # num_proc=data_args.preprocessing_num_workers, + # load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0), + # desc="Running tokenizer on dataset", + # ) + kwargs = dict( + num_proc=1, + load_from_cache_file=False, + desc="Running tokenizer on dataset", + ) + + dataset = dataset.map( + preprocess_func, + batched=True, + batch_size=data_args.preprocessing_batch_size, + remove_columns=column_names, + **kwargs, + ) + + if training_args.should_log: + try: + print("eval example:" if is_eval else "training example:") + print_function(next(iter(dataset))) + except StopIteration: + if stage == "pt": + raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.") + else: + raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.") + + return dataset + + +def get_dataset( + template: "Template", + model_args: "ModelArguments", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + stage: Literal["sft"], + tokenizer: "PretrainedTokenizer", + processor: Optional["ProcessorMixin"] = None, +) -> "DatasetModule": + r""" + Gets the train dataset and optionally gets the evaluation dataset. + """ + # Load tokenized dataset + if data_args.tokenized_path is not None: + if has_tokenized_data(data_args.tokenized_path): + logger(30, "Loading dataset from disk will ignore other data arguments.") + dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path) + logger(20, f"Loaded tokenized dataset from {data_args.tokenized_path}.") + + dataset_module: Dict[str, "Dataset"] = {} + if "train" in dataset_dict: + dataset_module["train_dataset"] = dataset_dict["train"] + + if "validation" in dataset_dict: + dataset_module["eval_dataset"] = dataset_dict["validation"] + + if data_args.streaming: + dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()} + + return dataset_module + + if data_args.streaming: + raise ValueError("Turn off `streaming` when saving dataset to disk.") + + # Load and preprocess dataset + with training_args.main_process_first(desc="load dataset"): + dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage) + eval_dataset = _get_merged_dataset(data_args.eval_dataset, model_args, data_args, training_args, stage) + + with training_args.main_process_first(desc="pre-process dataset"): + dataset = _get_preprocessed_dataset( + dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False + ) + eval_dataset = _get_preprocessed_dataset( + eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True + ) + + if data_args.val_size > 1e-6: + dataset_dict = split_dataset(dataset, data_args, seed=training_args.seed) + else: + dataset_dict = {} + if dataset is not None: + if data_args.streaming: + dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed) + + dataset_dict["train"] = dataset + + if eval_dataset is not None: + if data_args.streaming: + eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed) + + dataset_dict["validation"] = eval_dataset + + 
dataset_dict = DatasetDict(dataset_dict) + + if data_args.tokenized_path is not None: + if training_args.should_save: + dataset_dict.save_to_disk(data_args.tokenized_path) + logger(20, f"Tokenized dataset saved at {data_args.tokenized_path}.") + logger(20, f"Please restart the training with `tokenized_path: {data_args.tokenized_path}`.") + + sys.exit(0) + + dataset_module = {} + if "train" in dataset_dict: + dataset_module["train_dataset"] = dataset_dict["train"] + + if "validation" in dataset_dict: + dataset_module["eval_dataset"] = dataset_dict["validation"] + + return dataset_module + + +def list_config_paths(current_time: str) -> "gr.Dropdown": + r""" + Lists all the saved configuration files. + """ + config_files = [f"{current_time}.yaml"] + if os.path.isdir(DEFAULT_CONFIG_DIR): + for file_name in os.listdir(DEFAULT_CONFIG_DIR): + if file_name.endswith(".yaml") and file_name not in config_files: + config_files.append(file_name) + + return gr.Dropdown(choices=config_files) + + +def get_time() -> str: + r""" + Gets current date and time. + """ + return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S") + + +def list_output_dirs(model_name: Optional[str], finetuning_type: str, current_time: str) -> "gr.Dropdown": + r""" + Lists all the directories that can resume from. + """ + output_dirs = [f"train_{current_time}"] + if model_name: + save_dir = get_save_dir(model_name, finetuning_type) + if save_dir and os.path.isdir(save_dir): + for output in os.listdir(save_dir): + output_dirs.append(output) + return gr.Dropdown(choices=output_dirs) + + +def list_checkpoint_item(model_name, finetune_type, checkpoint_path): + items = [] + if checkpoint_path == "" or not isinstance(checkpoint_path, str): + return gr.update(choices=items) + cur_path = os.path.join(get_save_dir(model_name, finetune_type), checkpoint_path) + if not os.path.exists(cur_path): + return gr.update(choices=items) + for ckpt in os.listdir(cur_path): + if "checkpoint" in ckpt: + items.append(ckpt) + elif "lora_model_state.pdparams" in ckpt: + items.append(FINAL_CHECKPOINT_NAME) + items.sort() + return gr.update(choices=items) diff --git a/paddlemix/MULLM_WebUI/components/__init__.py b/paddlemix/MULLM_WebUI/components/__init__.py new file mode 100644 index 000000000..986fec615 --- /dev/null +++ b/paddlemix/MULLM_WebUI/components/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .chatbot import create_chat_box +from .infer import create_infer_tab +from .top import create_top +from .train import create_train_tab + +__all__ = [ + "create_chat_box", + "create_infer_tab", + "create_top", + "create_train_tab", +] diff --git a/paddlemix/MULLM_WebUI/components/chatbot.py b/paddlemix/MULLM_WebUI/components/chatbot.py new file mode 100644 index 000000000..13239b4e0 --- /dev/null +++ b/paddlemix/MULLM_WebUI/components/chatbot.py @@ -0,0 +1,101 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Dict, Tuple + +from ..extras.packages import is_gradio_available + +if is_gradio_available(): + import gradio as gr + + +if TYPE_CHECKING: + from gradio.components import Component + + from ..engine import Engine + + +def create_chat_box( + engine: "Engine", visible: bool = False +) -> Tuple["Component", "Component", Dict[str, "Component"]]: + with gr.Column(visible=visible) as chat_box: + chatbot = gr.Chatbot(show_copy_button=True) + messages = gr.State([]) + with gr.Row(): + with gr.Column(scale=4): + with gr.Row(): + with gr.Column(): + role = gr.Dropdown(choices=["user", "observation"], value="user") + # role = gr.Dropdown(choices=[Role.USER.value, Role.OBSERVATION.value], value="") + + system = gr.Textbox(show_label=False) + tools = gr.Textbox(show_label=False, lines=3) + + with gr.Column() as mm_box: + with gr.Tab("Image"): + image = gr.Image(sources=["upload"], type="pil") + + with gr.Tab("Video"): + video = gr.Video(sources=["upload"]) + + query = gr.Textbox(show_label=False, lines=8) + submit_btn = gr.Button(variant="primary") + + with gr.Column(scale=1): + max_new_tokens = gr.Slider(minimum=8, maximum=4096, value=512, step=1) + top_p = gr.Slider(minimum=0.01, maximum=1.0, value=0.7, step=0.01) + temperature = gr.Slider(minimum=0.01, maximum=1.5, value=0.95, step=0.01) + clear_btn = gr.Button() + + tools.input(inputs=[tools, engine.manager.get_elem_by_id("top.lang")]) + + submit_btn.click(engine.chatter.append, [chatbot, messages, role, query], [chatbot, messages, query],).then( + engine.chatter.stream, + [chatbot, messages, system, tools, image, video, max_new_tokens, top_p, temperature], + [chatbot, messages], + ) + clear_btn.click(lambda: ([], []), outputs=[chatbot, messages]) + + return ( + chatbot, + messages, + dict( + chat_box=chat_box, + role=role, + system=system, + tools=tools, + mm_box=mm_box, + image=image, + video=video, + query=query, + submit_btn=submit_btn, + max_new_tokens=max_new_tokens, + top_p=top_p, + temperature=temperature, + clear_btn=clear_btn, + ), + ) + + +def enable_chat_btn(checkbox): + if "Question" in checkbox and "Model" in checkbox and ("Video" in checkbox or "Image" in checkbox): + return gr.update(interactive=True) + else: + return gr.update(interactive=False) + + +def enable_checkpoint_box(checkpoint_path): + if isinstance(checkpoint_path, str) and checkpoint_path != "": + return gr.update(visible=True) + return gr.update(visible=False) diff --git a/paddlemix/MULLM_WebUI/components/data.py b/paddlemix/MULLM_WebUI/components/data.py new file mode 100644 index 000000000..aa54d4902 --- /dev/null +++ b/paddlemix/MULLM_WebUI/components/data.py @@ -0,0 +1,119 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from typing import TYPE_CHECKING, Any, Dict, List, Tuple + +from ..extras.constants import DATA_CONFIG +from ..extras.packages import is_gradio_available + +if is_gradio_available(): + import gradio as gr + + +if TYPE_CHECKING: + from gradio.components import Component + + +PAGE_SIZE = 2 + + +def prev_page(page_index: int) -> int: + return page_index - 1 if page_index > 0 else page_index + + +def next_page(page_index: int, total_num: int) -> int: + return page_index + 1 if (page_index + 1) * PAGE_SIZE < total_num else page_index + + +def can_preview(dataset_dir: str, dataset: list) -> "gr.Button": + try: + with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f: + dataset_info = json.load(f) + except Exception: + return gr.Button(interactive=False) + + if len(dataset) == 0 or "file_name" not in dataset_info[dataset[0]]: + return gr.Button(interactive=False) + + data_path = os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"]) + if os.path.isfile(data_path) or (os.path.isdir(data_path) and os.listdir(data_path)): + return gr.Button(interactive=True) + else: + return gr.Button(interactive=False) + + +def _load_data_file(file_path: str) -> List[Any]: + with open(file_path, encoding="utf-8") as f: + if file_path.endswith(".json"): + return json.load(f) + elif file_path.endswith(".jsonl"): + return [json.loads(line) for line in f] + else: + return list(f) + + +def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, "gr.Column"]: + with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f: + dataset_info = json.load(f) + + data_path = os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"]) + if os.path.isfile(data_path): + data = _load_data_file(data_path) + else: + data = [] + for file_name in os.listdir(data_path): + data.extend(_load_data_file(os.path.join(data_path, file_name))) + + return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.Column(visible=True) + + +def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dict[str, "Component"]: + data_preview_btn = gr.Button(interactive=False, scale=1) + with gr.Column(visible=False, elem_classes="modal-box") as preview_box: + with gr.Row(): + preview_count = gr.Number(value=0, interactive=False, precision=0) + page_index = gr.Number(value=0, interactive=False, precision=0) + + with gr.Row(): + prev_btn = gr.Button() + next_btn = gr.Button() + close_btn = gr.Button() + + with gr.Row(): + preview_samples = gr.JSON() + + dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False).then( + lambda: 0, outputs=[page_index], queue=False + ) + data_preview_btn.click( + get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False + ) + prev_btn.click(prev_page, [page_index], [page_index], queue=False).then( + get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False + ) + next_btn.click(next_page, [page_index, preview_count], [page_index], queue=False).then( + get_preview, 
[dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False + ) + close_btn.click(lambda: gr.Column(visible=False), outputs=[preview_box], queue=False) + return dict( + data_preview_btn=data_preview_btn, + preview_count=preview_count, + page_index=page_index, + prev_btn=prev_btn, + next_btn=next_btn, + close_btn=close_btn, + preview_samples=preview_samples, + ) diff --git a/paddlemix/MULLM_WebUI/components/infer.py b/paddlemix/MULLM_WebUI/components/infer.py new file mode 100644 index 000000000..f4aa8afd0 --- /dev/null +++ b/paddlemix/MULLM_WebUI/components/infer.py @@ -0,0 +1,153 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Dict + +from ..common import change_checkbox, list_checkpoint_item +from ..extras.packages import is_gradio_available +from .chatbot import enable_checkpoint_box + +if is_gradio_available(): + import gradio as gr + + +if TYPE_CHECKING: + from gradio.components import Component + + from ..engine import Engine + + +def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]: + input_elems = engine.manager.get_base_elems() + checkpoint_path: "gr.Dropdown" = engine.manager.get_elem_by_id("top.checkpoint_path") + + elem_dict = dict() + + with gr.Row(): + infer_dtype = gr.Dropdown(choices=["float16", "bfloat16", "float32"], value="float16") + ckpt_box = gr.Dropdown(value="", visible=False) + + with gr.Row(): + load_btn = gr.Button() + unload_btn = gr.Button() + + with gr.Row(): + with gr.Column(): + with gr.Tab("Image"): + image = gr.Image(type="pil", sources=["upload", "webcam", "clipboard"]) + + with gr.Tab("Video"): + video = gr.Video(sources=["upload"]) + + state_checkbox_group = gr.CheckboxGroup(value=[], interactive=False) + info_box = gr.Textbox(show_label=True, interactive=False) + + with gr.Column(scale=1): + question_box = gr.Textbox(value="", interactive=True) + question_type = gr.Dropdown(choices=["image", "video"], value="image") + seed_box = gr.Textbox(value=42, interactive=True) + max_new_tokens = gr.Slider(minimum=8, maximum=4096, value=512, step=1) + top_p = gr.Slider(minimum=0.01, maximum=1.0, value=0.7, step=0.01) + temperature = gr.Slider(minimum=0.01, maximum=1.5, value=0.95, step=0.01) + chat_btn = gr.Button() + clear_btn = gr.Button() + + chatbot = gr.Chatbot(show_copy_button=True) + messages = gr.State([]) + + input_elems.update({image}) + input_elems.update({video}) + input_elems.update({chatbot}) + input_elems.update({question_box, info_box}) + input_elems.update({messages}) + input_elems.update({infer_dtype, ckpt_box}) + input_elems.update({state_checkbox_group}) + input_elems.update({question_type, max_new_tokens, top_p, temperature, seed_box}) + elem_dict.update( + dict( + infer_dtype=infer_dtype, + ckpt_box=ckpt_box, + load_btn=load_btn, + unload_btn=unload_btn, + info_box=info_box, + question_box=question_box, + seed_box=seed_box, + question_type=question_type, + max_new_tokens=max_new_tokens, + 
top_p=top_p, + temperature=temperature, + chat_btn=chat_btn, + clear_btn=clear_btn, + state_checkbox_group=state_checkbox_group, + image=image, + video=video, + chatbot=chatbot, + messages=messages, + ) + ) + + # Button + load_btn.click(engine.chatter.load_model, input_elems, [info_box, state_checkbox_group]) + unload_btn.click(engine.chatter.unload_model, input_elems, [info_box, state_checkbox_group]) + + clear_btn.click(lambda: ([], [], "", None, None), outputs=[chatbot, messages, question_box, image, video]) + + chat_btn.click( + engine.chatter.multi_round_chat, + inputs=[ + engine.manager._id_to_elem["top.lang"], + chatbot, + messages, + question_box, + question_type, + image, + video, + state_checkbox_group, + max_new_tokens, + top_p, + temperature, + seed_box, + info_box, + ], + outputs=[chatbot, messages, question_box, info_box, chat_btn], + ) + question_box.change( + change_checkbox, + inputs=[state_checkbox_group, question_box, engine.manager._id_to_elem["top.lang"], gr.State("question_tag")], + outputs=state_checkbox_group, + every=3, + ) + + image.change( + change_checkbox, + inputs=[state_checkbox_group, image, engine.manager._id_to_elem["top.lang"], gr.State("image_tag")], + outputs=state_checkbox_group, + ) + video.change( + change_checkbox, + inputs=[state_checkbox_group, video, engine.manager._id_to_elem["top.lang"], gr.State("video_tag")], + outputs=state_checkbox_group, + ) + checkpoint_path.change( + list_checkpoint_item, + [ + engine.manager._id_to_elem["top.model_name"], + engine.manager._id_to_elem["top.finetuning_type"], + engine.manager._id_to_elem["top.checkpoint_path"], + ], + [ckpt_box], + queue=False, + ).then(enable_checkpoint_box, inputs=[checkpoint_path], outputs=[ckpt_box], show_progress=False) + + return elem_dict diff --git a/paddlemix/MULLM_WebUI/components/top.py b/paddlemix/MULLM_WebUI/components/top.py new file mode 100644 index 000000000..bd73d2801 --- /dev/null +++ b/paddlemix/MULLM_WebUI/components/top.py @@ -0,0 +1,57 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
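+# Top bar used by the other tabs: language, model name and path, fine-tuning
+# method, checkpoint path and prompt template. Changing `model_name` first calls
+# `get_model_info` to refresh `model_path` and `template`, then `list_checkpoints`
+# to refresh `checkpoint_path`; focusing `checkpoint_path` re-lists the
+# checkpoints for the currently selected model and fine-tuning method.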
+ +from typing import TYPE_CHECKING, Dict + +import gradio as gr + +from ..common import get_model_info, list_checkpoints +from ..extras.constants import METHODS +from ..extras.template import TEMPLATES + +if TYPE_CHECKING: + from gradio.components import Component + + +def create_top() -> Dict[str, "Component"]: + # available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"] + available_models = ["Qwen2-VL-2B-Instruct", "Qwen2-VL-7B-Instruct"] + with gr.Row(): + lang = gr.Dropdown(choices=["en", "zh"], scale=1, value="en") + model_name = gr.Dropdown(choices=available_models, scale=3, value="Qwen2-VL-2B-Instruct") + model_path = gr.Textbox(scale=3) + + with gr.Row(): + finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1) + checkpoint_path = gr.Dropdown(scale=6, value="") + + with gr.Row(): + template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="qwen2_vl", scale=2) + model_name.change(get_model_info, [model_name], [model_path, template], queue=False).then( + list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False + ) + checkpoint_path.focus(list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False) + + finetuning_type.change(inputs=[finetuning_type], outputs=[finetuning_type]).then( + list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False + ) + + return dict( + lang=lang, + model_name=model_name, + model_path=model_path, + template=template, + finetuning_type=finetuning_type, + checkpoint_path=checkpoint_path, + ) diff --git a/paddlemix/MULLM_WebUI/components/train.py b/paddlemix/MULLM_WebUI/components/train.py new file mode 100644 index 000000000..10ec44818 --- /dev/null +++ b/paddlemix/MULLM_WebUI/components/train.py @@ -0,0 +1,232 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
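+# SFT training tab. Each group of widgets created below is registered twice:
+# in `input_elems`, the set of components whose values are forwarded to
+# `engine.runner`, and in `elem_dict`, the name -> component mapping returned to
+# the caller. `start_btn` launches `engine.runner.run_train_v2`, `stop_btn` calls
+# `set_abort`, and the hidden `resume_btn` re-attaches `engine.runner.monitor` to
+# stream `output_box`, `progress_bar` and `loss_viewer`.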
+ +from typing import TYPE_CHECKING, Dict + +import gradio as gr +from paddlenlp.trainer.trainer_utils import SchedulerType + +from ..common import ( + get_device_count, + list_checkpoints, + list_config_paths, + list_datasets, + list_output_dirs, +) +from ..components.data import create_preview_box +from ..extras.constants import DEFAULT_DATA_DIR, TRAINING_STAGES + +if TYPE_CHECKING: + from gradio.components import Component + + from ..engine import Engine + + +def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: + input_elems = engine.manager.get_base_elems() + elem_dict = dict() + + with gr.Row(): + training_stage = gr.Dropdown( + choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=1 + ) + dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=1) + dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4) + preview_elems = create_preview_box(dataset_dir, dataset) + + input_elems.update({training_stage, dataset_dir, dataset}) + elem_dict.update(dict(training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems)) + + with gr.Row(): + learning_rate = gr.Textbox(value="5e-5") + num_train_epochs = gr.Textbox(value="3.0") + max_grad_norm = gr.Textbox(value="1.0") + max_samples = gr.Textbox(value="100000") + compute_type = gr.Dropdown(choices=["bf16", "fp16", "fp32", "pure_bf16"], value="bf16") + + input_elems.update({learning_rate, num_train_epochs, max_grad_norm, max_samples, compute_type}) + elem_dict.update( + dict( + learning_rate=learning_rate, + num_train_epochs=num_train_epochs, + max_grad_norm=max_grad_norm, + max_samples=max_samples, + compute_type=compute_type, + ) + ) + + with gr.Row(): + cutoff_len = gr.Slider(minimum=4, maximum=131072, value=2048, step=1) + batch_size = gr.Slider(minimum=1, maximum=1024, value=1, step=1) + gradient_accumulation_steps = gr.Slider(minimum=1, maximum=1024, value=8, step=1) + val_size = gr.Slider(minimum=0, maximum=1, value=0, step=0.001) + lr_scheduler_type = gr.Dropdown(choices=[scheduler.value for scheduler in SchedulerType], value="constant") + + input_elems.update({cutoff_len, batch_size, gradient_accumulation_steps, val_size, lr_scheduler_type}) + elem_dict.update( + dict( + cutoff_len=cutoff_len, + batch_size=batch_size, + gradient_accumulation_steps=gradient_accumulation_steps, + val_size=val_size, + lr_scheduler_type=lr_scheduler_type, + ) + ) + + with gr.Accordion(open=False) as extra_tab: + with gr.Row(): + logging_steps = gr.Slider(minimum=1, maximum=1000, value=5, step=5) + save_steps = gr.Slider(minimum=10, maximum=5000, value=100, step=10) + eval_steps = gr.Slider(minimum=10, maximum=5000, value=100, step=10) + warmup_steps = gr.Slider(minimum=0, maximum=5000, value=0, step=1) + extra_args = gr.Textbox(value='{"optim": "adamw"}') + + input_elems.update( + { + logging_steps, + save_steps, + eval_steps, + warmup_steps, + extra_args, + } + ) + elem_dict.update( + dict( + extra_tab=extra_tab, + logging_steps=logging_steps, + eval_steps=eval_steps, + save_steps=save_steps, + warmup_steps=warmup_steps, + extra_args=extra_args, + ) + ) + + with gr.Accordion(open=False) as lora_tab: + with gr.Row(): + lora_rank = gr.Slider(minimum=1, maximum=1024, value=32, step=1) + lora_alpha = gr.Slider(minimum=1, maximum=2048, value=32, step=1) + lora_dropout = gr.Slider(minimum=0, maximum=1, value=0, step=0.01) + loraplus_lr_ratio = gr.Slider(minimum=0, maximum=64, value=1, step=0.01) + + with gr.Row(): + use_rslora = gr.Checkbox() + use_pissa = gr.Checkbox() + + 
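+    # Register the LoRA widgets, like the groups above, in both `input_elems`
+    # (values forwarded to the runner) and `elem_dict` (name -> component).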
input_elems.update( + { + lora_rank, + lora_alpha, + lora_dropout, + loraplus_lr_ratio, + use_rslora, + use_pissa, + } + ) + elem_dict.update( + dict( + lora_tab=lora_tab, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + loraplus_lr_ratio=loraplus_lr_ratio, + use_rslora=use_rslora, + use_pissa=use_pissa, + ) + ) + + with gr.Row(): + arg_save_btn = gr.Button() + arg_load_btn = gr.Button() + start_btn = gr.Button(variant="primary") + stop_btn = gr.Button(variant="stop") + + with gr.Row(): + with gr.Column(scale=3): + with gr.Row(): + current_time = gr.Textbox(visible=False, interactive=False) + output_dir = gr.Dropdown(allow_custom_value=True) + config_path = gr.Dropdown(allow_custom_value=True) + + with gr.Row(): + device_count = gr.Textbox(value=str(get_device_count() or 1), interactive=False) + + with gr.Row(): + resume_btn = gr.Checkbox(visible=False, interactive=False) + progress_bar = gr.Slider(visible=False, interactive=False) + + with gr.Row(): + output_box = gr.Textbox(interactive=False) + + with gr.Column(scale=1): + loss_viewer = gr.Plot() + + input_elems.update({output_dir, config_path, output_box}) + elem_dict.update( + dict( + arg_save_btn=arg_save_btn, + arg_load_btn=arg_load_btn, + start_btn=start_btn, + stop_btn=stop_btn, + current_time=current_time, + output_dir=output_dir, + config_path=config_path, + device_count=device_count, + resume_btn=resume_btn, + progress_bar=progress_bar, + output_box=output_box, + loss_viewer=loss_viewer, + ) + ) + output_elems = [output_box, progress_bar, loss_viewer] + + start_btn.click(engine.runner.run_train_v2, input_elems, output_elems) + stop_btn.click(engine.runner.set_abort) + resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None) + + lang = engine.manager.get_elem_by_id("top.lang") + model_name: "gr.Dropdown" = engine.manager.get_elem_by_id("top.model_name") + finetuning_type: "gr.Dropdown" = engine.manager.get_elem_by_id("top.finetuning_type") + + arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None) + arg_load_btn.click( + engine.runner.load_args, [lang, config_path], list(input_elems) + [output_box], concurrency_limit=None + ) + + dataset.focus(list_datasets, [dataset_dir, training_stage], [dataset], queue=False) + model_name.change( + list_checkpoints, + [ + model_name, + finetuning_type, + ], + [output_dir], + queue=False, + ) + finetuning_type.change(list_checkpoints, [model_name, finetuning_type], [output_dir], queue=False) + output_dir.change( + list_output_dirs, + [model_name, finetuning_type, current_time], + [output_dir], + concurrency_limit=None, + queue=False, + ) + output_dir.input( + engine.runner.check_output_dir, + [lang, model_name, finetuning_type, output_dir], + [output_box], + concurrency_limit=None, + ) + config_path.change(list_config_paths, [current_time], [config_path], queue=False) + + return elem_dict diff --git a/paddlemix/MULLM_WebUI/css.py b/paddlemix/MULLM_WebUI/css.py new file mode 100644 index 000000000..13794cf47 --- /dev/null +++ b/paddlemix/MULLM_WebUI/css.py @@ -0,0 +1,41 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +CSS = r""" +.duplicate-button { + margin: auto !important; + color: white !important; + background: black !important; + border-radius: 100vh !important; +} + +.modal-box { + position: fixed !important; + top: 50%; + left: 50%; + transform: translate(-50%, -50%); /* center horizontally */ + max-width: 1000px; + max-height: 750px; + overflow-y: auto; + background-color: var(--input-background-fill); + flex-wrap: nowrap !important; + border: 2px solid black !important; + z-index: 1000; + padding: 10px; +} + +.dark .modal-box { + border: 2px solid white !important; +} +""" diff --git a/paddlemix/MULLM_WebUI/engine.py b/paddlemix/MULLM_WebUI/engine.py new file mode 100644 index 000000000..bc1dd3128 --- /dev/null +++ b/paddlemix/MULLM_WebUI/engine.py @@ -0,0 +1,73 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Any, Dict + +from .chatter import WebChatModel +from .common import get_time +from .locales import LOCALES +from .manager import Manager +from .runner import Runner + +if TYPE_CHECKING: + from gradio.components import Component + + +class Engine: + def __init__(self, demo_mode: bool = False, pure_chat: bool = False) -> None: + self.demo_mode = demo_mode + self.pure_chat = pure_chat + self.manager = Manager() + self.runner = Runner(self.manager, demo_mode) + self.chatter = WebChatModel(self.manager, demo_mode, lazy_init=(not pure_chat)) + + def _update_component(self, input_dict: Dict[str, Dict[str, Any]]) -> Dict["Component", "Component"]: + r""" + Gets the dict to update the components. 
+ """ + output_dict: Dict["Component", "Component"] = {} + for elem_id, elem_attr in input_dict.items(): + elem = self.manager.get_elem_by_id(elem_id) + output_dict[elem] = elem.__class__(**elem_attr) + + return output_dict + + def resume(self): + init_dict = { + "top.lang": {"value": "zh"}, + "top.model_name": {"value": "Qwen2-VL-2B-Instruct"}, + "top.model_path": {"value": "Qwen/Qwen2-VL-2B-Instruct"}, + } + + if not self.pure_chat: + current_time = get_time() + init_dict["train.current_time"] = {"value": current_time} + init_dict["train.output_dir"] = {"value": f"train_{current_time}"} + init_dict["train.config_path"] = {"value": f"{current_time}.yaml"} + + yield self._update_component(init_dict) + + if self.runner.running and not self.pure_chat: + yield {elem: elem.__class__(value=value) for elem, value in self.runner.running_data.items()} + if self.runner.do_train: + yield self._update_component({"train.resume_btn": {"value": True}}) + else: + yield self._update_component({"eval.resume_btn": {"value": True}}) + + def change_lang(self, lang: str): + return { + elem: elem.__class__(**LOCALES[elem_name][lang]) + for elem_name, elem in self.manager.get_elem_iter() + if elem_name in LOCALES + } diff --git a/paddlemix/MULLM_WebUI/extras/__init__.py b/paddlemix/MULLM_WebUI/extras/__init__.py new file mode 100644 index 000000000..fd05a9208 --- /dev/null +++ b/paddlemix/MULLM_WebUI/extras/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlemix/MULLM_WebUI/extras/args.py b/paddlemix/MULLM_WebUI/extras/args.py new file mode 100644 index 000000000..6a546d6e9 --- /dev/null +++ b/paddlemix/MULLM_WebUI/extras/args.py @@ -0,0 +1,872 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +from dataclasses import dataclass, field, fields +from typing import Any, Dict, List, Literal, Optional, Tuple + +import paddle +from paddlenlp.trainer import ( + PdArgumentParser, + Seq2SeqTrainingArguments, + get_last_checkpoint, +) +from typing_extensions import Self +from yaml import safe_dump, safe_load + +from ...utils.log import logger +from .constants import CHECKPOINT_NAMES +from .training import get_current_device + + +@dataclass +class DataArguments: + r""" + Arguments pertaining to what data we are going to input our model for training and evaluation. 
+ """ + + template: Optional[str] = field( + default=None, + metadata={"help": "Which template to use for constructing prompts in training and inference."}, + ) + dataset: Optional[str] = field( + default=None, + metadata={"help": "The name of dataset(s) to use for training. Use commas to separate multiple datasets."}, + ) + eval_dataset: Optional[str] = field( + default=None, + metadata={"help": "The name of dataset(s) to use for evaluation. Use commas to separate multiple datasets."}, + ) + dataset_dir: str = field( + default="data", + metadata={"help": "Path to the folder containing the datasets."}, + ) + image_dir: Optional[str] = field( + default=None, + metadata={"help": "Path to the folder containing the images or videos. Defaults to `dataset_dir`."}, + ) + cutoff_len: int = field( + default=2048, + metadata={"help": "The cutoff length of the tokenized inputs in the dataset."}, + ) + train_on_prompt: bool = field( + default=False, + metadata={"help": "Whether or not to disable the mask on the prompt."}, + ) + mask_history: bool = field( + default=False, + metadata={"help": "Whether or not to mask the history and train on the last turn only."}, + ) + streaming: bool = field( + default=False, + metadata={"help": "Enable dataset streaming."}, + ) + buffer_size: int = field( + default=16384, + metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."}, + ) + mix_strategy: Literal["concat", "interleave_under", "interleave_over"] = field( + default="concat", + metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."}, + ) + interleave_probs: Optional[str] = field( + default=None, + metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."}, + ) + overwrite_cache: bool = field( + default=False, + metadata={"help": "Overwrite the cached training and evaluation sets."}, + ) + preprocessing_batch_size: int = field( + default=1000, + metadata={"help": "The number of examples in one group in pre-processing."}, + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the pre-processing."}, + ) + max_samples: Optional[int] = field( + default=None, + metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}, + ) + eval_num_beams: Optional[int] = field( + default=None, + metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"}, + ) + ignore_pad_token_for_loss: bool = field( + default=True, + metadata={"help": "Whether or not to ignore the tokens corresponding to the pad label in loss computation."}, + ) + val_size: float = field( + default=0.0, + metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."}, + ) + packing: Optional[bool] = field( + default=None, + metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."}, + ) + neat_packing: bool = field( + default=False, + metadata={"help": "Enable sequence packing without cross-attention."}, + ) + tool_format: Optional[str] = field( + default=None, + metadata={"help": "Tool format to use for constructing function calling examples."}, + ) + tokenized_path: Optional[str] = field( + default=None, + metadata={ + "help": ( + "Path to save or load the tokenized datasets. " + "If tokenized_path not exists, it will save the tokenized datasets. 
" + "If tokenized_path exists, it will load the tokenized datasets." + ) + }, + ) + + def __post_init__(self): + def split_arg(arg): + if isinstance(arg, str): + return [item.strip() for item in arg.split(",")] + return arg + + self.dataset = split_arg(self.dataset) + self.eval_dataset = split_arg(self.eval_dataset) + + if self.image_dir is None: + self.image_dir = self.dataset_dir + + if self.dataset is None and self.val_size > 1e-6: + raise ValueError("Cannot specify `val_size` if `dataset` is None.") + + if self.eval_dataset is not None and self.val_size > 1e-6: + raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.") + + if self.interleave_probs is not None: + if self.mix_strategy == "concat": + raise ValueError("`interleave_probs` is only valid for interleaved mixing.") + + self.interleave_probs = list(map(float, split_arg(self.interleave_probs))) + if self.dataset is not None and len(self.dataset) != len(self.interleave_probs): + raise ValueError("The length of dataset and interleave probs should be identical.") + + if self.eval_dataset is not None and len(self.eval_dataset) != len(self.interleave_probs): + raise ValueError("The length of eval dataset and interleave probs should be identical.") + + if self.streaming and self.val_size > 1e-6 and self.val_size < 1: + raise ValueError("Streaming mode should have an integer val size.") + + if self.streaming and self.max_samples is not None: + raise ValueError("`max_samples` is incompatible with `streaming`.") + + if self.mask_history and self.train_on_prompt: + raise ValueError("`mask_history` is incompatible with `train_on_prompt`.") + + +@dataclass +class ProcessorArguments: + r""" + Arguments pertaining to the image processor. + """ + + image_resolution: int = field( + default=512 * 512, + metadata={"help": "Keeps the number of pixels of image below this resolution."}, + ) + video_resolution: int = field( + default=128 * 128, + metadata={"help": "Keeps the number of pixels of video below this resolution."}, + ) + video_fps: float = field( + default=2.0, + metadata={"help": "The frames to sample per second for video inputs."}, + ) + video_maxlen: int = field( + default=64, + metadata={"help": "The maximum number of sampled frames for video inputs."}, + ) + + +@dataclass +class ExportArguments: + r""" + Arguments pertaining to the model export. 
+ """ + + export_dir: Optional[str] = field( + default=None, + metadata={"help": "Path to the directory to save the exported model."}, + ) + export_size: int = field( + default=1, + metadata={"help": "The file shard size (in GB) of the exported model."}, + ) + export_device: Literal["cpu", "auto"] = field( + default="cpu", + metadata={"help": "The device used in model export, use `auto` to accelerate exporting."}, + ) + export_quantization_bit: Optional[int] = field( + default=None, + metadata={"help": "The number of bits to quantize the exported model."}, + ) + export_quantization_dataset: Optional[str] = field( + default=None, + metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."}, + ) + export_quantization_nsamples: int = field( + default=128, + metadata={"help": "The number of samples used for quantization."}, + ) + export_quantization_maxlen: int = field( + default=1024, + metadata={"help": "The maximum length of the model inputs used for quantization."}, + ) + export_legacy_format: bool = field( + default=False, + metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."}, + ) + export_hub_model_id: Optional[str] = field( + default=None, + metadata={"help": "The name of the repository if push the model to the Hugging Face hub."}, + ) + + +@dataclass +class VllmArguments: + r""" + Arguments pertaining to the vLLM worker. + """ + + vllm_maxlen: int = field( + default=4096, + metadata={"help": "Maximum sequence (prompt + response) length of the vLLM engine."}, + ) + vllm_gpu_util: float = field( + default=0.9, + metadata={"help": "The fraction of GPU memory in (0,1) to be used for the vLLM engine."}, + ) + vllm_enforce_eager: bool = field( + default=False, + metadata={"help": "Whether or not to disable CUDA graph in the vLLM engine."}, + ) + vllm_max_lora_rank: int = field( + default=32, + metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."}, + ) + vllm_config: Optional[str] = field( + default=None, + metadata={"help": "Config to initialize the vllm engine. Please use JSON strings."}, + ) + + +@dataclass +class ModelArguments(ProcessorArguments, ExportArguments, VllmArguments): + r""" + Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer. + """ + + model_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models." + }, + ) + adapter_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": ( + "Path to the adapter weight or identifier from huggingface.co/models. " + "Use commas to separate multiple adapters." 
+ ) + }, + ) + adapter_folder: Optional[str] = field( + default=None, + metadata={"help": "The folder containing the adapter weights to load."}, + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."}, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."}, + ) + resize_vocab: bool = field( + default=False, + metadata={"help": "Whether or not to resize the tokenizer vocab and the embedding layers."}, + ) + split_special_tokens: bool = field( + default=False, + metadata={"help": "Whether or not the special tokens should be split during the tokenization process."}, + ) + new_special_tokens: Optional[str] = field( + default=None, + metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."}, + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + low_cpu_mem_usage: bool = field( + default=True, + metadata={"help": "Whether or not to use memory-efficient model loading."}, + ) + rope_scaling: Optional[Literal["linear", "dynamic"]] = field( + default=None, + metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."}, + ) + flash_attn: Literal["auto", "disabled", "sdpa", "fa2"] = field( + default="auto", + metadata={"help": "Enable FlashAttention for faster training and inference."}, + ) + shift_attn: bool = field( + default=False, + metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}, + ) + mixture_of_depths: Optional[Literal["convert", "load"]] = field( + default=None, + metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."}, + ) + use_unsloth: bool = field( + default=False, + metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."}, + ) + use_unsloth_gc: bool = field( + default=False, + metadata={"help": "Whether or not to use unsloth's gradient checkpointing."}, + ) + enable_liger_kernel: bool = field( + default=False, + metadata={"help": "Whether or not to enable liger kernel for faster training."}, + ) + moe_aux_loss_coef: Optional[float] = field( + default=None, + metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."}, + ) + disable_gradient_checkpointing: bool = field( + default=False, + metadata={"help": "Whether or not to disable gradient checkpointing."}, + ) + upcast_layernorm: bool = field( + default=False, + metadata={"help": "Whether or not to upcast the layernorm weights in fp32."}, + ) + upcast_lmhead_output: bool = field( + default=False, + metadata={"help": "Whether or not to upcast the output of lm_head in fp32."}, + ) + train_from_scratch: bool = field( + default=False, + metadata={"help": "Whether or not to randomly initialize the model weights."}, + ) + infer_backend: Literal["huggingface", "vllm"] = field( + default="huggingface", + metadata={"help": "Backend engine used at inference."}, + ) + offload_folder: str = field( + default="offload", + metadata={"help": "Path to offload model weights."}, + ) + use_cache: bool = field( + default=True, + metadata={"help": "Whether or not to use KV cache in generation."}, + ) + infer_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field( + default="auto", + metadata={"help": "Data type for model 
weights and activations at inference."}, + ) + hf_hub_token: Optional[str] = field( + default=None, + metadata={"help": "Auth token to log in with Hugging Face Hub."}, + ) + ms_hub_token: Optional[str] = field( + default=None, + metadata={"help": "Auth token to log in with ModelScope Hub."}, + ) + om_hub_token: Optional[str] = field( + default=None, + metadata={"help": "Auth token to log in with Modelers Hub."}, + ) + print_param_status: bool = field( + default=False, + metadata={"help": "For debugging purposes, print the status of the parameters in the model."}, + ) + compute_dtype: Optional[paddle.dtype] = field( + default=None, + init=False, + metadata={"help": "Paddle data type for computing model outputs, derived from `fp/bf16`. Do not specify it."}, + ) + device_map: Optional[str] = field( + default=None, + init=False, + metadata={"help": "Device map for model placement, derived from training stage. Do not specify it."}, + ) + model_max_length: Optional[int] = field( + default=None, + init=False, + metadata={"help": "The maximum input length for model, derived from `cutoff_len`. Do not specify it."}, + ) + block_diag_attn: bool = field( + default=False, + init=False, + metadata={"help": "Whether use block diag attention or not, derived from `neat_packing`. Do not specify it."}, + ) + + def __post_init__(self): + if self.model_name_or_path is None: + raise ValueError("Please provide `model_name_or_path`.") + + if self.split_special_tokens and self.use_fast_tokenizer: + raise ValueError("`split_special_tokens` is only supported for slow tokenizers.") + + if self.adapter_name_or_path is not None: # support merging multiple lora weights + self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")] + + if self.new_special_tokens is not None: # support multiple special tokens + self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")] + + if self.export_quantization_bit is not None and self.export_quantization_dataset is None: + raise ValueError("Quantization dataset is necessary for exporting.") + + # if isinstance(self.vllm_config, str) and self.vllm_config.startswith("{"): + # self.vllm_config = _convert_str_dict(json.loads(self.vllm_config)) + + @classmethod + def copyfrom(cls, source: "Self", **kwargs) -> "Self": + init_args, lazy_args = {}, {} + for attr in fields(source): + if attr.init: + init_args[attr.name] = getattr(source, attr.name) + else: + lazy_args[attr.name] = getattr(source, attr.name) + + init_args.update(kwargs) + result = cls(**init_args) + for name, value in lazy_args.items(): + setattr(result, name, value) + + return result + + +@dataclass +class GeneratingArguments: + r""" + Arguments pertaining to specify the decoding parameters. + """ + + do_sample: bool = field( + default=True, + metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}, + ) + temperature: float = field( + default=0.95, + metadata={"help": "The value used to modulate the next token probabilities."}, + ) + top_p: float = field( + default=0.7, + metadata={ + "help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept." + }, + ) + top_k: int = field( + default=50, + metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}, + ) + num_beams: int = field( + default=1, + metadata={"help": "Number of beams for beam search. 
1 means no beam search."}, + ) + max_length: int = field( + default=1024, + metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}, + ) + max_new_tokens: int = field( + default=1024, + metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}, + ) + repetition_penalty: float = field( + default=1.0, + metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}, + ) + length_penalty: float = field( + default=1.0, + metadata={"help": "Exponential penalty to the length that is used with beam-based generation."}, + ) + default_system: Optional[str] = field( + default=None, + metadata={"help": "Default system message to use in chat completion."}, + ) + + +# finetune +@dataclass +class FreezeArguments: + r""" + Arguments pertaining to the freeze (partial-parameter) training. + """ + + freeze_trainable_layers: int = field( + default=2, + metadata={ + "help": ( + "The number of trainable layers for freeze (partial-parameter) fine-tuning. " + "Positive numbers mean the last n layers are set as trainable, " + "negative numbers mean the first n layers are set as trainable." + ) + }, + ) + freeze_trainable_modules: str = field( + default="all", + metadata={ + "help": ( + "Name(s) of trainable modules for freeze (partial-parameter) fine-tuning. " + "Use commas to separate multiple modules. " + "Use `all` to specify all the available modules." + ) + }, + ) + freeze_extra_modules: Optional[str] = field( + default=None, + metadata={ + "help": ( + "Name(s) of modules apart from hidden layers to be set as trainable " + "for freeze (partial-parameter) fine-tuning. " + "Use commas to separate multiple modules." + ) + }, + ) + + +@dataclass +class LoraArguments: + r""" + Arguments pertaining to the LoRA training. + """ + + additional_target: Optional[str] = field( + default=None, + metadata={ + "help": ( + "Name(s) of modules apart from LoRA layers to be set as trainable " + "and saved in the final checkpoint. " + "Use commas to separate multiple modules." + ) + }, + ) + lora_alpha: Optional[int] = field( + default=None, + metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."}, + ) + lora_dropout: float = field( + default=0.0, + metadata={"help": "Dropout rate for the LoRA fine-tuning."}, + ) + lora_rank: int = field( + default=8, + metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}, + ) + lora_target: str = field( + default="all", + metadata={ + "help": ( + "Name(s) of target modules to apply LoRA. " + "Use commas to separate multiple modules. " + "Use `all` to specify all the linear modules." + ) + }, + ) + loraplus_lr_ratio: Optional[float] = field( + default=1.0, + metadata={"help": "LoRA plus learning rate ratio (lr_B / lr_A)."}, + ) + loraplus_lr_embedding: float = field( + default=1e-6, + metadata={"help": "LoRA plus learning rate for lora embedding layers."}, + ) + use_rslora: bool = field( + default=False, + metadata={"help": "Whether or not to use the rank stabilization scaling factor for LoRA layer."}, + ) + use_dora: bool = field( + default=False, + metadata={"help": "Whether or not to use the weight-decomposed lora method (DoRA)."}, + ) + pissa_init: bool = field( + default=False, + metadata={"help": "Whether or not to initialize a PiSSA adapter."}, + ) + pissa_iter: int = field( + default=16, + metadata={"help": "The number of iteration steps performed by FSVD in PiSSA. 
Use -1 to disable it."}, + ) + pissa_convert: bool = field( + default=False, + metadata={"help": "Whether or not to convert the PiSSA adapter to a normal LoRA adapter."}, + ) + create_new_adapter: bool = field( + default=False, + metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."}, + ) + + +@dataclass +class GaloreArguments: + r""" + Arguments pertaining to the GaLore algorithm. + """ + + use_galore: bool = field( + default=False, + metadata={"help": "Whether or not to use the gradient low-Rank projection (GaLore)."}, + ) + galore_target: str = field( + default="all", + metadata={ + "help": ( + "Name(s) of modules to apply GaLore. Use commas to separate multiple modules. " + "Use `all` to specify all the linear modules." + ) + }, + ) + galore_rank: int = field( + default=16, + metadata={"help": "The rank of GaLore gradients."}, + ) + galore_update_interval: int = field( + default=200, + metadata={"help": "Number of steps to update the GaLore projection."}, + ) + galore_scale: float = field( + default=0.25, + metadata={"help": "GaLore scaling coefficient."}, + ) + galore_proj_type: Literal["std", "reverse_std", "right", "left", "full"] = field( + default="std", + metadata={"help": "Type of GaLore projection."}, + ) + galore_layerwise: bool = field( + default=False, + metadata={"help": "Whether or not to enable layer-wise update to further save memory."}, + ) + + +@dataclass +class FinetuningArguments(FreezeArguments, LoraArguments): + r""" + Arguments pertaining to which techniques we are going to fine-tuning with. + """ + + pure_bf16: bool = field( + default=False, + metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."}, + ) + stage: Literal["pt", "sft", "rm", "ppo", "dpo", "kto"] = field( + default="sft", + metadata={"help": "Which stage will be performed in training."}, + ) + finetuning_type: Literal["lora", "freeze", "full"] = field( + default="lora", + metadata={"help": "Which fine-tuning method to use."}, + ) + use_llama_pro: bool = field( + default=False, + metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."}, + ) + use_adam_mini: bool = field( + default=False, + metadata={"help": "Whether or not to use the Adam-mini optimizer."}, + ) + freeze_vision_tower: bool = field( + default=True, + metadata={"help": "Whether ot not to freeze vision tower in MLLM training."}, + ) + train_mm_proj_only: bool = field( + default=False, + metadata={"help": "Whether or not to train the multimodal projector for MLLM only."}, + ) + compute_accuracy: bool = field( + default=False, + metadata={"help": "Whether or not to compute the token-level accuracy at evaluation."}, + ) + plot_loss: bool = field( + default=False, + metadata={"help": "Whether or not to save the training loss curves."}, + ) + include_effective_tokens_per_second: bool = field( + default=False, + metadata={"help": "Whether or not to compute effective tokens per second."}, + ) + + def __post_init__(self): + def split_arg(arg): + if isinstance(arg, str): + return [item.strip() for item in arg.split(",")] + return arg + + self.freeze_trainable_modules: List[str] = split_arg(self.freeze_trainable_modules) + self.freeze_extra_modules: Optional[List[str]] = split_arg(self.freeze_extra_modules) + self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2 + # self.lora_target: List[str] = split_arg(self.lora_target) + self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only + + assert 
self.finetuning_type in ["lora", "full"], "Invalid fine-tuning method." + # assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization." + + if self.use_llama_pro and self.finetuning_type == "full": + raise ValueError("`use_llama_pro` is only valid for Freeze or LoRA training.") + + if self.train_mm_proj_only and self.finetuning_type != "full": + raise ValueError("`train_mm_proj_only` is only valid for full training.") + + if self.finetuning_type != "lora": + if self.loraplus_lr_ratio is not None: + raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.") + + if self.use_rslora: + raise ValueError("`use_rslora` is only valid for LoRA training.") + + if self.pissa_init: + raise ValueError("`pissa_init` is only valid for LoRA training.") + + +_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments] +_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments] + + +def _parse_args(parser: "PdArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]: + if args is not None: + return parser.parse_dict(args) + + if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")): + return parser.parse_yaml_file(os.path.abspath(sys.argv[1])) + + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + return parser.parse_json_file(os.path.abspath(sys.argv[1])) + + (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(return_remaining_strings=True) + + if unknown_args: + print(parser.format_help()) + print(f"Got unknown args, potentially deprecated arguments: {unknown_args}") + raise ValueError(f"Some specified arguments are not used by the PdArgumentParser: {unknown_args}") + + return (*parsed_args,) + + +def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: + parser = PdArgumentParser(_TRAIN_ARGS) + return _parse_args(parser, args) + + +def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: + model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args) + + # Check arguments + if finetuning_args.stage != "pt" and data_args.template is None: + raise ValueError("Please specify which `template` to use.") + + if finetuning_args.stage != "sft": + if training_args.predict_with_generate: + raise ValueError("`predict_with_generate` cannot be set as True except SFT.") + + if data_args.neat_packing: + raise ValueError("`neat_packing` cannot be set as True except SFT.") + + if data_args.train_on_prompt or data_args.mask_history: + raise ValueError("`train_on_prompt` or `mask_history` cannot be set as True except SFT.") + + if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate: + raise ValueError("Please enable `predict_with_generate` to save model predictions.") + + if training_args.max_steps == -1 and data_args.streaming: + raise ValueError("Please specify `max_steps` in streaming mode.") + + if training_args.do_train and data_args.dataset is None: + raise ValueError("Please specify dataset for training.") + + if (training_args.do_eval or training_args.do_predict) and ( + data_args.eval_dataset is None and data_args.val_size < 1e-6 + ): + raise ValueError("Please specify dataset for evaluation.") + + if training_args.predict_with_generate: + if data_args.eval_dataset is None: + raise ValueError("Cannot use `predict_with_generate` if `eval_dataset` is None.") + + if 
finetuning_args.compute_accuracy: + raise ValueError("Cannot use `predict_with_generate` and `compute_accuracy` together.") + + if ( + training_args.resume_from_checkpoint is None + and training_args.do_train + and os.path.isdir(training_args.output_dir) + and not training_args.overwrite_output_dir + ): + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and any( + os.path.isfile(os.path.join(training_args.output_dir, name)) for name in CHECKPOINT_NAMES + ): + raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.") + + if last_checkpoint is not None: + training_args.resume_from_checkpoint = last_checkpoint + logger(20, f"Resuming training from {training_args.resume_from_checkpoint}.") + logger(20, "Change `output_dir` or use `overwrite_output_dir` to avoid.") + + # Post-process model arguments + if training_args.bf16 or finetuning_args.pure_bf16: + model_args.compute_dtype = "bfloat16" + elif training_args.fp16: + model_args.compute_dtype = "float16" + else: + model_args.compute_dtype = "float32" + model_args.device_map = {"": get_current_device()} + model_args.model_max_length = data_args.cutoff_len + model_args.block_diag_attn = data_args.neat_packing + data_args.packing = data_args.packing if data_args.packing is not None else finetuning_args.stage == "pt" + + # Log on each process the small summary + logger( + 20, + "Process rank: {}, device: {}, compute dtype: {}".format( + training_args.local_rank, + training_args.device, + str(model_args.compute_dtype), + ), + ) + + # transformers.set_seed(training_args.seed) + + return model_args, data_args, training_args, finetuning_args, generating_args + + +def load_args(config_path: str) -> Optional[Dict[str, Any]]: + r""" + Loads saved arguments. + """ + try: + with open(config_path, encoding="utf-8") as f: + return safe_load(f) + except Exception: + return None + + +def save_args(config_path: str, config_dict: Dict[str, Any]): + r""" + Saves arguments. + """ + with open(config_path, "w", encoding="utf-8") as f: + safe_dump(config_dict, f) diff --git a/paddlemix/MULLM_WebUI/extras/callbacks.py b/paddlemix/MULLM_WebUI/extras/callbacks.py new file mode 100644 index 000000000..18484afbf --- /dev/null +++ b/paddlemix/MULLM_WebUI/extras/callbacks.py @@ -0,0 +1,205 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +import os +import sys +import time +from concurrent.futures import ThreadPoolExecutor +from datetime import timedelta +from typing import TYPE_CHECKING, Any, Dict, Optional + +from paddlenlp.trainer.trainer_callback import TrainerCallback +from paddlenlp.trainer.trainer_utils import has_length +from typing_extensions import override + +from ...utils.log import logger +from .constants import TRAINER_LOG +from .training import get_peak_memory + +if TYPE_CHECKING: + from paddlenlp.trainer.trainer_callback import ( + TrainerControl, + TrainerState, + TrainingArguments, + ) + + +class LogCallback(TrainerCallback): + r""" + A callback for logging training and evaluation status. + """ + + def __init__(self) -> None: + # Progress + self.start_time = 0 + self.cur_steps = 0 + self.max_steps = 0 + self.elapsed_time = "" + self.remaining_time = "" + self.thread_pool: Optional["ThreadPoolExecutor"] = None + # Status + self.aborted = False + self.do_train = False + # Web UI + self.webui_mode = True + # if self.webui_mode: + # signal.signal(signal.SIGABRT, self._set_abort) + # self.logger_handler = logging.LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR")) + # logging.add_handler(self.logger_handler) + # transformers.logging.add_handler(self.logger_handler) + + def _set_abort(self, signum, frame) -> None: + self.aborted = True + + def _reset(self, max_steps: int = 0) -> None: + self.start_time = time.time() + self.cur_steps = 0 + self.max_steps = max_steps + self.elapsed_time = "" + self.remaining_time = "" + + def _timing(self, cur_steps: int) -> None: + cur_time = time.time() + elapsed_time = cur_time - self.start_time + avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0 + remaining_time = (self.max_steps - cur_steps) * avg_time_per_step + self.cur_steps = cur_steps + self.elapsed_time = str(timedelta(seconds=int(elapsed_time))) + self.remaining_time = str(timedelta(seconds=int(remaining_time))) + + def _write_log(self, output_dir: str, logs: Dict[str, Any]) -> None: + with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f: + f.write(json.dumps(logs) + "\n") + + def _create_thread_pool(self, output_dir: str) -> None: + os.makedirs(output_dir, exist_ok=True) + self.thread_pool = ThreadPoolExecutor(max_workers=1) + + def _close_thread_pool(self) -> None: + if self.thread_pool is not None: + self.thread_pool.shutdown(wait=True) + self.thread_pool = None + + @override + def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + if ( + args.should_save + and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG)) + and args.overwrite_output_dir + ): + logger(30, "Previous trainer log in this folder will be deleted.") + os.remove(os.path.join(args.output_dir, TRAINER_LOG)) + + @override + def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + if args.should_save: + self.do_train = True + self._reset(max_steps=state.max_steps) + self._create_thread_pool(output_dir=args.output_dir) + + @override + def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + self._close_thread_pool() + + @override + def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + if self.aborted: + control.should_epoch_stop = True + control.should_training_stop = True + + @override + def on_step_end(self, args: "TrainingArguments", state: "TrainerState", 
control: "TrainerControl", **kwargs): + if self.aborted: + control.should_epoch_stop = True + control.should_training_stop = True + + @override + def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + if not self.do_train: + self._close_thread_pool() + + @override + def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + if not self.do_train: + self._close_thread_pool() + + @override + def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + if not args.should_save: + return + + self._timing(cur_steps=state.global_step) + logs = dict( + current_steps=self.cur_steps, + total_steps=self.max_steps, + loss=kwargs["logs"].get("loss"), + eval_loss=kwargs["logs"].get("eval_loss"), + predict_loss=kwargs["logs"].get("predict_loss"), + lr=kwargs["logs"].get("learning_rate"), + epoch=kwargs["logs"].get("epoch"), + percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100, + elapsed_time=self.elapsed_time, + remaining_time=self.remaining_time, + ) + # if state.num_input_tokens_seen: + # logs["throughput"] = round(state.num_input_tokens_seen / (time.time() - self.start_time), 2) + # logs["total_tokens"] = state.num_input_tokens_seen + + if os.environ.get("RECORD_VRAM", "0").lower() in ["true", "1"]: + vram_allocated, vram_reserved = get_peak_memory() + logs["vram_allocated"] = round(vram_allocated / (1024**3), 2) + logs["vram_reserved"] = round(vram_reserved / (1024**3), 2) + + logs = {k: v for k, v in logs.items() if v is not None} + if self.webui_mode and all(key in logs for key in ("loss", "lr", "epoch")): + log_str = f"'loss': {logs['loss']:.4f}, 'learning_rate': {logs['lr']:2.4e}, 'epoch': {logs['epoch']:.2f}" + for extra_key in ("reward", "accuracy", "throughput"): + if logs.get(extra_key): + log_str += f", '{extra_key}': {logs[extra_key]:.2f}" + + logger(30, "{" + log_str + "}") + + if self.thread_pool is not None: + self.thread_pool.submit(self._write_log, args.output_dir, logs) + + @override + def on_prediction_step( + self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs + ): + if self.do_train: + return + + if self.aborted: + sys.exit(0) + + if not args.should_save: + return + + eval_dataloader = kwargs.pop("eval_dataloader", None) + if has_length(eval_dataloader): + if self.max_steps == 0: + self._reset(max_steps=len(eval_dataloader)) + self._create_thread_pool(output_dir=args.output_dir) + + self._timing(cur_steps=self.cur_steps + 1) + if self.cur_steps % 5 == 0 and self.thread_pool is not None: + logs = dict( + current_steps=self.cur_steps, + total_steps=self.max_steps, + percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100, + elapsed_time=self.elapsed_time, + remaining_time=self.remaining_time, + ) + self.thread_pool.submit(self._write_log, args.output_dir, logs) diff --git a/paddlemix/MULLM_WebUI/extras/constants.py b/paddlemix/MULLM_WebUI/extras/constants.py new file mode 100644 index 000000000..05d636edb --- /dev/null +++ b/paddlemix/MULLM_WebUI/extras/constants.py @@ -0,0 +1,70 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +SUPPORTED_MODELS = { + "Qwen2-VL-2B-Instruct": "Qwen/Qwen2-VL-2B-Instruct", + "Qwen2-VL-7B-Instruct": "Qwen/Qwen2-VL-7B-Instruct", +} + + +MODEL_MAPPING = { + "Qwen2-VL-2B-Instruct": "Qwen2VLForConditionalGeneration", + "Qwen2-VL-7B-Instruct": "Qwen2VLForConditionalGeneration", +} + +DEFAULT_TEMPLATE = { + "Qwen2-VL-2B-Instruct": "qwen2_vl", + "Qwen2-VL-7B-Instruct": "qwen2_vl", +} + +METHODS = ["full", "lora"] + +# train +TRAINING_STAGES = { + "Supervised Fine-Tuning": "sft", +} + +STAGES_USE_PAIR_DATA = {} +PADDLEMIX_CONFIG = "config.yaml" + +DATA_CONFIG = "dataset_info.json" + +PEFT_METHODS = {"lora"} + +DEFAULT_DATA_DIR = "data" + +TRAINER_MAPPING = {} + +FILEEXT2TYPE = { + "arrow": "arrow", + "csv": "csv", + "json": "json", + "jsonl": "json", + "parquet": "parquet", + "txt": "text", +} + +IGNORE_INDEX = -100 + +IMAGE_PLACEHOLDER = os.environ.get("IMAGE_PLACEHOLDER", "") +VIDEO_PLACEHOLDER = os.environ.get("VIDEO_PLACEHOLDER", "