Add attention_backend to let the user choose
yanbing-j committed Jan 13, 2025
1 parent 654bb03 commit 757b224
Showing 3 changed files with 24 additions and 1 deletion.
2 changes: 2 additions & 0 deletions torchchat/cli/builder.py
@@ -69,6 +69,7 @@ class BuilderArgs:
     prefill_possible: bool = False
     dynamic_shapes: bool = False
     max_seq_length: Optional[int] = None
+    attention_backend: str = "math"
 
     def __post_init__(self):
         if self.device is None:
@@ -202,6 +203,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
             is_chat_model=is_chat_model,
             dynamic_shapes=getattr(args, "dynamic_shapes", False),
             max_seq_length=getattr(args, "max_seq_length", None),
+            attention_backend=args.attention_backend,
         )
 
     @classmethod
7 changes: 7 additions & 0 deletions torchchat/cli/cli.py
@@ -179,6 +179,13 @@ def _add_model_config_args(parser, verb: str) -> None:
         choices=["fast", "cpu", "cuda", "mps"],
         help="Hardware device to use. Options: fast, cpu, cuda, mps",
     )
+    model_config_parser.add_argument(
+        "--attention-backend",
+        type=str,
+        default="math",
+        choices=["math", "flash_attention", "efficient_attention", "cudnn_attention"],
+        help="SDPBackend to use. Options: MATH, FLASH_ATTENTION, EFFICIENT_ATTENTION, CUDNN_ATTENTION",
+    )
 
 
 # Add CLI Args representing output paths of exported model files
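The new option is registered alongside the other model-configuration arguments in _add_model_config_args, so any verb that pulls in this argument group accepts it; argparse rejects values outside the choices list, and the default stays "math", matching the previously hard-coded backend. A hypothetical invocation (the model alias and prompt are placeholders, not part of this change) might look like:

python3 torchchat.py generate llama3.1 --prompt "Hello" --attention-backend flash_attention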
16 changes: 15 additions & 1 deletion torchchat/generate.py
@@ -531,15 +531,22 @@ def decode_n_tokens(
         callback=lambda _: _,
         eos_token_id: int = 2,
         eot_id: Optional[int] = None,
+        attention_backend: str = "math",
         **sampling_kwargs,
     ):
         new_tokens, new_probs = [], []
         encountered_eos = False
+        sdp_backend_dict = {
+            'math': torch.nn.attention.SDPBackend.MATH,
+            'flash_attention': torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+            'efficient_attention': torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+            'cudnn_attention': torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
+        }
         for _i in range(
             num_new_tokens - 1
         ):  # -1 to save space to run an EoS if dont generate it naturally
             # Actually better for Inductor to codegen attention here
-            with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
+            with torch.nn.attention.sdpa_kernel([sdp_backend_dict[attention_backend]]):
 
                 out_token = cur_token.clone()
                 next_token, next_prob = self.decode_one_token(
@@ -683,6 +690,7 @@ def generate(
         sequential_prefill=True,
         callback=lambda x: x,
         max_seq_length: int,
+        attention_backend: str = "math",
         seed: Optional[int] = None,
         **sampling_kwargs,
     ) -> torch.Tensor:
@@ -799,6 +807,7 @@ def generate(
                 if self.is_llama3_model
                 else None
             ),
+            attention_backend=attention_backend,
             **sampling_kwargs,
         ):
             generated_tokens.append(generated_token.view(-1))
@@ -1170,6 +1179,10 @@ def callback(x, *, done_generating=False):
                 prof = torch.profiler.profile()
             t0 = time.perf_counter()
             num_tokens_generated = 0
+            if self.builder_args.device == "cpu" and (self.builder_args.attention_backend == "efficient_attention"
+                or self.builder_args.attention_backend == "cudnn_attention"):
+                print(f"Warning: {self.builder_args.attention_backend} is not supported on CPU. Using math instead.")
+                self.builder_args.attention_backend = "math"
             with prof:
                 generator_func = self.generate(
                     self.model,
@@ -1186,6 +1199,7 @@ def callback(x, *, done_generating=False):
                     start_pos=start_pos,
                     skip_cache_setup=not is_first_sample,
                     max_seq_length=max_seq_length,
+                    attention_backend=self.builder_args.attention_backend,
                 )
                 for token_tensor, metrics in generator_func:
                     if token_tensor is not None:
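For context, torch.nn.attention.sdpa_kernel is a PyTorch context manager that restricts torch.nn.functional.scaled_dot_product_attention to the backends passed in; before this change the call in decode_n_tokens was pinned to SDPBackend.MATH, and it now uses the backend selected on the command line. The sketch below is a minimal, standalone illustration of that selection logic (it requires a PyTorch version that provides torch.nn.attention); the tensor shapes and the chosen backend string are made up for the example and are not part of the commit.

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Same string-to-enum mapping that decode_n_tokens now builds internally.
sdp_backend_dict = {
    'math': SDPBackend.MATH,
    'flash_attention': SDPBackend.FLASH_ATTENTION,
    'efficient_attention': SDPBackend.EFFICIENT_ATTENTION,
    'cudnn_attention': SDPBackend.CUDNN_ATTENTION,
}

attention_backend = "math"  # e.g. the value parsed from --attention-backend
q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim); illustrative shapes

# Restrict scaled_dot_product_attention to the requested backend inside the block.
with sdpa_kernel([sdp_backend_dict[attention_backend]]):
    out = F.scaled_dot_product_attention(q, k, v)

print(out.shape)  # torch.Size([1, 8, 16, 64])

The CPU guard added in the commit reflects that efficient_attention and cudnn_attention are CUDA-oriented backends, so on CPU the generator falls back to math instead.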
