
Implements dual-chunk-flash-attn backend for dual chunk attention with sparse attention support #11844

Open · wants to merge 3 commits into main from dev/dual-chunk-attn

Conversation

@sighingnow (Contributor) commented Jan 8, 2025

This PR implements the dual-chunk flash attention, a training-free method to extend model context length (see also #6139), with sparse attention (https://github.com/microsoft/MInference) support.

This PR requires the sparse attention kernel from vllm-flash-attention. Qwen models with 1m context length support will be open-sourced in the next one or two weeks, and unit tests will be added later.

FIX #12452
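
For readers unfamiliar with the method, here is a rough, simplified sketch of the idea (illustrative only, not this PR's kernel code): the KV history is split into fixed-size chunks, and attention is decomposed into intra-chunk, successive-chunk, and inter-chunk parts, each computed with remapped query positions so that relative distances stay within the range seen during training.

# Illustrative sketch of how dual chunk attention partitions the KV history
# for a single decoding token. Names and chunk_len are assumptions for this
# example; the actual PR computes the split inside the attention backend.
import torch

def partition_kv_positions(seq_len: int, chunk_len: int):
    """Return key positions covered by the intra-, succ- and inter-chunk parts."""
    positions = torch.arange(seq_len)
    cur_chunk = (seq_len - 1) // chunk_len        # chunk containing the query token
    chunk_idx = positions // chunk_len
    intra = positions[chunk_idx == cur_chunk]     # same chunk as the query
    succ = positions[chunk_idx == cur_chunk - 1]  # immediately preceding chunk
    inter = positions[chunk_idx < cur_chunk - 1]  # all earlier chunks
    return intra, succ, inter

intra, succ, inter = partition_kv_positions(seq_len=10_000, chunk_len=4_096)
# Each partition is attended to with a differently remapped query
# (query, query_succ, query_inter in this PR's terminology).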


github-actions bot commented Jan 8, 2025

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add the ready label to the PR
  • Enable auto-merge.

🚀

@mergify mergify bot added the ci/build label Jan 8, 2025
@sighingnow sighingnow force-pushed the dev/dual-chunk-attn branch 2 times, most recently from 82b5a4c to 4c4a33e on January 9, 2025 06:17
@jacob-crux

I see that you have enforce_eager=True set, so it looks like there are still compatibility issues with cudagraph.
Do you plan to fix this in the future?


mergify bot commented Jan 13, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @sighingnow.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@sighingnow (Contributor, Author)

> I see that you have enforce_eager=True set, so it looks like there are still compatibility issues with cudagraph. Do you plan to fix this in the future?

All conflicts fixed, could you please take another look? Thanks!

st] = decode_metadata.block_tables[i, st:ed]
decode_metadata.block_tables_intra = block_tables_intra

seq_lens_succ = (chunk_num_curr -


When I try the Needle-in-a-Haystack test with qwen-7b and llama-8b (code modified to support Llama), there is a bug that produces a negative number once the context exceeds roughly 13k–15k tokens.
I modified the code as below and confirmed that it works.

seq_lens_succ = ((chunk_num_curr - (chunk_num_curr - 1).clip(min=0)) * chunk_len)
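
For what it's worth, a tiny standalone check (illustrative, not the PR's code) confirming that the clamped expression stays non-negative and collapses to zero when no successive chunk exists:

import torch

chunk_len = 4096
chunk_num_curr = torch.tensor([0, 1, 5])
seq_lens_succ = (chunk_num_curr - (chunk_num_curr - 1).clip(min=0)) * chunk_len
print(seq_lens_succ)  # tensor([   0, 4096, 4096]) -- never negative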


mergify bot commented Jan 15, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @sighingnow.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 15, 2025
@jacob-crux

> I see that you have enforce_eager=True set, so it looks like there are still compatibility issues with cudagraph. Do you plan to fix this in the future?

> All conflicts fixed, could you please take another look? Thanks!

I tested it because I thought it was fixed, but I still hit the same problem, shown below.
Are you saying that CUDA graph capture is possible (enforce_eager=False)?

Capturing CUDA graph shapes:   0%|                                                                                                                                                                                                               | 0/35 [00:00<?, ?it/s]
[rank0]: Traceback (most recent call last):
[rank0]:   File "/data/lme-storage_810/jacob/needle/NeedleInAHaystack-lme/run_needle_in_haystack.py", line 435, in <module>
[rank0]:     ht = LLMNeedleHaystackTester(
[rank0]:          ^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data/lme-storage_810/jacob/needle/NeedleInAHaystack-lme/run_needle_in_haystack.py", line 94, in __init__
[rank0]:     self.model_to_test = LLM(model=model_name)
[rank0]:                          ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/utils.py", line 1044, in inner
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/entrypoints/llm.py", line 228, in __init__
[rank0]:     self.llm_engine = self.engine_class.from_engine_args(
[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/engine/llm_engine.py", line 517, in from_engine_args
[rank0]:     engine = cls(
[rank0]:              ^^^^
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/engine/llm_engine.py", line 276, in __init__
[rank0]:     self._initialize_kv_caches()
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/engine/llm_engine.py", line 429, in _initialize_kv_caches
[rank0]:     self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/executor/gpu_executor.py", line 83, in initialize_cache
[rank0]:     self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/worker/worker.py", line 274, in initialize_cache
[rank0]:     self._warm_up_model()
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/worker/worker.py", line 292, in _warm_up_model
[rank0]:     self.model_runner.capture_model(self.gpu_cache)
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/worker/model_runner.py", line 1533, in capture_model
[rank0]:     graph_runner.capture(**capture_inputs)
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/worker/model_runner.py", line 1885, in capture
[rank0]:     self.model(
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/model_executor/models/qwen2.py", line 496, in forward
[rank0]:     hidden_states = self.model(input_ids, positions, kv_caches,
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/compilation/decorators.py", line 170, in __call__
[rank0]:     return self.forward(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/model_executor/models/qwen2.py", line 359, in forward
[rank0]:     hidden_states, residual = layer(
[rank0]:                               ^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/model_executor/models/qwen2.py", line 267, in forward
[rank0]:     hidden_states = self.self_attn(
[rank0]:                     ^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/model_executor/models/qwen2.py", line 189, in forward
[rank0]:     attn_output = self.attn(q,
[rank0]:                   ^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/attention/layer.py", line 185, in forward
[rank0]:     return torch.ops.vllm.unified_attention(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 1116, in __call__
[rank0]:     return self._op(*args, **(kwargs or {}))
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/attention/layer.py", line 280, in unified_attention
[rank0]:     return self.impl.forward(query, key, value, kv_cache, attn_metadata,
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/attention/backends/dual_chunk_flash_attn.py", line 373, in forward
[rank0]:     assert decode_meta.scaling_factor is not None
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: AssertionError

@mergify mergify bot removed the needs-rebase label Jan 16, 2025
@sighingnow (Contributor, Author)

> I tested it because I thought it was fixed, but I still hit the same problem, shown below. Are you saying that CUDA graph capture is possible (enforce_eager=False)?

Dual chunk attention doesn't support CUDA graph, and I have added an assertion in arg_utils.py.

> When I try the Needle-in-a-Haystack test with qwen-7b and llama-8b (code modified to support Llama), there is a bug that produces a negative number once the context exceeds roughly 13k–15k tokens.

It was indeed a bug introduced while preparing this PR; fixed. Thanks!
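
(For reference, a hypothetical sketch of the kind of guard described for arg_utils.py; the parameter names below are assumptions for illustration, not the actual vLLM code.)

from typing import Optional

def check_dual_chunk_attention(enforce_eager: bool,
                               dual_chunk_attention_config: Optional[dict]) -> None:
    # Dual chunk attention cannot be captured into a CUDA graph, so eager
    # execution must be forced whenever it is enabled.
    if dual_chunk_attention_config is not None:
        assert enforce_eager, (
            "Dual chunk attention does not support CUDA graph; "
            "please set enforce_eager=True.")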

@sighingnow (Contributor, Author) commented Jan 19, 2025

Rebased against main.

Hi @youkaichao @simon-mo @WoosukKwon, do you folks think there is anything else that needs to be improved in this pull request?

Thanks!

@tlrmchlsmth (Collaborator) left a comment


Spotted a few bits of commented-out code that look like debug cruft or are otherwise mysterious. Could you clean those up, along with any other similar spots?


mergify bot commented Jan 20, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @sighingnow.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 20, 2025
qc_freqs = torch.einsum("i,j -> ij", qc_t, inv_freq)
k_freqs = torch.einsum("i,j -> ij", k_t, inv_freq)
qc_no_clamp_freqs = torch.einsum("i,j -> ij", qc_no_clamp_t, inv_freq)
q_inter_freqs = torch.einsum("i,j -> ij", q_inter_t, inv_freq)
Collaborator


nit: I think these einsums are still slower on CUDA than (a * b).sum(-1); not on the hot path though, so not critical.

pytorch/pytorch#101249

ran bench_einsum.py from that issue on an H100 and got:

python einsum_bench.py 
[-------------------------------------  -------------------------------------]
                                  |  mul/sum  |  torch.einsum  |  numpy.einsum
1 threads: -------------------------------------------------------------------
      Nc,Nc->N cpu (1048576, 2)   |    5000   |      3100      |      4000    
      Nc,Nc->N cuda (1048576, 2)  |      20   |       747      |      3300    

Times are in microseconds (us).
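
As an aside, the einsums quoted above use the "i,j -> ij" (outer product) pattern, for which a broadcasted multiply is a drop-in replacement; a quick equivalence check with illustrative shapes:

import torch

qc_t = torch.randn(4096)    # illustrative position values
inv_freq = torch.randn(64)  # illustrative inverse frequencies

a = torch.einsum("i,j -> ij", qc_t, inv_freq)
b = qc_t[:, None] * inv_freq[None, :]  # equivalent to torch.outer(qc_t, inv_freq)
assert torch.allclose(a, b)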

Comment on lines +102 to +115
logits_soft_cap, attn_type, **{
"dual_chunk_attention_config": dual_chunk_attention_config,
"prefix": prefix,
} if dual_chunk_attention_config is not None else {})
Collaborator


I feel like this is messy; I think we should maybe do something like:

def __init__(..., **extra_attn_kwargs):
   self.impl = impl_cls(..., **extra_attn_kwargs)

the challenge here is that prefix would not be captured by extra_attn_kwargs but is only (currently) used by DualChunkFlashAttentionImpl. I do think it would be less messy, though, to do this and make prefix a standard arg for attention impls, given that it is pretty generic. Thoughts @WoosukKwon?
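
To make the proposal concrete, a minimal self-contained sketch of the forwarding pattern (class and parameter names are placeholders, not vLLM's actual signatures):

class DummyAttnImpl:
    def __init__(self, num_heads: int, head_size: int, prefix: str = "",
                 **backend_kwargs):
        # Backend-specific options (e.g. a dual_chunk_attention_config dict)
        # arrive here without the wrapper having to know about them.
        self.num_heads = num_heads
        self.head_size = head_size
        self.prefix = prefix
        self.backend_kwargs = backend_kwargs

class Attention:
    def __init__(self, num_heads: int, head_size: int, impl_cls,
                 prefix: str = "", **extra_attn_kwargs):
        # prefix promoted to a standard arg; everything else flows through.
        self.impl = impl_cls(num_heads, head_size, prefix=prefix,
                             **extra_attn_kwargs)

attn = Attention(8, 64, DummyAttnImpl, prefix="model.layers.0.self_attn",
                 dual_chunk_attention_config={"chunk_size": 8192})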

Comment on lines 148 to 158
if self.dual_chunk_attention_config:
    assert query_succ_and_inter is not None
    dca_kwargs = {
        "query_succ": query_succ_and_inter[0],
        "query_inter": query_succ_and_inter[1],
        "query_succ_critical": query_succ_and_inter[2],
        "query_inter_critical": query_succ_and_inter[3],
    } if query_succ_and_inter else {}
else:
    dca_kwargs = {}

Collaborator


I think we should try hard to see if there is a cleaner way of passing these; maybe they can be bundled into a single q tensor that gets reinterpreted as components via a combination of slicing and .view calls in the attn impl?
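
For example, one possible packing along these lines (shapes and names are assumptions for illustration, not the PR's actual layout):

import torch

num_tokens, num_heads, head_size = 8, 4, 64
q_succ = torch.randn(num_tokens, num_heads * head_size)
q_inter = torch.randn(num_tokens, num_heads * head_size)
q_succ_crit = torch.randn(num_tokens, num_heads * head_size)
q_inter_crit = torch.randn(num_tokens, num_heads * head_size)

# Caller side: concatenate the query variants into a single tensor.
packed = torch.cat([q_succ, q_inter, q_succ_crit, q_inter_crit], dim=-1)

# Attention impl side: recover the components via chunk (or slicing + .view),
# so no extra keyword arguments are needed in the forward signature.
succ, inter, succ_crit, inter_crit = packed.chunk(4, dim=-1)
assert torch.equal(succ, q_succ) and torch.equal(inter_crit, q_inter_crit)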

@sighingnow (Contributor, Author)


I'll take a try and see if it can be simplified.

@mergify mergify bot removed the needs-rebase label Jan 23, 2025
@sighingnow (Contributor, Author)

Hi @LucasWilkinson, most of the comments have been addressed; could you please take another look? Thanks!

The lint error comes from the prompt text; do you have any suggestions for how I could skip or resolve it?

@sighingnow sighingnow force-pushed the dev/dual-chunk-attn branch 3 times, most recently from c230fb7 to e1b1d0f on January 25, 2025 15:54
@ywang96 (Member) commented Jan 27, 2025

@sighingnow Sorry for the delayed response! I've merged main into your branch so the pre-commit error should be cleared. I'll enable ready status for this PR so at least we can get the CI going before @tlrmchlsmth or @LucasWilkinson want to give their final greenlight!

@ywang96 ywang96 added the ready label (ONLY add when PR is ready to merge/full CI is needed) Jan 27, 2025
Collaborator


I think we should revisit whether it makes sense to have this 1m context example here. This file would be by far the largest in the vLLM git repo at 4MB, and I think it might be better for the example to pull the data from an external source.

Side note, I thought having the example prompts start by discussing the differences between various narcotics was an interesting choice!

Comment on lines +176 to +177
vertical_indices_count: torch.
Tensor, # [N_HEADS] : different head use different number of indices
Collaborator


nit: some weird whitespace here - I wouldn't want us to spend too much effort fighting clang-format but maybe the comment can be moved to its own line

Comment on lines +198 to +206
for i in range(batch_size):
    st = ((chunk_num_curr[i] - 1).clip(min=0) * chunk_len //
          self.block_size)
    ed = min(
        st + (max_seq_len_succ - 1) // self.block_size + 1,
        (cache_seq_lens[i] - 1) // self.block_size + 1,
    )
    block_tables_succ[i, :ed -
                      st] = decode_metadata.block_tables[i, st:ed]
Collaborator


Improved readability:

Suggested change
-for i in range(batch_size):
-    st = ((chunk_num_curr[i] - 1).clip(min=0) * chunk_len //
-          self.block_size)
-    ed = min(
-        st + (max_seq_len_succ - 1) // self.block_size + 1,
-        (cache_seq_lens[i] - 1) // self.block_size + 1,
-    )
-    block_tables_succ[i, :ed -
-                      st] = decode_metadata.block_tables[i, st:ed]
+for i in range(batch_size):
+    start = ((chunk_num_curr[i] - 1).clip(min=0) * chunk_len //
+             self.block_size)
+    end = min(
+        start + (max_seq_len_succ - 1) // self.block_size + 1,
+        (cache_seq_lens[i] - 1) // self.block_size + 1,
+    )
+    block_tables_succ[i, :end -
+                      start] = decode_metadata.block_tables[i, start:end]

Comment on lines +1037 to +1061
if sparse_attn_enabled:
    flash_result = self._do_flash_attn(
        q_states_intra,
        k_states_intra,
        v_states_intra,
        softmax_scale=softmax_scale,
        causal=True,
        stage="intra",
        vertical_indices=vertical_buffer,
        slash_indices=slash_buffer,
        vertical_indices_count=vertical_size_buffer,
        slash_indices_count=slash_sizes_buffer,
        mergehead_softmax_scale=softmax_scale,
        sparse_attn_enabled=sparse_attn_enabled)
else:
    flash_result = self._do_flash_attn(
        q_states_intra,
        k_states_intra,
        v_states_intra,
        softmax_scale=softmax_scale,
        causal=True,
        stage="intra",
        vertical_indices=intra_vertical_indices,
        slash_indices=intra_slash_indices,
        sparse_attn_enabled=sparse_attn_enabled)
Collaborator


QQ: is the if/else necessary since sparse_attn_enabled is passed into the function in both cases?

@tlrmchlsmth (Collaborator) commented Feb 3, 2025


Please add SPDX headers to new files (see #12628 for reference)


mergify bot commented Feb 3, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @sighingnow.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 3, 2025
@tlrmchlsmth (Collaborator) left a comment


A few more review comments, mostly minor stuff. Looks pretty good, although I do suggest getting rid of the qwen_1m example.

Comment on lines +10 to +22
__device__ void save_blocks(int* block_offset, int64_t range_start,
int64_t range_end, int64_t block_size,
int64_t& block_count, int64_t kv_seqlen) {
if (range_start >= kv_seqlen) {
return;
}
if (range_end > kv_seqlen) {
range_end = kv_seqlen;
}
for (int idx = range_start; idx < range_end; idx += block_size) {
block_offset[block_count++] = idx;
}
}
Collaborator


I think this function would be clearer and more explicit in its behavior if it returned the current block count instead of modifying its input argument:

__device__ int64_t save_blocks(int* block_offset, int64_t range_start,
                               int64_t range_end, int64_t block_size,
                               int64_t input_block_count, int64_t kv_seqlen) {
  if (range_start >= kv_seqlen) {
    return input_block_count;
  }
  if (range_end > kv_seqlen) {
    range_end = kv_seqlen;
  }
  int64_t current_block_count = input_block_count;
  for (int idx = range_start; idx < range_end; idx += block_size) {
    block_offset[current_block_count++] = idx;
  }
  return current_block_count;
}

Collaborator


Could you add some comments describing what the functions in this file are doing? Comments describing what the blocks of code within convert_vertical_slash_indexes_kernel do would be helpful as well.

Comment on lines +835 to +836
int32_max = 2147483647 # avoid sort
int32_min = -2147483648
Collaborator


nit: best to avoid hardcoding constants

Suggested change
-int32_max = 2147483647 # avoid sort
-int32_min = -2147483648
+int32_max = torch.iinfo(torch.int32).max
+int32_min = torch.iinfo(torch.int32).min

(Also could you elaborate on the "avoid sort" comment?)

Comment on lines +1199 to +1201
q=query_states.bfloat16(),
k=key_states.bfloat16(),
v=value_states.bfloat16(),
Collaborator


Why convert these to bfloat16? I don't think we should be doing this, e.g., if the model's dtype is float16.

softmax_scale,
chunk_size,
local_size,
scaling_factor.item(),


Suggested change
-scaling_factor.item(),
+scaling_factor[i].item(),

Otherwise this will error since scaling_factor contains more than 1 element here

Labels: ci/build, needs-rebase, ready (ONLY add when PR is ready to merge/full CI is needed)
Projects: None yet

Development
Successfully merging this pull request may close these issues:
  • [Feature]: Support Qwen/Qwen2.5-14B-Instruct-1M

7 participants