Codegen fix (apache#3)
* Checkpoint, nothing works

* DNNL based codegen almost works

* Work in dnnl style

* Work in dnnl style

* Arg passing works

* Work in dnnl style

* Codegen somewhat works

* Requantization not working

* Codegen works

* Remove headsail_old

* Remove zero points from headsail codegen

* Remove unused things from codegen
vilukissa68 authored Oct 25, 2024
1 parent d24af7d commit e8850ed
Showing 2 changed files with 117 additions and 191 deletions.
205 changes: 70 additions & 135 deletions python/tvm/relay/op/contrib/headsail.py
@@ -43,8 +43,6 @@
from ..strategy.generic import is_depthwise_conv2d
logger = logging.getLogger("HEADSAIL")

conv2d_counter = True

def _register_external_op_helper(op_name, supported=True):
"""The helper function to indicate that a given operator can be supported
by Headsail.
@@ -75,62 +73,22 @@ def _func_wrapper(expr):
return _func_wrapper
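
The body of _func_wrapper is collapsed in this diff. For readers unfamiliar with the BYOC helper, a rough sketch of how this kind of helper usually looks is below; the target.headsail attribute key and the trivial wrapper body are assumptions based on the stock DNNL version of the file, not content of this commit.

import tvm

def _register_external_op_helper_sketch(op_name, supported=True):
    # Hypothetical sketch only: mark op_name as offloadable to the "headsail" target.
    @tvm.ir.register_op_attr(op_name, "target.headsail")
    def _func_wrapper(expr):
        return supported
    return _func_wrapper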


#_register_external_op_helper("qnn.add")
#_register_external_op_helper("qnn.conv2d")
#_register_external_op_helper("qnn.relu")

# Special case to handle tflite models converted to relay with fused activation
def qnn_tflite_conv2d_bias_relu():
def qnn_tflite_conv2d_bias():
data = wildcard()
weight = wildcard()
bias = wildcard()
pattern = is_op("qnn.conv2d")(
data, weight, is_constant(), is_constant(), is_constant(), is_constant()
)
pattern = is_op("nn.bias_add")(pattern, bias)
pattern = is_op("qnn.requantize")(
pattern, is_constant(), is_constant(), is_constant(), is_constant()
)
pattern = is_op("clip")(pattern)
#pattern = is_op("nn.bias_add")(pattern, bias)
pattern = is_op("add")(pattern, bias)
return pattern
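
For orientation, the operator chain these pattern builders target can be constructed by hand and tested with DFPattern.match. The sketch below is not part of the commit: shapes, scales and zero points are made-up illustrative values, and it assumes the standard relay.qnn.op API.

from tvm import relay

def _toy_tflite_conv2d_chain():
    # Hypothetical qnn.conv2d -> bias_add -> requantize -> clip chain, as produced
    # by the TFLite frontend for a conv2d with fused ReLU (illustrative values).
    data = relay.var("data", shape=(1, 3, 8, 8), dtype="int8")
    weight = relay.var("weight", shape=(16, 3, 3, 3), dtype="int8")
    bias = relay.var("bias", shape=(16,), dtype="int32")
    conv = relay.qnn.op.conv2d(
        data, weight,
        input_zero_point=relay.const(0), kernel_zero_point=relay.const(0),
        input_scale=relay.const(0.5), kernel_scale=relay.const(0.25),
        kernel_size=(3, 3), channels=16,
    )
    out = relay.nn.bias_add(conv, bias)
    out = relay.qnn.op.requantize(
        out,
        input_scale=relay.const(0.125), input_zero_point=relay.const(0),
        output_scale=relay.const(0.25), output_zero_point=relay.const(0),
        out_dtype="int8",
    )
    return relay.clip(out, 0, 127)

# The relu-flavoured variant above is expected to match such a chain:
# qnn_tflite_conv2d_bias_relu().match(_toy_tflite_conv2d_chain())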

def make_qnn_conv2d_pattern():
"""Make qnn.conv2d based pattern supported by DNNL
Returns
-------
pattern : Tuple(pattern_name, CallPattern)
Created pattern name, along with its CallPattern.
"""
data = wildcard()
weight = is_constant()
bias = is_constant()
o_scl = is_constant()
dst_zp = is_constant()
act_scl = is_constant()
sum_scl = is_constant()
sum_src = wildcard()

zero_zp = is_expr(const(0, dtype="int32"))

pat = is_op("qnn.conv2d")(data, weight, zero_zp, zero_zp, is_constant(), is_constant())
pat = is_op("cast")(pat)
pat = is_op("add")(pat, bias) | pat # optional bias
pat = is_op("multiply")(pat, o_scl)
pat = is_op("clip")(pat) # TBD, not only clip
pat = is_op("multiply")(pat, act_scl) | pat # optional multiply. Ex: act_scl == 1
pat = is_op("add")(pat, sum_scl * is_op("cast")(sum_src)) | pat # optional sum
pat = is_op("add")(pat, dst_zp) | pat # optional dst_zp, can be dst_zp == 0
pat = is_op("cast")(pat)
return pat
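
The "| pat" alternations in the pattern above are how DFPattern expresses optional post-ops: the pattern matches whether or not the extra operator is present. A minimal standalone illustration, with made-up op choices, not taken from this commit:

from tvm import relay
from tvm.relay.dataflow_pattern import is_op, wildcard

x = wildcard()
b = wildcard()
base = is_op("nn.relu")(x)
# "Optional add": matches add(relu(x), b) as well as relu(x) alone.
opt_add = is_op("add")(base, b) | base

a = relay.var("a", shape=(4,))
c = relay.var("c", shape=(4,))
assert opt_add.match(relay.nn.relu(a) + c)  # with the optional add
assert opt_add.match(relay.nn.relu(a))      # without it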

@register_pattern_table("headsail")
def pattern_table():
tflite_conv2d_bias_relu = ("headsail.tflite_conv2d_bias_relu", qnn_tflite_conv2d_bias_relu())
#tflite_conv2d_bias_relu = ("headsail.tflite_conv2d_bias_relu", make_qnn_conv2d_pattern())
#tflite_conv2d_bias= ("headsail.tflite_conv2d_bias", qnn_tflite_conv2d_bias())
return [tflite_conv2d_bias_relu]
#return [tflite_conv2d_bias_relu, tflite_conv2d_bias]
tflite_conv2d_bias= ("headsail.tflite_conv2d_bias", qnn_tflite_conv2d_bias())
return [tflite_conv2d_bias]
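
For context, a pattern table registered with register_pattern_table is normally consumed by the standard BYOC partitioning passes. A rough sketch of that flow is below; the helper name is hypothetical and the pass list assumes the usual MergeComposite / AnnotateTarget / PartitionGraph pipeline rather than anything specific to this commit.

import tvm
from tvm import relay

def partition_for_headsail(mod, params=None):
    # Hypothetical helper: offload the patterns above to the "headsail" target.
    if params:
        mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params)
    seq = tvm.transform.Sequential(
        [
            relay.transform.MergeComposite(pattern_table()),
            relay.transform.AnnotateTarget("headsail"),
            relay.transform.PartitionGraph(),
        ]
    )
    return seq(mod)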

class LegalizeQnnOpForHeadsail(DFPatternCallback):
"""Legalize QNN based patterns to match DNNL
@@ -142,22 +100,17 @@ class LegalizeQnnOpForHeadsail(DFPatternCallback):
%2 = (%1 - rq_in_zp) * rq_in_scl / rq_out_scl + rq_out_zp // qnn.requantize
%3 = act(%2) // activation == clip
transform to DNNL compatible:
transform to Headsail compatible:
%1 = OP<int>(SRC, WGH)
%2 = cast(%1, dtype="float")
%2 = (%1 + bias) * o_scl
%3 = act(%2) * act_scl
%4 = %3 + SRC2 * sum_scl
%5 = %4 + dst_zp
%6 = cast(%5, dtype="float")
%2 = (%1 + bias)
%3 = cast(%2, dtype="float")
%4 = act(%3) * act_scl
%5 = %4 + SRC2 * sum_scl
%6 = cast(%5, dtype="int8")
where:
o_scl = rq_in_scl / rq_out_scl
act_scl = sum_lhs_scl / sum_out_scl
sum_scl = sum_rhs_scl / sum_out_scl
bias = orig_bias - OP(src_zp, WGH) - rq_in_zp + rq_out_zp * rq_out_scl / rq_in_scl
dst_zp = sum_out_zp - sum_lhs_zp * sum_lhs_scl / sum_out_scl -
sum_rhs_zp * sum_rhs_scl / sum_out_scl
"""

def __init__(self):
@@ -184,17 +137,19 @@ def __init__(self):
self.sum_out_scl = is_constant()
self.sum_out_zp = is_constant()


self.root = (is_op("qnn.conv2d") | is_op("qnn.dense"))(
self.src, self.wgh, self.src_zp, self.wgh_zp, self.src_scl, self.wgh_scl
)
pat = is_op("add")(self.root, self.bias) | self.root # optional bias
pat = is_op("nn.bias_add")(self.root, self.bias) | self.root # optional bias
pat = is_op("qnn.requantize")(
pat, self.rq_in_scl, self.rq_in_zp, self.rq_out_scl, self.rq_out_zp
)
pat = is_op("clip")(pat)
cast = is_op("cast")(pat)
pat = is_op("qnn.add")(
cast,
self.clip = is_op("clip")(pat)
pat = pat | self.clip

add = is_op("qnn.add")(
pat,
self.sum_src,
self.sum_lhs_scl,
self.sum_lhs_zp,
@@ -203,118 +158,98 @@ def __init__(self):
self.sum_out_scl,
self.sum_out_zp,
)
pat = is_op("clip")(pat)
self.pattern = pat | cast
add = is_op("clip")(add)
self.pattern = pat | add


def callback(self, pre, post, node_map):
root = node_map[self.root][0]
src = node_map[self.src][0]
wgh = node_map[self.wgh][0]
bias = node_map.get(self.bias, default=[relay.const(0, dtype="int32")])[0]
src_scl = node_map[self.src_scl][0]
src_zp = node_map[self.src_zp][0]
rq_in_scl = node_map[self.rq_in_scl][0]
rq_in_zp = node_map[self.rq_in_zp][0]
rq_out_scl = node_map[self.rq_out_scl][0]
rq_out_zp = node_map[self.rq_out_zp][0]
final_dtype = "int8"

final_dtype = node_map[self.pattern][0].checked_type.dtype
print("src_scl", src_scl)
print("src_zp", src_zp)
print("rq_in_scl", rq_in_scl)
print("rq_in_zp", rq_in_zp)
print("rq_out_scl", rq_out_scl)
print("rq_out_zp", rq_out_zp)

if root.op == relay.op.get("qnn.conv2d"):
dst_layout = root.attrs.out_layout
dst_layout = root.attrs.data_layout if dst_layout == "" else dst_layout
wgh_layout = root.attrs.kernel_layout
else:
# qnn.dense has no layout attributes. Assume that it is plain
dst_layout = "NC"
wgh_layout = "OI"
def cast_fp(op):
return relay.op.cast(op, dtype="float32")
def cast_int8(op):
return relay.op.cast(op, dtype="int8")

# TODO(@apeskov): dst_layout may be blocked
bias_rank = len(dst_layout) - dst_layout.index("C")

sum_src = node_map[self.sum_src][0] if self.sum_src in node_map else None
# Default values if qnn.sum is not present
sum_src = node_map[self.sum_src][0] if self.sum_src in node_map else None
sum_lhs_scl = node_map[self.sum_lhs_scl][0] if sum_src else relay.const(1, dtype="float32")
sum_lhs_zp = node_map[self.sum_lhs_zp][0] if sum_src else relay.const(0, dtype="int32")
sum_rhs_scl = node_map[self.sum_rhs_scl][0] if sum_src else relay.const(0, dtype="float32")
sum_rhs_zp = node_map[self.sum_rhs_zp][0] if sum_src else relay.const(0, dtype="int32")
sum_out_scl = node_map[self.sum_out_scl][0] if sum_src else relay.const(1, dtype="float32")
sum_out_zp = node_map[self.sum_out_zp][0] if sum_src else relay.const(0, dtype="int32")

def cast_fp(op):
return relay.op.cast(op, dtype="float32")

# recalculate some factors
o_scl = rq_in_scl / rq_out_scl
# Compute scaling factors for requantization
zero_zp = relay.const(0, dtype="int32")
act_scl = sum_lhs_scl / sum_out_scl
sum_scl = sum_rhs_scl / sum_out_scl
dst_zp = (
cast_fp(sum_out_zp)
- cast_fp(sum_lhs_zp) * sum_lhs_scl / sum_out_scl
- cast_fp(sum_rhs_zp) * sum_rhs_scl / sum_out_scl
)
bias = self.squeeze_bias(bias, dst_layout)
bias = (
cast_fp(bias)
- cast_fp(self.fake_op(src_zp, wgh, wgh_layout))
- cast_fp(rq_in_zp)
+ cast_fp(rq_out_zp) * rq_out_scl / rq_in_scl
)
bias = self.broadcast_to_rank(bias, bias_rank)

zero_zp = relay.const(0, dtype="int32")
one_scl = relay.const(1.0, dtype="float32")
# Remove zero-point
rq_in_zp = zero_zp
rq_out_zp = zero_zp

# construct new graph with proper post op ordering
gr = tvm.relay.Call(
# Construct the new computation graph
output = tvm.relay.Call(
root.op,
[src, wgh, zero_zp, zero_zp, one_scl, one_scl],
[src, wgh, zero_zp, zero_zp, relay.const(1.0, dtype="float32"), relay.const(1.0, dtype="float32")],
root.attrs,
root.type_args,
root.span,
)
gr = relay.op.cast(gr, dtype="float32")
gr = gr + bias
gr = gr * o_scl
gr = relay.op.clip(gr, 0, 255) * act_scl
gr = gr + sum_scl * cast_fp(sum_src) if sum_src else gr
gr = gr + dst_zp
gr = relay.op.cast(gr, dtype=final_dtype)
return gr

@staticmethod
def fake_op(zp, wgh, layout):
"""Fake operator implementation for zp broadcast input"""
# Conv: reduce kernel {OC, IC, KH, KW} -> {OC} in case of group that is still correct
# Dense: reduce kernel {OC, IC} -> {OC}
wgh_int = relay.op.cast(wgh, dtype="int32")
reduced_kernel = relay.op.sum(
wgh_int, axis=[layout.index("O")], keepdims=False, exclude=True
)
output = output + bias

# Insert requantize node back
output = relay.qnn.op.requantize(
output,
input_scale=rq_in_scl,
input_zero_point=rq_in_zp,
output_scale=rq_out_scl,
output_zero_point=rq_out_zp,
out_dtype="int32"
)
return zp * reduced_kernel
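
The kernel reduction in fake_op relies on the identity conv(x - src_zp, W) = conv(x, W) - src_zp * sum(W), taken per output channel, which is why the input zero point can be folded into the bias term. A small numeric sketch with made-up values (NumPy is used here purely for illustration):

import numpy as np

# One flattened input patch and one output channel's flattened kernel (toy values).
x = np.array([4, -2, 7, 1], dtype=np.int32)
w = np.array([3, 5, -1, 2], dtype=np.int32)
src_zp = 3

with_zp = np.dot(x - src_zp, w)              # conv with the zero point applied to the input
corrected = np.dot(x, w) - src_zp * w.sum()  # plain conv, corrected by src_zp * reduced kernel

assert with_zp == corrected                  # -30 == -30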

@staticmethod
def squeeze_bias(bias, layout):
shape = transform.InferTypeLocal(bias).concrete_shape
c_position = layout.index("C") - len(layout) + len(shape)
squeeze_idxs = [i for i in range(len(shape)) if i != c_position]
return relay.op.squeeze(bias, squeeze_idxs)

@staticmethod
def broadcast_to_rank(op, rank):
"""Scalar or 1D tensor are supported"""
shape = transform.InferTypeLocal(op).concrete_shape
if len(shape) == 0:
return op
if len(shape) == 1:
return relay.op.expand_dims(op, 1, rank - 1)
raise ValueError("Unexpected bias rank to broadcast. Only 0 and 1 are supported.")

# Apply clipping with optional ReLU
if self.clip in node_map:
output = relay.op.clip(output, 0, 127)
else:
output = relay.op.clip(output, -128, 127)

# Apply qnn.add if sum was matched
if sum_src:
output = (cast_fp(output) * act_scl) + (cast_fp(sum_src) * sum_scl)
output = relay.op.clip(output, 0, 127)

# Cast to int8
output = relay.op.cast(output, dtype=final_dtype)


print("Legalization pass done")
return output

def legalize_qnn_for_headsail(mod):
"""Transform qnn primitives to DNNL compatible form. Eliminate source zero point and apply
strict sequence of post ops."""
print("Legalizing qnn for headsail")
#mod["main"] = rewrite(LegalizeQnnOpForHeadsail(), mod["main"])
mod["main"] = rewrite(LegalizeQnnOpForHeadsail(), mod["main"])

seq = tvm.transform.Sequential(
[
