Skip to content

Commit

Permalink
make the custom_op_name consistent and backout weights... (#1949)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1949

# context
* back out the previous diff that added the weights to the graph
* make the custom_op name consistent

# before
```
(Pdb) ep.graph.print_tabular()
opcode         name                                      target                                                 args                                                                                                                                   kwargs
-------------  ----------------------------------------  -----------------------------------------------------  -------------------------------------------------------------------------------------------------------------------------------------  --------
placeholder    p_ebc1_embedding_bags_t1_weight           p_ebc1_embedding_bags_t1_weight                        ()                                                                                                                                     {}
placeholder    p_ebc1_embedding_bags_t2_weight           p_ebc1_embedding_bags_t2_weight                        ()                                                                                                                                     {}
placeholder    p_ebc2_embedding_bags_t3_weight           p_ebc2_embedding_bags_t3_weight                        ()                                                                                                                                     {}
placeholder    p_ebc2_embedding_bags_t4_weight           p_ebc2_embedding_bags_t4_weight                        ()                                                                                                                                     {}
placeholder    features__values                          features__values                                       ()                                                                                                                                     {}
placeholder    features__weights                         features__weights                                      ()                                                                                                                                     {}
placeholder    features__lengths                         features__lengths                                      ()                                                                                                                                     {}
placeholder    features__offsets                         features__offsets                                      ()                                                                                                                                     {}
call_function  embedding_bag_collection_140614752342208  custom.EmbeddingBagCollection_140614752342208.default  ([features__values, features__weights, None, features__offsets, p_ebc1_embedding_bags_t1_weight, p_ebc1_embedding_bags_t2_weight], 2)  {}
call_function  getitem                                   <built-in function getitem>                            (embedding_bag_collection_140614752342208, 0)                                                                                          {}
call_function  embedding_bag_collection_140614752342209  custom.EmbeddingBagCollection_140614752342208.default  ([features__values, features__weights, None, features__offsets, p_ebc2_embedding_bags_t3_weight, p_ebc2_embedding_bags_t4_weight], 2)  {}
call_function  getitem_1                                 <built-in function getitem>                            (embedding_bag_collection_140614752342209, 0)                                                                                          {}
call_function  cat                                       aten.cat.default                                       ([getitem, getitem_1], 1)                                                                                                              {}
output         output                                    output                                                 ((cat,),)                                                                                                                              {}
```

# after
```
(Pdb) ep.graph.print_tabular()
opcode         name                             target                                   args                                                                 kwargs
-------------  -------------------------------  ---------------------------------------  -------------------------------------------------------------------  --------
placeholder    p_ebc1_embedding_bags_t1_weight  p_ebc1_embedding_bags_t1_weight          ()                                                                   {}
placeholder    p_ebc1_embedding_bags_t2_weight  p_ebc1_embedding_bags_t2_weight          ()                                                                   {}
placeholder    p_ebc2_embedding_bags_t3_weight  p_ebc2_embedding_bags_t3_weight          ()                                                                   {}
placeholder    p_ebc2_embedding_bags_t4_weight  p_ebc2_embedding_bags_t4_weight          ()                                                                   {}
placeholder    features__values                 features__values                         ()                                                                   {}
placeholder    features__weights                features__weights                        ()                                                                   {}
placeholder    features__lengths                features__lengths                        ()                                                                   {}
placeholder    features__offsets                features__offsets                        ()                                                                   {}
call_function  embedding_bag_collection_3       custom.EmbeddingBagCollection_3.default  ([features__values, features__weights, None, features__offsets], 2)  {}
call_function  getitem                          <built-in function getitem>              (embedding_bag_collection_3, 0)                                      {}
call_function  embedding_bag_collection_4       custom.EmbeddingBagCollection_3.default  ([features__values, features__weights, None, features__offsets], 2)  {}
call_function  getitem_1                        <built-in function getitem>              (embedding_bag_collection_4, 0)                                      {}
call_function  cat                              aten.cat.default                         ([getitem, getitem_1], 1)                                            {}
output         output                           output                                   ((cat,),)                                                            {}
```

Reviewed By: PaulZhang12, shruthign

Differential Revision: D56942421

fbshipit-source-id: cb9a5db4e4d94ed1c2f288e2a465110bbc3ddb1a
  • Loading branch information
TroyGarden authored and facebook-github-bot committed May 6, 2024
1 parent 54956aa commit 7c5571f
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 43 deletions.
2 changes: 1 addition & 1 deletion torchrec/modules/embedding_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def _non_strict_exporting_forward(
features.weights_or_none(),
features.lengths_or_none(),
features.offsets_or_none(),
] + [bag.weight for bag in self.embedding_bags.values()]
] # if want to include the weights: `+ [bag.weight for bag in self.embedding_bags.values()]`
dims = [sum(self._lengths_per_embedding)]
ebc_op = register_custom_op(self, dims)
outputs = ebc_op(arg_list, batch_size)
Expand Down
94 changes: 52 additions & 42 deletions torchrec/modules/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,11 @@ class OpRegistryState:
"""

op_registry_lock = threading.Lock()
# operator schema: op_name: schema

# operator schema: {class}.{id} => op_name
op_registry_schema: Dict[str, str] = {}
# operator counter: {class} => count
op_registry_counter: Dict[str, int] = defaultdict(int)


operator_registry_state = OpRegistryState()
Expand Down Expand Up @@ -267,6 +270,9 @@ def _permute_indices(indices: List[int], permute: List[int]) -> List[int]:
return permuted_indices


# register a customized operator that takes a list of tensors as input and returns
# a list of tensors as output. The operator is registered with the name of
# {module_class_name}_{instance_count}
def register_custom_op(
module: torch.nn.Module, dims: List[int]
) -> Callable[[List[Optional[torch.Tensor]], int], List[torch.Tensor]]:
Expand All @@ -280,47 +286,51 @@ def register_custom_op(

global operator_registry_state

op_name: str = f"{type(module).__name__}_{id(module)}"
m_name: str = type(module).__name__
op_id: str = f"{m_name}_{id(module)}"
with operator_registry_state.op_registry_lock:
if op_name in operator_registry_state.op_registry_schema:
return getattr(torch.ops.custom, op_name)

def custom_op(
values: List[Optional[torch.Tensor]],
batch_size: int,
) -> List[torch.Tensor]:
device = None
for v in values:
if v is not None:
device = v.device
break
if op_id in operator_registry_state.op_registry_schema:
op_name: str = operator_registry_state.op_registry_schema[op_id]
else:
raise AssertionError(
f"Custom op {op_name} expects at least one input tensor"
)

return [
torch.empty(
batch_size,
dim,
device=device,
operator_registry_state.op_registry_counter[m_name] += 1
op_name: str = (
f"{m_name}_{operator_registry_state.op_registry_counter[m_name]}"
)
for dim in dims
]

schema_string = f"{op_name}(Tensor?[] values, int batch_size) -> Tensor[]"
with operator_registry_state.op_registry_lock:
if op_name in operator_registry_state.op_registry_schema:
return getattr(torch.ops.custom, op_name)
operator_registry_state.op_registry_schema[op_name] = schema_string
# Register schema
lib.define(schema_string)

# Register implementation
lib.impl(op_name, custom_op, "CPU")
lib.impl(op_name, custom_op, "CUDA")

# Register meta formula
lib.impl(op_name, custom_op, "Meta")

return getattr(torch.ops.custom, op_name)
operator_registry_state.op_registry_schema[op_id] = op_name

def custom_op(
values: List[Optional[torch.Tensor]],
batch_size: int,
) -> List[torch.Tensor]:
device = None
for v in values:
if v is not None:
device = v.device
break
else:
raise AssertionError(
f"Custom op {op_name} expects at least one input tensor"
)

return [
torch.empty(
batch_size,
dim,
device=device,
)
for dim in dims
]

schema_string = f"{op_name}(Tensor?[] values, int batch_size) -> Tensor[]"
operator_registry_state.op_registry_schema[op_name] = schema_string
# Register schema
lib.define(schema_string)

# Register implementation
lib.impl(op_name, custom_op, "CPU")
lib.impl(op_name, custom_op, "CUDA")

# Register meta formula
lib.impl(op_name, custom_op, "Meta")

return getattr(torch.ops.custom, op_name)

0 comments on commit 7c5571f

Please sign in to comment.