Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: replace begin_forward/forward/end_forward with plan/run #466

Merged
merged 4 commits into from
Aug 25, 2024

Conversation

yzh119
Copy link
Collaborator

@yzh119 yzh119 commented Aug 24, 2024

This PR changes the use of begin_forward/forward/end_forward API with the new plan/run API.

  • forward is consistent with pytorch but confusing because flashinfer focus on inference and do not have a corresponding backward phase, this PR changes it to run, which is more precise and consistent with the naming convention of cutlass's python API.
  • begin_forward is renamed to plan, which is consistent with the naming convention of nvmath API.
  • end_forward is deprecated and has no effect after this PR.

There is some slight difference between the old forward and the new run API:

  • All problem specifications will be provided in plan (previously begin_forward) API, and cached until next plan call, and we only need to provide query and KV-Cache tensors in run API.

This is not a breaking change, and we keep backward compatibility of the old begin_forward/forward/end_forward APIs, they will be gradually deprecated in future releases.

@yzh119 yzh119 merged commit d940d2e into main Aug 25, 2024
yzh119 added a commit that referenced this pull request Aug 26, 2024
In the previous PR #466 we replace the old-style
`begin_forward`/`end_forward`/`forward` APIs with the new `plan`/`run`
APIs, but didn't update the unit tests accordingly (this is intentional
because we want a commit that keeps unit tests that uses the old-style
API to check backward compatibility).

This PR updates the unit tests with new APIs.

Some other changes:
- Remove old-style APIs from docstring.
- Fix some errors in docstring with new APIs.
yzh119 added a commit that referenced this pull request Aug 27, 2024
🤖 I have created a release *beep* *boop*
---


##
[0.1.6](v0.1.5...v0.1.6)
(2024-08-27)

### SM75 Support

Starting from
[0.1.6](v0.1.5...v0.1.6),
our pre-built wheels include experimental support sm75 (Turing
architecture GPUs such as Tesla T4, Quadro RTX 6000 and RTX 2080).

### API Changes

#### `plan`/`run`

Since
[0.1.6](v0.1.5...v0.1.6)
on, `begin_forward`/`forward`/`end_forward` APIs are replaced with the
new `plan`/`run` API.
- `forward` is renamed to `run`, which is more precise and consistent
with the naming convention of cutlass's python API.
- `begin_forward` is renamed to `plan`, which is consistent with the
naming convention of nvmath API.
- `end_forward` is deprecated and has no effect after this PR.

There is some slight difference between the old `forward` and the new
`run` API:
- All extra arguments such as `causal` and `logits_soft_cap` will be
provided in `plan` (previously `begin_forward`) API, and cached until
next `plan` call, and we only need to provide query and KV-Cache tensors
in `run` API.

The old `begin_forward`/`forward`/`end_forward` APIs are still
functional, but we will gradually deprecate them in future releases.

Check [#466](#466) for
more details.

#### `MultiLevelCascadeAttentionWrapper`

Since
[0.1.6](v0.1.5...v0.1.6)
on, we introduce a new `MultiLevelCascadeAttentionWrapper` API for
cascade inference,
which supports multi-level cascade inference where all levels' KV-Cache
can be managed in a unified Paged KV-Cache.

See
[documentation](https://docs.flashinfer.ai/api/python/cascade.html#flashinfer.cascade.MultiLevelCascadeAttentionWrapper)
and
[tutorial](https://docs.flashinfer.ai/tutorials/kv_layout.html#multi-level-cascade-inference-data-layout)
on API usage and layout explaination.

The old `BatchDecodeWithSharedPrefixPagedKVCacheWrapper` and
`BatchPrefillWithSharedPrefixPagedKVCacheWrapper` will be deprecated in
future releases.

### Features

* sm75 support
([#448](#448),
[#449](#449))
* add `MultiLevelCascadeAttentionWrapper` API
([#462](#462))
([1e37989](1e37989))
* add accept num, emit num metric for ChainSpeculativeSampling
([#450](#450))
([fa38b5e](fa38b5e))
* support bmm fp8
([#469](#469))
([f1c0b68](f1c0b68))

### Refactor

* refactor: replace `begin_forward`/`forward`/`end_forward` with
`plan`/`run`
[#466](#466)

### Misc

* misc: improve error handling of sampling kernels
([#456](#456))
([0dce178](0dce178))

### Performance Improvements

* slight optimization on f16->f8 fragment layout swizzling
([#453](#453))
([0d61871](0d61871))
* slight optimization on fragment layout swizzle
([#458](#458))
([7c397cb](7c397cb))
* use persistent kernel for merging attention states
([#459](#459))
([be6bf5b](be6bf5b))

### Acknowledgement

We thank [@LiuXiaoxuanPKU](https://github.com/LiuXiaoxuanPKU) on enhance
of speculative sampling operator,
[@merrymercy](https://github.com/merrymercy) on API change suggestion
and [@zhyncs](https://github.com/zhyncs) on integrating fp8 BMM cublas
implementation.

---
This PR was generated with [Release
Please](https://github.com/googleapis/release-please). See
[documentation](https://github.com/googleapis/release-please#release-please).

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Zihao Ye <[email protected]>
@yzh119 yzh119 deleted the refactor-forward branch August 27, 2024 04:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant