Fast FBGEMM path KT.regroup_as (#1910)
Summary:

Use the custom FBGEMM kernel when possible for inference/training. ~0-75% runtime speedup.
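
For context, a minimal sketch of the API this change accelerates (key names and
shapes are hypothetical; from_tensor_list builds a KeyedTensor from dense values):

  import torch
  from torchrec.sparse.jagged_tensor import KeyedTensor

  # Two KeyedTensors of pooled embeddings: batch 4, per-key dims 8, 4, and 6.
  kt_1 = KeyedTensor.from_tensor_list(
      ["f1", "f2"], [torch.randn(4, 8), torch.randn(4, 4)]
  )
  kt_2 = KeyedTensor.from_tensor_list(["f3"], [torch.randn(4, 6)])

  # Each key is consumed exactly once across the groups, so regroup can take the
  # fast FBGEMM permute_pooled_embs path; otherwise it falls back to the
  # original _regroup_keyed_tensors implementation.
  out = KeyedTensor.regroup([kt_1, kt_2], [["f1", "f3"], ["f2"]])
  assert out[0].shape == (4, 14) and out[1].shape == (4, 4)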

Benchmark Results [Forward]
  [fallback] _regroup_keyed_tenors    | B: 512      | F: 80       | device: cuda     | Runtime (P90):   0.4 ms | Memory (P90):  24.0
  [prod] KeyedTensor.regroup          | B: 512      | F: 80       | device: cuda     | Runtime (P90):   0.4 ms | Memory (P90):  36.0
  [fallback] _regroup_keyed_tenors    | B: 512      | F: 160      | device: cuda     | Runtime (P90):   0.8 ms | Memory (P90):  48.0
  [prod] KeyedTensor.regroup          | B: 512      | F: 160      | device: cuda     | Runtime (P90):   0.6 ms | Memory (P90):  72.0
  [fallback] _regroup_keyed_tenors    | B: 512      | F: 320      | device: cuda     | Runtime (P90):   1.9 ms | Memory (P90):  96.0
  [prod] KeyedTensor.regroup          | B: 512      | F: 320      | device: cuda     | Runtime (P90):   0.7 ms | Memory (P90): 144.0
  [fallback] _regroup_keyed_tenors    | B: 512      | F: 640      | device: cuda     | Runtime (P90):   4.6 ms | Memory (P90): 192.0
  [prod] KeyedTensor.regroup          | B: 512      | F: 640      | device: cuda     | Runtime (P90):   1.3 ms | Memory (P90): 288.0
  [fallback] _regroup_keyed_tenors    | B: 512      | F: 1280     | device: cuda     | Runtime (P90):  13.2 ms | Memory (P90): 384.0
  [prod] KeyedTensor.regroup          | B: 512      | F: 1280     | device: cuda     | Runtime (P90):   2.2 ms | Memory (P90): 576.0
  [fallback] _regroup_keyed_tenors    | B: 1024     | F: 80       | device: cuda     | Runtime (P90):   0.3 ms | Memory (P90):  48.0
  [prod] KeyedTensor.regroup          | B: 1024     | F: 80       | device: cuda     | Runtime (P90):   0.4 ms | Memory (P90):  72.0
  [fallback] _regroup_keyed_tenors    | B: 1024     | F: 160      | device: cuda     | Runtime (P90):   0.8 ms | Memory (P90):  96.0
  [prod] KeyedTensor.regroup          | B: 1024     | F: 160      | device: cuda     | Runtime (P90):   0.6 ms | Memory (P90): 144.0
  [fallback] _regroup_keyed_tenors    | B: 1024     | F: 320      | device: cuda     | Runtime (P90):   1.8 ms | Memory (P90): 192.0
  [prod] KeyedTensor.regroup          | B: 1024     | F: 320      | device: cuda     | Runtime (P90):   0.9 ms | Memory (P90): 288.0
  [fallback] _regroup_keyed_tenors    | B: 1024     | F: 640      | device: cuda     | Runtime (P90):   4.1 ms | Memory (P90): 384.0
  [prod] KeyedTensor.regroup          | B: 1024     | F: 640      | device: cuda     | Runtime (P90):   1.6 ms | Memory (P90): 576.0
  [fallback] _regroup_keyed_tenors    | B: 1024     | F: 1280     | device: cuda     | Runtime (P90):  12.8 ms | Memory (P90): 768.0
  [prod] KeyedTensor.regroup          | B: 1024     | F: 1280     | device: cuda     | Runtime (P90):   3.1 ms | Memory (P90): 1152.0
  [fallback] _regroup_keyed_tenors    | B: 2048     | F: 80       | device: cuda     | Runtime (P90):   0.4 ms | Memory (P90):  96.0
  [prod] KeyedTensor.regroup          | B: 2048     | F: 80       | device: cuda     | Runtime (P90):   0.5 ms | Memory (P90): 144.0
  [fallback] _regroup_keyed_tenors    | B: 2048     | F: 160      | device: cuda     | Runtime (P90):   0.7 ms | Memory (P90): 192.0
  [prod] KeyedTensor.regroup          | B: 2048     | F: 160      | device: cuda     | Runtime (P90):   0.8 ms | Memory (P90): 288.0
  [fallback] _regroup_keyed_tenors    | B: 2048     | F: 320      | device: cuda     | Runtime (P90):   1.6 ms | Memory (P90): 384.0
  [prod] KeyedTensor.regroup          | B: 2048     | F: 320      | device: cuda     | Runtime (P90):   1.4 ms | Memory (P90): 576.0
  [fallback] _regroup_keyed_tenors    | B: 2048     | F: 640      | device: cuda     | Runtime (P90):   4.8 ms | Memory (P90): 768.0
  [prod] KeyedTensor.regroup          | B: 2048     | F: 640      | device: cuda     | Runtime (P90):   2.8 ms | Memory (P90): 1152.0
  [fallback] _regroup_keyed_tenors    | B: 2048     | F: 1280     | device: cuda     | Runtime (P90):  12.5 ms | Memory (P90): 1536.0
  [prod] KeyedTensor.regroup          | B: 2048     | F: 1280     | device: cuda     | Runtime (P90):   5.6 ms | Memory (P90): 2304.0
  [fallback] _regroup_keyed_tenors    | B: 4096     | F: 80       | device: cuda     | Runtime (P90):   0.4 ms | Memory (P90): 192.0
  [prod] KeyedTensor.regroup          | B: 4096     | F: 80       | device: cuda     | Runtime (P90):   0.8 ms | Memory (P90): 288.0
  [fallback] _regroup_keyed_tenors    | B: 4096     | F: 160      | device: cuda     | Runtime (P90):   0.9 ms | Memory (P90): 384.0
  [prod] KeyedTensor.regroup          | B: 4096     | F: 160      | device: cuda     | Runtime (P90):   1.4 ms | Memory (P90): 576.0
  [fallback] _regroup_keyed_tenors    | B: 4096     | F: 320      | device: cuda     | Runtime (P90):   1.7 ms | Memory (P90): 768.0
  [prod] KeyedTensor.regroup          | B: 4096     | F: 320      | device: cuda     | Runtime (P90):   2.8 ms | Memory (P90): 1152.0
  [fallback] _regroup_keyed_tenors    | B: 4096     | F: 640      | device: cuda     | Runtime (P90):   4.1 ms | Memory (P90): 1536.0
  [prod] KeyedTensor.regroup          | B: 4096     | F: 640      | device: cuda     | Runtime (P90):   5.6 ms | Memory (P90): 2304.0
  [fallback] _regroup_keyed_tenors    | B: 4096     | F: 1280     | device: cuda     | Runtime (P90):  12.2 ms | Memory (P90): 3072.0
  [prod] KeyedTensor.regroup          | B: 4096     | F: 1280     | device: cuda     | Runtime (P90):  11.1 ms | Memory (P90): 4608.0

Benchmark Results [Forward + Backward]
  [prod] KeyedTensor.regroup          | B: 512      | F: 80       | device: cuda     | Runtime (P90):   2.2 ms | Memory (P90):  72.0
  [fallback] _regroup_keyed_tenors    | B: 512      | F: 160      | device: cuda     | Runtime (P90):   4.7 ms | Memory (P90): 144.0
  [prod] KeyedTensor.regroup          | B: 512      | F: 160      | device: cuda     | Runtime (P90):   3.4 ms | Memory (P90): 144.0
  [fallback] _regroup_keyed_tenors    | B: 512      | F: 320      | device: cuda     | Runtime (P90):   9.0 ms | Memory (P90): 288.0
  [prod] KeyedTensor.regroup          | B: 512      | F: 320      | device: cuda     | Runtime (P90):   6.5 ms | Memory (P90): 288.0
  [fallback] _regroup_keyed_tenors    | B: 512      | F: 640      | device: cuda     | Runtime (P90):  19.9 ms | Memory (P90): 576.0
  [prod] KeyedTensor.regroup          | B: 512      | F: 640      | device: cuda     | Runtime (P90):  11.4 ms | Memory (P90): 576.0
  [fallback] _regroup_keyed_tenors    | B: 512      | F: 1280     | device: cuda     | Runtime (P90):  46.7 ms | Memory (P90): 1152.0
  [prod] KeyedTensor.regroup          | B: 512      | F: 1280     | device: cuda     | Runtime (P90):  23.1 ms | Memory (P90): 1152.0
  [fallback] _regroup_keyed_tenors    | B: 1024     | F: 80       | device: cuda     | Runtime (P90):   2.6 ms | Memory (P90): 144.0
  [prod] KeyedTensor.regroup          | B: 1024     | F: 80       | device: cuda     | Runtime (P90):   2.5 ms | Memory (P90): 144.0
  [fallback] _regroup_keyed_tenors    | B: 1024     | F: 160      | device: cuda     | Runtime (P90):   4.5 ms | Memory (P90): 288.0
  [prod] KeyedTensor.regroup          | B: 1024     | F: 160      | device: cuda     | Runtime (P90):   3.9 ms | Memory (P90): 288.0
  [fallback] _regroup_keyed_tenors    | B: 1024     | F: 320      | device: cuda     | Runtime (P90):   8.8 ms | Memory (P90): 576.0
  [prod] KeyedTensor.regroup          | B: 1024     | F: 320      | device: cuda     | Runtime (P90):   6.7 ms | Memory (P90): 576.0
  [fallback] _regroup_keyed_tenors    | B: 1024     | F: 640      | device: cuda     | Runtime (P90):  18.7 ms | Memory (P90): 1152.0
  [prod] KeyedTensor.regroup          | B: 1024     | F: 640      | device: cuda     | Runtime (P90):  12.2 ms | Memory (P90): 1152.0
  [fallback] _regroup_keyed_tenors    | B: 1024     | F: 1280     | device: cuda     | Runtime (P90):  42.8 ms | Memory (P90): 2304.0
  [prod] KeyedTensor.regroup          | B: 1024     | F: 1280     | device: cuda     | Runtime (P90):  23.1 ms | Memory (P90): 2304.0
  [fallback] _regroup_keyed_tenors    | B: 2048     | F: 80       | device: cuda     | Runtime (P90):   2.5 ms | Memory (P90): 288.0
  [prod] KeyedTensor.regroup          | B: 2048     | F: 80       | device: cuda     | Runtime (P90):   2.4 ms | Memory (P90): 288.0
  [fallback] _regroup_keyed_tenors    | B: 2048     | F: 160      | device: cuda     | Runtime (P90):   4.5 ms | Memory (P90): 576.0
  [prod] KeyedTensor.regroup          | B: 2048     | F: 160      | device: cuda     | Runtime (P90):   4.2 ms | Memory (P90): 576.0
  [fallback] _regroup_keyed_tenors    | B: 2048     | F: 320      | device: cuda     | Runtime (P90):   8.9 ms | Memory (P90): 1152.0
  [prod] KeyedTensor.regroup          | B: 2048     | F: 320      | device: cuda     | Runtime (P90):   7.7 ms | Memory (P90): 1152.0
  [fallback] _regroup_keyed_tenors    | B: 2048     | F: 640      | device: cuda     | Runtime (P90):  19.2 ms | Memory (P90): 2304.0
  [prod] KeyedTensor.regroup          | B: 2048     | F: 640      | device: cuda     | Runtime (P90):  12.9 ms | Memory (P90): 2304.0
  [fallback] _regroup_keyed_tenors    | B: 2048     | F: 1280     | device: cuda     | Runtime (P90):  45.1 ms | Memory (P90): 4608.0
  [prod] KeyedTensor.regroup          | B: 2048     | F: 1280     | device: cuda     | Runtime (P90):  26.4 ms | Memory (P90): 4608.0
  [fallback] _regroup_keyed_tenors    | B: 4096     | F: 80       | device: cuda     | Runtime (P90):   2.4 ms | Memory (P90): 576.0
  [prod] KeyedTensor.regroup          | B: 4096     | F: 80       | device: cuda     | Runtime (P90):   2.7 ms | Memory (P90): 576.0
  [fallback] _regroup_keyed_tenors    | B: 4096     | F: 160      | device: cuda     | Runtime (P90):   4.4 ms | Memory (P90): 1152.0
  [prod] KeyedTensor.regroup          | B: 4096     | F: 160      | device: cuda     | Runtime (P90):   4.4 ms | Memory (P90): 1152.0
  [fallback] _regroup_keyed_tenors    | B: 4096     | F: 320      | device: cuda     | Runtime (P90):   8.4 ms | Memory (P90): 2304.0
  [prod] KeyedTensor.regroup          | B: 4096     | F: 320      | device: cuda     | Runtime (P90):   8.1 ms | Memory (P90): 2304.0
  [fallback] _regroup_keyed_tenors    | B: 4096     | F: 640      | device: cuda     | Runtime (P90):  28.0 ms | Memory (P90): 4608.0
  [prod] KeyedTensor.regroup          | B: 4096     | F: 640      | device: cuda     | Runtime (P90):  15.6 ms | Memory (P90): 4608.0
  [fallback] _regroup_keyed_tenors    | B: 4096     | F: 1280     | device: cuda     | Runtime (P90):  43.2 ms | Memory (P90): 9216.0
  [prod] KeyedTensor.regroup          | B: 4096     | F: 1280     | device: cuda     | Runtime (P90):  31.2 ms | Memory (P90): 9216.0

Reviewed By: PaulZhang12

Differential Revision: D56392296
dstaay-fb authored and facebook-github-bot committed Apr 24, 2024
1 parent 0f3954f commit 12385ab
Showing 5 changed files with 615 additions and 55 deletions.
122 changes: 108 additions & 14 deletions torchrec/sparse/jagged_tensor.py
@@ -9,6 +9,7 @@
 
 import abc
 import operator
+
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -21,6 +22,12 @@
 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
+    torch.ops.load_library(
+        "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_gpu"
+    )
+    torch.ops.load_library(
+        "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_cpu"
+    )
 except OSError:
     pass

@@ -128,18 +135,10 @@ def _assert_offsets_or_lengths_is_provided(
 
 
 @torch.fx.wrap
+# keep for legacy use cases
 def _regroup_keyed_tensors(
     keyed_tensors: List["KeyedTensor"], groups: List[List[str]]
 ) -> List[torch.Tensor]:
-    # Shortcut for no re-grouping
-    if len(keyed_tensors) == len(groups):
-        match = True
-        for kt, group in zip(keyed_tensors, groups):
-            if kt.keys() != group:
-                match = False
-                break
-        if match:
-            return [kt.values() for kt in keyed_tensors]
-
     embedding_dicts = [keyed_tensor.to_dict() for keyed_tensor in keyed_tensors]
     lengths = [keyed_tensor.length_per_key() for keyed_tensor in keyed_tensors]
@@ -165,6 +164,97 @@ def _regroup_keyed_tensors(
     return list(rearranged_values.split(split_lengths, dim=key_dim))


+@torch.fx.wrap
+def _all_keys_used_once(
+    keyed_tensors: List["KeyedTensor"], groups: List[List["str"]]
+) -> bool:
+    flat_keys: List[str] = []
+    flat_groups: List[str] = []
+    for keyed_tensor in keyed_tensors:
+        flat_keys.extend(keyed_tensor.keys())
+    for sub_group in groups:
+        flat_groups.extend(sub_group)
+    # jit.script does not support set, so we use a dict to represent the set
+    key_set: Dict[str, int] = {key: 1 for key in flat_keys}
+    group_set: Dict[str, int] = {key: 1 for key in flat_groups}
+    return len(key_set) == len(group_set) == len(flat_keys) == len(flat_groups)

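For intuition, the fast path is taken only when every key appears exactly once
across both the inputs and the requested groups; a small illustration with
hypothetical keys:

  keys_per_kt = [["f1", "f2"], ["f3"]]   # keys of each input KeyedTensor
  groups = [["f1", "f3"], ["f2"]]        # requested output grouping
  flat_keys = [k for sub in keys_per_kt for k in sub]
  flat_groups = [g for sub in groups for g in sub]
  # No key is duplicated or dropped, so all four counts agree -> fast path.
  assert len(set(flat_keys)) == len(set(flat_groups)) == len(flat_keys) == len(flat_groups)
  # If "f1" appeared in two groups, flat_groups would contain a duplicate and
  # KeyedTensor.regroup would fall back to _regroup_keyed_tensors.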

+@torch.fx.wrap
+def _fbgemm_permute_pooled_embs(
+    keyed_tensors: List["KeyedTensor"], groups: List[List["str"]]
+) -> List[torch.Tensor]:
+    keys, lengths, values = _desugar_keyed_tensors(keyed_tensors)
+    permute, inv_permute, offsets, inv_offsets, splits = _remap_to_groups(
+        keys, lengths, groups
+    )
+    values = torch.concat(values, dim=1)
+    permuted_values = torch.ops.fbgemm.permute_pooled_embs_auto_grad(
+        values,
+        offsets.to(device=values.device),
+        permute.to(device=values.device),
+        inv_offsets.to(device=values.device),
+        inv_permute.to(device=values.device),
+    )
+    return list(torch.split(permuted_values, splits, dim=1))

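For intuition, a rough pure-PyTorch sketch of what the fused FBGEMM op computes
in the forward pass (illustrative only; the real kernel fuses the gather and
also consumes inv_offsets/inv_permute so autograd can invert the permutation in
the backward):

  def permute_pooled_embs_ref(
      values: torch.Tensor,   # (B, total_dim): all KTs concatenated on dim 1
      offsets: torch.Tensor,  # (num_keys + 1,): prefix sums of per-key widths
      permute: torch.Tensor,  # (num_keys,): source key index per output slot
  ) -> torch.Tensor:
      # Gather each key's column span in output order, then concatenate once.
      spans = [
          values[:, int(offsets[p]) : int(offsets[p + 1])] for p in permute.tolist()
      ]
      return torch.cat(spans, dim=1)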

+@torch.fx.wrap
+def _desugar_keyed_tensors(
+    kts: List["KeyedTensor"],
+) -> Tuple[List[List[str]], List[List[int]], List[torch.Tensor]]:
+    """
+    Desugar a list of KeyedTensors into basic data structure
+    """
+    return (
+        [kt.keys() for kt in kts],
+        [kt.length_per_key() for kt in kts],
+        [kt.values() for kt in kts],
+    )


+@torch.fx.wrap
+def _remap_to_groups(
+    keys: List[List[str]],
+    key_lengths: List[List[int]],
+    groups: List[List[str]],
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
+    """
+    Given a list of keys and lengths per key for each group, return the permute indices, inverse_permute indices, offsets, inv_offsets, splits.
+    The output is used to re-arrange values based on groups with a single cat operation.
+    """
+
+    lengths: List[int] = []
+    flat_keys: List[str] = []
+    flat_groups: List[str] = []
+
+    for sub_keys_length in key_lengths:
+        lengths.extend(sub_keys_length)
+    for sub_keys in keys:
+        flat_keys.extend(sub_keys)
+
+    for sub_group in groups:
+        flat_groups.extend(sub_group)
+
+    key_splits = [len(sub_group) for sub_group in groups]
+
+    index_map = {key: idx for idx, key in enumerate(flat_keys)}
+    permute = [index_map[key] for key in flat_groups]
+    inv_lengths = [lengths[i] for i in permute]
+    splits = _sum_by_splits(inv_lengths, key_splits)
+
+    inv_permute = [0] * len(permute)
+    for i, p in enumerate(permute):
+        inv_permute[p] = i
+
+    offsets = torch.tensor(_cumsum(lengths), dtype=torch.int64)
+    inv_offsets = torch.tensor(_cumsum(inv_lengths), dtype=torch.int64)
+    permute = torch.tensor(permute, dtype=torch.int64)
+    inv_permute = torch.tensor(inv_permute, dtype=torch.int64)
+
+    return permute, inv_permute, offsets, inv_offsets, splits

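A worked example of _remap_to_groups on hypothetical inputs, to make the five
outputs concrete:

  keys = [["f1", "f2"], ["f3"]]   # keys of two KeyedTensors
  key_lengths = [[8, 4], [6]]     # embedding width of each key
  groups = [["f1", "f3"], ["f2"]]
  permute, inv_permute, offsets, inv_offsets, splits = _remap_to_groups(
      keys, key_lengths, groups
  )
  # Flat key order is [f1, f2, f3] with widths [8, 4, 6]; group order is [f1, f3, f2].
  assert permute.tolist() == [0, 2, 1]           # this permutation is self-inverse,
  assert inv_permute.tolist() == [0, 2, 1]       # so inv_permute matches it here
  assert offsets.tolist() == [0, 8, 12, 18]      # prefix sums of [8, 4, 6]
  assert inv_offsets.tolist() == [0, 8, 14, 18]  # prefix sums of permuted [8, 6, 4]
  assert splits == [14, 4]                       # total width of each output group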

 def _values_string(values: torch.Tensor, start: int, end: int) -> str:
     size = values.size()
     if len(size) == 1:
@@ -2474,18 +2564,22 @@ def to_dict(self) -> Dict[str, torch.Tensor]:
     def regroup(
         keyed_tensors: List["KeyedTensor"], groups: List[List[str]]
     ) -> List[torch.Tensor]:
-        return _regroup_keyed_tensors(keyed_tensors, groups)
+        # Fast path, one-to-one correspondence between keyed_tensors and groups
+        if _all_keys_used_once(keyed_tensors, groups) is True:
+            return _fbgemm_permute_pooled_embs(keyed_tensors, groups)
+        else:  # Fallback to slow path otherwise
+            return _regroup_keyed_tensors(keyed_tensors, groups)
 
     @staticmethod
     def regroup_as_dict(
         keyed_tensors: List["KeyedTensor"], groups: List[List[str]], keys: List[str]
     ) -> Dict[str, torch.Tensor]:
+        ret: Dict[str, torch.Tensor] = {}
         assert len(groups) == len(keys), "Groups and keys should have same length"
-        embeddings_list = _regroup_keyed_tensors(keyed_tensors, groups)
-        embeddings_dict: Dict[str, torch.Tensor] = {}
+        tensor_list = KeyedTensor.regroup(keyed_tensors, groups)
         for i, key in enumerate(keys):
-            embeddings_dict[key] = embeddings_list[i]
-        return embeddings_dict
+            ret[key] = tensor_list[i]
+        return ret
 
     @torch.jit.unused
     def record_stream(self, stream: torch.cuda.streams.Stream) -> None:

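Reusing kt_1/kt_2 from the sketch near the top, the dict variant now routes
through KeyedTensor.regroup and therefore also benefits from the fast path:

  named = KeyedTensor.regroup_as_dict(
      [kt_1, kt_2], [["f1", "f3"], ["f2"]], ["user", "item"]
  )
  assert named["user"].shape == (4, 14) and named["item"].shape == (4, 4)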