Skip to content

Commit

Permalink
Support min/max carry over for eager mode from_float method (pytorch#…
Browse files Browse the repository at this point in the history
…2046)

Summary:
X-link: pytorch/pytorch#127309


After QAT is completed, or when a pre-tuned weight observer is supplied via a tunable PTQ algorithm, the observer state should not be overwritten again from the given weight — for static QAT this must never happen.

Dynamic QAT, by design, also does not require re-running the weight observer.

This commit fixes that behavior by threading a `use_precomputed_fake_quant` flag through the eager-mode `from_float` methods.

Reviewed By: jerryzh168

Differential Revision: D57747749
  • Loading branch information
kwanghoon-meta authored and facebook-github-bot committed May 29, 2024
1 parent a71f049 commit cafccb8
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
8 changes: 6 additions & 2 deletions torchrec/distributed/quant_embedding_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,9 @@ def split_embedding_weights(
]

@classmethod
def from_float(cls, module: BaseEmbedding) -> "QuantBatchedEmbeddingBag":
def from_float(
cls, module: BaseEmbedding, use_precomputed_fake_quant: bool = False
) -> "QuantBatchedEmbeddingBag":
assert hasattr(
module, "qconfig"
), "BaseEmbedding input float module must have qconfig defined"
Expand Down Expand Up @@ -490,7 +492,9 @@ def named_buffers(
yield append_prefix(prefix, f"{config.name}.weight_qbias"), weight_qbias

@classmethod
def from_float(cls, module: BaseEmbedding) -> "QuantBatchedEmbedding":
def from_float(
cls, module: BaseEmbedding, use_precomputed_fake_quant: bool = False
) -> "QuantBatchedEmbedding":
assert hasattr(
module, "qconfig"
), "BaseEmbedding input float module must have qconfig defined"
Expand Down
14 changes: 11 additions & 3 deletions torchrec/quant/embedding_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,9 @@ def _get_name(self) -> str:

@classmethod
def from_float(
cls, module: OriginalEmbeddingBagCollection
cls,
module: OriginalEmbeddingBagCollection,
use_precomputed_fake_quant: bool = False,
) -> "EmbeddingBagCollection":
assert hasattr(
module, "qconfig"
Expand Down Expand Up @@ -617,7 +619,9 @@ def _get_name(self) -> str:
@classmethod
# pyre-ignore
def from_float(
cls, module: OriginalFeatureProcessedEmbeddingBagCollection
cls,
module: OriginalFeatureProcessedEmbeddingBagCollection,
use_precomputed_fake_quant: bool = False,
) -> "FeatureProcessedEmbeddingBagCollection":
fp_ebc = module
ebc = module._embedding_bag_collection
Expand Down Expand Up @@ -903,7 +907,11 @@ def forward(
return feature_embeddings

@classmethod
def from_float(cls, module: OriginalEmbeddingCollection) -> "EmbeddingCollection":
def from_float(
cls,
module: OriginalEmbeddingCollection,
use_precomputed_fake_quant: bool = False,
) -> "EmbeddingCollection":
assert hasattr(
module, "qconfig"
), "EmbeddingCollection input float module must have qconfig defined"
Expand Down

0 comments on commit cafccb8

Please sign in to comment.