Has anyone implemented batch infer on Qwen-VL-Chat? #487

Open
hhhheep opened this issue Nov 3, 2024 · 0 comments

hhhheep commented Nov 3, 2024

The model is Qwen-VL-Chat.

I tried to generate text in batches with generate(), following #240.

Because the tokenizer does not let me change the padding side, I pad the inputs to the left manually.

However, the quality of the generated text is not as good as what chat() produces.

Is there something wrong with my adjustment?
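
For reference, the manual left padding boils down to something like this (a simplified sketch, not the full code; it assumes `tokenizer.pad_token_id` has already been set, since the tokenizer itself does not expose a padding-side option):

```python
import torch

def left_pad(sequences, pad_token_id):
    # Left-pad a list of 1-D LongTensors to a common length and build the matching
    # attention mask, so the real tokens sit at the end of each row and generation
    # continues from them rather than from padding.
    max_len = max(seq.size(0) for seq in sequences)
    padded = torch.stack([
        torch.cat([seq.new_full((max_len - seq.size(0),), pad_token_id), seq])
        for seq in sequences
    ])
    attention_mask = (padded != pad_token_id).long()
    return padded, attention_mask
```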

My source code:

```python
def batch_chat(
    self,
    tokenizer: PreTrainedTokenizer,
    queries: List[str],
    histories: Optional[List[HistoryType]] = None,
    system: str = "You are a helpful assistant.",
    append_history: bool = True,
    stop_words_ids: Optional[List[List[int]]] = None,
    generation_config: Optional[GenerationConfig] = None,
    batch_size=1,
    **kwargs,
) -> Tuple[List[str], List[HistoryType]]:
    """
    Generate chat responses in batches.

    Args:
        tokenizer (PreTrainedTokenizer): the pretrained tokenizer.
        queries (List[str]): a batch of queries, each one a string.
        histories (Optional[List[HistoryType]]): one history per query.
        system (str): the system prompt.
        append_history (bool): whether to append the generated response to the history.
        stop_words_ids (Optional[List[List[int]]]): stop-word token ids used during generation.
        generation_config (Optional[GenerationConfig]): generation parameters.
        **kwargs: additional generation arguments, e.g. max_window_size.

    Returns:
        Tuple[List[str], List[HistoryType]]: the generated responses and the updated histories.
    """
    def custom_collate_fn(batch):
        batch_queries, batch_histories = zip(*batch)
        return list(batch_queries), list(batch_histories)

    # Generation config
    generation_config = generation_config if generation_config is not None else self.generation_config

    # Validate the chat format
    assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT

    # Initialize histories. The placeholders only exist so the DataLoader can zip
    # queries with histories; make_context still receives None below when no
    # histories were passed in.
    if histories is None:
        working_histories = [[] for _ in range(len(queries))]
    else:
        if len(histories) != len(queries):
            raise ValueError("The length of histories must match the length of queries.")
        working_histories = histories

    # Stop words
    final_stop_words_ids = stop_words_ids.copy() if stop_words_ids else []
    additional_stop_words = get_stop_words_ids(generation_config.chat_format, tokenizer)
    final_stop_words_ids += additional_stop_words

    # Maximum context window
    max_window_size = kwargs.get('max_window_size', generation_config.max_window_size)

    # Process the inputs in chunks to avoid running out of memory
    loader = DataLoader(
        list(zip(queries, working_histories)),
        batch_size=batch_size,
        collate_fn=custom_collate_fn,
    )

    all_responses = []
    all_new_histories = []

    for batch_queries, batch_histories in loader:
        # Build the contexts
        context_tokens_batch = []
        raw_texts = []

        for query, history in zip(batch_queries, batch_histories):
            if not isinstance(query, str):
                raise ValueError(f"Expected query to be str, got {type(query)}")

            raw_text, context_tokens = make_context(
                tokenizer,
                query=query,
                history=history if histories else None,
                system=system,
                max_window_size=max_window_size,
                chat_format=generation_config.chat_format,
            )
            raw_texts.append(raw_text)
            context_tokens_batch.append(torch.tensor(context_tokens, dtype=torch.long))

        # Pad manually on the left
        max_length = max(tokens.size(0) for tokens in context_tokens_batch)
        padded_context_tokens_batch = []

        for tokens in context_tokens_batch:
            padding_length = max_length - tokens.size(0)
            if padding_length > 0:
                # Prepend pad tokens
                padded_tokens = torch.cat(
                    [torch.full((padding_length,), tokenizer.pad_token_id, dtype=torch.long), tokens],
                    dim=0,
                )
            else:
                padded_tokens = tokens
            padded_context_tokens_batch.append(padded_tokens)

        # Stack the padded sequences into tensors
        input_ids = torch.stack(padded_context_tokens_batch).to(self.device)
        attention_mask = (input_ids != tokenizer.pad_token_id).long().to(self.device)

        # Generate responses for the whole batch
        with torch.no_grad():
            try:
                outputs = self.generate(
                    input_ids,
                    attention_mask=attention_mask,
                    stop_words_ids=final_stop_words_ids,
                    return_dict_in_generate=False,
                    generation_config=generation_config,
                    **kwargs,
                )
            except Exception as e:
                raise RuntimeError(f"Error during generation: {e}")

        # Decode and update the histories
        # (raw_text_len / context_length are the lengths of the unpadded prompt)
        for i, output in enumerate(outputs):
            try:
                response = decode_tokens(
                    output,
                    tokenizer,
                    raw_text_len=len(raw_texts[i]),
                    context_length=len(context_tokens_batch[i]),
                    chat_format=generation_config.chat_format,
                    verbose=False,
                    errors='replace',
                )
            except Exception:
                response = ""

            all_responses.append(response)

            if append_history:
                batch_histories[i].append((batch_queries[i], response))
            all_new_histories.append(batch_histories[i])

    return all_responses, all_new_histories
```
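
For context, this is roughly how I compare the two paths, with sampling disabled so the outputs are deterministic (a simplified sketch: `model` and `tokenizer` are assumed to be loaded with trust_remote_code=True, `batch_chat` is assumed to be attached to the model, and the image path is just a placeholder):

```python
from transformers import GenerationConfig

# Greedy decoding so chat() and batch_chat() are directly comparable
gen_cfg = GenerationConfig.from_pretrained("Qwen/Qwen-VL-Chat", do_sample=False)

query = tokenizer.from_list_format([
    {"image": "demo.jpeg"},            # placeholder image path
    {"text": "Describe this image."},
])

single, _ = model.chat(tokenizer, query, history=None, generation_config=gen_cfg)
[batched], _ = model.batch_chat(tokenizer, [query], histories=None,
                                generation_config=gen_cfg, batch_size=1)

print(single == batched)
```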
