add more explain of limitations
leslie-fang-intel committed Jun 9, 2023
1 parent 4b86060 commit f9ee24b
Showing 1 changed file with 35 additions and 15 deletions.
50 changes: 35 additions & 15 deletions prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst
@@ -1,7 +1,7 @@
(prototype) Quantization in PyTorch 2.0 Export Tutorial (Work in Progress)
(Work in Progress) Quantization in PyTorch 2.0 Export Tutorial
==============================================================

**Author**: `Leslie Fang <https://github.com/leslie-fang-intel>`_, `Weiwen Xia <https://github.com/Xia-Weiwen>`__, `Jiong Gong <https://github.com/jgong5>`__
**Author**: `Leslie Fang <https://github.com/leslie-fang-intel>`_, `Weiwen Xia <https://github.com/Xia-Weiwen>`__, `Jiong Gong <https://github.com/jgong5>`__, `Jerry Zhang <https://github.com/jerryzh168>`__

Today we have `FX Graph Mode
Quantization <https://pytorch.org/docs/stable/quantization.html#prototype-fx-graph-mode-quantization>`__
@@ -20,12 +20,12 @@ Prerequisites:
- `Understanding of the quantization concepts in PyTorch <https://pytorch.org/docs/master/quantization.html#quantization-api-summary>`__
- `Understanding of FX Graph Mode post training static quantization <https://pytorch.org/tutorials/prototype/fx_graph_mode_ptq_static.html>`__
- `Understanding of BackendConfig in PyTorch Quantization FX Graph Mode <https://pytorch.org/tutorials/prototype/backend_config_tutorial.html?highlight=backend>`__
- `Understanding of QConfigMapping in PyTorch Quantization FX Graph Mode <https://pytorch.org/tutorials/prototype/backend_config_tutorial.html#set-up-qconfigmapping-that-satisfies-the-backend-constraints>`__
- `Understanding of QConfig and QConfigMapping in PyTorch Quantization FX Graph Mode <https://pytorch.org/tutorials/prototype/backend_config_tutorial.html#set-up-qconfigmapping-that-satisfies-the-backend-constraints>`__

Previously in ``FX Graph Mode Quantization`` we used ``QConfigMapping`` for users to specify how the model should be quantized
and ``BackendConfig`` to specify the supported ways of quantization in their backend.
This API covers most use cases relatively well, but the main problem is that this API is not fully extensible
with two main limitations:
without involvement of the quantization team:

- Limitation around expressing quantization intentions for complicated operator patterns such as in the discussion of
`issue-96288 <https://github.com/pytorch/pytorch/issues/96288>`__ to support ``conv add`` fusion with oneDNN library.
@@ -34,6 +34,15 @@ with two main limitations:
- Limitation around supporting user's advanced quantization intention to quantize their model. For example, if backend
developer only wants to quantize inputs and outputs when the ``linear`` has a third input, it requires co-work from quantization
team and backend developer.
- Currently we use ``QConfigMapping`` and ``BackendConfig`` as separate objects. ``QConfigMapping`` describes the user's
intention of how they want their model to be quantized. ``BackendConfig`` describes what kinds of quantization a backend supports.
``BackendConfig`` is backend specific, but ``QConfigMapping`` is not, so a user can provide a ``QConfigMapping``
that is incompatible with a specific ``BackendConfig``. This is not a great UX. Ideally we can structure this better
by making both configuration (``QConfigMapping``) and quantization capability (``BackendConfig``) backend
specific, so there will be less confusion about incompatibilities.
- Currently in ``QConfig`` we expose observer/fake_quant classes as objects for the user to configure quantization.
This increases the number of things the user needs to care about, e.g. not only the ``dtype`` but also how the observation should
happen. These could potentially be hidden from the user so that the user interface is simpler.

To address these scalability issues,
`Quantizer <https://github.com/pytorch/pytorch/blob/3e988316b5976df560c51c998303f56a234a6a1f/torch/ao/quantization/_pt2e/quantizer/quantizer.py#L160>`__
@@ -127,24 +136,34 @@ Taking QNNPackQuantizer as an example, the overall Quantization 2.0 flow could b

# Step 4: Lower Reference Quantized Model into the backend

Inside the Quantizer, we will use the ``QuantizationAnnotation API``
to convey user's intent for what quantization spec to use and how to
observe certain tensor values in the prepare step. Now, we will have a step-by-step
tutorial for how to use the ``QuantizationAnnotation API`` with different types of
The Quantizer uses the annotation API to convey quantization intent for different operators/patterns.
The annotation API uses ``QuantizationSpec`` (
`definition is here <https://github.com/pytorch/pytorch/blob/1ca2e993af6fa6934fca35da6970308ce227ddc7/torch/ao/quantization/_pt2e/quantizer/quantizer.py#L38>`__
) to convey the intent of how a tensor will be quantized,
e.g. dtype, bitwidth, min and max values, symmetric vs. asymmetric, etc.
Furthermore, the annotation API also allows the quantizer to specify how a
tensor value should be observed, e.g. ``MinMaxObserver``, ``HistogramObserver``,
or some customized observer.

``QuantizationSpec`` is used to annotate a node's output tensor or input tensors. Annotating
input tensors is equivalent to annotating edges of the graph, while annotating the output tensor is
equivalent to annotating the node. Thus the annotation API requires the quantizer to annotate nodes (output tensor)
or edges (input tensors) of the graph.

Now, we will have a step-by-step tutorial for how to use the annotation API with different types of
``QuantizationSpec``.
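
The node-versus-edge distinction can be illustrated with a small, framework-free sketch. ``ToyQuantizationSpec`` and the name-based graph below are hypothetical stand-ins invented for this illustration; they are not the real ``QuantizationSpec`` class or an FX graph, and the real class has a different set of fields.

```python
from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass(frozen=True)
class ToyQuantizationSpec:
    """Simplified, hypothetical stand-in for QuantizationSpec (not the torch class)."""
    dtype: str        # e.g. "uint8"
    quant_min: int    # smallest representable quantized value
    quant_max: int    # largest representable quantized value
    symmetric: bool   # symmetric vs. asymmetric quantization
    observer: str     # name of the observer to use, e.g. "HistogramObserver"


# A tiny graph for out = add(x, y), described by node names only.
# Each edge is a (producer, consumer) pair.
nodes = ["x", "y", "add"]
edges = [("x", "add"), ("y", "add")]

act_spec = ToyQuantizationSpec("uint8", 0, 255, False, "HistogramObserver")

# Annotating an output tensor attaches a spec to the node that produces it.
node_annotations: Dict[str, ToyQuantizationSpec] = {"add": act_spec}

# Annotating input tensors attaches a spec to each incoming edge.
edge_annotations: Dict[Tuple[str, str], ToyQuantizationSpec] = {
    edge: act_spec for edge in edges
}
```

Here the output of ``add`` carries one annotation (a node annotation), while each of its two inputs carries its own annotation (an edge annotation), so the same producer could in principle be observed differently by different consumers.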

1. Annotate common operator patterns
--------------------------------------------------------

In order to use the quantized pattern/operators, e.g. ``quantized add``,
backend developers will have intent to quantize (as expressed by
`QuantizationSpec <https://github.com/pytorch/pytorch/blob/1ca2e993af6fa6934fca35da6970308ce227ddc7/torch/ao/quantization/_pt2e/quantizer/quantizer.py#L38>`__
) inputs, output of the pattern. Following is an example flow (take ``add`` operator as example)
backend developers will have intent to quantize (as expressed by ``QuantizationSpec``)
inputs, output of the pattern. Following is an example flow (take ``add`` operator as example)
of how this intent is conveyed in the quantization workflow with annotation API.

- Step 1: Identify the original floating point pattern in the FX graph. There are
several ways to identify this pattern: User may use a pattern matcher (e.g. SubgraphMatcher)
to match the operator pattern; User may go through the nodes from start to the end and compare
several ways to identify this pattern: Quantizer may use a pattern matcher (e.g. SubgraphMatcher)
to match the operator pattern; Quantizer may go through the nodes from start to the end and compare
the node's target type to match the operator pattern. In this example, we can use the
`get_source_partitions <https://github.com/pytorch/pytorch/blob/07104ca99c9d297975270fb58fda786e60b49b38/torch/fx/passes/utils/source_matcher_utils.py#L51>`__
to match this pattern. The original floating point ``add`` pattern contains only a single ``add`` node.
@@ -177,8 +196,9 @@ of how this intent is conveyed in the quantization workflow with annotation API.
- Step 3: Annotate the inputs and output of the pattern with
`QuantizationAnnotation <https://github.com/pytorch/pytorch/blob/07104ca99c9d297975270fb58fda786e60b49b38/torch/ao/quantization/_pt2e/quantizer/quantizer.py#L144>`__.
``QuantizationAnnotation`` is a ``dataclass`` with several fields as: ``input_qspec_map`` field is of class ``Dict``
to map each input ``Node`` to a ``QuantizationSpec``; ``output_qspec`` field expresses the ``QuantizationSpec`` used for
output node; ``_annotated`` field indicates if this node has already been annotated by quantizer.
to map each input ``Node`` to a ``QuantizationSpec``, which means each input edge is annotated with that ``QuantizationSpec``;
the ``output_qspec`` field expresses the ``QuantizationSpec`` used to
annotate the output node; the ``_annotated`` field indicates whether this node has already been annotated by the quantizer.
In this example, we will create the ``QuantizationAnnotation`` object with the ``QuantizationSpec`` objects
created in above step 2 for two inputs and one output of ``add`` node.
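
As a rough illustration of Step 3, the sketch below mirrors the three fields described above using plain Python dataclasses. ``ToySpec``, ``ToyNode``, ``ToyQuantizationAnnotation`` and ``annotate_add`` are hypothetical names made up for this example, and storing the annotation under the ``"quantization_annotation"`` key of a ``meta`` dictionary is an assumption; none of this is the actual PyTorch implementation.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass(frozen=True)
class ToySpec:
    """Hypothetical, reduced stand-in for QuantizationSpec."""
    dtype: str
    quant_min: int
    quant_max: int


@dataclass(eq=False)  # eq=False keeps identity hashing, so nodes can be dict keys
class ToyNode:
    """Hypothetical stand-in for a graph node."""
    name: str
    args: List["ToyNode"] = field(default_factory=list)
    meta: Dict[str, object] = field(default_factory=dict)


@dataclass
class ToyQuantizationAnnotation:
    """Mirrors the three fields described in Step 3."""
    input_qspec_map: Dict[ToyNode, ToySpec] = field(default_factory=dict)
    output_qspec: Optional[ToySpec] = None
    _annotated: bool = False


def annotate_add(add_node: ToyNode, input_spec: ToySpec, output_spec: ToySpec) -> None:
    """Annotate both input edges and the output of an add node, once."""
    existing = add_node.meta.get("quantization_annotation")  # assumed key name
    if existing is not None and existing._annotated:
        return  # already annotated; leave the earlier annotation untouched
    add_node.meta["quantization_annotation"] = ToyQuantizationAnnotation(
        input_qspec_map={arg: input_spec for arg in add_node.args},
        output_qspec=output_spec,
        _annotated=True,
    )


# Build out = add(x, y) and annotate it with one activation spec.
x, y = ToyNode("x"), ToyNode("y")
add_node = ToyNode("add", args=[x, y])
act_spec = ToySpec("uint8", 0, 255)
annotate_add(add_node, act_spec, act_spec)
```

Checking ``_annotated`` before writing is one way to keep a later annotation pass from overwriting an earlier decision for the same node.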
