Fast FBGEMM path KT.regroup_as #1910

Closed · wants to merge 1 commit
122 changes: 108 additions & 14 deletions torchrec/sparse/jagged_tensor.py
@@ -9,6 +9,7 @@

import abc
import operator

from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
@@ -21,6 +22,12 @@
try:
    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
    torch.ops.load_library(
        "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_gpu"
    )
    torch.ops.load_library(
        "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_cpu"
    )
except OSError:
    pass

@@ -128,18 +135,10 @@ def _assert_offsets_or_lengths_is_provided(


@torch.fx.wrap
# keep for legacy use cases
def _regroup_keyed_tensors(
    keyed_tensors: List["KeyedTensor"], groups: List[List[str]]
) -> List[torch.Tensor]:
    # Shortcut for no re-grouping
    if len(keyed_tensors) == len(groups):
        match = True
        for kt, group in zip(keyed_tensors, groups):
            if kt.keys() != group:
                match = False
                break
        if match:
            return [kt.values() for kt in keyed_tensors]

    embedding_dicts = [keyed_tensor.to_dict() for keyed_tensor in keyed_tensors]
    lengths = [keyed_tensor.length_per_key() for keyed_tensor in keyed_tensors]
@@ -165,6 +164,97 @@ def _regroup_keyed_tensors(
    return list(rearranged_values.split(split_lengths, dim=key_dim))


@torch.fx.wrap
def _all_keys_used_once(
    keyed_tensors: List["KeyedTensor"], groups: List[List[str]]
) -> bool:
    flat_keys: List[str] = []
    flat_groups: List[str] = []
    for keyed_tensor in keyed_tensors:
        flat_keys.extend(keyed_tensor.keys())
    for sub_group in groups:
        flat_groups.extend(sub_group)
    # jit.script does not support set, so use a dict to represent the set
    key_set: Dict[str, int] = {key: 1 for key in flat_keys}
    group_set: Dict[str, int] = {key: 1 for key in flat_groups}
    return len(key_set) == len(group_set) == len(flat_keys) == len(flat_groups)
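
For intuition, here is a small standalone sketch (not part of the diff) of the same uniqueness check applied to plain key lists; `all_keys_used_once` below is a hypothetical stand-in for the private helper above.

```python
from typing import Dict, List

def all_keys_used_once(key_lists: List[List[str]], groups: List[List[str]]) -> bool:
    # Flatten the keys of every input and of every requested group.
    flat_keys = [k for keys in key_lists for k in keys]
    flat_groups = [k for group in groups for k in group]
    # Dicts stand in for sets, mirroring the jit.script workaround above.
    key_set: Dict[str, int] = {k: 1 for k in flat_keys}
    group_set: Dict[str, int] = {k: 1 for k in flat_groups}
    return len(key_set) == len(group_set) == len(flat_keys) == len(flat_groups)

# Fast path: every key appears exactly once in the inputs and once in the groups.
assert all_keys_used_once([["f1", "f2"], ["f3"]], [["f1", "f3"], ["f2"]])
# Slow path: "f1" is requested by two groups, so the one-to-one condition fails.
assert not all_keys_used_once([["f1", "f2"], ["f3"]], [["f1", "f3"], ["f1", "f2"]])
```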


@torch.fx.wrap
def _fbgemm_permute_pooled_embs(
    keyed_tensors: List["KeyedTensor"], groups: List[List[str]]
) -> List[torch.Tensor]:
    keys, lengths, values = _desugar_keyed_tensors(keyed_tensors)
    permute, inv_permute, offsets, inv_offsets, splits = _remap_to_groups(
        keys, lengths, groups
    )
    values = torch.concat(values, dim=1)
    permuted_values = torch.ops.fbgemm.permute_pooled_embs_auto_grad(
        values,
        offsets.to(device=values.device),
        permute.to(device=values.device),
        inv_offsets.to(device=values.device),
        inv_permute.to(device=values.device),
    )
    return list(torch.split(permuted_values, splits, dim=1))
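
To show what the fused kernel computes, here is a minimal pure-PyTorch sketch of the forward permutation (an illustrative reference, not the FBGEMM implementation; the function name and toy values are assumptions).

```python
import torch

def permute_pooled_embs_reference(
    values: torch.Tensor,   # [B, sum(lengths)]: the concatenated pooled embeddings
    offsets: torch.Tensor,  # cumulative per-key lengths in the original key order
    permute: torch.Tensor,  # destination order of the per-key column blocks
) -> torch.Tensor:
    # Slice out each key's column block and re-emit the blocks in permuted order;
    # the FBGEMM op fuses this (plus its autograd inverse) into a single kernel.
    offs = offsets.tolist()
    blocks = [values[:, offs[p] : offs[p + 1]] for p in permute.tolist()]
    return torch.cat(blocks, dim=1)

# Toy check: per-key widths [2, 3, 4] -> offsets [0, 2, 5, 9]; permute [0, 2, 1]
# moves the third key's block ahead of the second.
vals = torch.arange(18.0).reshape(2, 9)
out = permute_pooled_embs_reference(
    vals, torch.tensor([0, 2, 5, 9]), torch.tensor([0, 2, 1])
)
assert out[0].tolist() == [0.0, 1.0, 5.0, 6.0, 7.0, 8.0, 2.0, 3.0, 4.0]
```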


@torch.fx.wrap
def _desugar_keyed_tensors(
    kts: List["KeyedTensor"],
) -> Tuple[List[List[str]], List[List[int]], List[torch.Tensor]]:
    """
    Desugar a list of KeyedTensors into basic data structures (keys, lengths, values).
    """
    return (
        [kt.keys() for kt in kts],
        [kt.length_per_key() for kt in kts],
        [kt.values() for kt in kts],
    )


@torch.fx.wrap
def _remap_to_groups(
    keys: List[List[str]],
    key_lengths: List[List[int]],
    groups: List[List[str]],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
    """
    Given the keys and the lengths per key of each input, return the permute indices,
    inverse-permute indices, offsets, inverse offsets, and splits needed to re-arrange
    the values according to the requested groups with a single cat operation.
    """

    lengths: List[int] = []
    flat_keys: List[str] = []
    flat_groups: List[str] = []

    for sub_keys_length in key_lengths:
        lengths.extend(sub_keys_length)
    for sub_keys in keys:
        flat_keys.extend(sub_keys)

    for sub_group in groups:
        flat_groups.extend(sub_group)

    key_splits = [len(sub_group) for sub_group in groups]

    index_map = {key: idx for idx, key in enumerate(flat_keys)}
    permute = [index_map[key] for key in flat_groups]
    inv_lengths = [lengths[i] for i in permute]
    splits = _sum_by_splits(inv_lengths, key_splits)

    inv_permute = [0] * len(permute)
    for i, p in enumerate(permute):
        inv_permute[p] = i

    offsets = torch.tensor(_cumsum(lengths), dtype=torch.int64)
    inv_offsets = torch.tensor(_cumsum(inv_lengths), dtype=torch.int64)
    permute = torch.tensor(permute, dtype=torch.int64)
    inv_permute = torch.tensor(inv_permute, dtype=torch.int64)

    return permute, inv_permute, offsets, inv_offsets, splits
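
A hand-worked toy trace (illustrative, not part of the diff) of the values this helper would produce, assuming `_sum_by_splits` sums lengths per group and `_cumsum` prepends a zero:

```python
# Plain-Python trace of the remapping for a tiny example.
keys = [["f1", "f2"], ["f3"]]   # per-input keys
key_lengths = [[2, 3], [4]]     # per-key embedding widths
groups = [["f1", "f3"], ["f2"]]

lengths = [l for sub in key_lengths for l in sub]   # [2, 3, 4]
flat_keys = [k for sub in keys for k in sub]        # ["f1", "f2", "f3"]
flat_groups = [g for sub in groups for g in sub]    # ["f1", "f3", "f2"]

index_map = {k: i for i, k in enumerate(flat_keys)}
permute = [index_map[k] for k in flat_groups]
inv_lengths = [lengths[i] for i in permute]

assert permute == [0, 2, 1]      # group order expressed in original key indices
assert inv_lengths == [2, 4, 3]  # widths in group order
# splits would be [2 + 4, 3] == [6, 3]: output widths of the two regrouped tensors;
# offsets/inv_offsets are the running sums [0, 2, 5, 9] and [0, 2, 6, 9].
```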


def _values_string(values: torch.Tensor, start: int, end: int) -> str:
    size = values.size()
    if len(size) == 1:
@@ -2474,18 +2564,22 @@ def to_dict(self) -> Dict[str, torch.Tensor]:
    def regroup(
        keyed_tensors: List["KeyedTensor"], groups: List[List[str]]
    ) -> List[torch.Tensor]:
        return _regroup_keyed_tensors(keyed_tensors, groups)
        # Fast path: each key appears exactly once across the inputs and the groups.
        if _all_keys_used_once(keyed_tensors, groups):
            return _fbgemm_permute_pooled_embs(keyed_tensors, groups)
        else:  # Otherwise fall back to the slow path.
            return _regroup_keyed_tensors(keyed_tensors, groups)

    @staticmethod
    def regroup_as_dict(
        keyed_tensors: List["KeyedTensor"], groups: List[List[str]], keys: List[str]
    ) -> Dict[str, torch.Tensor]:
        ret: Dict[str, torch.Tensor] = {}
        assert len(groups) == len(keys), "Groups and keys should have same length"
        embeddings_list = _regroup_keyed_tensors(keyed_tensors, groups)
        embeddings_dict: Dict[str, torch.Tensor] = {}
        tensor_list = KeyedTensor.regroup(keyed_tensors, groups)
        for i, key in enumerate(keys):
            embeddings_dict[key] = embeddings_list[i]
        return embeddings_dict
            ret[key] = tensor_list[i]
        return ret
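
End to end, the new fast path is exercised through `KeyedTensor.regroup`. A usage sketch, assuming fbgemm_gpu and its permute_pooled_embedding ops are installed and that the constructor arguments follow torchrec's `KeyedTensor` API:

```python
import torch
from torchrec.sparse.jagged_tensor import KeyedTensor

# Two pooled-embedding KeyedTensors over a batch of 2, with disjoint keys.
kt0 = KeyedTensor(
    keys=["f1", "f2"], length_per_key=[2, 3], values=torch.randn(2, 5), key_dim=1
)
kt1 = KeyedTensor(
    keys=["f3"], length_per_key=[4], values=torch.randn(2, 4), key_dim=1
)

# Every key appears exactly once across the inputs and the groups, so regroup
# takes the FBGEMM fast path; otherwise it falls back to _regroup_keyed_tensors.
grouped = KeyedTensor.regroup([kt0, kt1], [["f1", "f3"], ["f2"]])
print([t.shape for t in grouped])  # [torch.Size([2, 6]), torch.Size([2, 3])]

named = KeyedTensor.regroup_as_dict(
    [kt0, kt1], [["f1", "f3"], ["f2"]], ["dense", "sparse"]
)
print(list(named.keys()))  # ["dense", "sparse"]
```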

    @torch.jit.unused
    def record_stream(self, stream: torch.cuda.streams.Stream) -> None: