From cafccb88c4cd38ed54462c482fc0550944463839 Mon Sep 17 00:00:00 2001 From: Kwanghoon An Date: Tue, 28 May 2024 22:49:05 -0700 Subject: [PATCH] Support min/max carry over for eager mode from_float method (#2046) Summary: X-link: https://github.com/pytorch/pytorch/pull/127309 After QAT has completed, or when a pre-tuned weight observer has been supplied by a tunable PTQ algorithm, the observer's min/max state should not be overwritten by re-running it on the given weight — for static QAT this must never happen. Dynamic QAT likewise does not require re-running the weight observer, by design. This patch fixes that by threading a `use_precomputed_fake_quant` flag through the eager-mode `from_float` methods. Reviewed By: jerryzh168 Differential Revision: D57747749 --- torchrec/distributed/quant_embedding_kernel.py | 8 ++++++-- torchrec/quant/embedding_modules.py | 14 +++++++++++--- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/torchrec/distributed/quant_embedding_kernel.py b/torchrec/distributed/quant_embedding_kernel.py index 17e6d7b7f..283587618 100644 --- a/torchrec/distributed/quant_embedding_kernel.py +++ b/torchrec/distributed/quant_embedding_kernel.py @@ -338,7 +338,9 @@ def split_embedding_weights( ] @classmethod - def from_float(cls, module: BaseEmbedding) -> "QuantBatchedEmbeddingBag": + def from_float( + cls, module: BaseEmbedding, use_precomputed_fake_quant: bool = False + ) -> "QuantBatchedEmbeddingBag": assert hasattr( module, "qconfig" ), "BaseEmbedding input float module must have qconfig defined" @@ -490,7 +492,9 @@ def named_buffers( yield append_prefix(prefix, f"{config.name}.weight_qbias"), weight_qbias @classmethod - def from_float(cls, module: BaseEmbedding) -> "QuantBatchedEmbedding": + def from_float( + cls, module: BaseEmbedding, use_precomputed_fake_quant: bool = False + ) -> "QuantBatchedEmbedding": assert hasattr( module, "qconfig" ), "BaseEmbedding input float module must have qconfig defined" diff --git a/torchrec/quant/embedding_modules.py b/torchrec/quant/embedding_modules.py index a978cfe64..80132bf70 100644 --- a/torchrec/quant/embedding_modules.py +++ b/torchrec/quant/embedding_modules.py 
@@ -514,7 +514,9 @@ def _get_name(self) -> str: @classmethod def from_float( - cls, module: OriginalEmbeddingBagCollection + cls, + module: OriginalEmbeddingBagCollection, + use_precomputed_fake_quant: bool = False, ) -> "EmbeddingBagCollection": assert hasattr( module, "qconfig" @@ -617,7 +619,9 @@ def _get_name(self) -> str: @classmethod # pyre-ignore def from_float( - cls, module: OriginalFeatureProcessedEmbeddingBagCollection + cls, + module: OriginalFeatureProcessedEmbeddingBagCollection, + use_precomputed_fake_quant: bool = False, ) -> "FeatureProcessedEmbeddingBagCollection": fp_ebc = module ebc = module._embedding_bag_collection @@ -903,7 +907,11 @@ def forward( return feature_embeddings @classmethod - def from_float(cls, module: OriginalEmbeddingCollection) -> "EmbeddingCollection": + def from_float( + cls, + module: OriginalEmbeddingCollection, + use_precomputed_fake_quant: bool = False, + ) -> "EmbeddingCollection": assert hasattr( module, "qconfig" ), "EmbeddingCollection input float module must have qconfig defined"