diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py b/python/tvm/contrib/cutlass/gen_conv2d.py
index 6b5546a8b4646..0d46000bab70d 100644
--- a/python/tvm/contrib/cutlass/gen_conv2d.py
+++ b/python/tvm/contrib/cutlass/gen_conv2d.py
@@ -252,6 +252,8 @@ def select_op(
             lambda align: all([dim % align == 0 for dim in [IC, OC]]),
             use_3xtf32,
             profile_all_alignments,
+            # Use fp32 accumulation for wgrad to align with cuDNN
+            accumulator_dtype="float32" if conv_kind == ConvKind.Wgrad else out_dtype,
         )
 
         if not find_first_valid:
diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py
index bb591985cab5d..ac60f6c4ebc56 100644
--- a/python/tvm/contrib/cutlass/gen_gemm.py
+++ b/python/tvm/contrib/cutlass/gen_gemm.py
@@ -164,6 +164,8 @@ def get_default(
             lambda align: align == 1,  # Only request align1 kernels
             use_3xtf32,
             profile_all_alignments=True,  # To include all align1 kernels
+            # TODO(masahi): Investigate when fp32 accumulation is needed for gemm
+            accumulator_dtype=out_dtype,
         )
 
         default_kernel_name = DEFAULT_KERNELS[self.sm][(arg0_dtype, out_dtype)]
@@ -220,6 +222,8 @@ def select_op(
             lambda align: all([dim % align == 0 for dim in [M, N, K]]),
             use_3xtf32,
             profile_all_alignments=profile_all_alignments,
+            # TODO(masahi): Investigate when fp32 accumulation is needed for gemm
+            accumulator_dtype=out_dtype,
         )
 
         if not find_first_valid:
@@ -266,6 +270,7 @@ def profile(
             profile_all_alignments=profile_all_alignments,
             find_first_valid=find_first_valid,
             use_multiprocessing=use_multiprocessing,
+
         )
 
         name, opdef = create_gemm_operator_with_epilogue(
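Note on the accumulator choice above: wgrad reduces over the batch and spatial dimensions (N * H * W), a much longer reduction than fprop's C * R * S, so fp16 accumulation can drift. The rule the patch applies can be summarized as the following sketch (`ConvKind` comes from tvm.contrib.cutlass.library; the helper name is illustrative, not part of the patch):

    from tvm.contrib.cutlass.library import ConvKind

    def choose_accumulator_dtype(conv_kind, out_dtype):
        # Wgrad accumulates in fp32 to match what cuDNN does internally.
        if conv_kind == ConvKind.Wgrad:
            return "float32"
        # Fprop/Dgrad and GEMM keep accumulating in the output dtype for now.
        return out_dtype
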
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index d048ff5e1478c..c1ebc05a873dc 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -51,7 +51,7 @@ def generate_tensor_op_common(
     data_type = [
         math_inst.element_a,
         math_inst.element_b,
-        math_inst.element_accumulator,
+        math_inst.element_c,
         math_inst.element_accumulator,
     ]
 
@@ -63,7 +63,7 @@ def generate_tensor_op_common(
 
 
 def generate_sm75_tensor_op_1688(
-    out_dtype, arg0_dtype, arg1_dtype, op_creator, check_align, _, profile_all_alignments=False
+    out_dtype, arg0_dtype, arg1_dtype, op_creator, check_align, _, profile_all_alignments=False, accumulator_dtype="float32",
 ):
     """Generate GEMM or Conv2D kernels for Turing."""
     assert out_dtype in ["float32", "float16", "int32"]
@@ -77,6 +77,7 @@ def generate_sm75_tensor_op_1688(
             DataType.f16,
             DataType.f16,
             dtype_map[out_dtype],
+            dtype_map[accumulator_dtype],
             OpcodeClass.TensorOp,
             MathOperation.multiply_add,
         )
@@ -100,6 +101,7 @@ def generate_sm75_tensor_op_1688(
             dtype_map[arg0_dtype],
             dtype_map[arg1_dtype],
             DataType.s32,
+            DataType.s32,
             OpcodeClass.TensorOp,
             MathOperation.multiply_add_saturate,
         ),
@@ -141,6 +143,7 @@ def generate_sm80_tensor_op_16816(
     check_align,
     use_3xtf32=True,
     profile_all_alignments=False,
+    accumulator_dtype="float32",
 ):
     """Generate GEMM or Conv2D kernels for Ampere."""
     min_cc = 80
@@ -176,6 +179,7 @@ def get_default_tile_descriptions(block_k_factor):
             DataType.f16,
             DataType.f16,
             dtype_map[out_dtype],
+            dtype_map[accumulator_dtype],
             OpcodeClass.TensorOp,
             MathOperation.multiply_add,
         )
@@ -189,6 +193,7 @@ def get_default_tile_descriptions(block_k_factor):
             DataType.f32,
             DataType.f32,
             DataType.f32,
+            DataType.f32,
             OpcodeClass.TensorOp,
             MathOperation.multiply_add_fast_f32 if use_3xtf32 else MathOperation.multiply_add,
         ),
@@ -221,6 +226,7 @@ def get_default_tile_descriptions(block_k_factor):
             dtype_map[arg0_dtype],
             dtype_map[arg1_dtype],
             DataType.s32,
+            DataType.s32,
             OpcodeClass.TensorOp,
             MathOperation.multiply_add_saturate,
         ),
@@ -248,6 +254,7 @@ def get_tile_descriptions(math_inst):
             check_align,
             False,
             profile_all_alignments,
+            accumulator_dtype=accumulator_dtype,
         )
     else:
         # TF32 (float32 + float32 case) is only supported on sm80
@@ -292,6 +299,7 @@ def get_tile_descriptions(math_inst):
     "cutlass.conv2d_bias": (EpilogueFunctor.LinearCombinationBias, True),
     "cutlass.conv2d": (EpilogueFunctor.LinearCombination, False),
     "cutlass.conv2d_transpose": (EpilogueFunctor.LinearCombination, False),
+    "cutlass.conv2d_backward_weight": (EpilogueFunctor.LinearCombination, False),
 }
 
diff --git a/python/tvm/contrib/cutlass/library.py b/python/tvm/contrib/cutlass/library.py
index b21e5e0f1410f..8632ab15641d1 100644
--- a/python/tvm/contrib/cutlass/library.py
+++ b/python/tvm/contrib/cutlass/library.py
@@ -266,6 +266,7 @@ def __init__(
         instruction_shape,
         element_a,
         element_b,
+        element_c,
         element_accumulator,
         opcode_class,
         math_operation=MathOperation.multiply_add,
@@ -273,6 +274,7 @@ def __init__(
         self.instruction_shape = instruction_shape
         self.element_a = element_a
         self.element_b = element_b
+        self.element_c = element_c
         self.element_accumulator = element_accumulator
         self.opcode_class = opcode_class
         self.math_operation = math_operation
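With element_c split out from element_accumulator, an fp16 kernel that accumulates in fp32 can now be described directly. A sketch using the enums from tvm.contrib.cutlass.library (the instruction shape is only an example value):

    from tvm.contrib.cutlass.library import (
        DataType,
        MathInstruction,
        MathOperation,
        OpcodeClass,
    )

    math_inst = MathInstruction(
        [16, 8, 16],  # instruction shape, example value for fp16 tensor ops
        DataType.f16,  # element_a
        DataType.f16,  # element_b
        DataType.f16,  # element_c: the output stays fp16
        DataType.f32,  # element_accumulator: accumulate in fp32
        OpcodeClass.TensorOp,
        MathOperation.multiply_add,
    )
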
diff --git a/python/tvm/relay/op/contrib/cutlass.py b/python/tvm/relay/op/contrib/cutlass.py
index 49c59206b4e6f..5c906f7e69bed 100644
--- a/python/tvm/relay/op/contrib/cutlass.py
+++ b/python/tvm/relay/op/contrib/cutlass.py
@@ -94,6 +94,10 @@ def make_conv2d_transpose_pattern():
     return is_op("nn.conv2d_transpose")(wildcard(), wildcard())
 
 
+def make_conv2d_backward_weight_pattern():
+    return is_op("nn.conv2d_backward_weight")(wildcard(), wildcard())
+
+
 def make_residual_block_pattern(tensor_op_out, binary_op="add", with_act="relu"):
     """Add pattern for residual blocks."""
     residual_input = wildcard()
@@ -173,6 +177,10 @@ def check_conv2d_transpose(call):
     return check_conv2d_common("nn.conv2d_transpose", "IHWO", call)
 
 
+def check_conv2d_backward_weight(call):
+    return check_conv2d_common("nn.conv2d_backward_weight", "NHWC", call)
+
+
 def check_conv2d_residual(call, binary_op):
     """Check if the given conv2d workload can be offloaded to CUTLASS."""
     conv2d = get_root_call(call, "nn.conv2d")
@@ -245,6 +253,11 @@ def partition_for_cutlass(mod, params=None):
     # For now, no fusion for grad kernels
     conv2d_grad_patterns = [
         ("cutlass.conv2d_transpose", make_conv2d_transpose_pattern(), check_conv2d_transpose),
+        (
+            "cutlass.conv2d_backward_weight",
+            make_conv2d_backward_weight_pattern(),
+            check_conv2d_backward_weight,
+        ),
     ]
 
     residual_block_patterns = []
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index 1fa909e748a07..b96c7924bc82e 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -1143,6 +1143,40 @@ def legalize_conv2d_backward_weight(attrs, inputs, types):
     return backward_weight
 
 
+@reg.register_convert_op_layout("nn.conv2d_backward_weight")
+def convert_conv2d_backward_weight(attrs, inputs, _, desired_layouts):
+    """Convert Layout pass registration for conv2d_backward_weight op.
+    Note that `desired_layouts` must be a pair [`data_layout`, `kernel_layout`],
+    where `kernel_layout` affects the output of this op (since the output of this op
+    is the weight gradient). The layout of the output gradient (the second input to this op)
+    is assumed to be the same as `data_layout`.
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of current op
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be converted
+    tinfos : list of types
+        List of input and output types
+    desired_layouts : list of layout strings
+        List of layouts defining our desired
+        layout for the data and kernel inputs respectively.
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The transformed expr
+    """
+    new_attrs = dict(attrs)
+    assert len(desired_layouts) == 2, "A desired layout is expected for both data and gradient."
+    desired_data_layout, desired_kernel_layout = map(str, desired_layouts)
+    assert desired_data_layout != "default", "Data layout cannot be default"
+    new_attrs["grad_layout"] = desired_data_layout
+    new_attrs["data_layout"] = desired_data_layout
+    new_attrs["kernel_layout"] = desired_kernel_layout
+    new_attrs.pop("out_layout")
+    return relay.nn.conv2d_backward_weight(inputs[0], inputs[1], **new_attrs)
+
+
 #####################
 #  Shape functions  #
 #####################
diff --git a/python/tvm/topi/cuda/conv2d.py b/python/tvm/topi/cuda/conv2d.py
index 15fcaaa021344..5a5d59a6e2182 100644
--- a/python/tvm/topi/cuda/conv2d.py
+++ b/python/tvm/topi/cuda/conv2d.py
@@ -130,6 +130,9 @@ def conv2d_backward_weight_cudnn(
 ):
     """Compute conv2d wgrad using the cuDNN library"""
     assert layout in ["NCHW", "NHWC"]
+    # cuDNN does not seem to support other combinations.
+    assert output_dtype == "float16", "Only fp16 output is supported for cuDNN wgrad."
+    conv_dtype = "float32"
     return cudnn.conv_backward_filter(
         dy,
         x,
@@ -139,6 +142,6 @@ def conv2d_backward_weight_cudnn(
         dilation,
         conv_mode=1,
         tensor_format=0 if layout == "NCHW" else 1,
-        conv_dtype=output_dtype,
+        conv_dtype=conv_dtype,
         groups=groups,
     )
diff --git a/src/relay/backend/contrib/cutlass/codegen.cc b/src/relay/backend/contrib/cutlass/codegen.cc
index fee94c45c91bf..fdd268d1d9d17 100644
--- a/src/relay/backend/contrib/cutlass/codegen.cc
+++ b/src/relay/backend/contrib/cutlass/codegen.cc
@@ -615,6 +615,11 @@ class CodegenCutlass : public MemoizedExprTranslator<Array<String>>, publi
           GetRootCall(callee->body.as<CallNode>(), 0, {"nn.conv2d_transpose"});
       return GenerateBody(conv2d_call, "cutlass_conv2d_transpose", GetArgumentNames(caller),
                           Conv2dArgs(std::ref(attrs_), true, false));
+    } else if (pattern_name == "cutlass.conv2d_backward_weight") {
+      const auto* conv2d_call =
+          GetRootCall(callee->body.as<CallNode>(), 0, {"nn.conv2d_backward_weight"});
+      return GenerateBody(conv2d_call, "cutlass_conv2d_backward_weight", GetArgumentNames(caller),
+                          Conv2dArgs(std::ref(attrs_), false, true));
     }
 
     LOG(FATAL) << "Unknown composite function: " << pattern_name;
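End to end, the new BYOC pieces are driven as follows. A sketch, assuming `mod` is an IRModule containing an nn.conv2d_backward_weight call in NCHW; since check_conv2d_backward_weight above requires NHWC layouts, ConvertLayout must run before partitioning:

    import tvm
    from tvm import relay
    from tvm.relay.op.contrib.cutlass import partition_for_cutlass

    # Desired layouts for [data/grad, produced weight gradient]; both must be
    # NHWC for the "cutlass.conv2d_backward_weight" pattern check to pass.
    desired_layouts = {"nn.conv2d_backward_weight": ["NHWC", "NHWC"]}
    with tvm.transform.PassContext(opt_level=3):
        mod = relay.transform.ConvertLayout(desired_layouts)(mod)
    mod = partition_for_cutlass(mod)
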
diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py
index b3afc8dd14968..0975058cefeae 100644
--- a/tests/python/contrib/test_cutlass.py
+++ b/tests/python/contrib/test_cutlass.py
@@ -214,6 +214,30 @@ def get_conv2d_transpose_nchw(
     )
 
 
+def get_conv2d_backward_weight(
+    d_shape,
+    w_shape,
+    o_shape,
+    padding,
+    strides,
+    out_dtype="float32",
+    data_dtype="float32",
+    weight_dtype="float32",
+):
+    grad = relay.var("grad", shape=o_shape, dtype=weight_dtype)
+    data = relay.var("data", shape=d_shape, dtype=data_dtype)
+    out_channel = o_shape[1]
+    return relay.nn.conv2d_backward_weight(
+        grad=grad,
+        data=data,
+        kernel_size=w_shape[2:],
+        channels=out_channel,
+        padding=padding,
+        strides=strides,
+        out_dtype=out_dtype,
+    )
+
+
 def convert_conv2d_layout(mod, desired_layouts):
     with tvm.transform.PassContext(opt_level=3):
         seq = tvm.transform.Sequential([relay.transform.ConvertLayout(desired_layouts)])
@@ -602,6 +626,42 @@ def verify_conv2d(
     )
 
 
+def verify_conv2d_backward_weight(
+    expr_nchw,  # can be dynamic batch
+    expr_ref,  # always static batch
+    grad_shape,
+    data_shape,
+    sm=80,
+    atol=1e-5,
+    rtol=1e-5,
+    use_cudnn_ref=False,
+    use_fast_math=False,
+    grad_dtype="float16",
+    data_dtype="float16",
+    ref_target="cuda",
+    use_vm=False,
+):
+    np_grad = get_random_ndarray(grad_shape, grad_dtype)
+    np_data = get_random_ndarray(data_shape, data_dtype)
+    params = {}
+    input_names = ["grad", "data"]
+    return verify_conv2d_common(
+        expr_nchw,
+        expr_ref,
+        input_names,
+        [np_grad, np_data],
+        params,
+        sm,
+        atol,
+        rtol,
+        use_cudnn_ref,
+        False,
+        use_fast_math,
+        ref_target,
+        use_vm,
+    )
+
+
 def test_conv2d():
     padding = (1, 1)
     for IC in [3, 16]:
@@ -768,5 +828,92 @@ def test_conv2d_transpose():
     )
 
 
+
+@pytest.mark.skip("weird")
+def test_conv2d_backward_weight():
+    OC = 8
+    IC = 16
+    d_shape = (16, IC, 32, 32)
+    w_shape = (OC, IC, 3, 3)
+    dtype = "float32"
+
+    for strides in [(1, 1), (2, 2)]:
+        o_shape = (16, OC, 32 // strides[0], 32 // strides[1])
+        padding = (1, 1)
+
+        mod_nchw = get_conv2d_backward_weight(
+            d_shape,
+            w_shape,
+            o_shape,
+            padding,
+            strides,
+            out_dtype=dtype,
+            data_dtype=dtype,
+            weight_dtype=dtype,
+        )
+
+        verify_conv2d_backward_weight(
+            mod_nchw,
+            mod_nchw,
+            o_shape,
+            d_shape,
+            sm=80,
+            atol=1e-3,
+            rtol=1e-3,
+            use_cudnn_ref=False,
+            grad_dtype=dtype,
+            data_dtype=dtype,
+        )
+
+
+@pytest.mark.skip("weird")
+def test_conv2d_bwd():
+    IC = 16
+    OC = 8
+    dshape = (16, IC, 32, 32)
+    wshape = (OC, IC, 3, 3)
+    padding = (0, 0)
+    strides = (1, 1)
+
+    conv = get_conv2d_nchw(
+        dshape,
+        wshape,
+        padding,
+        strides=strides,
+        out_dtype="float32",
+        data_dtype="float32",
+        weight_dtype="float32",
+    )
+    fwd_mod = InferType()(tvm.IRModule.from_expr(conv))
+
+    use_fp16 = False  # Note: large difference between tvm and cutlass Wgrad results if fp16 is used
+    verify_dgrad = False  # False to verify wgrad
+    tol = 1e-5 if verify_dgrad else 1e-4  # Wgrad slightly less accurate
+
+    if use_fp16:
+        fwd_mod = ToMixedPrecision("float16")(fwd_mod)
+
+    fwd_bwd_func = FirstOrderGradient()(fwd_mod)["main"]
+
+    bwd_func = relay.Function(
+        fwd_bwd_func.params,
+        relay.TupleGetItem(relay.TupleGetItem(fwd_bwd_func.body, 1), 0 if verify_dgrad else 1),
+    )
+
+    verify_conv2d(
+        bwd_func,
+        bwd_func,
+        dshape,
+        wshape,
+        sm=80,
+        atol=1e-2 if use_fp16 else tol,
+        rtol=1e-2 if use_fp16 else tol,
+        use_cudnn_ref=False,
+        data_dtype="float32",
+        weight_dtype="float32",
+        use_vm=True,
+    )
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
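For reference, the shapes exercised by test_conv2d_backward_weight line up as follows; a standalone sketch of the op's contract, using only arguments that appear in the test above:

    import tvm
    from tvm import relay

    # data: (N, C, H, W) = (16, 16, 32, 32); grad: (N, O, P, Q) = (16, 8, 32, 32).
    # With kernel_size (3, 3), padding (1, 1), and strides (1, 1), the weight
    # gradient comes out as (O, C, R, S) = (8, 16, 3, 3).
    data = relay.var("data", shape=(16, 16, 32, 32), dtype="float32")
    grad = relay.var("grad", shape=(16, 8, 32, 32), dtype="float32")
    wgrad = relay.nn.conv2d_backward_weight(
        grad=grad,
        data=data,
        kernel_size=(3, 3),
        channels=8,
        padding=(1, 1),
        strides=(1, 1),
        out_dtype="float32",
    )
    mod = relay.transform.InferType()(tvm.IRModule.from_expr(wgrad))
    print(mod)
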