[quant][fx] Removing more unused code (pytorch#74603)
Summary:
Pull Request resolved: pytorch#74603

att

Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps

Imported from OSS

Reviewed By: andrewor14

Differential Revision: D35071546

fbshipit-source-id: 273a7f0cb2a8f306864eb118916056fad3bb1399
(cherry picked from commit 9c31a50)
jerryzh168 authored and pytorchmergebot committed Mar 25, 2022
1 parent cdcd1ac commit 0747bdb
Showing 1 changed file with 4 additions and 56 deletions.
torch/ao/quantization/fx/quantization_patterns.py (60 changes: 4 additions & 56 deletions)
@@ -160,18 +160,6 @@ def is_output_quantized(self, qconfig):
     torch.nn.functional.softmax: int8_dtypes,
 }
 
-QAT_CONV_MODULE_CLASSES = \
-    (torch.nn.qat.Conv2d,
-     torch.nn.qat.Conv3d,
-     torch.nn.intrinsic.qat.ConvBn1d,
-     torch.nn.intrinsic.qat.ConvBn2d,
-     torch.nn.intrinsic.qat.ConvBn3d,
-     torch.nn.intrinsic.qat.ConvBnReLU1d,
-     torch.nn.intrinsic.qat.ConvBnReLU2d,
-     torch.nn.intrinsic.qat.ConvBnReLU3d,
-     torch.nn.intrinsic.qat.ConvReLU2d,
-     torch.nn.intrinsic.qat.ConvReLU3d)
-
 @register_quant_pattern(operator.add)
 @register_quant_pattern(operator.sub)
 @register_quant_pattern(operator.mul)
@@ -295,18 +283,7 @@ def is_general_tensor_value_op(self) -> bool:
 @register_quant_pattern((torch.nn.functional.relu, torch.nn.Conv3d))
 # TODO: rename Relu -> ReLU to be more consistent with other classes
 class ConvReluQuantizeHandler(QuantizeHandler):
-    def __init__(self, node: Node, modules: Dict[str, torch.nn.Module]):
-        super().__init__(node, modules)
-        self.relu_node = None
-        if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
-                (node.op == 'call_module' and isinstance(modules[str(node.target)], torch.nn.ReLU)):
-            self.relu_node = node
-            node = node.args[0]  # type: ignore[assignment]
-        self.conv_node = node
-        if node.op == "call_module":
-            self.conv = modules[str(self.conv_node.target)]
-        elif node.op == "call_function":
-            self.conv = node.target  # type: ignore[assignment]
+    pass
 
 @register_quant_pattern(torch.nn.functional.linear)
 @register_quant_pattern(torch.nn.qat.Linear)
@@ -320,45 +297,20 @@ def __init__(self, node: Node, modules: Dict[str, torch.nn.Module]):
 @register_quant_pattern((torch.nn.ReLU, torch.nn.Linear))
 @register_quant_pattern((torch.nn.functional.relu, torch.nn.Linear))
 class LinearReLUQuantizeHandler(QuantizeHandler):
-    def __init__(
-            self,
-            node: Node,
-            modules: Dict[str, torch.nn.Module]):
-        super().__init__(node, modules)
-        self.relu_node = None
-        if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
-                (node.op == 'call_module' and isinstance(modules[str(node.target)], torch.nn.ReLU)):
-            self.relu_node = node
-            node = node.args[0]  # type: ignore[assignment]
-        self.linear_node = node
-        if node.op == 'call_module':
-            self.linear = modules[str(self.linear_node.target)]
+    pass
 
 @register_quant_pattern(torch.nn.BatchNorm2d)
 @register_quant_pattern(torch.nn.BatchNorm3d)
 @register_quant_pattern(torch.nn.intrinsic.BNReLU2d)
 @register_quant_pattern(torch.nn.intrinsic.BNReLU3d)
 class BatchNormQuantizeHandler(QuantizeHandler):
-    def __init__(
-            self,
-            node: Node,
-            modules: Dict[str, torch.nn.Module]):
-        super().__init__(node, modules)
-        assert node.op == 'call_module'
-        self.bn_node = node
-        self.bn = modules[str(self.bn_node.target)]
+    pass
 
 @register_quant_pattern(torch.nn.qat.Embedding)
 @register_quant_pattern(torch.nn.qat.EmbeddingBag)
 @register_quant_pattern(torch.nn.Embedding)
 @register_quant_pattern(torch.nn.EmbeddingBag)
 class EmbeddingQuantizeHandler(QuantizeHandler):
-    def __init__(
-            self,
-            node: Node,
-            modules: Dict[str, torch.nn.Module]):
-        super().__init__(node, modules)
-
     def input_output_observed(self) -> bool:
         return False
 
@@ -368,11 +320,7 @@ def input_output_observed(self) -> bool:
 @register_quant_pattern(torch.nn.RNNCell)
 @register_quant_pattern(torch.nn.LSTM)
 class RNNDynamicQuantizeHandler(QuantizeHandler):
-    def __init__(
-            self,
-            node: Node,
-            modules: Dict[str, torch.nn.Module]):
-        super().__init__(node, modules)
+    pass
 
 # we currently only support reference patterns for these ops so they have been removed
 # until they receive a proper fp16 kernel. To use the reference pattern, use a custom qconfig
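
For context, the handler classes touched in this diff back FX graph mode quantization; the removed __init__ bodies computed state the current workflow no longer reads, so only the pattern registration remains. Below is a minimal sketch of how such registered patterns are exercised from user code, assuming a recent PyTorch release. The SmallConvNet model is illustrative only, and prepare_fx's signature has changed across releases (around the time of this commit it took a qconfig_dict rather than a QConfigMapping).

import torch
from torch.ao.quantization import QConfigMapping, get_default_qconfig
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

class SmallConvNet(torch.nn.Module):
    # Toy Conv2d + ReLU model; this is the kind of conv+relu pattern registered above.
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))

model = SmallConvNet().eval()
qconfig_mapping = QConfigMapping().set_global(get_default_qconfig("fbgemm"))
example_inputs = (torch.randn(1, 3, 32, 32),)

prepared = prepare_fx(model, qconfig_mapping, example_inputs)  # inserts observers for matched patterns
prepared(*example_inputs)                                      # calibration pass
quantized = convert_fx(prepared)                               # produce the quantized model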
