From 44026ba7e7f5376a80cf0f2b333a0f25c0eeda6c Mon Sep 17 00:00:00 2001
From: Bing Xu
Date: Wed, 28 Sep 2022 22:54:11 -0700
Subject: [PATCH] v0.1

---
 .circleci/config.yml | 80 + .clang-format | 88 + .flake8 | 63 + .github/workflows/docs.yml | 67 + .github/workflows/lint.yml | 41 + .gitignore | 143 ++ .gitmodules | 10 + 3rdparty/composable_kernel | 1 + 3rdparty/cub | 1 + 3rdparty/cutlass | 1 + CITATION.cff | 54 + CODE_OF_CONDUCT.md | 80 + CONTRIBUTING.md | 37 + LICENSE | 201 ++ README.md | 119 + docker/Dockerfile.cuda | 58 + docker/Dockerfile.rocm | 147 ++ docker/README.md | 30 + docker/build.sh | 24 + docker/install/install_ait.sh | 5 + docker/install/install_basic_dep.sh | 4 + docker/install/install_detection_deps.sh | 9 + docker/install/install_doc_dep.sh | 6 + docker/install/install_test_dep.sh | 11 + docker/install/rocm_dev-requirements.txt | 3 + docker/rocm_fix/fix_10736.py | 9 + docs/Makefile | 22 + docs/README.md | 20 + docs/make.bat | 35 + docs/source/arch/index.rst | 12 + docs/source/arch/philosophy.rst | 16 + docs/source/conf.py | 67 + docs/source/debughints.rst | 14 + docs/source/genindex.rst | 2 + docs/source/index.rst | 44 + docs/source/install/index.rst | 64 + docs/source/reference/backend.rst | 60 + docs/source/reference/compiler.rst | 37 + docs/source/reference/cuda.rst | 12 + docs/source/reference/env.rst | 37 + docs/source/reference/frontend.rst | 14 + docs/source/reference/index.rst | 16 + docs/source/reference/ops.rst | 8 + docs/source/reference/rocm.rst | 11 + docs/source/reference/testing.rst | 27 + docs/source/reference/transform.rst | 209 ++ docs/source/reference/utils.rst | 12 + docs/source/runtime/cxx_design.rst | 29 + docs/source/runtime/index.rst | 9 + docs/source/runtime/py_design.rst | 135 + docs/source/tutorial/how_to_add_op.rst | 302 +++ docs/source/tutorial/how_to_infer_pt.rst | 188 ++ docs/source/tutorial/how_to_visualize.rst | 85 + docs/source/tutorial/index.rst | 9 + docs/static/ait_model.html | 866 +++++++ examples/01_resnet-50/README.md | 84 + examples/01_resnet-50/benchmark_ait.py | 132 + examples/01_resnet-50/benchmark_mi250.sh | 4 + examples/01_resnet-50/benchmark_pt.py | 51 + examples/01_resnet-50/infer_with_torch.py | 135 + examples/01_resnet-50/modeling/__init__.py | 14 + examples/01_resnet-50/modeling/resnet.py | 456 ++++ examples/01_resnet-50/weight_utils.py | 173 ++ examples/02_detectron2/README.md | 169 ++ examples/02_detectron2/compile_model.py | 149 ++ examples/02_detectron2/configs/__init__.py | 17 + examples/02_detectron2/configs/config.py | 26 + examples/02_detectron2/configs/defaults.py | 668 +++++ .../configs/faster_rcnn_R_101_FPN.yaml | 47 + .../configs/faster_rcnn_R_50_FPN.yaml | 45 + .../configs/mask_rcnn_R_101_FPN.yaml | 48 + .../configs/mask_rcnn_R_50_FPN.yaml | 46 + examples/02_detectron2/demo.py | 105 + .../modeling/backbone/__init__.py | 25 + .../02_detectron2/modeling/backbone/fpn.py | 228 ++ .../02_detectron2/modeling/backbone/resnet.py | 459 ++++ .../02_detectron2/modeling/backbone/utils.py | 30 + .../modeling/meta_arch/__init__.py | 18 + .../02_detectron2/modeling/meta_arch/rcnn.py | 56 + .../modeling/proposal_generator/__init__.py | 18 + .../modeling/proposal_generator/rpn.py | 177 ++ .../modeling/roi_heads/__init__.py | 20 + .../modeling/roi_heads/box_head.py | 67 + .../modeling/roi_heads/fast_rcnn.py | 209 ++ .../modeling/roi_heads/mask_head.py | 65 + .../modeling/roi_heads/roi_heads.py | 91 + examples/02_detectron2/predictor/__init__.py | 18 + .../02_detectron2/predictor/builtin_meta.py | 180 ++ 
examples/02_detectron2/predictor/predictor.py | 359 +++ .../02_detectron2/prepare_and_run_rcnn.sh | 59 + .../02_detectron2/tools/convert_pt2ait.py | 157 ++ examples/03_bert/README.md | 303 +++ examples/03_bert/benchmark_ait.py | 298 +++ examples/03_bert/benchmark_mi250.sh | 4 + examples/03_bert/benchmark_pt.py | 148 ++ examples/03_bert/demo.py | 108 + examples/03_bert/modeling/__init__.py | 14 + examples/03_bert/modeling/bert.py | 391 +++ examples/03_bert/modeling/torch_model.py | 51 + examples/04_vit/README.md | 126 + examples/04_vit/benchmark_ait.py | 186 ++ examples/04_vit/benchmark_mi250.sh | 4 + examples/04_vit/benchmark_pt.py | 100 + .../04_vit/modeling/vision_transformer.py | 323 +++ examples/04_vit/verification.py | 164 ++ examples/04_vit/weight_utils.py | 115 + examples/05_stable_diffusion/README.md | 136 ++ examples/05_stable_diffusion/benchmark.py | 304 +++ examples/05_stable_diffusion/benchmark_pt.py | 46 + examples/05_stable_diffusion/compile.py | 353 +++ examples/05_stable_diffusion/demo.py | 46 + .../05_stable_diffusion/modeling/attention.py | 104 + examples/05_stable_diffusion/modeling/clip.py | 590 +++++ .../modeling/embeddings.py | 101 + .../05_stable_diffusion/modeling/resnet.py | 238 ++ .../modeling/unet_2d_condition.py | 251 ++ .../modeling/unet_blocks.py | 761 ++++++ examples/05_stable_diffusion/modeling/vae.py | 152 ++ .../pipeline_stable_diffusion_ait.py | 371 +++ .../06_how_to_add_an_op/how_to_add_an_op.py | 249 ++ .../how_to_run_pt_model.py | 131 + licenses/LICENSE.composable_kernel.txt | 28 + licenses/LICENSE.cub.txt | 24 + licenses/LICENSE.cutlass.txt | 27 + licenses/LICENSE.dmlc.txt | 201 ++ licenses/LICENSE.flash_attention.txt | 201 ++ licenses/LICENSE.hipcub.txt | 25 + licenses/LICENSE.markdown_table.txt | 21 + licenses/LICENSE.oneflow.txt | 202 ++ licenses/LICENSE.pydot.txt | 21 + licenses/LICENSE.pytorch.txt | 77 + licenses/LICENSE.tensorrt.txt | 337 +++ licenses/license.header.txt | 13 + python/aitemplate/__init__.py | 42 + python/aitemplate/_libinfo.py | 17 + python/aitemplate/backend/__init__.py | 37 + python/aitemplate/backend/backend_spec.py | 280 +++ python/aitemplate/backend/builder.py | 295 +++ python/aitemplate/backend/codegen.py | 744 ++++++ .../backend/common/concatenate_common.py | 839 +++++++ .../backend/common/elementwise_common.py | 881 +++++++ .../aitemplate/backend/common/gemm_common.py | 72 + .../aitemplate/backend/common/split_common.py | 569 +++++ .../backend/common/tensor/argmax_common.py | 456 ++++ .../common/tensor/batch_gather_common.py | 221 ++ .../common/tensor/permute021_common.py | 304 +++ .../common/tensor/permute102_common.py | 310 +++ .../common/tensor/permute210_common.py | 289 +++ .../backend/common/tensor/slice_common.py | 902 +++++++ .../tensor/slice_reshape_scatter_common.py | 149 ++ .../backend/common/tensor/topk_common.py | 769 ++++++ .../backend/common/tensor_accessor.cuh | 110 + .../backend/common/tensor_accessor_codegen.py | 163 ++ .../backend/common/upsampling2d_common.py | 425 ++++ .../common/vision_ops/efficient_nms_common.py | 250 ++ .../common/vision_ops/efficient_nms_kernel.py | 1160 +++++++++ .../multi_level_roi_align_common.py | 464 ++++ .../backend/common/vision_ops/nms_common.py | 235 ++ .../backend/common/vision_ops/nms_kernel.py | 565 +++++ .../common/vision_ops/roi_align_common.py | 392 +++ python/aitemplate/backend/cuda/__init__.py | 37 + .../backend/cuda/attention/__init__.py | 20 + .../backend/cuda/attention/flash_attention.py | 319 +++ .../backend/cuda/attention/src/fmha.h | 211 ++ 
.../backend/cuda/attention/src/fmha/gemm.h | 482 ++++ .../cuda/attention/src/fmha/gmem_tile.h | 608 +++++ .../cuda/attention/src/fmha/kernel_traits.h | 143 ++ .../backend/cuda/attention/src/fmha/mask.h | 117 + .../cuda/attention/src/fmha/smem_tile.h | 1843 ++++++++++++++ .../backend/cuda/attention/src/fmha/softmax.h | 708 ++++++ .../backend/cuda/attention/src/fmha/utils.h | 1332 ++++++++++ .../src/fmha_block_fprop_fp16_kernel.sm80.cu | 155 ++ .../src/fmha_block_fprop_kernel_1xN.h | 661 +++++ .../cuda/attention/src/fmha_blockmask.h | 69 + .../src/fmha_fprop_fp16_kernel.sm80.cu | 262 ++ .../attention/src/fmha_fprop_kernel_1xN.h | 795 ++++++ .../backend/cuda/attention/src/fmha_kernel.h | 204 ++ .../backend/cuda/attention/src/fmha_utils.h | 111 + .../cuda/attention/src/licenses/LICENSE | 201 ++ .../backend/cuda/attention/src/philox.cuh | 171 ++ .../backend/cuda/common/__init__.py | 19 + .../backend/cuda/common/dummy_op.py | 36 + .../backend/cuda/conv2d/__init__.py | 33 + .../aitemplate/backend/cuda/conv2d/common.py | 244 ++ .../conv2d/common_conv2d_bias_activation.py | 373 +++ .../common_conv2d_bias_add_activation.py | 348 +++ .../cuda/conv2d/common_conv2d_few_channels.py | 111 + .../aitemplate/backend/cuda/conv2d/conv2d.py | 420 ++++ .../backend/cuda/conv2d/conv2d_bias.py | 86 + .../backend/cuda/conv2d/conv2d_bias_add.py | 149 ++ .../cuda/conv2d/conv2d_bias_add_hardswish.py | 149 ++ .../cuda/conv2d/conv2d_bias_add_relu.py | 149 ++ .../cuda/conv2d/conv2d_bias_few_channels.py | 211 ++ .../cuda/conv2d/conv2d_bias_hardswish.py | 81 + .../conv2d_bias_hardswish_few_channels.py | 123 + .../backend/cuda/conv2d/conv2d_bias_relu.py | 81 + .../conv2d/conv2d_bias_relu_few_channels.py | 115 + .../cuda/conv2d/conv2d_bias_sigmoid.py | 82 + .../backend/cuda/conv2d/transposed_conv2d.py | 256 ++ .../cuda/conv2d/transposed_conv2d_bias.py | 264 ++ python/aitemplate/backend/cuda/cuda_common.py | 48 + .../backend/cuda/elementwise/__init__.py | 20 + .../backend/cuda/elementwise/custom_math.cuh | 299 +++ .../cuda/elementwise/fused_elementwise.py | 65 + .../backend/cuda/embedding/__init__.py | 16 + .../backend/cuda/embedding/bert_embeddings.py | 450 ++++ .../cuda/gemm_epilogue_vistor/__init__.py | 18 + .../bmm_common_softmax.py | 256 ++ .../gemm_epilogue_vistor/bmm_rcr_softmax.py | 161 ++ .../gemm_epilogue_vistor/common_softmax.py | 538 ++++ .../gemm_rcr_bias_softmax.py | 118 + .../gemm_epilogue_vistor/gemm_rcr_softmax.py | 216 ++ .../include/gemm_with_softmax.h | 302 +++ .../backend/cuda/gemm_special/__init__.py | 21 + .../backend/cuda/gemm_special/bmm_rcr_n1.py | 616 +++++ .../cuda/gemm_special/bmm_rrr_k1_tanh.py | 258 ++ .../cuda/gemm_special/gemm_rrr_small_nk.py | 374 +++ .../backend/cuda/gemm_universal/__init__.py | 61 + .../backend/cuda/gemm_universal/bmm_ccr.py | 142 ++ .../cuda/gemm_universal/bmm_ccr_add.py | 120 + .../backend/cuda/gemm_universal/bmm_common.py | 391 +++ .../backend/cuda/gemm_universal/bmm_crr.py | 144 ++ .../cuda/gemm_universal/bmm_crr_add.py | 104 + .../cuda/gemm_universal/bmm_permute_common.py | 166 ++ .../backend/cuda/gemm_universal/bmm_rcr.py | 211 ++ .../cuda/gemm_universal/bmm_rcr_permute.py | 211 ++ .../backend/cuda/gemm_universal/bmm_rrr.py | 145 ++ .../cuda/gemm_universal/bmm_rrr_add.py | 121 + .../cuda/gemm_universal/bmm_rrr_permute.py | 219 ++ .../gemm_universal/bmm_softmax_bmm_permute.py | 31 + .../backend/cuda/gemm_universal/common.py | 944 +++++++ .../cuda/gemm_universal/common_bias.py | 134 + .../gemm_universal/common_bias_activation.py | 93 + 
.../gemm_universal/common_bias_broadcast.py | 585 +++++ .../cuda/gemm_universal/common_permute.py | 351 +++ .../backend/cuda/gemm_universal/gemm_rcr.py | 229 ++ .../cuda/gemm_universal/gemm_rcr_bias.py | 158 ++ .../cuda/gemm_universal/gemm_rcr_bias_add.py | 98 + .../gemm_universal/gemm_rcr_bias_add_add.py | 98 + .../gemm_rcr_bias_add_add_relu.py | 98 + .../gemm_universal/gemm_rcr_bias_add_relu.py | 98 + .../gemm_universal/gemm_rcr_bias_fast_gelu.py | 144 ++ .../cuda/gemm_universal/gemm_rcr_bias_gelu.py | 106 + .../gemm_universal/gemm_rcr_bias_hardswish.py | 106 + .../cuda/gemm_universal/gemm_rcr_bias_mul.py | 98 + .../gemm_universal/gemm_rcr_bias_mul_add.py | 98 + .../gemm_universal/gemm_rcr_bias_mul_tanh.py | 98 + .../gemm_universal/gemm_rcr_bias_permute.py | 117 + .../cuda/gemm_universal/gemm_rcr_bias_relu.py | 107 + .../gemm_universal/gemm_rcr_bias_sigmoid.py | 107 + .../gemm_rcr_bias_sigmoid_mul.py | 98 + .../gemm_rcr_bias_sigmoid_mul_tanh.py | 98 + .../gemm_universal/gemm_rcr_bias_swish.py | 107 + .../cuda/gemm_universal/gemm_rcr_bias_tanh.py | 144 ++ .../cuda/gemm_universal/gemm_rcr_permute.py | 220 ++ .../backend/cuda/gemm_universal/gemm_rrr.py | 161 ++ .../cuda/gemm_universal/gemm_rrr_permute.py | 221 ++ .../cuda/gemm_universal/group_common.py | 974 ++++++++ .../cuda/gemm_universal/group_common_bias.py | 76 + .../cuda/gemm_universal/group_gemm_rcr.py | 102 + .../gemm_universal/group_gemm_rcr_bias.py | 75 + .../group_gemm_rcr_bias_relu.py | 75 + .../group_gemm_rcr_bias_sigmoid.py | 75 + .../backend/cuda/gemm_universal/layout.py | 79 + .../cuda/gemm_universal/perm021fc_ccr.py | 124 + .../cuda/gemm_universal/perm021fc_ccr_bias.py | 130 + .../perm021fc_ccr_bias_permute.py | 165 ++ .../cuda/gemm_universal/perm021fc_crc.py | 127 + .../cuda/gemm_universal/perm021fc_crc_bias.py | 133 + .../cuda/gemm_universal/perm102_bmm_rcr.py | 179 ++ .../gemm_universal/perm102_bmm_rcr_bias.py | 155 ++ .../cuda/gemm_universal/perm102_bmm_rrr.py | 148 ++ .../gemm_universal/perm102_bmm_rrr_bias.py | 155 ++ .../backend/cuda/groupnorm/__init__.py | 17 + .../backend/cuda/groupnorm/groupnorm.py | 38 + .../cuda/groupnorm/groupnorm_common.py | 179 ++ .../cuda/groupnorm/groupnorm_kernel.cuh | 561 +++++ .../backend/cuda/groupnorm/groupnorm_swish.py | 38 + .../cuda/layernorm_sigmoid_mul/__init__.py | 28 + .../batch_layernorm_sigmoid_mul.py | 136 ++ .../group_layernorm_sigmoid_mul.py | 303 +++ .../layernorm_sigmoid_mul/layernorm_common.py | 113 + .../layernorm_sigmoid_mul.py | 184 ++ .../layernorm_sigmoid_mul_kernel.cuh | 1735 +++++++++++++ .../aitemplate/backend/cuda/lib_template.py | 46 + .../backend/cuda/padding/__init__.py | 20 + .../backend/cuda/padding/nhwc3to4.py | 218 ++ .../backend/cuda/padding/nhwc3to8.py | 221 ++ .../backend/cuda/padding/pad_last_dim.py | 262 ++ .../backend/cuda/pool2d/__init__.py | 20 + .../backend/cuda/pool2d/avg_pool2d.py | 191 ++ .../backend/cuda/pool2d/max_pool2d.py | 236 ++ .../aitemplate/backend/cuda/pool2d/pool2d.py | 76 + .../backend/cuda/reduce/__init__.py | 27 + .../backend/cuda/reduce/reduce_3d.py | 995 ++++++++ .../backend/cuda/reduce/reduce_common.py | 241 ++ .../backend/cuda/reduce/reduce_mean.py | 88 + .../backend/cuda/reduce/reduce_small_axis.py | 425 ++++ .../backend/cuda/reduce/reduce_sum.py | 105 + python/aitemplate/backend/cuda/reduce/var.py | 289 +++ .../backend/cuda/reduce/vector_norm.py | 102 + .../backend/cuda/softmax/__init__.py | 20 + .../backend/cuda/softmax/softmax.cuh | 538 ++++ .../backend/cuda/softmax/softmax.py | 347 +++ 
python/aitemplate/backend/cuda/target_def.py | 171 ++ .../backend/cuda/tensor/__init__.py | 51 + .../aitemplate/backend/cuda/tensor/argmax.py | 52 + .../backend/cuda/tensor/batch_gather.py | 46 + .../backend/cuda/tensor/concatenate.py | 87 + .../backend/cuda/tensor/concatenate_tanh.py | 103 + .../backend/cuda/tensor/dynamic_slice.py | 84 + .../aitemplate/backend/cuda/tensor/expand.py | 31 + .../aitemplate/backend/cuda/tensor/gather.py | 412 ++++ .../backend/cuda/tensor/permute021.py | 91 + .../backend/cuda/tensor/permute102.py | 91 + .../backend/cuda/tensor/permute210.py | 72 + .../cuda/tensor/slice_reshape_scatter.py | 167 ++ .../backend/cuda/tensor/slice_scatter.py | 90 + .../aitemplate/backend/cuda/tensor/split.py | 77 + python/aitemplate/backend/cuda/tensor/topk.py | 52 + .../backend/cuda/upsample/__init__.py | 20 + .../backend/cuda/upsample/upsampling2d.py | 96 + .../backend/cuda/upsample/upsampling2d_add.py | 99 + python/aitemplate/backend/cuda/utils.py | 63 + .../backend/cuda/view_ops/__init__.py | 20 + .../backend/cuda/view_ops/view_ops.py | 230 ++ .../backend/cuda/vision_ops/__init__.py | 21 + .../backend/cuda/vision_ops/nms/__init__.py | 18 + .../cuda/vision_ops/nms/batched_nms.py | 141 ++ .../vision_ops/nms/batched_nms_kernel.cuh | 203 ++ .../cuda/vision_ops/nms/efficient_nms.py | 62 + .../backend/cuda/vision_ops/nms/nms.py | 52 + .../cuda/vision_ops/roi_ops/__init__.py | 20 + .../roi_ops/multi_level_roi_align.py | 86 + .../cuda/vision_ops/roi_ops/roi_align.py | 108 + .../cuda/vision_ops/roi_ops/roi_ops.py | 94 + python/aitemplate/backend/main_templates.py | 378 +++ python/aitemplate/backend/profiler_cache.py | 554 +++++ python/aitemplate/backend/profiler_runner.py | 123 + python/aitemplate/backend/registry.py | 99 + python/aitemplate/backend/rocm/__init__.py | 30 + .../backend/rocm/common/__init__.py | 19 + .../backend/rocm/common/dummy_op.py | 36 + .../backend/rocm/conv2d/__init__.py | 36 + .../aitemplate/backend/rocm/conv2d/common.py | 892 +++++++ .../aitemplate/backend/rocm/conv2d/conv2d.py | 170 ++ .../backend/rocm/conv2d/conv2d_bias.py | 163 ++ .../rocm/conv2d/conv2d_bias_add_relu.py | 207 ++ .../backend/rocm/conv2d/conv2d_bias_relu.py | 165 ++ .../rocm/conv2d/conv2d_bias_sigmoid.py | 212 ++ .../backend/rocm/conv2d/transposed_conv2d.py | 198 ++ .../conv2d/transposed_conv2d_bias_relu.py | 172 ++ .../backend/rocm/elementwise/__init__.py | 20 + .../backend/rocm/elementwise/custom_math.h | 318 +++ .../rocm/elementwise/fused_elementwise.py | 65 + .../aitemplate/backend/rocm/gemm/__init__.py | 49 + .../aitemplate/backend/rocm/gemm/bmm_ccr.py | 170 ++ .../backend/rocm/gemm/bmm_common.py | 252 ++ .../aitemplate/backend/rocm/gemm/bmm_crr.py | 170 ++ .../backend/rocm/gemm/bmm_permute_common.py | 65 + .../aitemplate/backend/rocm/gemm/bmm_rcr.py | 170 ++ .../backend/rocm/gemm/bmm_rcr_permute.py | 185 ++ .../aitemplate/backend/rocm/gemm/bmm_rrr.py | 170 ++ .../backend/rocm/gemm/bmm_rrr_permute.py | 185 ++ .../backend/rocm/gemm/bmm_softmax_bmm.py | 289 +++ .../rocm/gemm/bmm_softmax_bmm_permute.py | 387 +++ python/aitemplate/backend/rocm/gemm/common.py | 974 ++++++++ .../backend/rocm/gemm/gemm_epilogue.py | 90 + .../aitemplate/backend/rocm/gemm/gemm_rcr.py | 151 ++ .../backend/rocm/gemm/gemm_rcr_bias.py | 151 ++ .../backend/rocm/gemm/gemm_rcr_bias_add.py | 193 ++ .../rocm/gemm/gemm_rcr_bias_add_add.py | 193 ++ .../rocm/gemm/gemm_rcr_bias_add_add_relu.py | 194 ++ .../rocm/gemm/gemm_rcr_bias_add_relu.py | 194 ++ .../rocm/gemm/gemm_rcr_bias_fast_gelu.py | 156 ++ 
.../backend/rocm/gemm/gemm_rcr_bias_mul.py | 193 ++ .../rocm/gemm/gemm_rcr_bias_mul_add.py | 164 ++ .../rocm/gemm/gemm_rcr_bias_mul_tanh.py | 196 ++ .../rocm/gemm/gemm_rcr_bias_permute.py | 166 ++ .../rocm/gemm/gemm_rcr_bias_permute_m2n3.py | 183 ++ .../rocm/gemm/gemm_rcr_bias_permute_m3n2.py | 183 ++ .../backend/rocm/gemm/gemm_rcr_bias_relu.py | 153 ++ .../rocm/gemm/gemm_rcr_bias_sigmoid.py | 204 ++ .../rocm/gemm/gemm_rcr_bias_sigmoid_mul.py | 195 ++ .../gemm/gemm_rcr_bias_sigmoid_mul_tanh.py | 200 ++ .../backend/rocm/gemm/gemm_rcr_bias_swish.py | 157 ++ .../backend/rocm/gemm/gemm_rcr_bias_tanh.py | 206 ++ .../rocm/gemm/gemm_rcr_permute_m2n3.py | 183 ++ .../aitemplate/backend/rocm/gemm/gemm_rrr.py | 151 ++ .../rocm/gemm/gemm_rrr_bias_permute.py | 166 ++ python/aitemplate/backend/rocm/gemm/layout.py | 246 ++ .../backend/rocm/gemm/permute_common.py | 128 + .../aitemplate/backend/rocm/lib_template.py | 42 + .../backend/rocm/normalization/__init__.py | 18 + .../backend/rocm/normalization/groupnorm.py | 444 ++++ .../rocm/normalization/groupnorm_swish.py | 50 + .../backend/rocm/normalization/layernorm.py | 371 +++ .../backend/rocm/normalization/norm_common.py | 503 ++++ .../backend/rocm/normalization/softmax.py | 239 ++ .../backend/rocm/pool2d/__init__.py | 20 + .../backend/rocm/pool2d/avg_pool2d.py | 45 + .../backend/rocm/pool2d/max_pool2d.py | 45 + .../aitemplate/backend/rocm/pool2d/pool2d.py | 278 +++ python/aitemplate/backend/rocm/target_def.py | 265 ++ .../backend/rocm/tensor/__init__.py | 31 + .../aitemplate/backend/rocm/tensor/argmax.py | 51 + .../backend/rocm/tensor/batch_gather.py | 45 + .../backend/rocm/tensor/concatenate.py | 85 + .../backend/rocm/tensor/concatenate_tanh.py | 122 + .../backend/rocm/tensor/dynamic_slice.py | 84 + .../backend/rocm/tensor/permute021.py | 90 + .../backend/rocm/tensor/permute102.py | 90 + .../backend/rocm/tensor/permute210.py | 71 + .../rocm/tensor/slice_reshape_scatter.py | 129 + .../backend/rocm/tensor/slice_scatter.py | 90 + .../aitemplate/backend/rocm/tensor/split.py | 77 + python/aitemplate/backend/rocm/tensor/topk.py | 51 + .../backend/rocm/upsample/__init__.py | 20 + .../backend/rocm/upsample/upsampling2d.py | 96 + .../backend/rocm/upsample/upsampling2d_add.py | 99 + python/aitemplate/backend/rocm/utils.py | 114 + .../backend/rocm/view_ops/__init__.py | 20 + .../backend/rocm/view_ops/view_ops.py | 228 ++ .../backend/rocm/vision_ops/__init__.py | 19 + .../backend/rocm/vision_ops/efficient_nms.py | 53 + .../aitemplate/backend/rocm/vision_ops/nms.py | 51 + .../rocm/vision_ops/roi_ops/__init__.py | 20 + .../roi_ops/multi_level_roi_align.py | 87 + .../rocm/vision_ops/roi_ops/roi_align.py | 108 + python/aitemplate/backend/target.py | 433 ++++ python/aitemplate/backend/task_runner.py | 327 +++ python/aitemplate/compiler/__init__.py | 29 + python/aitemplate/compiler/base.py | 829 +++++++ python/aitemplate/compiler/compiler.py | 236 ++ python/aitemplate/compiler/model.py | 856 +++++++ python/aitemplate/compiler/op_registry.py | 23 + python/aitemplate/compiler/ops/__init__.py | 34 + .../compiler/ops/attention/__init__.py | 21 + .../compiler/ops/attention/flash_attention.py | 186 ++ .../compiler/ops/common/__init__.py | 24 + .../compiler/ops/common/elementwise.py | 153 ++ .../compiler/ops/common/epilogue.py | 59 + .../compiler/ops/common/fused_elementwise.py | 156 ++ python/aitemplate/compiler/ops/common/math.py | 87 + .../compiler/ops/common/python_ops.py | 56 + .../compiler/ops/common/view_ops.py | 495 ++++ .../aitemplate/compiler/ops/conv/__init__.py | 32 
+ .../compiler/ops/conv/cache_entry.py | 71 + .../ops/conv/common_conv2d_bias_activation.py | 97 + .../conv/common_conv2d_bias_add_activation.py | 74 + python/aitemplate/compiler/ops/conv/conv2d.py | 621 +++++ .../compiler/ops/conv/conv2d_bias.py | 74 + .../compiler/ops/conv/conv2d_bias_add.py | 77 + .../ops/conv/conv2d_bias_add_hardswish.py | 76 + .../compiler/ops/conv/conv2d_bias_add_relu.py | 77 + .../ops/conv/conv2d_bias_few_channels.py | 41 + .../ops/conv/conv2d_bias_hardswish.py | 72 + .../conv2d_bias_hardswish_few_channels.py | 29 + .../compiler/ops/conv/conv2d_bias_relu.py | 71 + .../ops/conv/conv2d_bias_relu_few_channels.py | 39 + .../compiler/ops/conv/conv2d_bias_sigmoid.py | 72 + .../conv/special_conv2d_bias_activation.py | 87 + .../compiler/ops/conv/transposed_conv2d.py | 111 + .../ops/conv/transposed_conv2d_bias.py | 110 + .../ops/conv/transposed_conv2d_bias_relu.py | 73 + .../compiler/ops/embedding/__init__.py | 20 + .../compiler/ops/embedding/bert_embeddings.py | 136 ++ .../ops/gemm_epilogue_vistor/__init__.py | 20 + .../gemm_epilogue_vistor/bmm_rcr_softmax.py | 127 + .../gemm_rcr_bias_softmax.py | 76 + .../gemm_epilogue_vistor/gemm_rcr_softmax.py | 67 + .../compiler/ops/gemm_special/__init__.py | 23 + .../compiler/ops/gemm_special/bmm_rcr_n1.py | 97 + .../ops/gemm_special/bmm_rrr_k1_tanh.py | 84 + .../ops/gemm_special/gemm_rrr_small_nk.py | 109 + .../compiler/ops/gemm_universal/__init__.py | 63 + .../compiler/ops/gemm_universal/bmm.py | 67 + .../compiler/ops/gemm_universal/bmm_ccr.py | 111 + .../ops/gemm_universal/bmm_ccr_add.py | 81 + .../compiler/ops/gemm_universal/bmm_crr.py | 111 + .../ops/gemm_universal/bmm_crr_add.py | 81 + .../compiler/ops/gemm_universal/bmm_rcr.py | 111 + .../ops/gemm_universal/bmm_rcr_permute.py | 106 + .../compiler/ops/gemm_universal/bmm_rrr.py | 109 + .../ops/gemm_universal/bmm_rrr_add.py | 77 + .../ops/gemm_universal/bmm_rrr_permute.py | 105 + .../ops/gemm_universal/bmm_softmax_bmm.py | 159 ++ .../gemm_universal/bmm_softmax_bmm_permute.py | 184 ++ .../ops/gemm_universal/cache_entry.py | 58 + .../ops/gemm_universal/gemm_common.py | 762 ++++++ .../compiler/ops/gemm_universal/gemm_rcr.py | 103 + .../ops/gemm_universal/gemm_rcr_bias.py | 100 + .../ops/gemm_universal/gemm_rcr_bias_add.py | 45 + .../gemm_universal/gemm_rcr_bias_add_add.py | 46 + .../gemm_rcr_bias_add_add_relu.py | 46 + .../gemm_universal/gemm_rcr_bias_add_relu.py | 45 + .../gemm_universal/gemm_rcr_bias_broadcast.py | 74 + .../gemm_universal/gemm_rcr_bias_fast_gelu.py | 43 + .../ops/gemm_universal/gemm_rcr_bias_gelu.py | 43 + .../gemm_universal/gemm_rcr_bias_hardswish.py | 42 + .../ops/gemm_universal/gemm_rcr_bias_mul.py | 45 + .../gemm_universal/gemm_rcr_bias_mul_add.py | 46 + .../gemm_universal/gemm_rcr_bias_mul_tanh.py | 45 + .../gemm_universal/gemm_rcr_bias_permute.py | 69 + .../ops/gemm_universal/gemm_rcr_bias_relu.py | 43 + .../gemm_universal/gemm_rcr_bias_sigmoid.py | 43 + .../gemm_rcr_bias_sigmoid_mul.py | 44 + .../gemm_rcr_bias_sigmoid_mul_tanh.py | 45 + .../ops/gemm_universal/gemm_rcr_bias_swish.py | 43 + .../ops/gemm_universal/gemm_rcr_bias_tanh.py | 43 + .../ops/gemm_universal/gemm_rcr_permute.py | 72 + .../compiler/ops/gemm_universal/gemm_rrr.py | 106 + .../ops/gemm_universal/gemm_rrr_bias.py | 86 + .../gemm_universal/gemm_rrr_bias_permute.py | 67 + .../ops/gemm_universal/gemm_rrr_permute.py | 62 + .../ops/gemm_universal/group_gemm_rcr.py | 319 +++ .../ops/gemm_universal/group_gemm_rcr_bias.py | 168 ++ .../group_gemm_rcr_bias_relu.py | 52 + 
.../group_gemm_rcr_bias_sigmoid.py | 52 + .../ops/gemm_universal/perm021fc_ccr.py | 147 ++ .../ops/gemm_universal/perm021fc_ccr_bias.py | 77 + .../perm021fc_ccr_bias_permute.py | 77 + .../ops/gemm_universal/perm021fc_crc.py | 113 + .../ops/gemm_universal/perm021fc_crc_bias.py | 79 + .../ops/gemm_universal/perm102_bmm_rcr.py | 98 + .../gemm_universal/perm102_bmm_rcr_bias.py | 91 + .../ops/gemm_universal/perm102_bmm_rrr.py | 98 + .../gemm_universal/perm102_bmm_rrr_bias.py | 72 + .../compiler/ops/groupnorm/__init__.py | 19 + .../compiler/ops/groupnorm/groupnorm.py | 403 +++ .../compiler/ops/groupnorm/groupnorm_swish.py | 26 + .../compiler/ops/layernorm/__init__.py | 28 + .../layernorm/batch_layernorm_sigmoid_mul.py | 91 + .../compiler/ops/layernorm/group_layernorm.py | 160 ++ .../layernorm/group_layernorm_sigmoid_mul.py | 39 + .../compiler/ops/layernorm/layernorm.py | 417 ++++ .../ops/layernorm/layernorm_sigmoid_mul.py | 97 + .../compiler/ops/padding/__init__.py | 23 + .../compiler/ops/padding/nhwc3to4.py | 39 + .../compiler/ops/padding/nhwc3to8.py | 38 + .../compiler/ops/padding/nhwc_pad_common.py | 109 + .../compiler/ops/padding/pad_last_dim.py | 93 + .../aitemplate/compiler/ops/pool/__init__.py | 22 + .../compiler/ops/pool/avg_pool2d.py | 53 + .../compiler/ops/pool/max_pool2d.py | 56 + python/aitemplate/compiler/ops/pool/pool2d.py | 180 ++ .../compiler/ops/reduce/__init__.py | 24 + .../compiler/ops/reduce/reduce_common.py | 249 ++ .../compiler/ops/reduce/reduce_mean.py | 46 + .../compiler/ops/reduce/reduce_sum.py | 46 + python/aitemplate/compiler/ops/reduce/var.py | 52 + .../compiler/ops/reduce/vector_norm.py | 59 + .../compiler/ops/softmax/__init__.py | 21 + .../compiler/ops/softmax/cache_entry.py | 57 + .../compiler/ops/softmax/softmax.py | 367 +++ .../compiler/ops/tensor/__init__.py | 35 + .../aitemplate/compiler/ops/tensor/argmax.py | 206 ++ .../compiler/ops/tensor/batch_gather.py | 121 + .../aitemplate/compiler/ops/tensor/chunk.py | 70 + .../compiler/ops/tensor/concatenate.py | 260 ++ .../compiler/ops/tensor/concatenate_tanh.py | 28 + .../compiler/ops/tensor/dynamic_slice.py | 186 ++ .../aitemplate/compiler/ops/tensor/expand.py | 135 + .../aitemplate/compiler/ops/tensor/gather.py | 77 + .../aitemplate/compiler/ops/tensor/permute.py | 54 + .../compiler/ops/tensor/permute021.py | 104 + .../compiler/ops/tensor/permute102.py | 142 ++ .../compiler/ops/tensor/permute210.py | 115 + python/aitemplate/compiler/ops/tensor/size.py | 68 + .../ops/tensor/slice_reshape_scatter.py | 144 ++ .../compiler/ops/tensor/slice_scatter.py | 99 + .../aitemplate/compiler/ops/tensor/split.py | 165 ++ python/aitemplate/compiler/ops/tensor/topk.py | 189 ++ .../compiler/ops/upsample/__init__.py | 22 + .../compiler/ops/upsample/upsampling2d.py | 41 + .../compiler/ops/upsample/upsampling2d_add.py | 56 + .../ops/upsample/upsampling_common.py | 172 ++ .../compiler/ops/vision_ops/__init__.py | 19 + .../compiler/ops/vision_ops/nms/__init__.py | 23 + .../ops/vision_ops/nms/batched_nms.py | 114 + .../ops/vision_ops/nms/efficient_nms.py | 244 ++ .../compiler/ops/vision_ops/nms/nms.py | 228 ++ .../ops/vision_ops/roi_ops/__init__.py | 21 + .../roi_ops/multi_level_roi_align.py | 119 + .../ops/vision_ops/roi_ops/roi_align.py | 72 + .../ops/vision_ops/roi_ops/roi_ops.py | 214 ++ python/aitemplate/compiler/public/__init__.py | 77 + python/aitemplate/compiler/tensor_accessor.py | 447 ++++ .../aitemplate/compiler/transform/__init__.py | 39 + .../compiler/transform/apply_padding.py | 245 ++ .../compiler/transform/bind_constants.py 
| 53 + .../compiler/transform/constant_folding.py | 192 ++ .../transform/fuse_conv_elementwise.py | 72 + .../compiler/transform/fuse_conv_patterns.py | 137 ++ .../compiler/transform/fuse_group_ops.py | 716 ++++++ .../compiler/transform/fuse_mm_elementwise.py | 218 ++ .../transform/fuse_mm_elementwise_patterns.py | 169 ++ .../aitemplate/compiler/transform/fuse_ops.py | 200 ++ .../compiler/transform/fuse_parallel_gemms.py | 461 ++++ .../compiler/transform/fuse_permute_bmm.py | 224 ++ .../compiler/transform/fuse_split.py | 282 +++ .../compiler/transform/fuse_utils.py | 191 ++ .../compiler/transform/mark_param_tensor.py | 61 + .../compiler/transform/memory_planning.py | 289 +++ .../compiler/transform/name_graph.py | 86 + .../compiler/transform/optimize_graph.py | 87 + .../aitemplate/compiler/transform/profile.py | 72 + .../compiler/transform/profile_dynamic_dim.py | 46 + .../compiler/transform/refine_graph.py | 159 ++ .../compiler/transform/remove_no_ops.py | 168 ++ .../compiler/transform/remove_unused_ops.py | 43 + .../aitemplate/compiler/transform/toposort.py | 65 + .../transform/transform_memory_ops.py | 174 ++ .../transform/transform_odd_alignment.py | 301 +++ .../transform/transform_special_ops.py | 301 +++ .../transform_strided_op_and_view_op.py | 154 ++ .../transform/transform_strided_ops.py | 475 ++++ .../transform/transform_strided_ops_utils.py | 108 + .../transform/transform_strided_slice.py | 268 ++ .../compiler/transform/transform_utils.py | 341 +++ python/aitemplate/frontend/__init__.py | 19 + python/aitemplate/frontend/nn/__init__.py | 34 + python/aitemplate/frontend/nn/attention.py | 227 ++ python/aitemplate/frontend/nn/container.py | 890 +++++++ .../aitemplate/frontend/nn/conv2d/__init__.py | 30 + .../nn/conv2d/common_conv2d_bias_act.py | 76 + .../nn/conv2d/common_conv2d_bias_add_act.py | 51 + .../aitemplate/frontend/nn/conv2d/conv2d.py | 114 + .../frontend/nn/conv2d/conv2d_bias.py | 43 + .../nn/conv2d/conv2d_bias_add_hardswish.py | 43 + .../nn/conv2d/conv2d_bias_add_relu.py | 43 + .../nn/conv2d/conv2d_bias_few_channels.py | 45 + .../nn/conv2d/conv2d_bias_hardswish.py | 43 + .../conv2d_bias_hardswish_few_channels.py | 45 + .../frontend/nn/conv2d/conv2d_bias_relu.py | 43 + .../conv2d/conv2d_bias_relu_few_channels.py | 45 + .../frontend/nn/conv2d/conv2d_bias_sigmoid.py | 43 + .../nn/conv2d/special_conv2d_bias_act.py | 57 + .../nn/conv2d/transposed_conv2d_bias.py | 43 + .../nn/conv2d/transposed_conv2d_bias_act.py | 76 + .../nn/conv2d/transposed_conv2d_bias_relu.py | 43 + python/aitemplate/frontend/nn/dropout.py | 40 + python/aitemplate/frontend/nn/embedding.py | 121 + python/aitemplate/frontend/nn/fpn_proposal.py | 118 + python/aitemplate/frontend/nn/group_norm.py | 50 + python/aitemplate/frontend/nn/identity.py | 33 + python/aitemplate/frontend/nn/layer_norm.py | 58 + python/aitemplate/frontend/nn/linear.py | 70 + python/aitemplate/frontend/nn/module.py | 757 ++++++ python/aitemplate/frontend/nn/padding.py | 30 + python/aitemplate/frontend/nn/parameter.py | 30 + python/aitemplate/frontend/nn/pool2d.py | 41 + python/aitemplate/frontend/nn/proposal.py | 278 +++ python/aitemplate/frontend/nn/roi_ops.py | 75 + python/aitemplate/frontend/nn/upsample.py | 42 + python/aitemplate/frontend/nn/view_ops.py | 54 + python/aitemplate/frontend/parameter.py | 30 + python/aitemplate/testing/__init__.py | 25 + python/aitemplate/testing/benchmark_ait.py | 160 ++ python/aitemplate/testing/benchmark_pt.py | 55 + python/aitemplate/testing/detect_target.py | 97 + python/aitemplate/testing/test_utils.py 
| 105 + python/aitemplate/utils/__init__.py | 26 + python/aitemplate/utils/graph_utils.py | 74 + python/aitemplate/utils/logger.py | 38 + python/aitemplate/utils/markdown_table.py | 183 ++ .../utils/mk_ck_lib/conv2d_operation.py | 388 +++ .../utils/mk_ck_lib/gemm_operation.py | 513 ++++ .../aitemplate/utils/mk_ck_lib/generator.py | 2164 +++++++++++++++++ .../utils/mk_ck_lib/groupnorm_operation.py | 119 + .../utils/mk_ck_lib/layernorm_operation.py | 119 + python/aitemplate/utils/mk_ck_lib/library.py | 375 +++ python/aitemplate/utils/mk_ck_lib/manifest.py | 178 ++ .../utils/mk_ck_lib/softmax_operation.py | 113 + .../utils/mk_cutlass_lib/extra_conv_emit.py | 127 + .../mk_cutlass_lib/extra_cutlass_generator.py | 112 + .../utils/mk_cutlass_lib/extra_enum.py | 139 ++ .../utils/mk_cutlass_lib/extra_gemm_emit.py | 250 ++ .../utils/mk_cutlass_lib/mk_cutlass_lib.py | 90 + python/aitemplate/utils/shape_utils.py | 187 ++ python/aitemplate/utils/tensor_utils.py | 28 + python/aitemplate/utils/torch_utils.py | 38 + .../utils/visualization/__init__.py | 18 + .../utils/visualization/op_attr_factory.py | 21 + python/aitemplate/utils/visualization/plot.py | 202 ++ .../aitemplate/utils/visualization/pydot.py | 1962 +++++++++++++++ .../utils/visualization/web_template.py | 381 +++ python/setup.py | 176 ++ static/README.md | 143 ++ static/csrc/model_container.cpp | 475 ++++ static/csrc/model_interface.cpp | 229 ++ static/csrc/rocm_hack.cpp | 62 + static/csrc/utility.cpp | 69 + static/include/cuda_device_functions.h | 185 ++ static/include/logging.h | 622 +++++ static/include/macros.h | 29 + static/include/model_container.h | 189 ++ static/include/model_interface.h | 184 ++ static/include/owned_constants.h | 46 + static/include/raii_wrapper.h | 53 + static/include/rocm_device_functions.h | 192 ++ static/include/utility.h | 54 + tests/ci_profile_cache/README.md | 5 + tests/ci_profile_cache/update_cache.py | 827 +++++++ tests/lint/check_meta_header.py | 108 + tests/lint/flake8_problem_matcher.json | 17 + .../backend/test_fused_elementwise_backend.py | 412 ++++ tests/unittest/backend/test_model_api.py | 1408 +++++++++++ .../benchmark/test_group_gemm_benchmark.py | 654 +++++ .../test_strided_layernorm_benchmark.py | 85 + .../compiler/test_constant_folding.py | 316 +++ .../compiler/test_fuse_conv_elementwise.py | 696 ++++++ tests/unittest/compiler/test_fuse_expand.py | 63 + .../compiler/test_fuse_mm_elementwise.py | 1387 +++++++++++ .../compiler/test_fuse_permute_bmm.py | 647 +++++ ...st_fused_elementwise_complex_dependency.py | 283 +++ .../test_fused_elementwise_out_of_order.py | 135 + tests/unittest/compiler/test_group_fusions.py | 458 ++++ .../unittest/compiler/test_memory_planning.py | 124 + .../test_pad_bmm_rrr_bias_with_cat.py | 99 + .../compiler/test_pad_gemm_rrr_with_cat.py | 83 + .../compiler/test_pad_gemm_with_cat.py | 95 + .../test_pad_gemm_with_elementwise.py | 171 ++ .../compiler/test_parallel_gemm_fusions.py | 542 +++++ .../compiler/test_permute_bmm_special_op.py | 79 + tests/unittest/compiler/test_public_import.py | 54 + tests/unittest/compiler/test_refine_graph.py | 319 +++ .../compiler/test_remove_unused_ops.py | 77 + .../compiler/test_slice_elemwise_fusion.py | 513 ++++ .../compiler/test_slice_gemm_fusion.py | 769 ++++++ .../compiler/test_slice_reshape_scatter.py | 140 ++ .../compiler/test_slice_scatter_pattern.py | 474 ++++ .../compiler/test_slice_view_strided.py | 122 + .../compiler/test_split_bmm_fusion.py | 299 +++ .../compiler/test_split_bmm_softmax_bmm.py | 92 + 
.../compiler/test_split_view_strided.py | 180 ++ .../compiler/test_strided_group_gemm.py | 255 ++ .../compiler/test_strided_group_layernorm.py | 335 +++ .../compiler/test_strided_layernorm.py | 297 +++ .../test_strided_layernorm_reshape.py | 147 ++ .../compiler/test_strided_op_cat_pattern.py | 1557 ++++++++++++ .../compiler/test_strided_reshape_cat.py | 247 ++ .../unittest/compiler/test_strided_scatter.py | 870 +++++++ .../compiler/test_strided_split_group_gemm.py | 338 +++ .../compiler/test_strided_view_cat.py | 206 ++ .../unittest/compiler/test_strided_view_op.py | 390 +++ .../unittest/compiler/test_tensor_accessor.py | 360 +++ .../compiler/test_transform_memory_ops.py | 247 ++ .../compiler/test_transform_odd_alignment.py | 493 ++++ .../compiler/test_transform_special_op.py | 366 +++ .../unittest/compiler/test_transform_utils.py | 171 ++ .../unittest/compiler/test_view_strided_op.py | 519 ++++ tests/unittest/frontend/test_module.py | 246 ++ tests/unittest/ops/test_activation.py | 136 ++ tests/unittest/ops/test_argmax.py | 58 + tests/unittest/ops/test_attention.py | 294 +++ tests/unittest/ops/test_avg_pool2d.py | 51 + tests/unittest/ops/test_batch_gather.py | 166 ++ tests/unittest/ops/test_bert_embeddings.py | 167 ++ tests/unittest/ops/test_bmm.py | 399 +++ tests/unittest/ops/test_bmm_add.py | 288 +++ tests/unittest/ops/test_bmm_alpha.py | 282 +++ tests/unittest/ops/test_bmm_permute.py | 112 + tests/unittest/ops/test_bmm_rcr_n1.py | 89 + tests/unittest/ops/test_bmm_rrr_k1_tanh.py | 55 + tests/unittest/ops/test_bmm_softmax.py | 62 + tests/unittest/ops/test_bmm_softmax_bmm.py | 189 ++ tests/unittest/ops/test_chunk.py | 120 + tests/unittest/ops/test_clamp_nan_to_num.py | 178 ++ tests/unittest/ops/test_concatenate.py | 401 +++ tests/unittest/ops/test_concatenate_tanh.py | 369 +++ tests/unittest/ops/test_conv.py | 58 + tests/unittest/ops/test_conv2d_bias_add.py | 71 + tests/unittest/ops/test_conv_bias.py | 62 + .../ops/test_conv_bias_act_few_channels.py | 104 + .../ops/test_conv_bias_add_hardswish.py | 74 + tests/unittest/ops/test_conv_bias_add_relu.py | 71 + .../unittest/ops/test_conv_bias_hardswish.py | 71 + tests/unittest/ops/test_conv_bias_relu.py | 63 + tests/unittest/ops/test_conv_bias_sigmoid.py | 63 + tests/unittest/ops/test_dynamic_conv.py | 70 + tests/unittest/ops/test_efficient_nms.py | 307 +++ tests/unittest/ops/test_expand.py | 104 + tests/unittest/ops/test_flatten.py | 135 + tests/unittest/ops/test_fpn_roi_align.py | 201 ++ tests/unittest/ops/test_fused_elementwise.py | 381 +++ .../ops/test_fused_elementwise_broadcast.py | 471 ++++ ..._fused_elementwise_with_strided_outputs.py | 141 ++ tests/unittest/ops/test_gather.py | 90 + tests/unittest/ops/test_gemm.py | 190 ++ tests/unittest/ops/test_gemm_bias.py | 79 + .../unittest/ops/test_gemm_bias_broadcast.py | 316 +++ .../unittest/ops/test_gemm_bias_hardswish.py | 55 + tests/unittest/ops/test_gemm_bias_permute.py | 187 ++ tests/unittest/ops/test_gemm_bias_relu.py | 76 + tests/unittest/ops/test_gemm_bias_sigmoid.py | 50 + tests/unittest/ops/test_gemm_bias_softmax.py | 69 + tests/unittest/ops/test_gemm_bias_swish.py | 55 + tests/unittest/ops/test_gemm_bias_tanh.py | 68 + tests/unittest/ops/test_gemm_permute.py | 100 + .../ops/test_gemm_rcr_bias_fast_gelu.py | 79 + tests/unittest/ops/test_gemm_rrr_small_nk.py | 68 + tests/unittest/ops/test_gemm_softmax.py | 67 + tests/unittest/ops/test_group_gemm_rcr.py | 91 + .../unittest/ops/test_group_gemm_rcr_bias.py | 77 + .../test_group_gemm_rcr_bias_activation.py | 77 + 
.../ops/test_group_gemm_rcr_bias_cat.py | 79 + tests/unittest/ops/test_group_gemm_rcr_cat.py | 74 + tests/unittest/ops/test_groupnorm.py | 149 ++ tests/unittest/ops/test_layernorm.py | 146 ++ .../ops/test_layernorm_sigmoid_mul.py | 700 ++++++ tests/unittest/ops/test_max_pool2d.py | 51 + tests/unittest/ops/test_nhwc3to4.py | 57 + tests/unittest/ops/test_nhwc3to8.py | 57 + tests/unittest/ops/test_nms.py | 207 ++ tests/unittest/ops/test_norm.py | 176 ++ tests/unittest/ops/test_pad_last_dim.py | 70 + tests/unittest/ops/test_perm021fc_ccr.py | 62 + tests/unittest/ops/test_perm021fc_ccr_bias.py | 68 + .../ops/test_perm021fc_ccr_bias_perm021.py | 69 + tests/unittest/ops/test_perm021fc_crc.py | 63 + tests/unittest/ops/test_perm021fc_crc_bias.py | 66 + tests/unittest/ops/test_perm102_bmm_rcr.py | 97 + tests/unittest/ops/test_perm102_bmm_rrr.py | 97 + tests/unittest/ops/test_permute.py | 55 + tests/unittest/ops/test_permute021.py | 45 + tests/unittest/ops/test_permute102.py | 45 + tests/unittest/ops/test_permute210.py | 48 + tests/unittest/ops/test_proposal.py | 499 ++++ tests/unittest/ops/test_reduce.py | 367 +++ tests/unittest/ops/test_reshape.py | 168 ++ tests/unittest/ops/test_roi_align.py | 136 ++ tests/unittest/ops/test_size_getitem_ops.py | 112 + tests/unittest/ops/test_slice.py | 220 ++ tests/unittest/ops/test_softmax.py | 74 + tests/unittest/ops/test_split.py | 193 ++ tests/unittest/ops/test_split_getitem.py | 222 ++ tests/unittest/ops/test_squeeze.py | 133 + tests/unittest/ops/test_topk.py | 73 + tests/unittest/ops/test_transpose_conv2d.py | 58 + .../ops/test_transpose_conv2d_bias.py | 67 + .../ops/test_transpose_conv2d_bias_relu.py | 68 + .../unittest/ops/test_tuple_list_construct.py | 82 + tests/unittest/ops/test_upsamping2d.py | 78 + tests/unittest/ops/test_upsamping2d_add.py | 109 + tests/unittest/ops/test_var.py | 132 + 846 files changed, 149832 insertions(+) create mode 100644 .circleci/config.yml create mode 100644 .clang-format create mode 100644 .flake8 create mode 100644 .github/workflows/docs.yml create mode 100644 .github/workflows/lint.yml create mode 100644 .gitignore create mode 100644 .gitmodules create mode 160000 3rdparty/composable_kernel create mode 160000 3rdparty/cub create mode 160000 3rdparty/cutlass create mode 100644 CITATION.cff create mode 100644 CODE_OF_CONDUCT.md create mode 100644 CONTRIBUTING.md create mode 100644 LICENSE create mode 100644 README.md create mode 100644 docker/Dockerfile.cuda create mode 100644 docker/Dockerfile.rocm create mode 100644 docker/README.md create mode 100755 docker/build.sh create mode 100644 docker/install/install_ait.sh create mode 100644 docker/install/install_basic_dep.sh create mode 100644 docker/install/install_detection_deps.sh create mode 100644 docker/install/install_doc_dep.sh create mode 100644 docker/install/install_test_dep.sh create mode 100644 docker/install/rocm_dev-requirements.txt create mode 100644 docker/rocm_fix/fix_10736.py create mode 100644 docs/Makefile create mode 100644 docs/README.md create mode 100644 docs/make.bat create mode 100644 docs/source/arch/index.rst create mode 100644 docs/source/arch/philosophy.rst create mode 100644 docs/source/conf.py create mode 100644 docs/source/debughints.rst create mode 100644 docs/source/genindex.rst create mode 100644 docs/source/index.rst create mode 100644 docs/source/install/index.rst create mode 100644 docs/source/reference/backend.rst create mode 100644 docs/source/reference/compiler.rst create mode 100644 docs/source/reference/cuda.rst create mode 100644 
docs/source/reference/env.rst create mode 100644 docs/source/reference/frontend.rst create mode 100644 docs/source/reference/index.rst create mode 100644 docs/source/reference/ops.rst create mode 100644 docs/source/reference/rocm.rst create mode 100644 docs/source/reference/testing.rst create mode 100644 docs/source/reference/transform.rst create mode 100644 docs/source/reference/utils.rst create mode 100644 docs/source/runtime/cxx_design.rst create mode 100644 docs/source/runtime/index.rst create mode 100644 docs/source/runtime/py_design.rst create mode 100644 docs/source/tutorial/how_to_add_op.rst create mode 100644 docs/source/tutorial/how_to_infer_pt.rst create mode 100644 docs/source/tutorial/how_to_visualize.rst create mode 100644 docs/source/tutorial/index.rst create mode 100644 docs/static/ait_model.html create mode 100644 examples/01_resnet-50/README.md create mode 100644 examples/01_resnet-50/benchmark_ait.py create mode 100644 examples/01_resnet-50/benchmark_mi250.sh create mode 100644 examples/01_resnet-50/benchmark_pt.py create mode 100644 examples/01_resnet-50/infer_with_torch.py create mode 100644 examples/01_resnet-50/modeling/__init__.py create mode 100644 examples/01_resnet-50/modeling/resnet.py create mode 100644 examples/01_resnet-50/weight_utils.py create mode 100644 examples/02_detectron2/README.md create mode 100644 examples/02_detectron2/compile_model.py create mode 100644 examples/02_detectron2/configs/__init__.py create mode 100644 examples/02_detectron2/configs/config.py create mode 100644 examples/02_detectron2/configs/defaults.py create mode 100644 examples/02_detectron2/configs/faster_rcnn_R_101_FPN.yaml create mode 100644 examples/02_detectron2/configs/faster_rcnn_R_50_FPN.yaml create mode 100644 examples/02_detectron2/configs/mask_rcnn_R_101_FPN.yaml create mode 100644 examples/02_detectron2/configs/mask_rcnn_R_50_FPN.yaml create mode 100644 examples/02_detectron2/demo.py create mode 100644 examples/02_detectron2/modeling/backbone/__init__.py create mode 100644 examples/02_detectron2/modeling/backbone/fpn.py create mode 100644 examples/02_detectron2/modeling/backbone/resnet.py create mode 100644 examples/02_detectron2/modeling/backbone/utils.py create mode 100644 examples/02_detectron2/modeling/meta_arch/__init__.py create mode 100644 examples/02_detectron2/modeling/meta_arch/rcnn.py create mode 100644 examples/02_detectron2/modeling/proposal_generator/__init__.py create mode 100644 examples/02_detectron2/modeling/proposal_generator/rpn.py create mode 100644 examples/02_detectron2/modeling/roi_heads/__init__.py create mode 100644 examples/02_detectron2/modeling/roi_heads/box_head.py create mode 100644 examples/02_detectron2/modeling/roi_heads/fast_rcnn.py create mode 100644 examples/02_detectron2/modeling/roi_heads/mask_head.py create mode 100644 examples/02_detectron2/modeling/roi_heads/roi_heads.py create mode 100644 examples/02_detectron2/predictor/__init__.py create mode 100644 examples/02_detectron2/predictor/builtin_meta.py create mode 100644 examples/02_detectron2/predictor/predictor.py create mode 100755 examples/02_detectron2/prepare_and_run_rcnn.sh create mode 100644 examples/02_detectron2/tools/convert_pt2ait.py create mode 100644 examples/03_bert/README.md create mode 100644 examples/03_bert/benchmark_ait.py create mode 100644 examples/03_bert/benchmark_mi250.sh create mode 100644 examples/03_bert/benchmark_pt.py create mode 100644 examples/03_bert/demo.py create mode 100644 examples/03_bert/modeling/__init__.py create mode 100644 
examples/03_bert/modeling/bert.py create mode 100644 examples/03_bert/modeling/torch_model.py create mode 100644 examples/04_vit/README.md create mode 100644 examples/04_vit/benchmark_ait.py create mode 100644 examples/04_vit/benchmark_mi250.sh create mode 100644 examples/04_vit/benchmark_pt.py create mode 100644 examples/04_vit/modeling/vision_transformer.py create mode 100644 examples/04_vit/verification.py create mode 100644 examples/04_vit/weight_utils.py create mode 100644 examples/05_stable_diffusion/README.md create mode 100644 examples/05_stable_diffusion/benchmark.py create mode 100644 examples/05_stable_diffusion/benchmark_pt.py create mode 100644 examples/05_stable_diffusion/compile.py create mode 100644 examples/05_stable_diffusion/demo.py create mode 100644 examples/05_stable_diffusion/modeling/attention.py create mode 100644 examples/05_stable_diffusion/modeling/clip.py create mode 100644 examples/05_stable_diffusion/modeling/embeddings.py create mode 100644 examples/05_stable_diffusion/modeling/resnet.py create mode 100644 examples/05_stable_diffusion/modeling/unet_2d_condition.py create mode 100644 examples/05_stable_diffusion/modeling/unet_blocks.py create mode 100644 examples/05_stable_diffusion/modeling/vae.py create mode 100644 examples/05_stable_diffusion/pipeline_stable_diffusion_ait.py create mode 100644 examples/06_how_to_add_an_op/how_to_add_an_op.py create mode 100644 examples/07_how_to_run_pt_model/how_to_run_pt_model.py create mode 100644 licenses/LICENSE.composable_kernel.txt create mode 100644 licenses/LICENSE.cub.txt create mode 100644 licenses/LICENSE.cutlass.txt create mode 100644 licenses/LICENSE.dmlc.txt create mode 100644 licenses/LICENSE.flash_attention.txt create mode 100644 licenses/LICENSE.hipcub.txt create mode 100644 licenses/LICENSE.markdown_table.txt create mode 100644 licenses/LICENSE.oneflow.txt create mode 100644 licenses/LICENSE.pydot.txt create mode 100644 licenses/LICENSE.pytorch.txt create mode 100644 licenses/LICENSE.tensorrt.txt create mode 100644 licenses/license.header.txt create mode 100644 python/aitemplate/__init__.py create mode 100644 python/aitemplate/_libinfo.py create mode 100644 python/aitemplate/backend/__init__.py create mode 100644 python/aitemplate/backend/backend_spec.py create mode 100644 python/aitemplate/backend/builder.py create mode 100644 python/aitemplate/backend/codegen.py create mode 100644 python/aitemplate/backend/common/concatenate_common.py create mode 100644 python/aitemplate/backend/common/elementwise_common.py create mode 100644 python/aitemplate/backend/common/gemm_common.py create mode 100644 python/aitemplate/backend/common/split_common.py create mode 100644 python/aitemplate/backend/common/tensor/argmax_common.py create mode 100644 python/aitemplate/backend/common/tensor/batch_gather_common.py create mode 100644 python/aitemplate/backend/common/tensor/permute021_common.py create mode 100644 python/aitemplate/backend/common/tensor/permute102_common.py create mode 100644 python/aitemplate/backend/common/tensor/permute210_common.py create mode 100644 python/aitemplate/backend/common/tensor/slice_common.py create mode 100644 python/aitemplate/backend/common/tensor/slice_reshape_scatter_common.py create mode 100644 python/aitemplate/backend/common/tensor/topk_common.py create mode 100644 python/aitemplate/backend/common/tensor_accessor.cuh create mode 100644 python/aitemplate/backend/common/tensor_accessor_codegen.py create mode 100644 python/aitemplate/backend/common/upsampling2d_common.py create mode 
100644 python/aitemplate/backend/common/vision_ops/efficient_nms_common.py create mode 100644 python/aitemplate/backend/common/vision_ops/efficient_nms_kernel.py create mode 100644 python/aitemplate/backend/common/vision_ops/multi_level_roi_align_common.py create mode 100644 python/aitemplate/backend/common/vision_ops/nms_common.py create mode 100644 python/aitemplate/backend/common/vision_ops/nms_kernel.py create mode 100644 python/aitemplate/backend/common/vision_ops/roi_align_common.py create mode 100644 python/aitemplate/backend/cuda/__init__.py create mode 100644 python/aitemplate/backend/cuda/attention/__init__.py create mode 100644 python/aitemplate/backend/cuda/attention/flash_attention.py create mode 100644 python/aitemplate/backend/cuda/attention/src/fmha.h create mode 100644 python/aitemplate/backend/cuda/attention/src/fmha/gemm.h create mode 100644 python/aitemplate/backend/cuda/attention/src/fmha/gmem_tile.h create mode 100644 python/aitemplate/backend/cuda/attention/src/fmha/kernel_traits.h create mode 100644 python/aitemplate/backend/cuda/attention/src/fmha/mask.h create mode 100644 python/aitemplate/backend/cuda/attention/src/fmha/smem_tile.h create mode 100644 python/aitemplate/backend/cuda/attention/src/fmha/softmax.h create mode 100644 python/aitemplate/backend/cuda/attention/src/fmha/utils.h create mode 100644 python/aitemplate/backend/cuda/attention/src/fmha_block_fprop_fp16_kernel.sm80.cu create mode 100644 python/aitemplate/backend/cuda/attention/src/fmha_block_fprop_kernel_1xN.h create mode 100644 python/aitemplate/backend/cuda/attention/src/fmha_blockmask.h create mode 100644 python/aitemplate/backend/cuda/attention/src/fmha_fprop_fp16_kernel.sm80.cu create mode 100644 python/aitemplate/backend/cuda/attention/src/fmha_fprop_kernel_1xN.h create mode 100644 python/aitemplate/backend/cuda/attention/src/fmha_kernel.h create mode 100644 python/aitemplate/backend/cuda/attention/src/fmha_utils.h create mode 100644 python/aitemplate/backend/cuda/attention/src/licenses/LICENSE create mode 100644 python/aitemplate/backend/cuda/attention/src/philox.cuh create mode 100644 python/aitemplate/backend/cuda/common/__init__.py create mode 100644 python/aitemplate/backend/cuda/common/dummy_op.py create mode 100644 python/aitemplate/backend/cuda/conv2d/__init__.py create mode 100644 python/aitemplate/backend/cuda/conv2d/common.py create mode 100644 python/aitemplate/backend/cuda/conv2d/common_conv2d_bias_activation.py create mode 100644 python/aitemplate/backend/cuda/conv2d/common_conv2d_bias_add_activation.py create mode 100644 python/aitemplate/backend/cuda/conv2d/common_conv2d_few_channels.py create mode 100644 python/aitemplate/backend/cuda/conv2d/conv2d.py create mode 100644 python/aitemplate/backend/cuda/conv2d/conv2d_bias.py create mode 100644 python/aitemplate/backend/cuda/conv2d/conv2d_bias_add.py create mode 100644 python/aitemplate/backend/cuda/conv2d/conv2d_bias_add_hardswish.py create mode 100644 python/aitemplate/backend/cuda/conv2d/conv2d_bias_add_relu.py create mode 100644 python/aitemplate/backend/cuda/conv2d/conv2d_bias_few_channels.py create mode 100644 python/aitemplate/backend/cuda/conv2d/conv2d_bias_hardswish.py create mode 100644 python/aitemplate/backend/cuda/conv2d/conv2d_bias_hardswish_few_channels.py create mode 100644 python/aitemplate/backend/cuda/conv2d/conv2d_bias_relu.py create mode 100644 python/aitemplate/backend/cuda/conv2d/conv2d_bias_relu_few_channels.py create mode 100644 python/aitemplate/backend/cuda/conv2d/conv2d_bias_sigmoid.py create mode 
100644 python/aitemplate/backend/cuda/conv2d/transposed_conv2d.py create mode 100644 python/aitemplate/backend/cuda/conv2d/transposed_conv2d_bias.py create mode 100644 python/aitemplate/backend/cuda/cuda_common.py create mode 100644 python/aitemplate/backend/cuda/elementwise/__init__.py create mode 100644 python/aitemplate/backend/cuda/elementwise/custom_math.cuh create mode 100644 python/aitemplate/backend/cuda/elementwise/fused_elementwise.py create mode 100644 python/aitemplate/backend/cuda/embedding/__init__.py create mode 100644 python/aitemplate/backend/cuda/embedding/bert_embeddings.py create mode 100644 python/aitemplate/backend/cuda/gemm_epilogue_vistor/__init__.py create mode 100644 python/aitemplate/backend/cuda/gemm_epilogue_vistor/bmm_common_softmax.py create mode 100644 python/aitemplate/backend/cuda/gemm_epilogue_vistor/bmm_rcr_softmax.py create mode 100644 python/aitemplate/backend/cuda/gemm_epilogue_vistor/common_softmax.py create mode 100644 python/aitemplate/backend/cuda/gemm_epilogue_vistor/gemm_rcr_bias_softmax.py create mode 100644 python/aitemplate/backend/cuda/gemm_epilogue_vistor/gemm_rcr_softmax.py create mode 100644 python/aitemplate/backend/cuda/gemm_epilogue_vistor/include/gemm_with_softmax.h create mode 100644 python/aitemplate/backend/cuda/gemm_special/__init__.py create mode 100644 python/aitemplate/backend/cuda/gemm_special/bmm_rcr_n1.py create mode 100644 python/aitemplate/backend/cuda/gemm_special/bmm_rrr_k1_tanh.py create mode 100644 python/aitemplate/backend/cuda/gemm_special/gemm_rrr_small_nk.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/__init__.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/bmm_ccr.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/bmm_ccr_add.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/bmm_common.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/bmm_crr.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/bmm_crr_add.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/bmm_permute_common.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/bmm_rcr.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/bmm_rcr_permute.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/bmm_rrr.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/bmm_rrr_add.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/bmm_rrr_permute.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/bmm_softmax_bmm_permute.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/common.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/common_bias.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/common_bias_activation.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/common_bias_broadcast.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/common_permute.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/gemm_rcr.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add_add.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add_add_relu.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add_relu.py create 
mode 100644 python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_fast_gelu.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_gelu.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_hardswish.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_mul.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_mul_add.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_mul_tanh.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_permute.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_relu.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_sigmoid.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_sigmoid_mul.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_sigmoid_mul_tanh.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_swish.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_tanh.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_permute.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/gemm_rrr.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/gemm_rrr_permute.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/group_common.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/group_common_bias.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr_bias.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr_bias_relu.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr_bias_sigmoid.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/layout.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/perm021fc_ccr.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/perm021fc_ccr_bias.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/perm021fc_ccr_bias_permute.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/perm021fc_crc.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/perm021fc_crc_bias.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rcr.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rcr_bias.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rrr.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rrr_bias.py create mode 100644 python/aitemplate/backend/cuda/groupnorm/__init__.py create mode 100644 python/aitemplate/backend/cuda/groupnorm/groupnorm.py create mode 100644 python/aitemplate/backend/cuda/groupnorm/groupnorm_common.py create mode 100644 python/aitemplate/backend/cuda/groupnorm/groupnorm_kernel.cuh create mode 100644 python/aitemplate/backend/cuda/groupnorm/groupnorm_swish.py create mode 100644 python/aitemplate/backend/cuda/layernorm_sigmoid_mul/__init__.py create mode 100644 python/aitemplate/backend/cuda/layernorm_sigmoid_mul/batch_layernorm_sigmoid_mul.py create mode 100644 python/aitemplate/backend/cuda/layernorm_sigmoid_mul/group_layernorm_sigmoid_mul.py create mode 100644 
python/aitemplate/backend/cuda/layernorm_sigmoid_mul/layernorm_common.py create mode 100644 python/aitemplate/backend/cuda/layernorm_sigmoid_mul/layernorm_sigmoid_mul.py create mode 100644 python/aitemplate/backend/cuda/layernorm_sigmoid_mul/layernorm_sigmoid_mul_kernel.cuh create mode 100644 python/aitemplate/backend/cuda/lib_template.py create mode 100644 python/aitemplate/backend/cuda/padding/__init__.py create mode 100644 python/aitemplate/backend/cuda/padding/nhwc3to4.py create mode 100644 python/aitemplate/backend/cuda/padding/nhwc3to8.py create mode 100644 python/aitemplate/backend/cuda/padding/pad_last_dim.py create mode 100644 python/aitemplate/backend/cuda/pool2d/__init__.py create mode 100644 python/aitemplate/backend/cuda/pool2d/avg_pool2d.py create mode 100644 python/aitemplate/backend/cuda/pool2d/max_pool2d.py create mode 100644 python/aitemplate/backend/cuda/pool2d/pool2d.py create mode 100644 python/aitemplate/backend/cuda/reduce/__init__.py create mode 100644 python/aitemplate/backend/cuda/reduce/reduce_3d.py create mode 100644 python/aitemplate/backend/cuda/reduce/reduce_common.py create mode 100644 python/aitemplate/backend/cuda/reduce/reduce_mean.py create mode 100644 python/aitemplate/backend/cuda/reduce/reduce_small_axis.py create mode 100644 python/aitemplate/backend/cuda/reduce/reduce_sum.py create mode 100644 python/aitemplate/backend/cuda/reduce/var.py create mode 100644 python/aitemplate/backend/cuda/reduce/vector_norm.py create mode 100644 python/aitemplate/backend/cuda/softmax/__init__.py create mode 100644 python/aitemplate/backend/cuda/softmax/softmax.cuh create mode 100644 python/aitemplate/backend/cuda/softmax/softmax.py create mode 100644 python/aitemplate/backend/cuda/target_def.py create mode 100644 python/aitemplate/backend/cuda/tensor/__init__.py create mode 100644 python/aitemplate/backend/cuda/tensor/argmax.py create mode 100644 python/aitemplate/backend/cuda/tensor/batch_gather.py create mode 100644 python/aitemplate/backend/cuda/tensor/concatenate.py create mode 100644 python/aitemplate/backend/cuda/tensor/concatenate_tanh.py create mode 100644 python/aitemplate/backend/cuda/tensor/dynamic_slice.py create mode 100644 python/aitemplate/backend/cuda/tensor/expand.py create mode 100644 python/aitemplate/backend/cuda/tensor/gather.py create mode 100644 python/aitemplate/backend/cuda/tensor/permute021.py create mode 100644 python/aitemplate/backend/cuda/tensor/permute102.py create mode 100644 python/aitemplate/backend/cuda/tensor/permute210.py create mode 100644 python/aitemplate/backend/cuda/tensor/slice_reshape_scatter.py create mode 100644 python/aitemplate/backend/cuda/tensor/slice_scatter.py create mode 100644 python/aitemplate/backend/cuda/tensor/split.py create mode 100644 python/aitemplate/backend/cuda/tensor/topk.py create mode 100644 python/aitemplate/backend/cuda/upsample/__init__.py create mode 100644 python/aitemplate/backend/cuda/upsample/upsampling2d.py create mode 100644 python/aitemplate/backend/cuda/upsample/upsampling2d_add.py create mode 100644 python/aitemplate/backend/cuda/utils.py create mode 100644 python/aitemplate/backend/cuda/view_ops/__init__.py create mode 100644 python/aitemplate/backend/cuda/view_ops/view_ops.py create mode 100644 python/aitemplate/backend/cuda/vision_ops/__init__.py create mode 100644 python/aitemplate/backend/cuda/vision_ops/nms/__init__.py create mode 100644 python/aitemplate/backend/cuda/vision_ops/nms/batched_nms.py create mode 100644 
python/aitemplate/backend/cuda/vision_ops/nms/batched_nms_kernel.cuh create mode 100644 python/aitemplate/backend/cuda/vision_ops/nms/efficient_nms.py create mode 100644 python/aitemplate/backend/cuda/vision_ops/nms/nms.py create mode 100644 python/aitemplate/backend/cuda/vision_ops/roi_ops/__init__.py create mode 100644 python/aitemplate/backend/cuda/vision_ops/roi_ops/multi_level_roi_align.py create mode 100644 python/aitemplate/backend/cuda/vision_ops/roi_ops/roi_align.py create mode 100644 python/aitemplate/backend/cuda/vision_ops/roi_ops/roi_ops.py create mode 100644 python/aitemplate/backend/main_templates.py create mode 100644 python/aitemplate/backend/profiler_cache.py create mode 100644 python/aitemplate/backend/profiler_runner.py create mode 100644 python/aitemplate/backend/registry.py create mode 100644 python/aitemplate/backend/rocm/__init__.py create mode 100644 python/aitemplate/backend/rocm/common/__init__.py create mode 100644 python/aitemplate/backend/rocm/common/dummy_op.py create mode 100644 python/aitemplate/backend/rocm/conv2d/__init__.py create mode 100644 python/aitemplate/backend/rocm/conv2d/common.py create mode 100644 python/aitemplate/backend/rocm/conv2d/conv2d.py create mode 100644 python/aitemplate/backend/rocm/conv2d/conv2d_bias.py create mode 100644 python/aitemplate/backend/rocm/conv2d/conv2d_bias_add_relu.py create mode 100644 python/aitemplate/backend/rocm/conv2d/conv2d_bias_relu.py create mode 100644 python/aitemplate/backend/rocm/conv2d/conv2d_bias_sigmoid.py create mode 100644 python/aitemplate/backend/rocm/conv2d/transposed_conv2d.py create mode 100644 python/aitemplate/backend/rocm/conv2d/transposed_conv2d_bias_relu.py create mode 100644 python/aitemplate/backend/rocm/elementwise/__init__.py create mode 100644 python/aitemplate/backend/rocm/elementwise/custom_math.h create mode 100644 python/aitemplate/backend/rocm/elementwise/fused_elementwise.py create mode 100644 python/aitemplate/backend/rocm/gemm/__init__.py create mode 100644 python/aitemplate/backend/rocm/gemm/bmm_ccr.py create mode 100644 python/aitemplate/backend/rocm/gemm/bmm_common.py create mode 100644 python/aitemplate/backend/rocm/gemm/bmm_crr.py create mode 100644 python/aitemplate/backend/rocm/gemm/bmm_permute_common.py create mode 100644 python/aitemplate/backend/rocm/gemm/bmm_rcr.py create mode 100644 python/aitemplate/backend/rocm/gemm/bmm_rcr_permute.py create mode 100644 python/aitemplate/backend/rocm/gemm/bmm_rrr.py create mode 100644 python/aitemplate/backend/rocm/gemm/bmm_rrr_permute.py create mode 100644 python/aitemplate/backend/rocm/gemm/bmm_softmax_bmm.py create mode 100644 python/aitemplate/backend/rocm/gemm/bmm_softmax_bmm_permute.py create mode 100644 python/aitemplate/backend/rocm/gemm/common.py create mode 100644 python/aitemplate/backend/rocm/gemm/gemm_epilogue.py create mode 100644 python/aitemplate/backend/rocm/gemm/gemm_rcr.py create mode 100644 python/aitemplate/backend/rocm/gemm/gemm_rcr_bias.py create mode 100644 python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_add.py create mode 100644 python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_add_add.py create mode 100644 python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_add_add_relu.py create mode 100644 python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_add_relu.py create mode 100644 python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_fast_gelu.py create mode 100644 python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_mul.py create mode 100644 python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_mul_add.py create mode 100644 
python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_mul_tanh.py create mode 100644 python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_permute.py create mode 100644 python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_permute_m2n3.py create mode 100644 python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_permute_m3n2.py create mode 100644 python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_relu.py create mode 100644 python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_sigmoid.py create mode 100644 python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_sigmoid_mul.py create mode 100644 python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_sigmoid_mul_tanh.py create mode 100644 python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_swish.py create mode 100644 python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_tanh.py create mode 100644 python/aitemplate/backend/rocm/gemm/gemm_rcr_permute_m2n3.py create mode 100644 python/aitemplate/backend/rocm/gemm/gemm_rrr.py create mode 100644 python/aitemplate/backend/rocm/gemm/gemm_rrr_bias_permute.py create mode 100644 python/aitemplate/backend/rocm/gemm/layout.py create mode 100644 python/aitemplate/backend/rocm/gemm/permute_common.py create mode 100644 python/aitemplate/backend/rocm/lib_template.py create mode 100644 python/aitemplate/backend/rocm/normalization/__init__.py create mode 100644 python/aitemplate/backend/rocm/normalization/groupnorm.py create mode 100644 python/aitemplate/backend/rocm/normalization/groupnorm_swish.py create mode 100644 python/aitemplate/backend/rocm/normalization/layernorm.py create mode 100644 python/aitemplate/backend/rocm/normalization/norm_common.py create mode 100644 python/aitemplate/backend/rocm/normalization/softmax.py create mode 100644 python/aitemplate/backend/rocm/pool2d/__init__.py create mode 100644 python/aitemplate/backend/rocm/pool2d/avg_pool2d.py create mode 100644 python/aitemplate/backend/rocm/pool2d/max_pool2d.py create mode 100644 python/aitemplate/backend/rocm/pool2d/pool2d.py create mode 100644 python/aitemplate/backend/rocm/target_def.py create mode 100644 python/aitemplate/backend/rocm/tensor/__init__.py create mode 100644 python/aitemplate/backend/rocm/tensor/argmax.py create mode 100644 python/aitemplate/backend/rocm/tensor/batch_gather.py create mode 100644 python/aitemplate/backend/rocm/tensor/concatenate.py create mode 100644 python/aitemplate/backend/rocm/tensor/concatenate_tanh.py create mode 100644 python/aitemplate/backend/rocm/tensor/dynamic_slice.py create mode 100644 python/aitemplate/backend/rocm/tensor/permute021.py create mode 100644 python/aitemplate/backend/rocm/tensor/permute102.py create mode 100644 python/aitemplate/backend/rocm/tensor/permute210.py create mode 100644 python/aitemplate/backend/rocm/tensor/slice_reshape_scatter.py create mode 100644 python/aitemplate/backend/rocm/tensor/slice_scatter.py create mode 100644 python/aitemplate/backend/rocm/tensor/split.py create mode 100644 python/aitemplate/backend/rocm/tensor/topk.py create mode 100644 python/aitemplate/backend/rocm/upsample/__init__.py create mode 100644 python/aitemplate/backend/rocm/upsample/upsampling2d.py create mode 100644 python/aitemplate/backend/rocm/upsample/upsampling2d_add.py create mode 100644 python/aitemplate/backend/rocm/utils.py create mode 100644 python/aitemplate/backend/rocm/view_ops/__init__.py create mode 100644 python/aitemplate/backend/rocm/view_ops/view_ops.py create mode 100644 python/aitemplate/backend/rocm/vision_ops/__init__.py create mode 100644 python/aitemplate/backend/rocm/vision_ops/efficient_nms.py create 
mode 100644 python/aitemplate/backend/rocm/vision_ops/nms.py create mode 100644 python/aitemplate/backend/rocm/vision_ops/roi_ops/__init__.py create mode 100644 python/aitemplate/backend/rocm/vision_ops/roi_ops/multi_level_roi_align.py create mode 100644 python/aitemplate/backend/rocm/vision_ops/roi_ops/roi_align.py create mode 100644 python/aitemplate/backend/target.py create mode 100644 python/aitemplate/backend/task_runner.py create mode 100644 python/aitemplate/compiler/__init__.py create mode 100644 python/aitemplate/compiler/base.py create mode 100644 python/aitemplate/compiler/compiler.py create mode 100644 python/aitemplate/compiler/model.py create mode 100644 python/aitemplate/compiler/op_registry.py create mode 100644 python/aitemplate/compiler/ops/__init__.py create mode 100644 python/aitemplate/compiler/ops/attention/__init__.py create mode 100644 python/aitemplate/compiler/ops/attention/flash_attention.py create mode 100644 python/aitemplate/compiler/ops/common/__init__.py create mode 100644 python/aitemplate/compiler/ops/common/elementwise.py create mode 100644 python/aitemplate/compiler/ops/common/epilogue.py create mode 100644 python/aitemplate/compiler/ops/common/fused_elementwise.py create mode 100644 python/aitemplate/compiler/ops/common/math.py create mode 100644 python/aitemplate/compiler/ops/common/python_ops.py create mode 100644 python/aitemplate/compiler/ops/common/view_ops.py create mode 100644 python/aitemplate/compiler/ops/conv/__init__.py create mode 100644 python/aitemplate/compiler/ops/conv/cache_entry.py create mode 100644 python/aitemplate/compiler/ops/conv/common_conv2d_bias_activation.py create mode 100644 python/aitemplate/compiler/ops/conv/common_conv2d_bias_add_activation.py create mode 100644 python/aitemplate/compiler/ops/conv/conv2d.py create mode 100644 python/aitemplate/compiler/ops/conv/conv2d_bias.py create mode 100644 python/aitemplate/compiler/ops/conv/conv2d_bias_add.py create mode 100644 python/aitemplate/compiler/ops/conv/conv2d_bias_add_hardswish.py create mode 100644 python/aitemplate/compiler/ops/conv/conv2d_bias_add_relu.py create mode 100644 python/aitemplate/compiler/ops/conv/conv2d_bias_few_channels.py create mode 100644 python/aitemplate/compiler/ops/conv/conv2d_bias_hardswish.py create mode 100644 python/aitemplate/compiler/ops/conv/conv2d_bias_hardswish_few_channels.py create mode 100644 python/aitemplate/compiler/ops/conv/conv2d_bias_relu.py create mode 100644 python/aitemplate/compiler/ops/conv/conv2d_bias_relu_few_channels.py create mode 100644 python/aitemplate/compiler/ops/conv/conv2d_bias_sigmoid.py create mode 100644 python/aitemplate/compiler/ops/conv/special_conv2d_bias_activation.py create mode 100644 python/aitemplate/compiler/ops/conv/transposed_conv2d.py create mode 100644 python/aitemplate/compiler/ops/conv/transposed_conv2d_bias.py create mode 100644 python/aitemplate/compiler/ops/conv/transposed_conv2d_bias_relu.py create mode 100644 python/aitemplate/compiler/ops/embedding/__init__.py create mode 100644 python/aitemplate/compiler/ops/embedding/bert_embeddings.py create mode 100644 python/aitemplate/compiler/ops/gemm_epilogue_vistor/__init__.py create mode 100644 python/aitemplate/compiler/ops/gemm_epilogue_vistor/bmm_rcr_softmax.py create mode 100644 python/aitemplate/compiler/ops/gemm_epilogue_vistor/gemm_rcr_bias_softmax.py create mode 100644 python/aitemplate/compiler/ops/gemm_epilogue_vistor/gemm_rcr_softmax.py create mode 100644 python/aitemplate/compiler/ops/gemm_special/__init__.py create mode 100644 
python/aitemplate/compiler/ops/gemm_special/bmm_rcr_n1.py create mode 100644 python/aitemplate/compiler/ops/gemm_special/bmm_rrr_k1_tanh.py create mode 100644 python/aitemplate/compiler/ops/gemm_special/gemm_rrr_small_nk.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/__init__.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/bmm.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/bmm_ccr.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/bmm_ccr_add.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/bmm_crr.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/bmm_crr_add.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/bmm_rcr.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/bmm_rcr_permute.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/bmm_rrr.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/bmm_rrr_add.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/bmm_rrr_permute.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/bmm_softmax_bmm.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/bmm_softmax_bmm_permute.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/cache_entry.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/gemm_common.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/gemm_rcr.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_add.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_add_add.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_add_add_relu.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_add_relu.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_broadcast.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_fast_gelu.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_gelu.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_hardswish.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_mul.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_mul_add.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_mul_tanh.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_permute.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_relu.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_sigmoid.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_sigmoid_mul.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_sigmoid_mul_tanh.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_swish.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_tanh.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_permute.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/gemm_rrr.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/gemm_rrr_bias.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/gemm_rrr_bias_permute.py create mode 
100644 python/aitemplate/compiler/ops/gemm_universal/gemm_rrr_permute.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/group_gemm_rcr.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/group_gemm_rcr_bias.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/group_gemm_rcr_bias_relu.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/group_gemm_rcr_bias_sigmoid.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/perm021fc_ccr.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/perm021fc_ccr_bias.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/perm021fc_ccr_bias_permute.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/perm021fc_crc.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/perm021fc_crc_bias.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/perm102_bmm_rcr.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/perm102_bmm_rcr_bias.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/perm102_bmm_rrr.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/perm102_bmm_rrr_bias.py create mode 100644 python/aitemplate/compiler/ops/groupnorm/__init__.py create mode 100644 python/aitemplate/compiler/ops/groupnorm/groupnorm.py create mode 100644 python/aitemplate/compiler/ops/groupnorm/groupnorm_swish.py create mode 100644 python/aitemplate/compiler/ops/layernorm/__init__.py create mode 100644 python/aitemplate/compiler/ops/layernorm/batch_layernorm_sigmoid_mul.py create mode 100644 python/aitemplate/compiler/ops/layernorm/group_layernorm.py create mode 100644 python/aitemplate/compiler/ops/layernorm/group_layernorm_sigmoid_mul.py create mode 100644 python/aitemplate/compiler/ops/layernorm/layernorm.py create mode 100644 python/aitemplate/compiler/ops/layernorm/layernorm_sigmoid_mul.py create mode 100644 python/aitemplate/compiler/ops/padding/__init__.py create mode 100644 python/aitemplate/compiler/ops/padding/nhwc3to4.py create mode 100644 python/aitemplate/compiler/ops/padding/nhwc3to8.py create mode 100644 python/aitemplate/compiler/ops/padding/nhwc_pad_common.py create mode 100644 python/aitemplate/compiler/ops/padding/pad_last_dim.py create mode 100644 python/aitemplate/compiler/ops/pool/__init__.py create mode 100644 python/aitemplate/compiler/ops/pool/avg_pool2d.py create mode 100644 python/aitemplate/compiler/ops/pool/max_pool2d.py create mode 100644 python/aitemplate/compiler/ops/pool/pool2d.py create mode 100644 python/aitemplate/compiler/ops/reduce/__init__.py create mode 100644 python/aitemplate/compiler/ops/reduce/reduce_common.py create mode 100644 python/aitemplate/compiler/ops/reduce/reduce_mean.py create mode 100644 python/aitemplate/compiler/ops/reduce/reduce_sum.py create mode 100644 python/aitemplate/compiler/ops/reduce/var.py create mode 100644 python/aitemplate/compiler/ops/reduce/vector_norm.py create mode 100644 python/aitemplate/compiler/ops/softmax/__init__.py create mode 100644 python/aitemplate/compiler/ops/softmax/cache_entry.py create mode 100644 python/aitemplate/compiler/ops/softmax/softmax.py create mode 100644 python/aitemplate/compiler/ops/tensor/__init__.py create mode 100644 python/aitemplate/compiler/ops/tensor/argmax.py create mode 100644 python/aitemplate/compiler/ops/tensor/batch_gather.py create mode 100644 python/aitemplate/compiler/ops/tensor/chunk.py create mode 100644 python/aitemplate/compiler/ops/tensor/concatenate.py create mode 
100644 python/aitemplate/compiler/ops/tensor/concatenate_tanh.py create mode 100644 python/aitemplate/compiler/ops/tensor/dynamic_slice.py create mode 100644 python/aitemplate/compiler/ops/tensor/expand.py create mode 100644 python/aitemplate/compiler/ops/tensor/gather.py create mode 100644 python/aitemplate/compiler/ops/tensor/permute.py create mode 100644 python/aitemplate/compiler/ops/tensor/permute021.py create mode 100644 python/aitemplate/compiler/ops/tensor/permute102.py create mode 100644 python/aitemplate/compiler/ops/tensor/permute210.py create mode 100644 python/aitemplate/compiler/ops/tensor/size.py create mode 100644 python/aitemplate/compiler/ops/tensor/slice_reshape_scatter.py create mode 100644 python/aitemplate/compiler/ops/tensor/slice_scatter.py create mode 100644 python/aitemplate/compiler/ops/tensor/split.py create mode 100644 python/aitemplate/compiler/ops/tensor/topk.py create mode 100644 python/aitemplate/compiler/ops/upsample/__init__.py create mode 100644 python/aitemplate/compiler/ops/upsample/upsampling2d.py create mode 100644 python/aitemplate/compiler/ops/upsample/upsampling2d_add.py create mode 100644 python/aitemplate/compiler/ops/upsample/upsampling_common.py create mode 100644 python/aitemplate/compiler/ops/vision_ops/__init__.py create mode 100644 python/aitemplate/compiler/ops/vision_ops/nms/__init__.py create mode 100644 python/aitemplate/compiler/ops/vision_ops/nms/batched_nms.py create mode 100644 python/aitemplate/compiler/ops/vision_ops/nms/efficient_nms.py create mode 100644 python/aitemplate/compiler/ops/vision_ops/nms/nms.py create mode 100644 python/aitemplate/compiler/ops/vision_ops/roi_ops/__init__.py create mode 100644 python/aitemplate/compiler/ops/vision_ops/roi_ops/multi_level_roi_align.py create mode 100644 python/aitemplate/compiler/ops/vision_ops/roi_ops/roi_align.py create mode 100644 python/aitemplate/compiler/ops/vision_ops/roi_ops/roi_ops.py create mode 100644 python/aitemplate/compiler/public/__init__.py create mode 100644 python/aitemplate/compiler/tensor_accessor.py create mode 100644 python/aitemplate/compiler/transform/__init__.py create mode 100644 python/aitemplate/compiler/transform/apply_padding.py create mode 100644 python/aitemplate/compiler/transform/bind_constants.py create mode 100644 python/aitemplate/compiler/transform/constant_folding.py create mode 100644 python/aitemplate/compiler/transform/fuse_conv_elementwise.py create mode 100644 python/aitemplate/compiler/transform/fuse_conv_patterns.py create mode 100644 python/aitemplate/compiler/transform/fuse_group_ops.py create mode 100644 python/aitemplate/compiler/transform/fuse_mm_elementwise.py create mode 100644 python/aitemplate/compiler/transform/fuse_mm_elementwise_patterns.py create mode 100644 python/aitemplate/compiler/transform/fuse_ops.py create mode 100644 python/aitemplate/compiler/transform/fuse_parallel_gemms.py create mode 100644 python/aitemplate/compiler/transform/fuse_permute_bmm.py create mode 100644 python/aitemplate/compiler/transform/fuse_split.py create mode 100644 python/aitemplate/compiler/transform/fuse_utils.py create mode 100644 python/aitemplate/compiler/transform/mark_param_tensor.py create mode 100644 python/aitemplate/compiler/transform/memory_planning.py create mode 100644 python/aitemplate/compiler/transform/name_graph.py create mode 100644 python/aitemplate/compiler/transform/optimize_graph.py create mode 100644 python/aitemplate/compiler/transform/profile.py create mode 100644 
python/aitemplate/compiler/transform/profile_dynamic_dim.py create mode 100644 python/aitemplate/compiler/transform/refine_graph.py create mode 100644 python/aitemplate/compiler/transform/remove_no_ops.py create mode 100644 python/aitemplate/compiler/transform/remove_unused_ops.py create mode 100644 python/aitemplate/compiler/transform/toposort.py create mode 100644 python/aitemplate/compiler/transform/transform_memory_ops.py create mode 100644 python/aitemplate/compiler/transform/transform_odd_alignment.py create mode 100644 python/aitemplate/compiler/transform/transform_special_ops.py create mode 100644 python/aitemplate/compiler/transform/transform_strided_op_and_view_op.py create mode 100644 python/aitemplate/compiler/transform/transform_strided_ops.py create mode 100644 python/aitemplate/compiler/transform/transform_strided_ops_utils.py create mode 100644 python/aitemplate/compiler/transform/transform_strided_slice.py create mode 100644 python/aitemplate/compiler/transform/transform_utils.py create mode 100644 python/aitemplate/frontend/__init__.py create mode 100644 python/aitemplate/frontend/nn/__init__.py create mode 100644 python/aitemplate/frontend/nn/attention.py create mode 100644 python/aitemplate/frontend/nn/container.py create mode 100644 python/aitemplate/frontend/nn/conv2d/__init__.py create mode 100644 python/aitemplate/frontend/nn/conv2d/common_conv2d_bias_act.py create mode 100644 python/aitemplate/frontend/nn/conv2d/common_conv2d_bias_add_act.py create mode 100644 python/aitemplate/frontend/nn/conv2d/conv2d.py create mode 100644 python/aitemplate/frontend/nn/conv2d/conv2d_bias.py create mode 100644 python/aitemplate/frontend/nn/conv2d/conv2d_bias_add_hardswish.py create mode 100644 python/aitemplate/frontend/nn/conv2d/conv2d_bias_add_relu.py create mode 100644 python/aitemplate/frontend/nn/conv2d/conv2d_bias_few_channels.py create mode 100644 python/aitemplate/frontend/nn/conv2d/conv2d_bias_hardswish.py create mode 100644 python/aitemplate/frontend/nn/conv2d/conv2d_bias_hardswish_few_channels.py create mode 100644 python/aitemplate/frontend/nn/conv2d/conv2d_bias_relu.py create mode 100644 python/aitemplate/frontend/nn/conv2d/conv2d_bias_relu_few_channels.py create mode 100644 python/aitemplate/frontend/nn/conv2d/conv2d_bias_sigmoid.py create mode 100644 python/aitemplate/frontend/nn/conv2d/special_conv2d_bias_act.py create mode 100644 python/aitemplate/frontend/nn/conv2d/transposed_conv2d_bias.py create mode 100644 python/aitemplate/frontend/nn/conv2d/transposed_conv2d_bias_act.py create mode 100644 python/aitemplate/frontend/nn/conv2d/transposed_conv2d_bias_relu.py create mode 100644 python/aitemplate/frontend/nn/dropout.py create mode 100644 python/aitemplate/frontend/nn/embedding.py create mode 100644 python/aitemplate/frontend/nn/fpn_proposal.py create mode 100644 python/aitemplate/frontend/nn/group_norm.py create mode 100644 python/aitemplate/frontend/nn/identity.py create mode 100644 python/aitemplate/frontend/nn/layer_norm.py create mode 100644 python/aitemplate/frontend/nn/linear.py create mode 100644 python/aitemplate/frontend/nn/module.py create mode 100644 python/aitemplate/frontend/nn/padding.py create mode 100644 python/aitemplate/frontend/nn/parameter.py create mode 100644 python/aitemplate/frontend/nn/pool2d.py create mode 100644 python/aitemplate/frontend/nn/proposal.py create mode 100644 python/aitemplate/frontend/nn/roi_ops.py create mode 100644 python/aitemplate/frontend/nn/upsample.py create mode 100644 python/aitemplate/frontend/nn/view_ops.py 
create mode 100644 python/aitemplate/frontend/parameter.py create mode 100644 python/aitemplate/testing/__init__.py create mode 100644 python/aitemplate/testing/benchmark_ait.py create mode 100644 python/aitemplate/testing/benchmark_pt.py create mode 100644 python/aitemplate/testing/detect_target.py create mode 100644 python/aitemplate/testing/test_utils.py create mode 100644 python/aitemplate/utils/__init__.py create mode 100644 python/aitemplate/utils/graph_utils.py create mode 100644 python/aitemplate/utils/logger.py create mode 100644 python/aitemplate/utils/markdown_table.py create mode 100644 python/aitemplate/utils/mk_ck_lib/conv2d_operation.py create mode 100644 python/aitemplate/utils/mk_ck_lib/gemm_operation.py create mode 100644 python/aitemplate/utils/mk_ck_lib/generator.py create mode 100644 python/aitemplate/utils/mk_ck_lib/groupnorm_operation.py create mode 100644 python/aitemplate/utils/mk_ck_lib/layernorm_operation.py create mode 100644 python/aitemplate/utils/mk_ck_lib/library.py create mode 100644 python/aitemplate/utils/mk_ck_lib/manifest.py create mode 100644 python/aitemplate/utils/mk_ck_lib/softmax_operation.py create mode 100644 python/aitemplate/utils/mk_cutlass_lib/extra_conv_emit.py create mode 100644 python/aitemplate/utils/mk_cutlass_lib/extra_cutlass_generator.py create mode 100644 python/aitemplate/utils/mk_cutlass_lib/extra_enum.py create mode 100644 python/aitemplate/utils/mk_cutlass_lib/extra_gemm_emit.py create mode 100644 python/aitemplate/utils/mk_cutlass_lib/mk_cutlass_lib.py create mode 100644 python/aitemplate/utils/shape_utils.py create mode 100644 python/aitemplate/utils/tensor_utils.py create mode 100644 python/aitemplate/utils/torch_utils.py create mode 100644 python/aitemplate/utils/visualization/__init__.py create mode 100644 python/aitemplate/utils/visualization/op_attr_factory.py create mode 100644 python/aitemplate/utils/visualization/plot.py create mode 100644 python/aitemplate/utils/visualization/pydot.py create mode 100644 python/aitemplate/utils/visualization/web_template.py create mode 100644 python/setup.py create mode 100644 static/README.md create mode 100644 static/csrc/model_container.cpp create mode 100644 static/csrc/model_interface.cpp create mode 100644 static/csrc/rocm_hack.cpp create mode 100644 static/csrc/utility.cpp create mode 100644 static/include/cuda_device_functions.h create mode 100644 static/include/logging.h create mode 100644 static/include/macros.h create mode 100644 static/include/model_container.h create mode 100644 static/include/model_interface.h create mode 100644 static/include/owned_constants.h create mode 100644 static/include/raii_wrapper.h create mode 100644 static/include/rocm_device_functions.h create mode 100644 static/include/utility.h create mode 100644 tests/ci_profile_cache/README.md create mode 100644 tests/ci_profile_cache/update_cache.py create mode 100644 tests/lint/check_meta_header.py create mode 100644 tests/lint/flake8_problem_matcher.json create mode 100644 tests/unittest/backend/test_fused_elementwise_backend.py create mode 100644 tests/unittest/backend/test_model_api.py create mode 100644 tests/unittest/benchmark/test_group_gemm_benchmark.py create mode 100644 tests/unittest/benchmark/test_strided_layernorm_benchmark.py create mode 100644 tests/unittest/compiler/test_constant_folding.py create mode 100644 tests/unittest/compiler/test_fuse_conv_elementwise.py create mode 100644 tests/unittest/compiler/test_fuse_expand.py create mode 100644 
tests/unittest/compiler/test_fuse_mm_elementwise.py create mode 100644 tests/unittest/compiler/test_fuse_permute_bmm.py create mode 100644 tests/unittest/compiler/test_fused_elementwise_complex_dependency.py create mode 100644 tests/unittest/compiler/test_fused_elementwise_out_of_order.py create mode 100644 tests/unittest/compiler/test_group_fusions.py create mode 100644 tests/unittest/compiler/test_memory_planning.py create mode 100644 tests/unittest/compiler/test_pad_bmm_rrr_bias_with_cat.py create mode 100644 tests/unittest/compiler/test_pad_gemm_rrr_with_cat.py create mode 100644 tests/unittest/compiler/test_pad_gemm_with_cat.py create mode 100644 tests/unittest/compiler/test_pad_gemm_with_elementwise.py create mode 100644 tests/unittest/compiler/test_parallel_gemm_fusions.py create mode 100644 tests/unittest/compiler/test_permute_bmm_special_op.py create mode 100644 tests/unittest/compiler/test_public_import.py create mode 100644 tests/unittest/compiler/test_refine_graph.py create mode 100644 tests/unittest/compiler/test_remove_unused_ops.py create mode 100644 tests/unittest/compiler/test_slice_elemwise_fusion.py create mode 100644 tests/unittest/compiler/test_slice_gemm_fusion.py create mode 100644 tests/unittest/compiler/test_slice_reshape_scatter.py create mode 100644 tests/unittest/compiler/test_slice_scatter_pattern.py create mode 100644 tests/unittest/compiler/test_slice_view_strided.py create mode 100644 tests/unittest/compiler/test_split_bmm_fusion.py create mode 100644 tests/unittest/compiler/test_split_bmm_softmax_bmm.py create mode 100644 tests/unittest/compiler/test_split_view_strided.py create mode 100644 tests/unittest/compiler/test_strided_group_gemm.py create mode 100644 tests/unittest/compiler/test_strided_group_layernorm.py create mode 100644 tests/unittest/compiler/test_strided_layernorm.py create mode 100644 tests/unittest/compiler/test_strided_layernorm_reshape.py create mode 100644 tests/unittest/compiler/test_strided_op_cat_pattern.py create mode 100644 tests/unittest/compiler/test_strided_reshape_cat.py create mode 100644 tests/unittest/compiler/test_strided_scatter.py create mode 100644 tests/unittest/compiler/test_strided_split_group_gemm.py create mode 100644 tests/unittest/compiler/test_strided_view_cat.py create mode 100644 tests/unittest/compiler/test_strided_view_op.py create mode 100644 tests/unittest/compiler/test_tensor_accessor.py create mode 100644 tests/unittest/compiler/test_transform_memory_ops.py create mode 100644 tests/unittest/compiler/test_transform_odd_alignment.py create mode 100644 tests/unittest/compiler/test_transform_special_op.py create mode 100644 tests/unittest/compiler/test_transform_utils.py create mode 100644 tests/unittest/compiler/test_view_strided_op.py create mode 100644 tests/unittest/frontend/test_module.py create mode 100644 tests/unittest/ops/test_activation.py create mode 100644 tests/unittest/ops/test_argmax.py create mode 100644 tests/unittest/ops/test_attention.py create mode 100644 tests/unittest/ops/test_avg_pool2d.py create mode 100644 tests/unittest/ops/test_batch_gather.py create mode 100644 tests/unittest/ops/test_bert_embeddings.py create mode 100644 tests/unittest/ops/test_bmm.py create mode 100644 tests/unittest/ops/test_bmm_add.py create mode 100644 tests/unittest/ops/test_bmm_alpha.py create mode 100644 tests/unittest/ops/test_bmm_permute.py create mode 100644 tests/unittest/ops/test_bmm_rcr_n1.py create mode 100644 tests/unittest/ops/test_bmm_rrr_k1_tanh.py create mode 100644 
tests/unittest/ops/test_bmm_softmax.py create mode 100644 tests/unittest/ops/test_bmm_softmax_bmm.py create mode 100644 tests/unittest/ops/test_chunk.py create mode 100644 tests/unittest/ops/test_clamp_nan_to_num.py create mode 100644 tests/unittest/ops/test_concatenate.py create mode 100644 tests/unittest/ops/test_concatenate_tanh.py create mode 100644 tests/unittest/ops/test_conv.py create mode 100644 tests/unittest/ops/test_conv2d_bias_add.py create mode 100644 tests/unittest/ops/test_conv_bias.py create mode 100644 tests/unittest/ops/test_conv_bias_act_few_channels.py create mode 100644 tests/unittest/ops/test_conv_bias_add_hardswish.py create mode 100644 tests/unittest/ops/test_conv_bias_add_relu.py create mode 100644 tests/unittest/ops/test_conv_bias_hardswish.py create mode 100644 tests/unittest/ops/test_conv_bias_relu.py create mode 100644 tests/unittest/ops/test_conv_bias_sigmoid.py create mode 100644 tests/unittest/ops/test_dynamic_conv.py create mode 100644 tests/unittest/ops/test_efficient_nms.py create mode 100644 tests/unittest/ops/test_expand.py create mode 100644 tests/unittest/ops/test_flatten.py create mode 100644 tests/unittest/ops/test_fpn_roi_align.py create mode 100644 tests/unittest/ops/test_fused_elementwise.py create mode 100644 tests/unittest/ops/test_fused_elementwise_broadcast.py create mode 100644 tests/unittest/ops/test_fused_elementwise_with_strided_outputs.py create mode 100644 tests/unittest/ops/test_gather.py create mode 100644 tests/unittest/ops/test_gemm.py create mode 100644 tests/unittest/ops/test_gemm_bias.py create mode 100644 tests/unittest/ops/test_gemm_bias_broadcast.py create mode 100644 tests/unittest/ops/test_gemm_bias_hardswish.py create mode 100644 tests/unittest/ops/test_gemm_bias_permute.py create mode 100644 tests/unittest/ops/test_gemm_bias_relu.py create mode 100644 tests/unittest/ops/test_gemm_bias_sigmoid.py create mode 100644 tests/unittest/ops/test_gemm_bias_softmax.py create mode 100644 tests/unittest/ops/test_gemm_bias_swish.py create mode 100644 tests/unittest/ops/test_gemm_bias_tanh.py create mode 100644 tests/unittest/ops/test_gemm_permute.py create mode 100644 tests/unittest/ops/test_gemm_rcr_bias_fast_gelu.py create mode 100644 tests/unittest/ops/test_gemm_rrr_small_nk.py create mode 100644 tests/unittest/ops/test_gemm_softmax.py create mode 100644 tests/unittest/ops/test_group_gemm_rcr.py create mode 100644 tests/unittest/ops/test_group_gemm_rcr_bias.py create mode 100644 tests/unittest/ops/test_group_gemm_rcr_bias_activation.py create mode 100644 tests/unittest/ops/test_group_gemm_rcr_bias_cat.py create mode 100644 tests/unittest/ops/test_group_gemm_rcr_cat.py create mode 100644 tests/unittest/ops/test_groupnorm.py create mode 100644 tests/unittest/ops/test_layernorm.py create mode 100644 tests/unittest/ops/test_layernorm_sigmoid_mul.py create mode 100644 tests/unittest/ops/test_max_pool2d.py create mode 100644 tests/unittest/ops/test_nhwc3to4.py create mode 100644 tests/unittest/ops/test_nhwc3to8.py create mode 100644 tests/unittest/ops/test_nms.py create mode 100644 tests/unittest/ops/test_norm.py create mode 100644 tests/unittest/ops/test_pad_last_dim.py create mode 100644 tests/unittest/ops/test_perm021fc_ccr.py create mode 100644 tests/unittest/ops/test_perm021fc_ccr_bias.py create mode 100644 tests/unittest/ops/test_perm021fc_ccr_bias_perm021.py create mode 100644 tests/unittest/ops/test_perm021fc_crc.py create mode 100644 tests/unittest/ops/test_perm021fc_crc_bias.py create mode 100644 
tests/unittest/ops/test_perm102_bmm_rcr.py create mode 100644 tests/unittest/ops/test_perm102_bmm_rrr.py create mode 100644 tests/unittest/ops/test_permute.py create mode 100644 tests/unittest/ops/test_permute021.py create mode 100644 tests/unittest/ops/test_permute102.py create mode 100644 tests/unittest/ops/test_permute210.py create mode 100644 tests/unittest/ops/test_proposal.py create mode 100644 tests/unittest/ops/test_reduce.py create mode 100644 tests/unittest/ops/test_reshape.py create mode 100644 tests/unittest/ops/test_roi_align.py create mode 100644 tests/unittest/ops/test_size_getitem_ops.py create mode 100644 tests/unittest/ops/test_slice.py create mode 100644 tests/unittest/ops/test_softmax.py create mode 100644 tests/unittest/ops/test_split.py create mode 100644 tests/unittest/ops/test_split_getitem.py create mode 100644 tests/unittest/ops/test_squeeze.py create mode 100644 tests/unittest/ops/test_topk.py create mode 100644 tests/unittest/ops/test_transpose_conv2d.py create mode 100644 tests/unittest/ops/test_transpose_conv2d_bias.py create mode 100644 tests/unittest/ops/test_transpose_conv2d_bias_relu.py create mode 100644 tests/unittest/ops/test_tuple_list_construct.py create mode 100644 tests/unittest/ops/test_upsamping2d.py create mode 100644 tests/unittest/ops/test_upsamping2d_add.py create mode 100644 tests/unittest/ops/test_var.py
diff --git a/.circleci/config.yml b/.circleci/config.yml
new file mode 100644
index 000000000..19c2d377a
--- /dev/null
+++ b/.circleci/config.yml
@@ -0,0 +1,80 @@
+# Use the latest 2.1 version of CircleCI pipeline process engine.
+# See: https://circleci.com/docs/2.0/configuration-reference
+version: 2.1
+
+# Orbs are reusable packages of CircleCI configuration that you may share across projects, enabling you to create encapsulated, parameterized commands, jobs, and executors that can be used across multiple projects.
+# See: https://circleci.com/docs/2.0/orb-intro/
+orbs:
+  # The python orb contains a set of prepackaged CircleCI configuration you can use repeatedly in your configuration files
+  # Orb commands and jobs help you with common scripting around a language/tool
+  # so you don't have to copy and paste it everywhere.
+  # See the orb documentation here: https://circleci.com/developer/orbs/orb/circleci/python
+  python: circleci/python@1.5.0
+
+# -------------------------------------------------------------------------------------
+# Re-usable commands
+# -------------------------------------------------------------------------------------
+setup_env: &setup_env
+  - run:
+      name: Setup environment
+      command: |
+        python3.8 --version
+        python3.8 -m pip install --upgrade pip
+        cd python
+        python3.8 setup.py bdist_wheel
+        sudo python3.8 -m pip install --no-input dist/*.whl
+        cd ..
+        python3.8 -m pip install pytest
+        python3.8 -m pip install torch
+        python3.8 -m pip install numpy
+        python3.8 -m pip install jinja2
+        python3.8 -m pip install recordtype
+        python3.8 -m pip install parameterized
+        python3.8 -m pip install einops
+        git submodule sync
+        git submodule update --init
+        echo 'export PYTHONPATH=$PWD/python:$PYTHONPATH' >> $BASH_ENV
+        echo 'export PATH=/usr/local/cuda-11.4/bin:$PATH' >> $BASH_ENV
+        echo 'export CI_FLAG=CIRCLECI' >> $BASH_ENV
+        echo 'export CACHE_DIR=$PWD/tests/ci_profile_cache' >> $BASH_ENV
+
+basic_tests: &basic_tests
+  - run:
+      name: Run tests
+      command: |
+        set -e
+        TEST_FILES=$(circleci tests glob "tests/unittest/**/test_*.py" | grep -v benchmark | circleci tests split --split-by=timings)
+        mkdir test-results
+        python3.8 -m pytest $TEST_FILES --junitxml=test-results/junit.xml --verbose --continue-on-collection-errors -rA
+
+
+# Define a job to be invoked later in a workflow.
+# See: https://circleci.com/docs/2.0/configuration-reference/#jobs
+jobs:
+  build-and-test:
+    machine:
+      image: ubuntu-2004-cuda-11.4:202110-01
+    # Check T101565170 for multi-gpu use cases.
+    resource_class: gpu.nvidia.medium
+
+    parallelism: 10
+
+    # Checkout the code as the first step. This is a dedicated CircleCI step.
+    # The python orb's install-packages step will install the dependencies from a Pipfile via Pipenv by default.
+    # Here we're making sure we just use the system-wide pip. By default it uses the project root's requirements.txt.
+    # Then run your tests!
+    # CircleCI will report the results back to your VCS provider.
+    steps:
+      - checkout
+      - <<: *setup_env
+      - <<: *basic_tests
+      - store_test_results:
+          path: test-results
+
+# Invoke jobs via workflows
+# See: https://circleci.com/docs/2.0/configuration-reference/#workflows
+workflows:
+  unittest: # This is the name of the workflow, feel free to change it to better match your workflow.
+    # Inside the workflow, you define the jobs you want to run.
+ jobs: + - build-and-test diff --git a/.clang-format b/.clang-format new file mode 100644 index 000000000..73304266b --- /dev/null +++ b/.clang-format @@ -0,0 +1,88 @@ +--- +AccessModifierOffset: -1 +AlignAfterOpenBracket: AlwaysBreak +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +AlignEscapedNewlinesLeft: true +AlignOperands: false +AlignTrailingComments: false +AllowAllParametersOfDeclarationOnNextLine: false +AllowShortBlocksOnASingleLine: false +AllowShortCaseLabelsOnASingleLine: false +AllowShortFunctionsOnASingleLine: Empty +AllowShortIfStatementsOnASingleLine: false +AllowShortLoopsOnASingleLine: false +AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: true +AlwaysBreakTemplateDeclarations: true +BinPackArguments: false +BinPackParameters: false +BraceWrapping: + AfterClass: false + AfterControlStatement: false + AfterEnum: false + AfterFunction: false + AfterNamespace: false + AfterObjCDeclaration: false + AfterStruct: false + AfterUnion: false + BeforeCatch: false + BeforeElse: false + IndentBraces: false +BreakBeforeBinaryOperators: None +BreakBeforeBraces: Attach +BreakBeforeTernaryOperators: true +BreakConstructorInitializersBeforeComma: false +BreakAfterJavaFieldAnnotations: false +BreakStringLiterals: false +ColumnLimit: 80 +CommentPragmas: '^ IWYU pragma:' +CompactNamespaces: false +ConstructorInitializerAllOnOneLineOrOnePerLine: true +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 4 +Cpp11BracedListStyle: true +DerivePointerAlignment: false +DisableFormat: false +ForEachMacros: [ FOR_EACH_RANGE, FOR_EACH, ] +IncludeCategories: + - Regex: '^<.*\.h(pp)?>' + Priority: 1 + - Regex: '^<.*' + Priority: 2 + - Regex: '.*' + Priority: 3 +IndentCaseLabels: true +IndentWidth: 2 +IndentWrappedFunctionNames: false +KeepEmptyLinesAtTheStartOfBlocks: false +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +ObjCBlockIndentWidth: 2 +ObjCSpaceAfterProperty: false +ObjCSpaceBeforeProtocolList: false +PenaltyBreakBeforeFirstCallParameter: 1 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 120 +PenaltyBreakString: 1000 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 2000000 +PointerAlignment: Left +ReflowComments: true +SortIncludes: true +SpaceAfterCStyleCast: false +SpaceBeforeAssignmentOperators: true +SpaceBeforeParens: ControlStatements +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 1 +SpacesInAngles: false +SpacesInContainerLiterals: true +SpacesInCStyleCastParentheses: false +SpacesInParentheses: false +SpacesInSquareBrackets: false +Standard: Cpp11 +TabWidth: 8 +UseTab: Never +... diff --git a/.flake8 b/.flake8 new file mode 100644 index 000000000..3f29e8612 --- /dev/null +++ b/.flake8 @@ -0,0 +1,63 @@ +[flake8] +select = B,C,E,F,P,W,B9 +max-line-length = 80 +# Main Explanation Docs: https://github.com/grantmcconnaughey/Flake8Rules +ignore = + # Black conflicts and overlaps. + # Found in https://github.com/psf/black/issues/429 + B950, # Line too long. + E111, # Indentation is not a multiple of four. + E115, # Expected an indented block (comment). + E117, # Over-indented. + E121, # Continuation line under-indented for hanging indent. + E122, # Continuation line missing indentation or outdented. + E123, # Closing bracket does not match indentation of opening bracket's line. + E124, # Closing bracket does not match visual indentation. + E125, # Continuation line with same indent as next logical line. 
+ E126, # Continuation line over-indented for hanging indent. + E127, # Continuation line over-indented for visual indent. + E128, # Continuation line under-indented for visual indent. + E129, # Visually indented line with same indent as next logical line. + E131, # Continuation line unaligned for hanging indent. + E201, # Whitespace after '('. + E202, # Whitespace before ')'. + E203, # Whitespace before ':'. + E221, # Multiple spaces before operator. + E222, # Multiple spaces after operator. + E225, # Missing whitespace around operator. + E226, # Missing whitespace around arithmetic operator. + E227, # Missing whitespace around bitwise or shift operator. + E231, # Missing whitespace after ',', ';', or ':'. + E241, # Multiple spaces after ','. + E251, # Unexpected spaces around keyword / parameter equals. + E252, # Missing whitespace around parameter equals. + E261, # At least two spaces before inline comment. + E262, # Inline comment should start with '# '. + E265, # Block comment should start with '# '. + E271, # Multiple spaces after keyword. + E272, # Multiple spaces before keyword. + E301, # Expected 1 blank line, found 0. + E302, # Expected 2 blank lines, found 0. + E303, # Too many blank lines (3). + E305, # Expected 2 blank lines after end of function or class. + E306, # Expected 1 blank line before a nested definition. + E501, # Line too long (82 > 79 characters). + E502, # The backslash is redundant between brackets. + E701, # Multiple statements on one line (colon). + E702, # Multiple statements on one line (semicolon). + E703, # Statement ends with a semicolon. + E704, # Multiple statements on one line (def). + W291, # Trailing whitespace. + W292, # No newline at end of file. + W293, # Blank line contains whitespace. + W391, # Blank line at end of file. + W504, # Line break occurred after a binary operator. + + # Too opinionated. + E265, # Block comment should start with '# '. + E266, # Too many leading '#' for block comment. + E402, # Module level import not at top of file. (Use cases like demandimport https://fburl.com/demandimport require statements before imports) + E722, # Do not use bare except, specify exception instead. (Duplicate of B001) + P207, # (Duplicate of B003) + P208, # (Duplicate of C403) + W503 # Line break occurred before a binary operator. 
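For contributors who want to run these checks locally before pushing, the following is a minimal sketch that mirrors the lint workflow added below; it assumes Python with pip is available and that the commands are run from the repository root (package versions are left unpinned here):

    pip install flake8 ufmt click
    flake8 .                                                               # style checks against the .flake8 configuration above
    ufmt diff python                                                       # formatting diffs for the Python sources
    ufmt diff tests
    ufmt diff docs
    python tests/lint/check_meta_header.py --path=./python --fixit=False   # copyright header check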
diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml new file mode 100644 index 000000000..208bd1f77 --- /dev/null +++ b/.github/workflows/docs.yml @@ -0,0 +1,67 @@ +# Simple workflow for deploying static content to GitHub Pages +name: Documentation + +on: + # Runs on pushes targeting the default branch + push: + branches: ["main"] + + # Allows you to run this workflow manually from the Actions tab + workflow_dispatch: + +# Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages +permissions: + contents: read + pages: write + id-token: write + +# Allow one concurrent deployment +concurrency: + group: "pages" + cancel-in-progress: true + +jobs: + # Single deploy job since we're just deploying + deploy: + environment: + name: github-pages + url: ${{ steps.deployment.outputs.page_url }} + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.9"] + steps: + - name: Checkout + uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install autodocsumm + pip install sphinx_rtd_theme + pip install sphinx_gallery + pip install sphinxcontrib-inlinesyntaxhighlight + pip install sphinx_toolbox + pip install numpy + pip install jinja2 + pip install torch + cd python + python setup.py develop + cd .. + - name: Build documents with Sphinx + run: | + cd docs + BUILD_DOCS=1 make html + cd .. + - name: Setup Pages + uses: actions/configure-pages@v2 + - name: Upload artifact + uses: actions/upload-pages-artifact@v1 + with: + path: './docs/build/html' + - name: Deploy to GitHub Pages + id: deployment + uses: actions/deploy-pages@v1 diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 000000000..dbd4beb83 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,41 @@ +name: Lint + +on: + push: + branches: + - main + + pull_request: + branches: + - main +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.8"] + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install ufmt + pip install click + pip install flake8 + - name: Analyzing the code with flake8 + run: | + echo "::add-matcher::tests/lint/flake8_problem_matcher.json" + flake8 . + - name: Analyzing the code with ufmt + run: | + ufmt diff python + ufmt diff tests + ufmt diff docs + - name: Check Meta copyright header + run: | + python tests/lint/check_meta_header.py --path=./tests --fixit=False + python tests/lint/check_meta_header.py --path=./python --fixit=False \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..f3bbc0889 --- /dev/null +++ b/.gitignore @@ -0,0 +1,143 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# tmp +tmp/ + +tags + +# macOS dir files +.DS_Store + +# vscode +.vscode + +# vim temp files +*.swp diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 000000000..2aeb63ba5 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,10 @@ +[submodule "3rdparty/cutlass"] + path = 3rdparty/cutlass + url = https://github.com/NVIDIA/cutlass.git +[submodule "3rdparty/cub"] + path = 3rdparty/cub + url = https://github.com/NVIDIA/cub.git +[submodule "3rdparty/composable_kernel"] + path = 3rdparty/composable_kernel + url = https://github.com/ROCmSoftwarePlatform/composable_kernel.git + branch = develop diff --git a/3rdparty/composable_kernel b/3rdparty/composable_kernel new file mode 160000 index 000000000..b88255475 --- /dev/null +++ b/3rdparty/composable_kernel @@ -0,0 +1 @@ +Subproject commit b8825547586855ec730a2eca47e415b1404bb5f2 diff --git a/3rdparty/cub b/3rdparty/cub new file mode 160000 index 000000000..dcd5b06a4 --- /dev/null +++ b/3rdparty/cub @@ -0,0 +1 @@ +Subproject commit dcd5b06a417bdfdc2699678bddf7dd7ee38be466 diff --git a/3rdparty/cutlass b/3rdparty/cutlass new file mode 160000 index 000000000..dadc881a9 --- /dev/null +++ b/3rdparty/cutlass @@ -0,0 +1 @@ +Subproject commit dadc881a9606f95cba1b20acda03c9d07c286239 diff --git a/CITATION.cff b/CITATION.cff new file mode 100644 index 000000000..810680431 --- /dev/null +++ b/CITATION.cff @@ -0,0 +1,54 @@ +cff-version: 1.2.0 +title: AITemplate +message: >- + If you use this software, please cite using the + following metadata. 
+type: software +authors: + - given-names: Bing + family-names: Xu + affiliation: Meta + email: bingxu@meta.com + - given-names: Ying + family-names: Zhang + affiliation: Meta + email: yingz@meta.com + - given-names: Hao + family-names: Lu + affiliation: Meta + email: hlu@meta.com + - given-names: Yang + family-names: Chen + affiliation: Meta + email: yangche@meta.com + - given-names: Terry + family-names: Chen + affiliation: Meta + email: terrychen@meta.com + - given-names: Mike + family-names: Iovine + affiliation: Meta + email: mikeiovine@meta.com + - given-names: Mu-Chu + family-names: Lee + affiliation: Meta + email: mlee8@meta.com + - given-names: Zhijing + family-names: Li + affiliation: Meta + email: tissue030@meta.com + + +repository-code: 'https://github.com/facebookincubator/AITemplate' +abstract: >- + AITemplate (AIT) is a unified inference framework with separate acceleration backends for both AMD and NVIDIA GPU hardware. It delivers close to hardware-native Tensor Core (NVIDIA GPU) and Matrix Core (AMD GPU) performance on a variety of widely used AI models such as convolutional neural networks, transformers, and diffusers. +keywords: + - 'neural network, cutlass, composable kernel, cuda, rocm' +license: Apache 2.0 +license-url: https://github.com/facebookincubator/AITemplate/LICENSE +version: '0.1' +date-released: '2022-10-03' +identifiers: + - type: url + value: "https://github.com/facebookincubator/AITemplate/tree/v0.1.0" + description: The GitHub release URL of tag 0.1.0 \ No newline at end of file diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 000000000..08b500a22 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,80 @@ +# Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to make participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or + advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic + address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. 
+ +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when +an individual is representing the project or its community in public spaces. +Examples of representing a project or community include using an official +project e-mail address, posting via an official social media account, or acting +as an appointed representative at an online or offline event. Representation of +a project may be further defined and clarified by project maintainers. + +This Code of Conduct also applies outside the project spaces when there is a +reasonable belief that an individual's behavior may have a negative impact on +the project or its community. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at . All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 000000000..fde4225fa --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,37 @@ +# Contributing to AITemplate +We want to make contributing to this project as easy and transparent as +possible. + +## Our Development Process +1. For major change, submit RFC to discuss the change. +2. For feature extension, submit PR with tests and documentation. +3. For bug fix, submit PR with tests and documentation. + +## Pull Requests +We actively welcome your pull requests. + +1. Fork the repo and create your branch from `main`. +2. If you've added code that should be tested, add tests. +3. If you've changed APIs, update the documentation. +4. Ensure the test suite passes. +5. Make sure your code lints. +6. If you haven't already, complete the Contributor License Agreement ("CLA"). + +## Contributor License Agreement ("CLA") +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Meta's open source projects. + +Complete your CLA here: + +## Issues +We use GitHub issues to track public bugs. Please ensure your description is +clear and has sufficient instructions to be able to reproduce the issue. + +Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe +disclosure of security bugs. In those cases, please go through the process +outlined on that page and do not file a public issue. 
+ + +## License +By contributing to AITemplate, you agree that your contributions will be licensed +under the LICENSE file in the root directory of this source tree. \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 000000000..b09cd7856 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ +Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." 
+ + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. 
We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [yyyy] [name of copyright owner]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
diff --git a/README.md b/README.md
new file mode 100644
index 000000000..833b6cebc
--- /dev/null
+++ b/README.md
@@ -0,0 +1,119 @@
+# AITemplate
+
+[![License](https://img.shields.io/badge/License-Apache_2.0-brightgreen.svg)](https://github.com/facebookincubator/AITemplate/blob/main/LICENSE) |
+[![Documentation](https://github.com/facebookincubator/AITemplate/actions/workflows/docs.yml/badge.svg)](https://facebookincubator.github.io/AITemplate) |
+[![CircleCI](https://circleci.com/gh/facebookincubator/AITemplate.svg?style=svg)](https://app.circleci.com/pipelines/github/facebookincubator/AITemplate)
+
+
+
+AITemplate (AIT) is a Python framework that transforms deep neural networks into CUDA (NVIDIA GPU) / HIP (AMD GPU) C++ code for lightning-fast inference serving. AITemplate highlights include:
+
+- High performance: close to roofline fp16 TensorCore (NVIDIA GPU) / MatrixCore (AMD GPU) performance on major models, including ResNet, MaskRCNN, BERT, VisionTransformer, Stable Diffusion, etc.
+- Unified, open, and flexible. Seamless fp16 deep neural network models for NVIDIA GPU or AMD GPU. Fully open source, with Lego-style, easily extendable high-performance primitives for new model support. Supports a significantly more comprehensive range of fusions than existing solutions for both GPU platforms.
+
+## More about AITemplate
+
+### Excellent Backwards Compatibility
+
+AITemplate doesn't depend on third-party libraries or runtimes such as cuBLAS, cuDNN, rocBLAS, MIOpen, TensorRT, or MIGraphX. Each model is compiled into a self-contained portable binary, which can be used in any software environment with the same hardware.
+
+### Horizontal Fusion
+
+AITemplate provides unique advanced horizontal fusion. AITemplate can fuse parallel GEMM, LayerNorm, and other operators with different input shapes into a single GPU kernel.
+
+### Vertical Fusion
+
+AITemplate provides strong vertical fusion. AITemplate can fuse a large range of operations into TensorCore/MatrixCore operations, such as elementwise operations, reductions, and layout permutations. AITemplate also provides back-to-back style TensorCore / MatrixCore operation fusion.
+
+### Memory Fusion
+
+AITemplate provides innovative memory fusions. AITemplate can fuse GEMM, LayerNorm, and other operators with subsequent memory operations such as concatenation, split, and slice into a single operator.
+
+### Working w/wo PyTorch
+The AITemplate-generated Python runtime can take PyTorch tensors as inputs and outputs without an extra copy. For environments without PyTorch, the AITemplate Python/C++ runtime is self-contained.
+
+### Extensions without suffering
+
+AITemplate provides a straightforward approach for making an extension in codegen. To add a new operator or a new fused kernel into AITemplate, most of the time one only needs to add two Python files: one for a graph node definition and another for the backend codegen. The CUDA/HIP kernel in a text header file can be used directly in the codegen.
+
+## Installation
+
+**Hardware requirements:**
+ - **NVIDIA**: AIT is only tested on SM80+ GPUs (Ampere etc). Not all kernels work with old SM75/SM70 (T4/V100) GPUs.
+ - **AMD**: AIT is only tested on CDNA2 (MI-210/250) GPUs. There may be compiler issues for old CDNA1 (MI-100) GPUs.
+
+### Docker Image
+We highly recommend using AITemplate with Docker to avoid accidentally using a wrong version of NVCC or HIPCC.
+- CUDA: `./docker/build.sh cuda`
+- ROCM: `DOCKER_BUILDKIT=1 ./docker/build.sh rocm`
+
+This will build a Docker image with the tag `ait:latest`.
+
+### From Source
+The following command will create a Python wheel for AITemplate. Please ensure you have the correct CUDA/ROCm compiler installed.
+- CUDA: CUDA 11.6
+- ROCm: We tested on ROCm 5.2.3 with a customized build of HIPCC, using the command in docker/Dockerfile.rocm#L87-L96
+
+*An incorrect compiler will lead to performance regression.*
+
+```
+cd python
+python setup.py bdist_wheel
+pip install dist/*.whl
+```
+
+## Getting Started
+
+Check out the [AITemplate Documentation](https://facebookincubator.github.io/AITemplate) for the API reference.
+
+There are a few tutorials for onboarding:
+
+- 01: [How to run inference on a PyTorch model with AIT](https://facebookincubator.github.io/AITemplate/tutorial/how_to_infer_pt.html)
+- 02: [How to add an op to AIT codegen](https://facebookincubator.github.io/AITemplate/tutorial/how_to_add_op.html)
+- 03: [How to visualize AIT's optimization](https://facebookincubator.github.io/AITemplate/tutorial/how_to_visualize.html)
+
+A minimal, illustrative runtime sketch is also included near the end of this README.
+
+
+## Examples & Performance
+AITemplate provides the following model templates & reference performance data on A100/MI-250:
+
+- [01_ResNet-50](examples/01_resnet-50/) with PyTorch Image Models (TIMM)
+- [02_MaskRCNN-FPN](examples/02_detectron2/) with Detectron2
+- [03_BERT](examples/03_bert/) with HuggingFace Transformers
+- [04_Vision Transformer](examples/04_vit/) with PyTorch Image Models (TIMM)
+- [05_Stable Diffusion](examples/05_stable_diffusion/) with HuggingFace Diffusers
+
+## Release
+
+AITemplate has a 90-day release cycle.
+In the next one or two releases, we will focus on:
+- Deprecating FlashAttention: Unify CUDA attention computation to Composable Kernel (AMD GPU) style back-to-back fusion to improve performance and increase flexibility for NVIDIA GPU Transformer users.
+- Remove the kernel profiling requirement.
+- GEMM + LayerNorm fusion, GEMM + GEMM fusion, Conv + Conv fusion.
+- Better dynamic shape support: Focus on dynamic sequence lengths in Transformers.
+- More model templates: Provide model templates with control flow and containers.
+- More automatic graph passes: Reduce the need to manually rewrite models to obtain the best performance.
+- Enable more fusions on the AMD backend.
+
+Some ongoing/potential work that won't appear in the next short-term release:
+- Automatic PyTorch FX, ONNX, Open-XLA, and other model format conversion.
+- Quantized model (int8/fp8/int4) support.
+- Composable Kernel CPU extension on AVX2/AVX-512 for AMD EPYC CPUs.
+
+## Contributing
+Check our [contributing guide](CONTRIBUTING.md) to learn about how to contribute to the project.
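+
+## Minimal Runtime Sketch
+
+To complement the tutorials above, the snippet below is a minimal, illustrative sketch of calling an already-compiled AIT module from Python with PyTorch tensors. It is based on the `AITData` / `torch_to_ait_data` / `Model.run` APIs described in `docs/source/runtime/py_design.rst`; the import path of `torch_to_ait_data`, the input/output names, and the shapes here are assumptions for illustration only, not a definitive reference.
+
+```python
+import torch
+
+# Assumed import path; the Model bindings and the torch_to_ait_data helper
+# are documented in compiler/model.py (see the Python runtime note in docs).
+from aitemplate.compiler.model import torch_to_ait_data
+
+
+def run_with_torch_tensors(module):
+    """Run a compiled AIT module on PyTorch GPU tensors without extra copies.
+
+    `module` is assumed to be an already-compiled AIT Model whose graph has
+    one input named "input0" and one output named "output0" of shape
+    (1, 64, 64, 3); adjust the names and shapes to your own model.
+    """
+    # Allocate fp16 GPU tensors with PyTorch; AIT reads/writes their memory directly.
+    x = torch.randn(1, 64, 64, 3).half().cuda()   # input
+    y = torch.empty(1, 64, 64, 3).half().cuda()   # pre-allocated output
+
+    # Wrap them as AITData (unowned pointer + shape + dtype) and run by name.
+    module.run(
+        {"input0": torch_to_ait_data(x)},
+        {"output0": torch_to_ait_data(y)},
+    )
+    return y  # `y` now holds the result and can be used directly in PyTorch.
+```
+
+See the tutorials above for the authoritative end-to-end flow, including how to compile `module` in the first place.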
+
+## The Team
+
+AITemplate is co-created by Meta engineers: [Bing Xu](https://github.com/antinucleon), [Ying Zhang](https://github.com/ipiszy), [Hao Lu](https://github.com/hlu1), [Yang Chen](https://github.com/chenyang78), and [Terry Chen](https://github.com/terrychenism), with major contributions from many more talented engineers, including (non-exhaustively) Mike Iovine, Mu-Chu Lee, Scott Wolchok, Oleg Khabinov, Shirong Wu, Huaming Li, Hui Guo, Zhijing Li, and Max Podkorytov. We also want to thank Andrew Tulloch, Yinghai Lu, and Lu Fang for valuable discussions.
+
+AITemplate is currently maintained by Meta engineers: [Ying Zhang](https://github.com/ipiszy), [Hao Lu](https://github.com/hlu1), [Yang Chen](https://github.com/chenyang78), [Terry Chen](https://github.com/terrychenism), [Mike Iovine](https://github.com/mikeiovine), [Mu-Chu Lee](https://github.com/muchulee8), and [Bing Xu](https://github.com/antinucleon).
+
+
+## Acknowledgement
+
+The AITemplate team works closely with the NVIDIA [CUTLASS](https://github.com/NVIDIA/cutlass) team (led by Andrew Kerr and Haicheng Wu) and the AMD [Composable Kernel](https://github.com/ROCmSoftwarePlatform/composable_kernel) team (led by Chao Liu and Jing Zhang). We co-designed many advanced GPU optimizations specialized for each platform; none of this would have been possible without our close collaboration.
+
+
+## License
+AITemplate is licensed under the [Apache 2.0 License](https://github.com/facebookincubator/AITemplate/blob/main/LICENSE).
diff --git a/docker/Dockerfile.cuda b/docker/Dockerfile.cuda
new file mode 100644
index 000000000..0461f45bf
--- /dev/null
+++ b/docker/Dockerfile.cuda
@@ -0,0 +1,58 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# +# CUDA Docker Image for AITemplate + +FROM nvidia/cuda:11.6.2-devel-ubuntu20.04 + +# Base scripts +RUN apt-get update --fix-missing +RUN apt install -y python3 python3-dev python3-pip + +# Environment variables +ENV PATH=/usr/local/nvidia/bin:${PATH} +ENV PATH=/usr/local/cuda/bin:${PATH} +ENV LIBRARY_PATH=/usr/local/cuda/lib64:${LIBRARY_PATH} +ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:${LD_LIBRARY_PATH} + +ADD ./docker/install/ /Install +# necessary package +RUN bash /Install/install_basic_dep.sh + +# for test +RUN bash /Install/install_test_dep.sh + +# for docs +RUN bash /Install/install_doc_dep.sh + + +# install Pytorch +RUN pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113 + +# for detection +RUN DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata +RUN bash /Install/install_detection_deps.sh + +# Copy AITemplate to Docker +RUN mkdir /AITemplate +ADD ./COMMIT_INFO /AITemplate/COMMIT_INFO +ADD ./python /AITemplate/python +ADD ./3rdparty /AITemplate/3rdparty +ADD ./examples /AITemplate/examples +ADD ./tests /AITemplate/tests +ADD ./docs /AITemplate/docs +ADD ./static /AITemplate/static +ADD ./licenses /AITemplate/licenses +ADD ./docker/install/install_ait.sh /AITemplate/ +RUN bash /AITemplate/install_ait.sh diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm new file mode 100644 index 000000000..991bc3095 --- /dev/null +++ b/docker/Dockerfile.rocm @@ -0,0 +1,147 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# ROCM Docker Image for AITemplate +FROM ubuntu:20.04 + +ARG ROCMVERSION=5.2.3 +ARG compiler_version=b0f4678b9058a4ae00200dfb1de0da5f2ea84dcb + + +RUN set -xe + +ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/ +# Add rocm repository +RUN apt-get update +RUN apt-get install -y wget gnupg +RUN wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - +RUN sh -c "echo deb [arch=amd64] $DEB_ROCM_REPO ubuntu main > /etc/apt/sources.list.d/rocm.list" +RUN wget --no-check-certificate -qO - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null | apt-key add - +RUN sh -c "echo deb https://apt.kitware.com/ubuntu/ bionic main | tee -a /etc/apt/sources.list" + +# Install dependencies +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ + apt-utils \ + build-essential \ + cmake-data \ + cmake \ + curl \ + git \ + hip-rocclr \ + jq \ + libelf-dev \ + libncurses5-dev \ + libnuma-dev \ + libpthread-stubs0-dev \ + llvm-amdgpu \ + pkg-config \ + python \ + python3 \ + python-dev \ + python3-dev \ + python3-pip \ + software-properties-common \ + rocm-dev \ + rocm-device-libs \ + rocm-cmake \ + rocm-libs \ + vim \ + zlib1g-dev \ + openssh-server \ + clang-format-10 \ + kmod && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +# Setup ubsan environment to printstacktrace +RUN ln -s /usr/bin/llvm-symbolizer-3.8 /usr/local/bin/llvm-symbolizer +ENV UBSAN_OPTIONS=print_stacktrace=1 + +# Install an init system +RUN wget https://github.com/Yelp/dumb-init/releases/download/v1.2.0/dumb-init_1.2.0_amd64.deb +RUN dpkg -i dumb-init_*.deb && rm dumb-init_*.deb + +ARG PREFIX=/opt/rocm +# Install packages for processing the performance results +RUN pip3 install --upgrade pip +RUN pip3 install sqlalchemy +RUN pip3 install pymysql +RUN pip3 install pandas +RUN pip3 install setuptools-rust +RUN pip3 install sshtunnel +# Setup ubsan environment to printstacktrace +ENV UBSAN_OPTIONS=print_stacktrace=1 + +ENV LC_ALL=C.UTF-8 +ENV LANG=C.UTF-8 +ADD ./docker/install/rocm_dev-requirements.txt dev-requirements.txt +RUN groupadd -f render + +# Install the new rocm-cmake version +RUN git clone -b master https://github.com/RadeonOpenCompute/rocm-cmake.git && \ + cd rocm-cmake && mkdir build && cd build && \ + cmake .. && cmake --build . && cmake --build . 
--target install + +WORKDIR / + +ENV compiler_version=$compiler_version +RUN sh -c "echo compiler version = '$compiler_version'" + +RUN --mount=type=ssh if [ "$compiler_version" != "release" ]; then \ + git clone https://github.com/RadeonOpenCompute/llvm-project.git && \ + cd llvm-project && \ + git checkout "$compiler_version" && \ + mkdir build && cd build && \ + cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld;compiler-rt" ../llvm && \ + make -j 96 ; \ + else echo "using the release compiler"; \ + fi + +ENV HIP_CLANG_PATH='/llvm-project/build/bin' +RUN sh -c "echo HIP_CLANG_PATH = '$HIP_CLANG_PATH'" + +# Fix compiler bug in 10736 +ADD ./docker/rocm_fix /rocm_fix +RUN python3 /rocm_fix/fix_10736.py + +ADD ./docker/install/ /Install +# necessary package +RUN bash /Install/install_basic_dep.sh + +# for test +RUN bash /Install/install_test_dep.sh + +# for docs +RUN bash /Install/install_doc_dep.sh + +# Install Pytorch +RUN pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.1.1 + +# for detection +RUN DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata +RUN bash /Install/install_detection_deps.sh + + +# Copy AITemplate to Docker +RUN mkdir /AITemplate +ADD ./COMMIT_INFO /AITemplate/COMMIT_INFO +ADD ./python /AITemplate/python +ADD ./3rdparty /AITemplate/3rdparty +ADD ./examples /AITemplate/examples +ADD ./tests /AITemplate/tests +ADD ./docs /AITemplate/docs +ADD ./static /AITemplate/static +ADD ./licenses /AITemplate/licenses +ADD ./docker/install/install_ait.sh /AITemplate/ +RUN bash /AITemplate/install_ait.sh diff --git a/docker/README.md b/docker/README.md new file mode 100644 index 000000000..dea4b35b9 --- /dev/null +++ b/docker/README.md @@ -0,0 +1,30 @@ +# Docker + AITemplate + +AITemplate provides a Docker image with all test, benchmark, and documentation dependencies installed. + +## Build CUDA Docker Image + +```bash docker/build.sh cuda``` +This will build a CUDA 11 docker image with tag: `ait:latest` + +## Build ROCM Docker Image + +```DOCKER_BUILDKIT=1 bash docker/build.sh rocm``` +This will build a RCOM 5 docker image with tag: `ait:latest` + +## Running Unit Tests in Docker + +```docker run --gpus all ait:latest bash /AITemplate/tests/nightly/unittest.sh``` + +## Launching CUDA Docker +```docker run --gpus all -it ait:latest``` + +## Launching ROCM Docker + +```docker run -it --network=host --device=/dev/kfd --device=/dev/dri --group-add=video --ipc=host --cap-add=SYS_PTRACE --security-opt seccomp=unconfined ait:latest``` + + +## Common questions: +- Q: When building ROCm Docker, I hit this error ` => ERROR [internal] load metadata for docker.io/library/ubuntu:20.04`, what shall I do? 
+ + A: Run `docker pull docker.io/library/ubuntu:20.04` to pull base image manually, then re-run `./docker/build.sh rocm` diff --git a/docker/build.sh b/docker/build.sh new file mode 100755 index 000000000..37c3612b1 --- /dev/null +++ b/docker/build.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +TARGET=$1 +COMMIT=$(git show --format="%H" --no-patch) +COMMIT_AUTHOR=$(git show --format="%an" --no-patch) +COMMIT_TIME=$(git show --format="%cI" --no-patch) +echo "$COMMIT" > COMMIT_INFO +echo "$COMMIT_AUTHOR" >> COMMIT_INFO +echo "$COMMIT_TIME" >> COMMIT_INFO + +if [ "$TARGET" = "cuda" ]; then + if [ "$2" = "debug" ]; then + echo "Build in DEBUG mode with git files" + echo "RUN apt install -y vim git" >> ./docker/Dockerfile.cuda + echo "ADD .git /AITemplate/.git" >> ./docker/Dockerfile.cuda + fi + echo "Building CUDA Docker Image with tag ait:latest" + docker build -f ./docker/Dockerfile.cuda -t ait . +elif [ "$TARGET" = "rocm" ]; then + echo "Building ROCM Docker Image with tag ait:latest" + docker build -f ./docker/Dockerfile.rocm -t ait . +else + echo "Unknown target" +fi diff --git a/docker/install/install_ait.sh b/docker/install/install_ait.sh new file mode 100644 index 000000000..3b1fdf6f3 --- /dev/null +++ b/docker/install/install_ait.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +cd /AITemplate/python +python3 setup.py bdist_wheel +pip3 install --no-input /AITemplate/python/dist/*.whl diff --git a/docker/install/install_basic_dep.sh b/docker/install/install_basic_dep.sh new file mode 100644 index 000000000..801ef53ef --- /dev/null +++ b/docker/install/install_basic_dep.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +pip3 install numpy +pip3 install jinja2 diff --git a/docker/install/install_detection_deps.sh b/docker/install/install_detection_deps.sh new file mode 100644 index 000000000..47238cd3c --- /dev/null +++ b/docker/install/install_detection_deps.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +apt install -y ffmpeg libsm6 libxext6 wget +pip3 install yacs +pip3 install opencv-python +pip3 install tqdm +pip3 install timm +pip3 install transformers +pip3 install diffusers diff --git a/docker/install/install_doc_dep.sh b/docker/install/install_doc_dep.sh new file mode 100644 index 000000000..350738142 --- /dev/null +++ b/docker/install/install_doc_dep.sh @@ -0,0 +1,6 @@ +#! 
/bin/bash +pip3 install autodocsumm +pip3 install sphinx_rtd_theme +pip3 install sphinx_gallery +pip3 install sphinxcontrib-inlinesyntaxhighlight +pip3 install sphinx_toolbox diff --git a/docker/install/install_test_dep.sh b/docker/install/install_test_dep.sh new file mode 100644 index 000000000..6cc7c1b44 --- /dev/null +++ b/docker/install/install_test_dep.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +pip3 install click +pip3 install pytest +pip3 install parameterized +pip3 install pylint==2.13.9 +pip3 install ufmt +pip3 install pyGithub +pip3 install gitpython +pip3 install xmltodict +pip3 install einops diff --git a/docker/install/rocm_dev-requirements.txt b/docker/install/rocm_dev-requirements.txt new file mode 100644 index 000000000..3c8cbd155 --- /dev/null +++ b/docker/install/rocm_dev-requirements.txt @@ -0,0 +1,3 @@ +ROCmSoftwarePlatform/rocm-recipes +# 1.90+ +danmar/cppcheck@dd05839a7e63ef04afd34711cb3e1e0ef742882f diff --git a/docker/rocm_fix/fix_10736.py b/docker/rocm_fix/fix_10736.py new file mode 100644 index 000000000..c91e7f200 --- /dev/null +++ b/docker/rocm_fix/fix_10736.py @@ -0,0 +1,9 @@ +src = "" +with open("/opt/rocm/hip/bin/hipcc.pl", "r") as fi: + src = fi.read() + +src = src.replace( + "$HIP_CLANG_TARGET = chomp($HIP_CLANG_TARGET);", "chomp($HIP_CLANG_TARGET);" +) +with open("/opt/rocm/hip/bin/hipcc.pl", "w") as fo: + fo.write(src) diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 000000000..7f6e76eb5 --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,22 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + cp static/ait_model.html build/html/tutorial/ait_model.html + diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 000000000..1a07a9b4c --- /dev/null +++ b/docs/README.md @@ -0,0 +1,20 @@ +# AITemplate Documentation + + +## Build locally + +1. Install AITemplate + +2. Install Sphinx +``` +pip install autodocsumm +pip install sphinx_rtd_theme +pip install sphinx_gallery +pip install sphinxcontrib-inlinesyntaxhighlight +pip install sphinx_toolbox +``` + +3. Build HTML +``` +make html +``` diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 000000000..747ffb7b3 --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=source +set BUILDDIR=build + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. 
+ echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "" goto help + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/docs/source/arch/index.rst b/docs/source/arch/index.rst new file mode 100644 index 000000000..13fdf207f --- /dev/null +++ b/docs/source/arch/index.rst @@ -0,0 +1,12 @@ +Design and Architecture +======================= + + +.. toctree:: + :maxdepth: 1 + + philosophy + + + +Stay tuned for more... diff --git a/docs/source/arch/philosophy.rst b/docs/source/arch/philosophy.rst new file mode 100644 index 000000000..2eefb8f5d --- /dev/null +++ b/docs/source/arch/philosophy.rst @@ -0,0 +1,16 @@ +Design Philosophy +================== + + +KISS (Keep it simple and stupid) +-------------------------------- + +AITemplate avoids deep IR lowering stacks to reduce the system's complexity. A highly modularized, multiple backend codegen system written in pure Python directly attacks the pain point in high-performance GPU inference. + +Pragmatism +---------- + +AITemplate provides a PyTorch-style frontend to enable engineers to manually match the PyTorch model & weights to AITemplate for optimization. Using it is less painful than debugging different lowering IR stacks, especially for complex models such as MaskRCNN. + + +We believe most of the neural network workload can be decoupled. For example, most of the network can be decoupled into Encoder, Decoder, and Decoder logics. For encoder and decoder, it is a computation bounded problem. For decoder logic, it may involve more control flows. By using divide and conquer, we left the decoder logic part to C++ or Python rather than build a unified language / IR stack to play as the silver bullet. \ No newline at end of file diff --git a/docs/source/conf.py b/docs/source/conf.py new file mode 100644 index 000000000..bf239d5d1 --- /dev/null +++ b/docs/source/conf.py @@ -0,0 +1,67 @@ +# Configuration file for the Sphinx documentation builder. +# +# This file only contains a selection of the most common options. For a full +# list see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +# import os +# import sys +# sys.path.insert(0, os.path.abspath('.')) + + +# -- Project information ----------------------------------------------------- + +project = "AITemplate" +copyright = "2022, Meta Platforms" +author = "Meta Platforms" + +# The full version, including alpha/beta/rc tags +release = "0.1" + + +# -- General configuration --------------------------------------------------- + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + "sphinx.ext.todo", + "sphinx.ext.viewcode", + "sphinx.ext.autosummary", + "sphinx.ext.intersphinx", + "sphinx.ext.napoleon", + "sphinx.ext.autodoc", + "sphinx.ext.mathjax", + "autodocsumm", + "sphinxcontrib.inlinesyntaxhighlight", + "sphinx_toolbox.code", +] + +# Add any paths that contain templates here, relative to this directory. 
+templates_path = ["_templates"]
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+# This pattern also affects html_static_path and html_extra_path.
+exclude_patterns = []
+
+
+# -- Options for HTML output -------------------------------------------------
+
+# The theme to use for HTML and HTML Help pages.  See the documentation for
+# a list of builtin themes.
+#
+# html_theme = 'alabaster'
+html_theme = "sphinx_rtd_theme"
+
+
+# Add any paths that contain custom static files (such as style sheets) here,
+# relative to this directory. They are copied after the builtin static files,
+# so a file named "default.css" will overwrite the builtin "default.css".
+html_static_path = ["_static"]
diff --git a/docs/source/debughints.rst b/docs/source/debughints.rst
new file mode 100644
index 000000000..074254a75
--- /dev/null
+++ b/docs/source/debughints.rst
@@ -0,0 +1,14 @@
+Debug Hints
+===========
+
+AITemplate is a new project under active development. We have a rich test suite to help avoid bugs, but don't be surprised if anything unexpected happens.
+
+Here are some helpful tips we learned during the development of AITemplate:
+
+1. Once the codegen for an op which requires profiling is changed, remember to delete the old profilers (usually located in the workdir) and flush the cache, either by deleting ~/.aitemplate or by setting the environment variable FLUSH_PROFILE_CACHE=1.
+
+2. If an optimization seems harmful, check the pseudo code/visualization generated by each optimization pass.
+
+3. Always do numerical tests, from small to large, to make sure the entire model is correct.
+
+4. First make the new fusion subgraph work manually, then add an automatic pass to rewrite the graph with the fused subgraph.
\ No newline at end of file
diff --git a/docs/source/genindex.rst b/docs/source/genindex.rst
new file mode 100644
index 000000000..66a235227
--- /dev/null
+++ b/docs/source/genindex.rst
@@ -0,0 +1,2 @@
+Index
+=====
\ No newline at end of file
diff --git a/docs/source/index.rst b/docs/source/index.rst
new file mode 100644
index 000000000..c8e070eac
--- /dev/null
+++ b/docs/source/index.rst
@@ -0,0 +1,44 @@
+
+AITemplate Documentation
+======================================
+
+AITemplate (AIT) is a Python framework that transforms deep neural networks into CUDA (NVIDIA GPU) / HIP (AMD GPU) C++ code for lightning-fast inference serving. AITemplate highlights include:
+
+* High performance: close to roofline fp16 TensorCore (NVIDIA GPU) / MatrixCore (AMD GPU) performance on major models, including ResNet, MaskRCNN, BERT, VisionTransformer, Stable Diffusion, etc.
+* Unified, open, and flexible. Seamless fp16 deep neural network models for NVIDIA GPU or AMD GPU. Fully open source, with Lego-style, easily extendable high-performance primitives for new model support. Supports a significantly more comprehensive range of fusions than existing solutions for both GPU platforms.
+
+
+.. toctree::
+   :maxdepth: 1
+   :caption: Getting Started
+
+   install/index
+
+
+.. toctree::
+   :maxdepth: 1
+   :caption: User Guide
+
+   tutorial/index
+   debughints
+
+.. toctree::
+   :maxdepth: 1
+   :caption: Runtime Design
+
+   runtime/index
+
+.. toctree::
+   :maxdepth: 1
+   :caption: Architecture Guide
+
+   arch/index
+
+
+.. toctree::
+   :maxdepth: 1
+   :caption: Reference Guide
+
+   reference/index
+   reference/env
+   genindex
\ No newline at end of file
diff --git a/docs/source/install/index.rst b/docs/source/install/index.rst
new file mode 100644
index 000000000..862212889
--- /dev/null
+++ b/docs/source/install/index.rst
@@ -0,0 +1,64 @@
+Installing AITemplate
+=====================
+
+Using Docker
+------------
+
+The easiest way to get started is to use Docker. Using Docker avoids performance regressions caused by an incorrect version of NVCC or HIPCC.
+We provide a bash script to build the Docker image.
+
+- CUDA:
+    .. code-block:: bash
+
+        ./docker/build.sh cuda
+- ROCM:
+    .. code-block:: bash
+
+        DOCKER_BUILDKIT=1 ./docker/build.sh rocm
+
+
+This will build a Docker image with the tag `ait:latest`.
+
+To launch the Docker container:
+
+- CUDA:
+    .. code-block:: bash
+
+        docker run --gpus all -it ait:latest
+
+- ROCM:
+    .. code-block:: bash
+
+        docker run -it --network=host --device=/dev/kfd --device=/dev/dri --group-add=video --ipc=host --cap-add=SYS_PTRACE --security-opt seccomp=unconfined ait:latest
+
+AITemplate will be installed as a Python package for Python 3.8. There will also be a copy of the source code and examples at `/AITemplate`.
+
+
+Install as a standard Python package
+------------------------------------
+
+Before installing AITemplate, first make sure you have the correct hardware and software environment.
+
+- Hardware
+    - NVIDIA: AIT is only tested on SM80+ GPUs (Ampere etc).
+    - AMD: AIT is only tested on CDNA2 (MI-210/250) GPUs.
+
+.. warning::
+    - Not all kernels work with old SM75/SM70 (T4/V100) GPUs.
+    - There may be compiler issues for old CDNA1 (MI-100) GPUs.
+
+- Software
+    - NVIDIA: CUDA 11.6
+    - AMD: ROCm 5.2, with HIPCC 10736 (commit `b0f4678b9058a4ae00200dfb1de0da5f2ea84dcb`)
+
+.. warning::
+    - An incorrect compiler version will lead to performance regression.
+    - Instructions for building HIPCC 10736 can be found in `docker/Dockerfile.rocm`.
+
+Then build the Python wheel package and install it:
+
+    .. code-block:: bash
+
+        cd python
+        python setup.py bdist_wheel
+        pip install dist/aitemplate-0.0.1-py3-none-any.whl
diff --git a/docs/source/reference/backend.rst b/docs/source/reference/backend.rst
new file mode 100644
index 000000000..fefac5d3f
--- /dev/null
+++ b/docs/source/reference/backend.rst
@@ -0,0 +1,60 @@
+aitemplate.backend
+===========================
+
+aitemplate.backend.task_runner
+------------------------------
+.. automodule:: aitemplate.backend.task_runner
+    :members:
+    :imported-members:
+    :exclude-members: OrderedDict
+    :autosummary:
+
+
+aitemplate.backend.builder
+--------------------------
+.. automodule:: aitemplate.backend.builder
+    :members:
+    :imported-members:
+    :exclude-members: BaseRunner, Target
+    :autosummary:
+
+
+aitemplate.backend.codegen
+---------------------------
+.. automodule:: aitemplate.backend.codegen
+    :members:
+    :imported-members:
+    :exclude-members: Tensor, Target
+    :autosummary:
+
+aitemplate.backend.profiler_cache
+----------------------------------
+.. automodule:: aitemplate.backend.profiler_cache
+    :members:
+    :imported-members:
+    :exclude-members:
+    :autosummary:
+
+aitemplate.backend.profiler_runner
+-----------------------------------
+.. automodule:: aitemplate.backend.profiler_runner
+    :members:
+    :imported-members:
+    :exclude-members: Target, Task, namedtuple, BaseRunner
+    :autosummary:
+
+aitemplate.backend.registry
+----------------------------
+..
automodule:: aitemplate.backend.registry + :members: + :imported-members: + :exclude-members: + :autosummary: + +aitemplate.backend.target +-------------------------- +.. automodule:: aitemplate.backend.target + :members: + :imported-members: + :exclude-members: + :autosummary: diff --git a/docs/source/reference/compiler.rst b/docs/source/reference/compiler.rst new file mode 100644 index 000000000..7b41c26b9 --- /dev/null +++ b/docs/source/reference/compiler.rst @@ -0,0 +1,37 @@ +aitemplate.compiler +============================== + + +base +------------------------ +.. automodule:: aitemplate.compiler.base + :members: + :imported-members: + :exclude-members: ABC, Enum, abstructmethod, dataclass, pformat, reduce + :autosummary: + + +tensor_accessor +----------------------------------- +.. automodule:: aitemplate.compiler.tensor_accessor + :members: + :imported-members: + :exclude-members: IntImm, IntVar, Tensor, pformat + :autosummary: + +compiler +---------------------------- + +.. automodule:: aitemplate.compiler.compiler + :members: + :imported-members: + :exclude-members: IntImm, IntVar, Tensor, pformat, DynamicProfileStrategy + :autosummary: + +model +---------------------------- +.. automodule:: aitemplate.compiler.model + :members: + :imported-members: + :exclude-members: NamedTuple, TypeVar + :autosummary: \ No newline at end of file diff --git a/docs/source/reference/cuda.rst b/docs/source/reference/cuda.rst new file mode 100644 index 000000000..4770026eb --- /dev/null +++ b/docs/source/reference/cuda.rst @@ -0,0 +1,12 @@ +aitemplate.backend.cuda +=========================== + +target_def +---------- +.. automodule:: aitemplate.backend.cuda.target_def + :members: + :imported-members: + :exclude-members: Path, ProfileCacheDB, TargetType + :autosummary: + + diff --git a/docs/source/reference/env.rst b/docs/source/reference/env.rst new file mode 100644 index 000000000..1342becf6 --- /dev/null +++ b/docs/source/reference/env.rst @@ -0,0 +1,37 @@ +Environment Variables +===================== +AITemplate uses environment variables to control the behavior of codegen and profiling. All the environment variables used in AITemplate are listed here. + +Codegen +------- + +**NUM_BUILDERS**: The number of CPU jobs running in parallel during codegen. It controls both the profiler codegen and the final .so codegen. It's set to 12 in NIGHTLY jobs. Internally, it's set to 12 for normal tests and 24 for heavy tests. By default, the builder uses all the available CPUs for building. + +**RECOMPILE**: If set to "0", it skips compilation for the .so and reuses the previously compiled ones. It is used to speed up local testing. The default value is "1" to always recompile. + +Profiling +--------- + +**CACHE_DIR**: The directory for the profiling cache. If unset, it defaults to `~/.aitemplate`. + +**FLUSH_PROFILE_CACHE**: If set to "1", it removes the cache file and recreates an empty one. + +**DISABLE_PROFILER_CODEGEN**: Normally in CI we randomly choose two profilers to codegen. If set to "1", this flag disables profiler codegen completely to speed up long running tests so that the tests don't time out. The default value is "0". + +**CUDA_VISIBLE_DEVICES**: This one is from CUDA itself. It's used to set the number of GPU devices available for profiling. Set to "0,1,2,3,4,5,6,7" to speed up profiling. For benchmarking, it's useful to set to a particular device to lower noise. + +**HIP_VISIBLE_DEVICES**: This one is from ROCm itself. It's used to set the number of GPU devices available for profiling. 
Set to "0,1,2,3,4,5,6,7" to speed up profiling. For benchmarking, it's useful to set to a particular device to lower noise. + +**FORCE_PROFILE**: If set to "1", it will do profiling regarless in_ci_env and disable_profiler_codegen. For non-NIGHTLY CI, we do not do profiling, and we could use FORCE_PROFILE=1 in these CI to do runs with codegen, compile, and profile. + +OSS CI +------ + +**CI_FLAG**: It is set to "CIRCLECI" in OSS CI to indicate we're in OSS CI environment. The behavior of the profiler and codegen is different in CI to speed up testing. Profiling itself for gemm/conv ops is disabled in CI. But we still compiles two random profilers to make sure the profiler codegen is not broken. + +**BUILD_DOCS**: If set to "1", it will create a fake CUDA target to enable doc building in Github Actions. + +Miscellaneous +------------- + +**LOGLEVEL**: It is used to control the logging level in python. It's default to "INFO". "DEBUG" is useful for debugging. diff --git a/docs/source/reference/frontend.rst b/docs/source/reference/frontend.rst new file mode 100644 index 000000000..41bd777dc --- /dev/null +++ b/docs/source/reference/frontend.rst @@ -0,0 +1,14 @@ +aitemplate.frontend +==================== + +.. automodule:: aitemplate.frontend.nn + :members: + :imported-members: + :exclude-members: + :autosummary: + +.. automodule:: aitemplate.frontend.tensor + :members: + :imported-members: + :exclude-members: + :autosummary: diff --git a/docs/source/reference/index.rst b/docs/source/reference/index.rst new file mode 100644 index 000000000..8eb88c925 --- /dev/null +++ b/docs/source/reference/index.rst @@ -0,0 +1,16 @@ +Python API +========== + + +.. toctree:: + :maxdepth: 2 + + compiler + ops + transform + backend + cuda + rocm + frontend + testing + utils diff --git a/docs/source/reference/ops.rst b/docs/source/reference/ops.rst new file mode 100644 index 000000000..f75510665 --- /dev/null +++ b/docs/source/reference/ops.rst @@ -0,0 +1,8 @@ +aitemplate.compiler.ops +======================== + +.. automodule:: aitemplate.compiler.ops + :members: + :imported-members: + :exclude-members: Tensor, TensorAccessor, Enum, Operator, IntImm, IntVar, IntVarTensor, wrap_dim + :autosummary: diff --git a/docs/source/reference/rocm.rst b/docs/source/reference/rocm.rst new file mode 100644 index 000000000..2dc4f6c1e --- /dev/null +++ b/docs/source/reference/rocm.rst @@ -0,0 +1,11 @@ +aitemplate.backend.rocm +=========================== + +target_def +---------- +.. automodule:: aitemplate.backend.rocm.target_def + :members: + :imported-members: + :exclude-members: + :autosummary: + diff --git a/docs/source/reference/testing.rst b/docs/source/reference/testing.rst new file mode 100644 index 000000000..042df07d2 --- /dev/null +++ b/docs/source/reference/testing.rst @@ -0,0 +1,27 @@ +aitemplate.testing +================== + +detect_target +------------- +.. automodule:: aitemplate.testing.detect_target + :members: + :imported-members: + :exclude-members: CUDA, ROCM, Popen + :autosummary: + + +benchmark_pt +------------ +.. automodule:: aitemplate.testing.benchmark_pt + :members: + :imported-members: + :exclude-members: CUDA, ROCM, Popen + :autosummary: + +benchmark_ait +------------- +.. 
automodule:: aitemplate.testing.benchmark_ait + :members: + :imported-members: + :exclude-members: CUDA, ROCM, Popen + :autosummary: \ No newline at end of file diff --git a/docs/source/reference/transform.rst b/docs/source/reference/transform.rst new file mode 100644 index 000000000..4614e15ca --- /dev/null +++ b/docs/source/reference/transform.rst @@ -0,0 +1,209 @@ +aitemplate.compiler.transform +============================== + + +apply_padding +------------------------------------------- +.. automodule:: aitemplate.compiler.transform.apply_padding + :members: + :imported-members: + :exclude-members: DimInfo, IntImm, Operator, Source, Tensor, gemm + :autosummary: + + +bind_constants +------------------------------------------- +.. automodule:: aitemplate.compiler.transform.bind_constants + :members: + :imported-members: + :exclude-members: DimInfo, IntImm, Operator, Source, Tensor, gemm + :autosummary: + +constant_folding +------------------------------------------- +.. automodule:: aitemplate.compiler.transform.constant_folding + :members: + :imported-members: + :exclude-members: DimInfo, IntImm, Operator, Source, Tensor, gemm, AITData, replace_tensor + :autosummary: + +fuse_conv_elementwise +------------------------------------------- +.. automodule:: aitemplate.compiler.transform.fuse_conv_elementwise + :members: + :imported-members: + :exclude-members: DimInfo, IntImm, Operator, Source, Tensor, gemm, AITData, replace_tensor + :autosummary: + +fuse_group_ops +------------------------------------------- +.. automodule:: aitemplate.compiler.transform.fuse_group_ops + :members: + :imported-members: + :exclude-members: DimInfo, IntImm, Operator, Source, Tensor, gemm, AITData, replace_tensor, all_static_dimensions + :autosummary: + + +fuse_mm_elementwise +------------------------------------------- +.. automodule:: aitemplate.compiler.transform.fuse_mm_elementwise + :members: + :imported-members: + :exclude-members: DimInfo, IntImm, Operator, Source, Tensor, gemm, AITData, replace_tensor, FuncEnum, elementwise, gemm_rcr, gemm_rcr_bias, gemm_rcr_bias_swish, copy_tensor_attributes, extract_only_one_op, get_patterns, remove_single_tensor_op_from_sorted_graph, sanitize_sorted_graph + :autosummary: + +fuse_ops +------------------------------------------- +.. automodule:: aitemplate.compiler.transform.fuse_ops + :members: + :imported-members: + :exclude-members: DimInfo, IntImm, Operator, Source, Tensor, gemm, AITData, replace_tensor, FuncEnum, elementwise, gemm_rcr, gemm_rcr_bias, gemm_rcr_bias_swish, copy_tensor_attributes, extract_only_one_op, get_patterns, remove_single_tensor_op_from_sorted_graph, sanitize_sorted_graph, layernorm_sigmoid_mul + :autosummary: + +fuse_parallel_gemms +------------------------------------------- +.. automodule:: aitemplate.compiler.transform.fuse_parallel_gemms + :members: + :imported-members: + :exclude-members: TensorAccessor, is_static_dimension, DimInfo, IntImm, Operator, Source, Tensor, gemm, AITData, replace_tensor, FuncEnum, elementwise, gemm_rcr, gemm_rcr_bias, gemm_rcr_bias_swish, copy_tensor_attributes, extract_only_one_op, get_patterns, remove_single_tensor_op_from_sorted_graph, sanitize_sorted_graph, layernorm_sigmoid_mul + :autosummary: + +fuse_permute_bmm +------------------------------------------- +.. 
automodule:: aitemplate.compiler.transform.fuse_permute_bmm + :members: + :imported-members: + :exclude-members: copy_src_op_attributes, remove_tensor_from_sorted_graph, bmm_ccr, bmm_crr, bmm_rcr, bmm_rrr, gemm_rrr, gemm_rrr_bias, permute021, TensorAccessor, is_static_dimension, DimInfo, IntImm, Operator, Source, Tensor, gemm, AITData, replace_tensor, FuncEnum, elementwise, gemm_rcr, gemm_rcr_bias, gemm_rcr_bias_swish, copy_tensor_attributes, extract_only_one_op, get_patterns, remove_single_tensor_op_from_sorted_graph, sanitize_sorted_graph, layernorm_sigmoid_mul + :autosummary: + +fuse_split +------------------------------------------- +.. automodule:: aitemplate.compiler.transform.fuse_split + :members: + :imported-members: _fuse_split_and_strided_op + :exclude-members: IntVar, copy_src_op_attributes, remove_tensor_from_sorted_graph, bmm_ccr, bmm_crr, bmm_rcr, bmm_rrr, gemm_rrr, gemm_rrr_bias, permute021, TensorAccessor, is_static_dimension, DimInfo, IntImm, Operator, Source, Tensor, gemm, AITData, replace_tensor, FuncEnum, elementwise, gemm_rcr, gemm_rcr_bias, gemm_rcr_bias_swish, copy_tensor_attributes, extract_only_one_op, get_patterns, remove_single_tensor_op_from_sorted_graph, sanitize_sorted_graph, layernorm_sigmoid_mul + :autosummary: + +mark_param_tensor +------------------------------------------- +.. automodule:: aitemplate.compiler.transform.mark_param_tensor + :members: + :imported-members: + :exclude-members: DimInfo, IntImm, Operator, Source, Tensor, gemm + :autosummary: + +memory_planning +------------------------------------------- +.. automodule:: aitemplate.compiler.transform.memory_planning + :members: + :imported-members: + :exclude-members: TensorUsageRecord, Workspace, assign_offsets_to_views_and_outputs, greedy_by_size_memory_planning, DimInfo, IntImm, Operator, Source, Tensor, gemm, defaultdict, dataclass + :autosummary: + +name_graph +------------------------------------------- +.. automodule:: aitemplate.compiler.transform.name_graph + :members: + :imported-members: + :exclude-members: DimInfo, IntImm, Operator, Source, Tensor, gemm + :autosummary: + + +optimize_graph +------------------------------------------- +.. automodule:: aitemplate.compiler.transform.optimize_graph + :members: + :imported-members: + :exclude-members: transform_strided_ops, transform_special_ops, transform_odd_alignment, transform_memory_ops, fuse_permute_bmm, fuse_parallel_gemms, fuse_mm_elementwise, apply_padding, fuse_conv_elementwise, fuse_group_ops, DimInfo, IntImm, Operator, Source, Tensor, gemm + :autosummary: + + +profile +------------------------------------------- +.. automodule:: aitemplate.compiler.transform.profile + :members: + :imported-members: + :exclude-members: DynamicProfileStrategy, DimInfo, IntImm, Operator, Source, Tensor, gemm + :autosummary: + +refine_graph +------------------------------------------- +.. automodule:: aitemplate.compiler.transform.refine_graph + :members: + :imported-members: + :exclude-members: DimInfo, IntImm, Operator, Source, Tensor, gemm + :autosummary: + +remove_no_ops +------------------------------------------- +.. automodule:: aitemplate.compiler.transform.remove_no_ops + :members: + :imported-members: + :exclude-members: IntVar, is_singleton_dimension, DimInfo, IntImm, Operator, Source, Tensor, gemm + :autosummary: + +remove_unused_ops +------------------------------------------- +.. 
automodule:: aitemplate.compiler.transform.remove_unused_ops + :members: + :imported-members: + :exclude-members: deque, DimInfo, IntImm, Operator, Source, Tensor, gemm + :autosummary: + +toposort +------------------------------------------- +.. automodule:: aitemplate.compiler.transform.toposort + :members: + :imported-members: + :exclude-members: DimInfo, IntImm, Operator, Source, Tensor, gemm + :autosummary: + + +transform_memory_ops +------------------------------------------- +.. automodule:: aitemplate.compiler.transform.transform_memory_ops + :members: + :imported-members: + :exclude-members: TensorAccessor, DimInfo, IntImm, Operator, Source, Tensor, gemm + :autosummary: + +transform_odd_alignment +------------------------------------------- +.. automodule:: aitemplate.compiler.transform.transform_odd_alignment + :members: + :imported-members: + :exclude-members: can_be_constant_folded, copy_src_op_attributes, copy_tensor_attributes, extract_only_one_op, remove_tensor_from_sorted_graph, replace_tensor, sanitize_sorted_graph, toposort, IntVar, bmm_ccr, bmm_crr, bmm_rcr, bmm_rrr, permute021, unsqueeze, TensorAccessor, DimInfo, IntImm, Operator, Source, Tensor, gemm + :autosummary: + + +transform_special_ops +------------------------------------------- +.. automodule:: aitemplate.compiler.transform.transform_special_ops + :members: + :imported-members: + :exclude-members: gemm_rrr, is_singleton_dimension, gemm_rcr, gemm_rrr_small_nk, can_be_constant_folded, copy_src_op_attributes, copy_tensor_attributes, extract_only_one_op, remove_tensor_from_sorted_graph, replace_tensor, sanitize_sorted_graph, toposort, IntVar, bmm_ccr, bmm_crr, bmm_rcr, bmm_rrr, permute021, unsqueeze, TensorAccessor, DimInfo, IntImm, Operator, Source, Tensor, gemm + :autosummary: + +transform_strided_op_and_view_op +------------------------------------------- +.. automodule:: aitemplate.compiler.transform.transform_strided_op_and_view_op + :members: + :imported-members: + :exclude-members: IntVar, is_singleton_dimension, DimInfo, IntImm, Operator, Source, Tensor, gemm + :autosummary: + +transform_strided_ops +------------------------------------------- +.. automodule:: aitemplate.compiler.transform.transform_strided_ops + :members: + :imported-members: + :exclude-members: get_tensor_index, slice_reshape_scatter, slice_scatter, gen_tensor_index, IntVar, is_singleton_dimension, DimInfo, IntImm, Operator, Source, Tensor, gemm + :autosummary: + +transform_strided_slice +------------------------------------------- +.. automodule:: aitemplate.compiler.transform.transform_strided_slice + :members: + :imported-members: + :exclude-members: dynamic_slice, get_tensor_index, slice_reshape_scatter, slice_scatter, gen_tensor_index, IntVar, is_singleton_dimension, DimInfo, IntImm, Operator, Source, Tensor, gemm + :autosummary: diff --git a/docs/source/reference/utils.rst b/docs/source/reference/utils.rst new file mode 100644 index 000000000..6c35fc39d --- /dev/null +++ b/docs/source/reference/utils.rst @@ -0,0 +1,12 @@ +aitemplate.utils +================== + + +visualization.plot +------------------ +.. automodule:: aitemplate.utils.visualization.plot + :members: + :imported-members: + :exclude-members: Tensor, Operator + :autosummary: + diff --git a/docs/source/runtime/cxx_design.rst b/docs/source/runtime/cxx_design.rst new file mode 100644 index 000000000..5ef18f889 --- /dev/null +++ b/docs/source/runtime/cxx_design.rst @@ -0,0 +1,29 @@ +================== +C++ Runtime Note +================== + +`Model` v.s. 
`ModelContainer` +============================== + +These are the two main classes involved in the C++ runtime implementation. + +* The bulk of the runtime implementation is in `Model`. +* `ModelContainer` stores a set of shared constants and a collection of `Model`s. Almost all functions in `model_interface.h` forward to a method on `ModelContainer`. When `Run` is invoked, `ModelContainer` looks for an available `Model`, or blocks until one is available (see the section on asynchronous predictions). It then forwards the run request to the runtime. + +Code Structure +============== + +Some important files: + +1. `include/model_interface.h`: The interface that we expose in the compiled .so +2. `include/model_container.h`: The bulk of the `ModelContainer` implementation. + +Some files are generated at compile time. These include: + +* `model-generated.h`: The implementation for `Model`. +* `model_container_base.cu`: A small part of the implementation for `ModelContainer` needs to be codegened. So `ModelContainer` inherits from `ModelContainerBase`, and `ModelContainerBase`'s implementation lives in this file. See `model_container.h` for more details. + +All codegen templates can be found in `backend/main_templates.py`. The codegen implementation is in `backend/codegen.py`. + +Note that many of the headers in this directory rely on generated code and thus cannot be `#include`d in external projects. The exception is `model_interface.h`. + diff --git a/docs/source/runtime/index.rst b/docs/source/runtime/index.rst new file mode 100644 index 000000000..0dd2462ff --- /dev/null +++ b/docs/source/runtime/index.rst @@ -0,0 +1,9 @@ +Runtime Note +================== + + +.. toctree:: + :maxdepth: 1 + + cxx_design + py_design diff --git a/docs/source/runtime/py_design.rst b/docs/source/runtime/py_design.rst new file mode 100644 index 000000000..c143123de --- /dev/null +++ b/docs/source/runtime/py_design.rst @@ -0,0 +1,135 @@ +===================== +Python Runtime Note +===================== + +Python `Model` +============== + +`Model` is a collection of Python bindings to the C++ AIT runtime. This section describes the API. + +`AITData` +--------- + +This class represents a contiguous blob of memory that AIT will use as a tensor. It is simply a named tuple with these fields: + +* `data_ptr: int`: An **unowned** pointer to **GPU** memory. In general, all of the APIs expect that this pointer will be valid for the entire duration of the call. +* `shape: List[int]`: The shape of the tensor. +* `dtype: str`: The tensor's dtype; one of `"float32", "float16", "int32", "int64"`. Note that most ops only support float16 at this stage. + +If using AITemplate with PyTorch, `AITData`s can be constructed with the `torch_to_ait_data` utility: + +.. code-block:: python + + x = torch.randn(3, 3, 3).half().cuda() + # Equivalent to AITData(x.data_ptr(), [3, 3, 3], "float16") + x_ait = torch_to_ait_data(x) + + +If PyTorch is not available, `Model` provides a set of functions for copying, allocating, and freeing GPU memory. See the docstrings in `compiler/model.py` for more information. + +`run` +----- + +`run` takes a set of inputs and outputs as `AITData`s. Both arguments can be passed as either an ordered list or a dictionary (mapping name to tensor). + +.. code-block:: python + + # Arguments as a dictionary + module.run( + {"input0": in0_ait, "input1": in1_ait}, + {"output0": out0_ait, "output1": out0_ait}, + ) + + # Arguments as an ordered list. Note that you might need to query + # the index mapping. 
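+    # get_input_name_to_index_map() / get_output_name_to_index_map() return
+    # dicts mapping tensor names to their positions in the ordered lists.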
+    input_name_to_idx = module.get_input_name_to_index_map()
+    output_name_to_idx = module.get_output_name_to_index_map()
+
+    inputs = [None for i in range(len(input_name_to_idx))]
+    outputs = [None for i in range(len(output_name_to_idx))]
+
+    for name in input_name_to_idx:
+        inputs[input_name_to_idx[name]] = ait_inputs[name]
+
+    for name in output_name_to_idx:
+        outputs[output_name_to_idx[name]] = ait_outputs[name]
+
+    module.run(inputs, outputs)
+
+
+One important caveat is that the output buffer must be allocated at its **maximum** size. This is because of dynamic shapes: the size of the output may vary, and its actual shape is not inferred until inference time. The maximum shape can be queried with `get_output_maximum_shape()`:
+
+.. code-block:: python
+
+    # Can use either name or index.
+    name_to_idx = module.get_output_name_to_index_map()
+    max_shape = module.get_output_maximum_shape(name_to_idx["output"])
+    max_shape = module.get_output_maximum_shape("output")
+
+
+`Model.run` returns a dictionary of output `AITData`s with the (possibly dynamic) shapes that the runtime inferred.
+
+Nullptr Inputs/Outputs
+----------------------
+
+In general, inputs are allowed to be null if they are size 0 (i.e. at least one dimension is 0). The runtime enforces this with a check before any kernels are launched:
+
+.. code-block:: cpp
+
+    if (input_name == nullptr && dim0 * dim1 * … * dimN != 0) {
+      throw std::runtime_error("input_name cannot be null!");
+    }
+
+
+This is convenient since `torch.data_ptr()` returns null for size-zero tensors. The dynamic shape computation is skipped if the lower bound of the tensor's size is positive.
+
+Constants
+---------
+
+There are two types of constants in AIT: *bound* and *unbound* constants. A bound constant is known at compile time and may participate in constant folding. Bound constants are copied into GPU memory at model loading time. Values for bound constants may be provided by passing a dictionary (mapping constant name to AIT tensor) to `compile_model`.
+
+Unbound constants, on the other hand, do not participate in constant folding and must be provided before running the model. These must be set via `Model.set_constant`:
+
+.. code-block:: python
+
+    module.set_constant("my_constant", AITData(...))
+    # The pointer in the tensor must live for the entire duration of run()
+    module.run(...)
+
+
+Constants are read-only and *shared* across all runtimes in the `ModelContainer`.
+
+`run_with_tensors`
+------------------
+
+`run_with_tensors` is a convenience method with the same interface as `run`, except that it can take lists of `torch.Tensor`s:
+
+.. code-block:: python
+
+    input0 = torch.randn(input0_shape).cuda().half()
+    output0 = torch.empty(output0_shape).cuda().half()
+    # Returns a dictionary of reshaped outputs
+    result = module.run_with_tensors([input0], [output0])
+
+
+Streams and Asynchronous Predictions
+------------------------------------
+
+A pointer to a stream can optionally be passed to `run`. If none is given, the prediction happens on the default stream 0. If the `sync` argument is set to `True`, the stream is synchronized before `run()` returns. `sync` is `True` by default.
+
+Multiple predictions can happen at the same time (on the same or different streams). Under the hood, there is a fixed-size pool of runtime objects. When all the runtimes are in use, `run()` blocks until one becomes available.
+The size of this pool can be configured with the `num_runtimes` option in `Model`'s constructor.
+
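+A minimal sketch of issuing concurrent predictions is shown below. It is illustrative only: the module path and tensor shapes are hypothetical placeholders, and only `num_runtimes` and `run_with_tensors` come from the API described in this note.
+
+.. code-block:: python
+
+    from concurrent.futures import ThreadPoolExecutor
+
+    import torch
+    from aitemplate.compiler import Model
+
+    # Hypothetical compiled module and shapes -- substitute your own.
+    module = Model("./tmp/my_model/test.so", num_runtimes=2)
+    inputs = [torch.randn(8, 64).cuda().half() for _ in range(4)]
+    outputs = [torch.empty(8, 64).cuda().half() for _ in range(4)]
+
+    def predict(inp, out):
+        # Each call borrows a runtime from the fixed-size pool,
+        # or blocks until one becomes available.
+        return module.run_with_tensors([inp], [out])
+
+    with ThreadPoolExecutor(max_workers=2) as pool:
+        futures = [pool.submit(predict, i, o) for i, o in zip(inputs, outputs)]
+        results = [f.result() for f in futures]
+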
+CUDA Graph
+----------
+
+`run` also takes a `graph_mode` option. If set to true, the runtime will try to use `CUDA graphs <https://developer.nvidia.com/blog/cuda-graphs/>`_ to run the model. `graph_mode` is not supported on ROCm.
+
+The following is a high-level overview of how graph mode works:
+
+1) Each `Model` has an internal stream used for graph capturing. The model first runs all ops on this stream in capture mode. No kernel launches happen during this stage.
+2) If this is the first run, a graph is instantiated via `cudaGraphInstantiate`.
+3) On subsequent runs, we try to avoid the relatively expensive `cudaGraphInstantiate` call by updating the graph executor (`cudaGraphExecUpdate`). However, a new graph may still be instantiated if the topology of the graph changed between runs.
+4) Once we have the graph executor, we launch a single kernel on the stream that the user provided to `run()`.
+
+Graph mode is mainly beneficial when there are many small kernel launches, since all of that launch overhead is replaced by a single kernel launch.
diff --git a/docs/source/tutorial/how_to_add_op.rst b/docs/source/tutorial/how_to_add_op.rst
new file mode 100644
index 000000000..160745336
--- /dev/null
+++ b/docs/source/tutorial/how_to_add_op.rst
@@ -0,0 +1,302 @@
+How to add an operator to the AIT codegen
+=========================================
+
+This tutorial will demonstrate how to add a new operator to the AIT codegen.
+The full source code can be found at `examples/06_how_to_add_an_op/how_to_add_an_op.py`.
+
+
+0. Prerequisites
+-----------------
+
+We need to import the necessary Python modules:
+
+.. code-block:: python
+
+    from typing import Any, Dict, List
+
+    import jinja2
+    import torch
+
+    from aitemplate import backend
+    from aitemplate.backend import registry
+    from aitemplate.backend.backend_spec import CUDASpec, ROCMSpec
+    from aitemplate.compiler import compile_model
+    from aitemplate.compiler.base import IntVar, Operator, Tensor
+    from aitemplate.testing import detect_target
+
+
+1. Define the operator graph node
+----------------------------------
+
+Graph nodes are usually defined in `aitemplate/compiler/ops`.
+
+.. code-block:: python
+
+    class add_one(Operator):
+        def __init__(self):
+            super().__init__()
+            # required: a unique identity for the operator category
+            self._attrs["op"] = "add_one"
+            # we can put whatever we want into the op attrs for later use
+            self._attrs["has_profiler"] = False
+            self._attrs["nop"] = False
+
+        def __call__(self, x: Tensor) -> Tensor:
+            # each operator needs to keep a record of its input tensors
+            self._attrs["inputs"] = [x]
+            # optional: set the depth of the op based on the inputs' depth, used in DFS
+            self._set_depth()
+            # infer the output shape
+            output_shape = self._infer_shape(x)
+            # create the output Tensor, whose source op is the current op
+            output = Tensor(output_shape, src_ops={self})
+            # remember the current op's outputs
+            self._attrs["outputs"] = [output]
+            return output
+
+        def _infer_shape(self, x) -> List[IntVar]:
+            # Infer the output shape.
+            # If we need to infer the shape on the C++ side, we create a jinja2 template
+            # for the shape inference function, render it into Python code in the graph
+            # node, and render it into C++ code in the codegen.
+            return x.shape()
+
+        def gen_function(self) -> str:
+            # this function will be used in codegen;
+            # here we only need to redirect to the backend codegen function
+            target = backend.target.Target.current()
+            func_key = f"{target.name()}.{self._attrs['op']}.gen_function"
+            func = registry.get(func_key)
+            return func(self._attrs)
+
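+Before wiring up any codegen, it can be handy to sanity-check the graph node on its own. The short sketch below is illustrative only; it reuses the `Tensor` constructor arguments from step 6 of this tutorial with a hypothetical `[16, 512]` shape.
+
+.. code-block:: python
+
+    # Build a tiny graph with the new op; no compilation happens here.
+    X = Tensor(shape=[16, 512], dtype="float16", name="X", is_input=True)
+    Y = add_one()(X)
+    # The output shape mirrors the input shape, per _infer_shape above.
+    print(Y.shape())
+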
+.. note::
+
+   - `_attrs` in `Operator` is the most important data structure for codegen.
+   - `_attrs["op"]` is the identity of the operator category, which is used to find the corresponding codegen function in the backend; it must be **unique**.
+
+2. Define the necessary templates for Codegen
+----------------------------------------------
+
+In AIT, there are 4 important templates for codegen:
+
+- `FUNC_TEMPLATE`: the template for generating the function body of the operator, which invokes the GPU kernel.
+- `FUNC_SIGNATURE_TEMPLATE`: the template for generating the function signature of the operator. The signature defines the name and arguments of the function.
+- `FUNC_CALL_TEMPLATE`: the template for generating the function call of the operator. The call is used during inference to invoke the GPU kernel with the given arguments.
+- `FUNC_DECL`: the template for the forward declaration of the operator function. This is usually an alias of `FUNC_SIGNATURE_TEMPLATE`.
+
+.. code-block:: python
+
+    FUNC_TEMPLATE = jinja2.Template(
+        """
+    {{header_files}}
+
+    namespace {
+
+    {{kernel}}
+
+    } // namespace
+
+    {{func_signature}}
+    {
+        invoke_add_one(output, input, num_elements, stream);
+    }
+        """
+    )
+
+    FUNC_SIGNATURE = jinja2.Template(
+        """
+    void {{func_name}}(half* output,
+                       const half* input,
+                       const int64_t num_elements,
+                       {{prefix}}Stream_t stream)
+        """
+    )
+
+    FUNC_DECL = jinja2.Template(
+        """
+    {{func_signature}};
+        """
+    )
+
+    FUNC_CALL_TEMPLATE = jinja2.Template(
+        """
+    {{indent}}int64_t num_elements = 1;
+    {% for dim_name in dim_names %}
+    {{indent}}num_elements *= {{dim_name}};
+    {% endfor %}
+    {{indent}}{{func_name}}(
+    {{indent}}    {{output}}, {{input}}, num_elements, stream /* default stream */
+    {{indent}});
+        """
+    )
+
+3. Create the GPU kernels
+--------------------------
+
+In this example we use the simplest possible "add one" kernel. The kernel can be written by hand (as a programmer is normally expected to do) or generated by other tools.
+
+.. code-block:: python
+
+    KERNEL_TEMPLATE = jinja2.Template(
+        """
+    __global__ void add_one(half* output, const half* input, const int64_t num_elements) {
+      const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+      if (idx < num_elements) {
+        output[idx] = input[idx] + half(1.0);
+      }
+    }
+
+    void invoke_add_one(half* output, const half* input, int64_t num_elements, {{prefix}}Stream_t stream) {
+      if (num_elements < 1024) {
+        dim3 grid(1);
+        dim3 block(num_elements);
+        add_one<<<grid, block, 0, stream>>>(output, input, num_elements);
+      } else {
+        dim3 grid((num_elements + 1024 - 1) / 1024);
+        dim3 block(1024);
+        add_one<<<grid, block, 0, stream>>>(output, input, num_elements);
+      }
+    }
+        """
+    )
+
+(Optional) We also provide a helper function to handle the CUDA/ROCm float16 data type difference.
+
+.. code-block:: python
+
+    FUNC_CALL_FP16_PARAM_TEMPLATE = jinja2.Template(
+        """reinterpret_cast<half*>(
+    {% if is_cuda %}&({% endif %}{{name}}{% if is_cuda %}->raw()){% endif %})"""
+    )
+
+4. Define the codegen function
+-------------------------------
+
+The codegen function renders the templates we defined above into a valid C++ code string.
+It takes `func_attrs` from the graph node and fills them into the jinja2 templates.
+
+..
code-block:: python + + def gen_function_call(func_attrs: Dict[str, Any], indent=" ", is_cuda=False) -> str: + assert len(func_attrs["outputs"]) == 1 + assert len(func_attrs["inputs"]) == 1 + + output_name = FUNC_CALL_FP16_PARAM_TEMPLATE.render( + name=func_attrs["outputs"][0]._attrs["name"], is_cuda=is_cuda + ) + input_name = FUNC_CALL_FP16_PARAM_TEMPLATE.render( + name=func_attrs["inputs"][0]._attrs["name"], is_cuda=is_cuda + ) + + dim_names = [dim._attrs["name"] for dim in func_attrs["inputs"][0].shape()] + return FUNC_CALL_TEMPLATE.render( + func_name=func_attrs["name"], + output=output_name, + input=input_name, + dim_names=dim_names, + indent=indent, + ) + + + def gen_function(func_attrs: Dict[str, Any], header_files: str, backend_spec) -> str: + prefix = backend_spec.prefix + return FUNC_TEMPLATE.render( + header_files=header_files, + kernel=KERNEL_TEMPLATE.render(prefix=prefix), + func_signature=FUNC_SIGNATURE.render( + func_name=func_attrs["name"], prefix=prefix + ), + ) + + + def gen_function_decl(func_attrs: Dict[str, Any], backend_spec) -> str: + return FUNC_DECL.render( + func_signature=FUNC_SIGNATURE.render( + func_name=func_attrs["name"], + prefix=backend_spec.prefix, + ).strip() + ) + +5.1 Register the codegen function to CUDA backend +--------------------------------------------------- + +CUDA backend functions is usually defined at `aitemplate/backend/cuda/`. + +.. code-block:: python + + CUDA_HEADER_FILES = """ + #include + """ + + + @registry.reg("cuda.add_one.gen_function") + def cuda_add_one_gen_function(func_attrs: Dict[str, Any]) -> str: + return gen_function(func_attrs, CUDA_HEADER_FILES, CUDASpec()) + + + @registry.reg("cuda.add_one.func_decl") + def cuda_add_one_gen_function_decl(func_attrs: Dict[str, Any]) -> str: + return gen_function_decl(func_attrs, CUDASpec()) + + + @registry.reg("cuda.add_one.func_call") + def cuda_add_one_gen_function_call(func_attrs: Dict[str, Any], indent=" ") -> str: + return gen_function_call(func_attrs, indent, is_cuda=True) + +5.2 (Optional) Register the codegen function to ROCm backend +-------------------------------------------------------------- + +ROCm backend functions is usually defined at `aitemplate/backend/rocm/`. + + +.. code-block:: python + + HIP_HEADER_FILES = """ + #include + #include + """ + + + @registry.reg("rocm.add_one.gen_function") + def rocm_add_one_gen_function(func_attrs: Dict[str, Any]) -> str: + return gen_function(func_attrs, HIP_HEADER_FILES, ROCMSpec()) + + + @registry.reg("rocm.add_one.func_decl") + def rocm_add_one_gen_function_decl(func_attrs: Dict[str, Any]) -> str: + return gen_function_decl(func_attrs, ROCMSpec()) + + + @registry.reg("rocm.add_one.func_call") + def rocm_add_one_gen_function_call(func_attrs: Dict[str, Any], indent=" ") -> str: + return gen_function_call(func_attrs, indent, is_cuda=False) + + +6. Compile and verify the results with PyTorch +------------------------------------------------ + +.. 
code-block:: python + + def create_ait_model(shapes): + X = Tensor( + shape=shapes, + dtype="float16", + name="X", + is_input=True, + ) + Y = add_one()(X) + Y._attrs["is_output"] = True + Y._attrs["name"] = "Y" + return Y + + + def verify_add_one(): + shapes = [16, 512] + x = torch.randn(shapes).cuda().half() + y_pt = x + 1.0 + + Y = create_ait_model([16, 512]) + target = detect_target() + with compile_model(Y, target, "./tmp", "add_one") as module: + y = torch.empty(shapes).cuda().half() + inputs = {"X": x} + outputs = {"Y": y} + module.run_with_tensors(inputs, outputs) + print(torch.allclose(y, y_pt, atol=1e-2, rtol=1e-2)) + diff --git a/docs/source/tutorial/how_to_infer_pt.rst b/docs/source/tutorial/how_to_infer_pt.rst new file mode 100644 index 000000000..67891c46a --- /dev/null +++ b/docs/source/tutorial/how_to_infer_pt.rst @@ -0,0 +1,188 @@ +How to inference a PyTorch model with AIT +========================================== + +This tutorial will demonstrate how to inference a PyTorch model with AIT. +Full source code can be founded at `examples/07_how_to_run_pt_model/how_to_run_pt_model.py` + +0. Prerequisites +----------------- + +We need to import necessary Python modules + +.. code-block:: python + + from collections import OrderedDict + + import torch + + from aitemplate.compiler import compile_model + from aitemplate.frontend import nn, Tensor + from aitemplate.testing import detect_target + from aitemplate.testing.benchmark_pt import benchmark_torch_function + from aitemplate.utils.graph_utils import sorted_graph_pseudo_code + + +1. Define a PyTorch module +--------------------------- + +Here we define a PyTorch model which is commonly seen in Transformers. + +.. code-block:: python + + class PTSimpleModel(torch.nn.Module): + def __init__(self, hidden, eps: float = 1e-5): + super().__init__() + self.dense1 = torch.nn.Linear(hidden, 4 * hidden) + self.act1 = torch.nn.functional.gelu + self.dense2 = torch.nn.Linear(4 * hidden, hidden) + self.layernorm = torch.nn.LayerNorm(hidden, eps=eps) + + def forward(self, input): + hidden_states = self.dense1(input) + hidden_states = self.act1(hidden_states) + hidden_states = self.dense2(hidden_states) + hidden_states = hidden_states + input + hidden_states = self.layernorm(hidden_states) + return hidden_states + +2. Define an AIT module +------------------------ + +We can define a similar AIT module as follows: + +.. code-block:: python + + class AITSimpleModel(nn.Module): + def __init__(self, hidden, eps: float = 1e-5): + super().__init__() + self.dense1 = nn.Linear(hidden, 4 * hidden, specialization="fast_gelu") + self.dense2 = nn.Linear(4 * hidden, hidden) + self.layernorm = nn.LayerNorm(hidden, eps=eps) + + def forward(self, input): + hidden_states = self.dense1(input) + hidden_states = self.dense2(hidden_states) + hidden_states = hidden_states + input + hidden_states = self.layernorm(hidden_states) + return hidden_states + +.. warning:: + The `nn.Module` API in AIT looks similar to PyTorch, but it is not the same. + + The fundamental difference is that AIT module is a container to build graph, while PyTorch module is a container to store parameters for eager. + Which means, each AIT module's `forward` method can be only called once, and the graph is built during the first call. If you want to share parameters, needs to call `compiler.ops` instead. The `compiler.ops` is similar to `functional` in PyTorch. + + AITemplate supports automatically fusion on linear followed by other operators. 
+However, in many cases, especially for quick iterations, we use a manual `specialization` to specify the fused operator. For example, `specialization="fast_gelu"` fuses the linear layer with the `fast_gelu` operator.
+
+3. Define a helper function to map PyTorch parameters to AIT parameters
+-------------------------------------------------------------------------
+
+In AIT, all names must follow the C variable naming standard because the names are used in the codegen process.
+
+.. code-block:: python
+
+    def map_pt_params(ait_model, pt_model):
+        ait_model.name_parameter_tensor()
+        pt_params = dict(pt_model.named_parameters())
+        mapped_pt_params = OrderedDict()
+        for name, _ in ait_model.named_parameters():
+            ait_name = name.replace(".", "_")
+            assert name in pt_params
+            mapped_pt_params[ait_name] = pt_params[name]
+        return mapped_pt_params
+
+.. warning::
+
+   - Unlike PyTorch, it is required to call the ait_model **.name_parameter_tensor()** method, which gives each parameter a name that maps directly to its PyTorch counterpart.
+   - Because all names in AIT must follow the C variable naming standard, you can simply replace `.` with `_`, or use a regular expression, to make sure each name is valid.
+   - For networks with conv + bn subgraphs, we currently do not provide an automatic pass to fold them. Refer to our ResNet and Detectron2 examples to see how we handle CNN layout transforms and BatchNorm folding.
+
+4. Create PyTorch module, inputs/outputs
+-----------------------------------------
+
+.. code-block:: python
+
+    batch_size = 1024
+    hidden = 512
+    # create pt model
+    pt_model = PTSimpleModel(hidden).cuda().half()
+
+    # create pt input
+    x = torch.randn([batch_size, hidden]).cuda().half()
+
+    # run pt model
+    pt_model.eval()
+    y_pt = pt_model(x)
+
+5. Create AIT module, inputs/outputs
+-------------------------------------
+
+.. code-block:: python
+
+    batch_size = 1024
+    hidden = 512
+    # create AIT model
+    ait_model = AITSimpleModel(hidden)
+    # create AIT input Tensor
+    X = Tensor(
+        shape=[batch_size, hidden],
+        name="X",
+        dtype="float16",
+        is_input=True,
+    )
+    # run AIT module to generate output tensor
+    Y = ait_model(X)
+    # mark the output tensor
+    Y._attrs["is_output"] = True
+    Y._attrs["name"] = "Y"
+
+.. warning::
+
+   - Similar to MetaTensor, LazyTensor, and many other lazy-evaluation frameworks, AIT's Tensor records the computation graph, and the graph is built when the Tensor is compiled.
+   - For input tensors, it is required to set the attribute **is_input=True**.
+   - For output tensors, it is required to set the attribute **Y._attrs["is_output"] = True**.
+   - For input and output tensors, it is better to provide **name** attributes so they can be referenced at runtime.
+
+6. Compile the AIT module into a runtime, and do verification
+--------------------------------------------------------------
+
+..
code-block:: python + + # map pt weights to ait + weights = map_pt_params(ait_model, pt_model) + + # codegen + target = detect_target() + with compile_model( + Y, target, "./tmp", "simple_model_demo", constants=weights + ) as module: + # create storage for output tensor + y = torch.empty([batch_size, hidden]).cuda().half() + + # inputs and outputs dict + inputs = {"X": x} + outputs = {"Y": y} + + # run + module.run_with_tensors(inputs, outputs, graph_mode=True) + + # verify output is correct + print(torch.allclose(y, y_pt, atol=1e-2, rtol=1e-2)) + + # benchmark ait and pt + count = 1000 + ait_t, _, _ = module.benchmark_with_tensors( + inputs, outputs, graph_mode=True, count=count + ) + print(f"AITemplate time: {ait_t} ms/iter") + + pt_t = benchmark_torch_function(count, pt_model.forward, x) + print(f"PyTorch eager time: {pt_t} ms/iter") + + +In this example, AIT will automatically fuse GELU and elementwise add into TensorCore/MatrixCore gemm operation. On RTX-3080 for this example, AIT is about 1.15X fast than PyTorch Eager in this example. + +.. note:: + + - In this example, we fold parameters (weights) into AIT runtime, which the final dynamic library will contains parameters. + - If during compile we don't provide parameters, for example the total parameters size is greater than 2GB, we can always call `set_constant` function in runtime. Check runtime API for details. \ No newline at end of file diff --git a/docs/source/tutorial/how_to_visualize.rst b/docs/source/tutorial/how_to_visualize.rst new file mode 100644 index 000000000..5af7c89a5 --- /dev/null +++ b/docs/source/tutorial/how_to_visualize.rst @@ -0,0 +1,85 @@ +How to visualize an AIT model +============================== + +Visualization is important for understanding the behavior of a model optimization. +In AIT, we modify the codegen a little bit, from generating CUDA/HIP C++ code to HTML/Javascript code, +then we can generate a visualization of the model. + + +The following code will generate a visualization of our first example. + +1. Define the AIT Model +------------------------ + +.. code-block:: python + + from aitemplate import compiler + from aitemplate.frontend import nn, Tensor + from aitemplate.testing import detect_target + from aitemplate.utils.visualization import plot_graph + + class AITSimpleModel(nn.Module): + def __init__(self, hidden, eps: float = 1e-5): + super().__init__() + self.dense1 = nn.Linear(hidden, 4 * hidden, specialization="fast_gelu") + self.dense2 = nn.Linear(4 * hidden, hidden) + self.layernorm = nn.LayerNorm(hidden, eps=eps) + + def forward(self, input): + hidden_states = self.dense1(input) + hidden_states = self.dense2(hidden_states) + hidden_states = hidden_states + input + hidden_states = self.layernorm(hidden_states) + return hidden_states + + def gen_ait_model(): + batch_size = 512 + hidden = 1024 + ait_model = AITSimpleModel(hidden) + ait_model.name_parameter_tensor() + X = Tensor( + shape=[batch_size, hidden], + name="X", + dtype="float16", + is_input=True, + ) + Y = ait_model(X) + Y._attrs["is_output"] = True + Y._attrs["name"] = "Y" + return Y + + output_tensor = gen_ait_model() + +2. Apply optimizations on the AIT Model +--------------------------------------- + +.. 
code-block:: python + + def apply_optimizations(tensors): + target = detect_target() + # first, convert output tensors to graph + with target: + graph = compiler.transform.toposort(tensors) + # second, provide names to the graph + compiler.transform.name_graph(graph) + compiler.transform.mark_param_tensor(graph) + compiler.transform.mark_special_views(graph) + # we can apply optimizations to the graph, or test single optimization pass on the graph + graph = compiler.transform.optimize_graph(graph, "./tmp") + return graph + + graph = apply_optimizations(output_tensor) + +3. Generate visualization +-------------------------- + +.. code-block:: python + + # Plot the graph + plot_graph(graph, file_path="ait_model.html", network_name="ait_sample_net") + +The visualization will be generated in the "ait_model.html" file. This file can be opened in Chrome without any web server. + +.. raw:: html + + \ No newline at end of file diff --git a/docs/source/tutorial/index.rst b/docs/source/tutorial/index.rst new file mode 100644 index 000000000..339bd16c1 --- /dev/null +++ b/docs/source/tutorial/index.rst @@ -0,0 +1,9 @@ +Tutorials +========= + +.. toctree:: + :maxdepth: 1 + + how_to_infer_pt + how_to_add_op + how_to_visualize diff --git a/docs/static/ait_model.html b/docs/static/ait_model.html new file mode 100644 index 000000000..18c56089d --- /dev/null +++ b/docs/static/ait_model.html @@ -0,0 +1,866 @@ + + + + + + + ait_sample_net + + + + + + + + + + + + + + + + + + + + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/examples/01_resnet-50/README.md b/examples/01_resnet-50/README.md new file mode 100644 index 000000000..3f75060ff --- /dev/null +++ b/examples/01_resnet-50/README.md @@ -0,0 +1,84 @@ +# ResNet-50 + +In this example, we will demo how to use AITemplate for inference on the ResNet-50 model from PyTorch Image Models (TIMM). + +We will demo two usages: +* Using AIT to accelerate PyTorch inference +* Using AIT standalone without PyTorch + +## Code structure +``` +modeling + resnet.py # ResNet definition using AIT's frontend API +weight_utils.py # Utils to convert TIMM R-50 weights to AIT +infer_with_torch.py # Example to accelerate PyTorch, and seamlessly use with other PyTorch code +infer_with_numpy.py # Dump TIMM weights to Numpy and use AIT & Numpy without 3rdparties +benchmark_pt.py # Benchmark code for PyTorch +benchmark_ait.py # Benchmark code for AIT +``` + +## Multi-GPU profiling +AIT requires to do profiling to decide best algorithms for CUTLASS and CK. +To enable multiple GPUs profiling, use the environment variable `CUDA_VISIBLE_DEVICES` on NVIDIA platform and `HIP_VISIBLE_DEVICES` on AMD platform. + +For example, `CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 benchmark_ait.py`. + +Benchmark is fast once the profilers are built. + +## Reference Speed vs PyTorch Eager + +### A100-40GB / CUDA 11.6.2 +_PT = PyTorch 1.12 Eager_ + +| Batch size | PT Latency (ms) | PT QPS (im/s) | AIT Latency (ms) | AIT QPS (im/s) | +|------------|-----------------|---------------|------------------|----------------| +| 1 | 7.68 | 130.29 | 0.58 | 1730.17 | +| 2 | 7.16 | 279.36 | 0.62 | 3250.74 | +| 4 | 7.17 | 557.68 | 0.69 | 5773.20 | +| 8 | 7.02 | 1138.83 | 0.88 | 9104.44 | +| 16 | 7.01 | 2280.97 | 1.33 | 12012.81 | +| 32 | 7.53 | 4251.30 | 2.40 | 13350.58 | +| 64 | 13.98 | 4578.09 | 4.53 | 14140.83 | +| 128 | 26.57 | 4816.71 | 8.57 | 14935.82 | +| 256 | 50.93 | 5026.40 | 16.58 | 15444.57 | + + +### MI-250 / ROCm 5.2.3 / HIPCC-10736 +_PT = PyTorch 1.12 Eager_ +#### 1 GCD + +| Batch size | PT Latency (ms) | PT QPS (im/s) | AIT Latency (ms) | AIT QPS (im/s) | +|------------|-----------------|---------------|------------------|----------------| +| 1 | 3.94 | 254.06 | 2.28 | 438.60 | +| 2 | 3.89 | 514.48 | 2.25 | 888.89 | +| 4 | 3.82 | 1047.11 | 2.38 | 1680.67 | +| 8 | 4.40 | 1819.27 | 2.62 | 3053.44 | +| 16 | 6.48 | 2468.65 | 3.41 | 4692.08 | +| 32 | 10.40 | 3076.97 | 4.86 | 6584.36 | +| 64 | 18.35 | 3488.12 | 8.26 | 7748.18 | +| 128 | 34.36 | 3724.76 | 15.38 | 8322.50 | +| 256 | 65.35 | 3917.29 | 29.62 | 8642.81 | + +#### 2 GCDs + +| Batch size | PT Latency (ms) | PT QPS (im/s) | AIT Latency (ms) | AIT QPS (im/s) | +|------------|-----------------|---------------|------------------|----------------| +| 1 | | | | | +| 2 | 3.94 | 507.54 | 2.36 | 848.15 | +| 4 | 3.89 | 1028.60 | 2.34 | 1710.94 | +| 8 | 3.88 | 2059.41 | 2.70 | 2960.46 | +| 16 | 4.56 | 3507.48 | 2.83 | 5663.52 | +| 32 | 6.72 | 4762.89 | 3.87 | 8275.98 | +| 64 | 10.82 | 5917.63 | 5.26 | 12173.67 | +| 128 | 18.79 | 6812.09 | 8.98 | 14247.09 | +| 256 | 35.99 | 7112.59 | 16.69 | 15338.58 | + + + +### Note for Performance Results + +- For NVIDIA A100, our test cluster doesn't allow to lock frequency. We make warm up longer to collect more stable results, but it is expected to have small variance to the results with locked frequency. 
+- To benchmark MI-250, the first step is to run `python3 benchmark_ait.py` to generate all necessary model dynamic library files with single GCD. Then run `./benchmark_mi250.sh {batch_size}` to simulate data parallel execution on 2 GCDs, each GCD is processing half of the batch. +- To benchmark MI-250 1 GCD, we lock the frequency with command `rocm-smi -d x --setperfdeterminism 1700`, where `x` is the GPU id. +- To benchmark MI-250 2 GCDs, we observed performance regression with rocm perf-determ mode. The 2 GCDs number is running without perf-determ mode set with command `rocm-smi -d x --resetperfdeterminism`, where `x` is the GPU id. +- Performance results are what we can reproduce and for reference only. It should not be used for other purposes. diff --git a/examples/01_resnet-50/benchmark_ait.py b/examples/01_resnet-50/benchmark_ait.py new file mode 100644 index 000000000..577a4472d --- /dev/null +++ b/examples/01_resnet-50/benchmark_ait.py @@ -0,0 +1,132 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""benchmark for resnet50""" + +import os + +import click + +import torch +from aitemplate.compiler import compile_model, Model + +from aitemplate.frontend import Tensor +from aitemplate.testing import detect_target +from modeling.resnet import build_resnet_backbone +from weight_utils import export_to_torch_tensor + + +def mark_output(y): + """Different to PyTorch, we need to explicit mark output tensor for optimization, + + Parameters + ---------- + y : List[Tensor] + List of output tensors + """ + if type(y) is not tuple: + y = (y,) + for i in range(len(y)): + y[i]._attrs["is_output"] = True + y[i]._attrs["name"] = "output_%d" % (i) + y_shape = [d._attrs["values"][0] for d in y[i]._attrs["shape"]] + print("output_{} shape: {}".format(i, y_shape)) + + +def compile_module(model_name, batch_size, **kwargs): + + if model_name != "resnet50": + raise NotImplementedError + + model_name = f"{model_name}_{batch_size}" + target = detect_target(**kwargs) + # Create input tensor, need to specify the shape, dtype and is_input flag + x = Tensor( + shape=[batch_size, 224, 224, 3], dtype="float16", name="input0", is_input=True + ) + model = build_resnet_backbone(50, activation="ReLU") + # Mark all parameters with name same to PyTorch name convention + model.name_parameter_tensor() + # Forward the input tensor to the model, get output tensor + y = model(x) + # Mark output tensor + mark_output(y) + # Compile the model + module = compile_model(y, target, "./tmp", model_name) + return module + + +def benchmark(model_name, batch_size, mod=None, graph_mode=True): + # Load params + cuda_params = export_to_torch_tensor(model_name) + # Load compiled model + if mod is None: + model_name = f"{model_name}_{batch_size}" + mod = Model(os.path.join("./tmp", model_name, "test.so")) + + # Set params + for k, v in cuda_params.items(): + mod.set_constant_with_tensor(k, v) + + # prepare input/output tensor + x_input = torch.randn([batch_size, 224, 224, 
3]).cuda().half() + x_input = x_input.contiguous() + y_output = torch.zeros([batch_size, 1, 1, 1000]).cuda().half() + y_output = y_output.contiguous() + + # warm up + t, _, __ = mod.benchmark_with_tensors( + [x_input], + [y_output], + count=100, + repeat=4, + graph_mode=graph_mode, + ) + # benchmark + t, _, __ = mod.benchmark_with_tensors( + [x_input], + [y_output], + count=100, + repeat=4, + graph_mode=graph_mode, + ) + print(f"batch_size: {batch_size}, latency: {t}") + dev_flag = os.environ.get("HIP_VISIBLE_DEVICES", "-1") + dev_flag = dev_flag.replace(",", "_") + with open(f"resnet50_ait_benchmark_dev_{dev_flag}.txt", "a") as f: + f.write(f"batch_size: {batch_size}, latency: {t}\n") + + +@click.command() +@click.option( + "--use-fp16-acc", + type=bool, + default=True, + help="Whether to use FP16 for accumulation (similar to TensorRT)", +) +@click.option("--use-graph", type=bool, default=True, help="Whether to use CUDA graph") +@click.option("--batch-size", type=int, default=0, help="Batch size") +def main(use_fp16_acc=True, use_graph=True, batch_size=0): + if detect_target().name() == "rocm": + use_graph = False + if batch_size < 1: + for bs in (1, 2, 4, 8, 16, 32, 64, 128, 256): + compile_module("resnet50", bs, use_fp16_acc=use_fp16_acc) + benchmark("resnet50", bs, graph_mode=use_graph) + else: + benchmark("resnet50", batch_size, graph_mode=use_graph) + + +if __name__ == "__main__": + main() diff --git a/examples/01_resnet-50/benchmark_mi250.sh b/examples/01_resnet-50/benchmark_mi250.sh new file mode 100644 index 000000000..883846b68 --- /dev/null +++ b/examples/01_resnet-50/benchmark_mi250.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +HIP_VISIBLE_DEVICES=0 python3 benchmark_ait.py --batch-size "$1" & +HIP_VISIBLE_DEVICES=1 python3 benchmark_ait.py --batch-size "$1" && fg diff --git a/examples/01_resnet-50/benchmark_pt.py b/examples/01_resnet-50/benchmark_pt.py new file mode 100644 index 000000000..82c74bc89 --- /dev/null +++ b/examples/01_resnet-50/benchmark_pt.py @@ -0,0 +1,51 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import os + +import click +import timm +import torch +from aitemplate.testing.benchmark_pt import benchmark_torch_function + + +def benchmark(model, batch_size): + with torch.inference_mode(): + input_shape = (batch_size, 3, 224, 224) + input_data = torch.randn(input_shape).cuda().half() + # warm up + benchmark_torch_function(100, model, input_data) + # benchmark + t = benchmark_torch_function(100, model, input_data) + print("batch_size: {}, time: {}".format(batch_size, t)) + dev_flag = os.environ.get("HIP_VISIBLE_DEVICES", "-1") + dev_flag = dev_flag.replace(",", "_") + with open(f"resnet50_pt_benchmark_dev_{dev_flag}.txt", "a") as f: + f.write("batch_size: {}, latency: {}\n".format(batch_size, t)) + + +@click.command() +@click.option("--batch-size", default=0, type=int) +def main(batch_size): + model = timm.create_model("resnet50", pretrained=False).cuda().half() + model.eval() + if batch_size == 0: + for batch_size in [1, 2, 4, 8, 16, 32, 64, 128, 256]: + benchmark(model, batch_size) + else: + benchmark(model, batch_size) + + +if __name__ == "__main__": + main() diff --git a/examples/01_resnet-50/infer_with_torch.py b/examples/01_resnet-50/infer_with_torch.py new file mode 100644 index 000000000..23269b2e4 --- /dev/null +++ b/examples/01_resnet-50/infer_with_torch.py @@ -0,0 +1,135 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import os + +import numpy as np +import torch +from aitemplate.compiler import compile_model, Model + +from aitemplate.frontend import Tensor +from aitemplate.testing import detect_target +from modeling.resnet import build_resnet_backbone +from PIL import Image +from weight_utils import timm_export + + +def mark_output(y): + """Different to PyTorch, we need to explicit mark output tensor for optimization, + + Parameters + ---------- + y : List[Tensor] + List of output tensors + """ + if type(y) is not tuple: + y = (y,) + for i in range(len(y)): + y[i]._attrs["is_output"] = True + y[i]._attrs["name"] = "output_%d" % (i) + y_shape = [d._attrs["values"][0] for d in y[i]._attrs["shape"]] + print("output_{} shape: {}".format(i, y_shape)) + + +def compile_module(model_name, **kwargs): + batch_size = 1 + + if model_name != "resnet50": + raise NotImplementedError + + model_name = f"{model_name}_{batch_size}" + target = detect_target(**kwargs) + # Create input tensor, need to specify the shape, dtype and is_input flag + x = Tensor( + shape=[batch_size, 224, 224, 3], dtype="float16", name="input0", is_input=True + ) + model = build_resnet_backbone(50, activation="ReLU") + # Mark all parameters with name same to PyTorch name convention + model.name_parameter_tensor() + # Forward the input tensor to the model, get output tensor + y = model(x) + # Mark output tensor + mark_output(y) + # Compile the model + module = compile_model(y, target, "./tmp", model_name) + return module + + +def prepare_data(img_path=None): + # we find a 224x224 image online for demo purpose: + img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true" + if img_path is None: + if os.path.exists("cat.png") is False: + os.system(f"wget -O cat.png {img_url}") + img_path = "cat.png" + image = Image.open(img_path).resize((224, 224)) + image = torch.as_tensor(np.array(image).astype("float32")).cuda().half() + image = image.unsqueeze(0) + mean = torch.tensor([0.485, 0.456, 0.406]).cuda().half() + std = torch.tensor([0.229, 0.224, 0.225]).cuda().half() + image = (image / 255.0 - mean[None, None, None, :]) / std[None, None, None, :] + return image + + +def export_to_torch_tensor(model_name="resnet50"): + if model_name != "resnet50": + raise NotImplementedError + timm2ait = timm_export(model_name) + params = timm2ait.export_model(half=True) + return params, timm2ait.pt_model + + +def inference(model_name, mod=None): + # Load params + cuda_params, pt_model = export_to_torch_tensor(model_name) + # Load compiled model + if mod is None: + mod = Model(os.path.join("./tmp", model_name, "test.so")) + + # Set torch tensor params to runtime + for k, v in cuda_params.items(): + mod.set_constant_with_tensor(k, v) + + # prepare input/output tensor + x_input = prepare_data() + x_input = x_input.contiguous() + y_output = torch.zeros([1, 1, 1, 1000]).cuda().half() + y_output = y_output.contiguous() + + # execute + mod.run_with_tensors([x_input], [y_output]) + + # process output with pytorch + y_label = torch.argmax(y_output, dim=-1) + y_cpu = y_label.cpu().numpy() + print(y_cpu) + + # run pytorch + pt_model.eval() + pt_model = pt_model.cuda().half() + pt_output = pt_model(x_input.permute([0, 3, 1, 2])) + y_label = torch.argmax(pt_output, dim=-1) + y_cpu = y_label.cpu().numpy() + print(y_cpu) + + # verify outputs + assert torch.allclose(y_output, pt_output, 1e-1, 1e-1) + print("Verification done!") + + +if __name__ == "__main__": + np.random.seed(4896) + model_name = "resnet50" + mod = compile_module(model_name, use_fp16_acc=True) 
+ inference(model_name, mod) diff --git a/examples/01_resnet-50/modeling/__init__.py b/examples/01_resnet-50/modeling/__init__.py new file mode 100644 index 000000000..5cf1a826f --- /dev/null +++ b/examples/01_resnet-50/modeling/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/examples/01_resnet-50/modeling/resnet.py b/examples/01_resnet-50/modeling/resnet.py new file mode 100644 index 000000000..9842aa18d --- /dev/null +++ b/examples/01_resnet-50/modeling/resnet.py @@ -0,0 +1,456 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import numpy as np +from aitemplate.frontend import nn +from aitemplate.testing import detect_target + + +class CNNBlockBase(nn.Module): + """ + A CNN block is assumed to have input channels, output channels and a stride. + The input and output of `forward()` method must be NHWC tensors. + The method can perform arbitrary computation but must match the given + channels and stride specification. + Attribute: + in_channels (int): + out_channels (int): + stride (int): + """ + + def __init__(self, in_channels, out_channels, stride): + """ + The `__init__` method of any subclass should also contain these arguments. + Args: + in_channels (int): + out_channels (int): + stride (int): + """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + + +class BasicStem(CNNBlockBase): + """ + The standard ResNet stem (layers before the first residual block), + with a conv, relu and max_pool. 
+ """ + + def __init__(self, in_channels=3, out_channels=64, norm="BN", activation="ReLU"): + super().__init__(in_channels, out_channels, 4) + conv_op = None + if detect_target().name() == "cuda": + if activation == "ReLU": + conv_op = nn.Conv2dBiasReluFewChannels + elif activation == "Hardswish": + conv_op = nn.Conv2dBiasHardswishFewChannels + else: + raise NotImplementedError + else: + if activation == "ReLU": + conv_op = nn.Conv2dBiasRelu + elif activation == "Hardswish": + conv_op = nn.Conv2dBiasHardswish + else: + raise NotImplementedError + self.conv1 = conv_op(in_channels, out_channels, 7, 2, 7 // 2) + self.pool = nn.MaxPool2d(3, 2, 1) + + def forward(self, x): + x = self.conv1(x) + x = self.pool(x) + return x + + +class BasicBlock(CNNBlockBase): + """ + The basic residual block for ResNet-18 and ResNet-34 defined in :paper:`ResNet`, + with two 3x3 conv layers and a projection shortcut if needed. + """ + + def __init__(self, in_channels, out_channels, *, stride=1, norm="BN"): + super().__init__(in_channels, out_channels, stride) + + def forward(self, x): + raise NotImplementedError() + + +class BottleneckBlock(CNNBlockBase): + """ + The standard bottleneck residual block used by ResNet-50, 101 and 152 + defined in :paper:`ResNet`. It contains 3 conv layers with kernels + 1x1, 3x3, 1x1, and a projection shortcut if needed. + """ + + def __init__( + self, + in_channels, + out_channels, + *, + bottleneck_channels, + stride=1, + num_groups=1, + norm="BN", + activation="ReLU", + stride_in_1x1=False, + dilation=1, + ): + """ + Args: + bottleneck_channels (int): number of output channels for the 3x3 + "bottleneck" conv layers. + num_groups (int): number of groups for the 3x3 conv layer. + norm (str or callable): normalization for all conv layers. + See :func:`layers.get_norm` for supported format. + stride_in_1x1 (bool): when stride>1, whether to put stride in the + first 1x1 convolution or the bottleneck 3x3 convolution. + dilation (int): the dilation rate of the 3x3 conv layer. + """ + super().__init__(in_channels, out_channels, stride) + + if in_channels != out_channels: + self.downsample_0 = nn.Conv2dBias(in_channels, out_channels, 1, stride, 0) + else: + self.downsample_0 = None + + # The original MSRA ResNet models have stride in the first 1x1 conv + # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have + # stride in the 3x3 conv + stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride) + + conv_op = None + conv_op_add = None + if activation == "ReLU": + conv_op = nn.Conv2dBiasRelu + conv_op_add = nn.Conv2dBiasAddRelu + elif activation == "Hardswish": + conv_op = nn.Conv2dBiasHardswish + conv_op_add = nn.Conv2dBiasAddHardswish + else: + raise NotImplementedError + + self.conv1 = conv_op(in_channels, bottleneck_channels, 1, stride_1x1, 0) + + self.conv2 = conv_op( + bottleneck_channels, + bottleneck_channels, + 3, + stride_3x3, + 1 * dilation, + dilation, + ) + + self.conv3 = conv_op_add(bottleneck_channels, out_channels, 1, 1, 0) + + # for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]: + # if layer is not None: # shortcut can be None + # weight_init.c2_msra_fill(layer) + + # Zero-initialize the last normalization in each residual branch, + # so that at the beginning, the residual branch starts with zeros, + # and each residual block behaves like an identity. 
+ # See Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour": + # "For BN layers, the learnable scaling coefficient γ is initialized + # to be 1, except for each residual block's last BN + # where γ is initialized to be 0." + + # nn.init.constant_(self.conv3.norm.weight, 0) + # TODO this somehow hurts performance when training GN models from scratch. + # Add it as an option when we need to use this code to train a backbone. + + def forward(self, x): + out = self.conv1(x) + out = self.conv2(out) + + if self.downsample_0 is not None: + downsample = self.downsample_0(x) + else: + downsample = x + + out = self.conv3(out, downsample) + return out + + +class ResNet(nn.Module): + """ + Implement :paper:`ResNet`. + """ + + def __init__(self, stem, stages, num_classes=None, out_features=None, freeze_at=0): + """ + Args: + stem (nn.Module): a stem module + stages (list[list[CNNBlockBase]]): several (typically 4) stages, + each contains multiple :class:`CNNBlockBase`. + activation (str): activation function to use. + num_classes (None or int): if None, will not perform classification. + Otherwise, will create a linear layer. + out_features (list[str]): name of the layers whose outputs should + be returned in forward. Can be anything in "stem", "linear", or "res2" ... + If None, will return the output of the last layer. + freeze_at (int): The number of stages at the beginning to freeze. + see :meth:`freeze` for detailed explanation. + """ + super().__init__() + self.stem = stem + self.num_classes = num_classes + + current_stride = self.stem.stride + self._out_feature_strides = {"stem": current_stride} + self._out_feature_channels = {"stem": self.stem.out_channels} + + self.stage_names, self.stages = [], [] + + if out_features is not None: + # Avoid keeping unused layers in this module. They consume extra memory + # and may cause allreduce to fail + num_stages = max( + [ + {"layer1": 1, "layer2": 2, "layer3": 3, "layer4": 4}.get(f, 0) + for f in out_features + ] + ) + stages = stages[:num_stages] + + for i, blocks in enumerate(stages): + assert len(blocks) > 0, len(blocks) + for block in blocks: + assert isinstance(block, CNNBlockBase), block + + name = "layer" + str(i + 1) + stage = nn.Sequential(*blocks) + + self.add_module(name, stage) + self.stage_names.append(name) + self.stages.append(stage) + + self._out_feature_strides[name] = current_stride = int( + current_stride * np.prod([k.stride for k in blocks]) + ) + self._out_feature_channels[name] = curr_channels = blocks[-1].out_channels + + self.stage_names = tuple(self.stage_names) # Make it static for scripting + + if num_classes is not None: + self.avgpool = nn.AvgPool2d(7, 1, 0) + self.fc = nn.Linear(curr_channels, num_classes) + + if out_features is None: + out_features = [name] + self._out_features = out_features + assert len(self._out_features) + children = [x[0] for x in self.named_children()] + for out_feature in self._out_features: + assert out_feature in children, "Available children: {}".format( + ", ".join(children) + ) + self.reshape = nn.Reshape() + + def forward(self, x): + """ + Args: + x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``. + Returns: + dict[str->Tensor]: names and the corresponding features + """ + # assert x.dim() == 4, f"ResNet takes an input of shape (N, C, H, W). Got {x.shape} instead!" 
+ outputs = {} + x = self.stem(x) + if "stem" in self._out_features: + outputs["stem"] = x + for name, stage in zip(self.stage_names, self.stages): + x = stage(x) + if name in self._out_features: + outputs[name] = x + if self.num_classes is not None: + x = self.avgpool(x) + x = self.fc(x) + if x._rank() == 2: + x = self.reshape(x, [x._size(0), 1, 1, x._size(1)]) + return x + return outputs + + @staticmethod + def make_stage(block_class, num_blocks, *, in_channels, out_channels, **kwargs): + """ + Create a list of blocks of the same type that forms one ResNet stage. + Args: + block_class (type): a subclass of CNNBlockBase that's used to create all blocks in this + stage. A module of this type must not change spatial resolution of inputs unless its + stride != 1. + num_blocks (int): number of blocks in this stage + in_channels (int): input channels of the entire stage. + out_channels (int): output channels of **every block** in the stage. + kwargs: other arguments passed to the constructor of + `block_class`. If the argument name is "xx_per_block", the + argument is a list of values to be passed to each block in the + stage. Otherwise, the same argument is passed to every block + in the stage. + Returns: + list[CNNBlockBase]: a list of block module. + Examples: + :: + stage = ResNet.make_stage( + BottleneckBlock, 3, in_channels=16, out_channels=64, + bottleneck_channels=16, num_groups=1, + stride_per_block=[2, 1, 1], + dilations_per_block=[1, 1, 2] + ) + Usually, layers that produce the same feature map spatial size are defined as one + "stage" (in :paper:`FPN`). Under such definition, ``stride_per_block[1:]`` should + all be 1. + """ + blocks = [] + for i in range(num_blocks): + curr_kwargs = {} + for k, v in kwargs.items(): + if k.endswith("_per_block"): + assert len(v) == num_blocks, ( + f"Argument '{k}' of make_stage should have the " + f"same length as num_blocks={num_blocks}." + ) + newk = k[: -len("_per_block")] + assert ( + newk not in kwargs + ), f"Cannot call make_stage with both {k} and {newk}!" + curr_kwargs[newk] = v[i] + else: + curr_kwargs[k] = v + + blocks.append( + block_class( + in_channels=in_channels, out_channels=out_channels, **curr_kwargs + ) + ) + in_channels = out_channels + return blocks + + @staticmethod + def make_default_stages(depth, block_class=None, **kwargs): + """ + Created list of ResNet stages from pre-defined depth (one of 18, 34, 50, 101, 152). + If it doesn't create the ResNet variant you need, please use :meth:`make_stage` + instead for fine-grained customization. + Args: + depth (int): depth of ResNet + block_class (type): the CNN block class. Has to accept + `bottleneck_channels` argument for depth > 50. + By default it is BasicBlock or BottleneckBlock, based on the + depth. + kwargs: + other arguments to pass to `make_stage`. Should not contain + stride and channels, as they are predefined for each depth. + Returns: + list[list[CNNBlockBase]]: modules in all stages; see arguments of + :class:`ResNet.__init__`. 
+ """ + num_blocks_per_stage = { + 18: [2, 2, 2, 2], + 34: [3, 4, 6, 3], + 50: [3, 4, 6, 3], + 101: [3, 4, 23, 3], + 152: [3, 8, 36, 3], + }[depth] + if block_class is None: + block_class = BasicBlock if depth < 50 else BottleneckBlock + if depth < 50: + in_channels = [64, 64, 128, 256] + out_channels = [64, 128, 256, 512] + else: + in_channels = [64, 256, 512, 1024] + out_channels = [256, 512, 1024, 2048] + ret = [] + for (n, s, i, o) in zip( + num_blocks_per_stage, [1, 2, 2, 2], in_channels, out_channels + ): + if depth >= 50: + kwargs["bottleneck_channels"] = o // 4 + ret.append( + ResNet.make_stage( + block_class=block_class, + num_blocks=n, + stride_per_block=[s] + [1] * (n - 1), + in_channels=i, + out_channels=o, + **kwargs, + ) + ) + return ret + + +def make_stage(*args, **kwargs): + """ + Deprecated alias for backward compatibiltiy. + """ + return ResNet.make_stage(*args, **kwargs) + + +def build_resnet_backbone(depth, activation): + """ + Create a ResNet instance from config. + Returns: + ResNet: a :class:`ResNet` instance. + """ + norm = "BN" + activation = activation + num_groups = 1 + stride_in_1x1 = False + num_groups = 1 + width_per_group = 64 + bottleneck_channels = num_groups * width_per_group + in_channels = 64 + out_channels = 256 + + stem = BasicStem(in_channels=3, out_channels=64, norm=norm, activation=activation) + + num_blocks_per_stage = { + 18: [2, 2, 2, 2], + 34: [3, 4, 6, 3], + 50: [3, 4, 6, 3], + 101: [3, 4, 23, 3], + 152: [3, 8, 36, 3], + }[depth] + + stages = [] + + for idx, stage_idx in enumerate(range(2, 6)): + # res5_dilation is used this way as a convention in R-FCN & Deformable Conv paper + dilation = 1 + first_stride = 1 if idx == 0 or (stage_idx == 5 and dilation == 2) else 2 + stage_kargs = { + "num_blocks": num_blocks_per_stage[idx], + "stride_per_block": [first_stride] + [1] * (num_blocks_per_stage[idx] - 1), + "in_channels": in_channels, + "out_channels": out_channels, + "norm": norm, + "activation": activation, + } + # Use BasicBlock for R18 and R34. + if depth in [18, 34]: + stage_kargs["block_class"] = BasicBlock + else: + stage_kargs["bottleneck_channels"] = bottleneck_channels + stage_kargs["stride_in_1x1"] = stride_in_1x1 + stage_kargs["dilation"] = dilation + stage_kargs["num_groups"] = num_groups + stage_kargs["block_class"] = BottleneckBlock + blocks = ResNet.make_stage(**stage_kargs) + in_channels = out_channels + out_channels *= 2 + bottleneck_channels *= 2 + stages.append(blocks) + + return ResNet(stem, stages, num_classes=1000) diff --git a/examples/01_resnet-50/weight_utils.py b/examples/01_resnet-50/weight_utils.py new file mode 100644 index 000000000..beaebd330 --- /dev/null +++ b/examples/01_resnet-50/weight_utils.py @@ -0,0 +1,173 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +""" +script for converting model from timm to aitemplate +Only tested on resnet50 +""" + + +import pickle +import re + +import click +import numpy as np +import timm +import torch +from aitemplate.testing import detect_target + +CONV_WEIGHT_PATTERN = re.compile(r"conv\d+\.weight") + + +class timm_export(object): + def __init__(self, model_name): + self.model_name = model_name + if model_name != "resnet50": + raise NotImplementedError + + with torch.no_grad(): + self.pt_model = timm.create_model( + model_name, pretrained=True, num_classes=1000 + ) + self.pt_state = self.pt_model.state_dict() + + def export_model(self, half=True): + fused_model = {} + for param_name in self.pt_state.keys(): + self.transform_params(param_name, fused_model) + ait_model = {k.replace(".", "_"): weight for k, weight in fused_model.items()} + if detect_target().name() == "cuda": + self.export_conv0(ait_model, fused_model) + if half: + half_params = {} + for k, v in ait_model.items(): + half_params[k] = v.detach().cuda().half().contiguous() + return half_params + return ait_model + + def fuse_conv_bn_weights( + self, conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=False + ): + conv_w = torch.tensor(conv_w) + bn_rm = torch.tensor(bn_rm) + bn_rv = torch.tensor(bn_rv) + bn_w = torch.tensor(bn_w) + bn_b = torch.tensor(bn_b) + bn_eps = torch.tensor(bn_eps) + + if conv_b is None: + conv_b = torch.zeros_like(bn_rm) + if bn_w is None: + bn_w = torch.ones_like(bn_rm) + if bn_b is None: + bn_b = torch.zeros_like(bn_rm) + bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps) + + if transpose: + shape = [1, -1] + [1] * (len(conv_w.shape) - 2) + else: + shape = [-1, 1] + [1] * (len(conv_w.shape) - 2) + + conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape(shape) + conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b + + # NCHW -> NHWC + conv_w = conv_w.permute(0, 2, 3, 1).contiguous() + for arr in [conv_w.numpy(), conv_b.numpy()]: + if np.isnan(arr).any(): + print("fuse bn error") + return conv_w, conv_b + + def transform_conv0(self): + conv_w = self.pt_state["conv1.weight"] + bn_w = self.pt_state["bn1.weight"] + bn_b = self.pt_state["bn1.bias"] + bn_rm = self.pt_state["bn1.running_mean"] + bn_rv = self.pt_state["bn1.running_var"] + fused_w, fused_b = self.fuse_conv_bn_weights( + conv_w, None, bn_rm, bn_rv, 1e-5, bn_w, bn_b + ) + return fused_w, fused_b + + def transform_params(self, param_name, fused_model): + if param_name == "conv1.weight": + fused_w, fused_b = self.transform_conv0() + fused_model["stem.conv1.weight"] = fused_w + fused_model["stem.conv1.bias"] = fused_b + elif "downsample.0.weight" in param_name: + fused_w, fused_b = self.transform_downsample(param_name) + fused_model[param_name] = fused_w + fused_model[param_name.replace("weight", "bias")] = fused_b + elif param_name == "fc.weight": + fused_model["fc.weight"] = self.pt_state["fc.weight"] + fused_model["fc.bias"] = self.pt_state["fc.bias"] + elif CONV_WEIGHT_PATTERN.search(param_name) is not None: + bn_w_name = param_name.replace("conv", "bn") + conv_w = self.pt_state[param_name] + bn_w = self.pt_state[bn_w_name] + bn_b = self.pt_state[bn_w_name.replace("weight", "bias")] + bn_rm = self.pt_state[bn_w_name.replace("weight", "running_mean")] + bn_rv = self.pt_state[bn_w_name.replace("weight", "running_var")] + fused_w, fused_b = self.fuse_conv_bn_weights( + conv_w, None, bn_rm, bn_rv, 1e-5, bn_w, bn_b + ) + fused_model[param_name] = fused_w + fused_model[param_name.replace("weight", "bias")] = fused_b + else: + pass + + def transform_downsample(self, 
param_name):
+        assert "downsample" in param_name
+        tags = param_name.split(".")
+        block_tag = ".".join(tags[:2])
+        conv_w = self.pt_state[f"{block_tag}.downsample.0.weight"]
+        bn_w = self.pt_state[f"{block_tag}.downsample.1.weight"]
+        bn_b = self.pt_state[f"{block_tag}.downsample.1.bias"]
+        bn_rm = self.pt_state[f"{block_tag}.downsample.1.running_mean"]
+        bn_rv = self.pt_state[f"{block_tag}.downsample.1.running_var"]
+        fused_w, fused_b = self.fuse_conv_bn_weights(
+            conv_w, None, bn_rm, bn_rv, 1e-5, bn_w, bn_b
+        )
+        return fused_w, fused_b
+
+    def export_conv0(self, ait_model, fuse_model):
+        pt_name = "stem.conv1.weight"
+        x = fuse_model[pt_name]
+        conv_w = torch.zeros((64, 7, 7, 4))
+        conv_w[:, :, :, :3] = x
+        ait_model[pt_name.replace(".", "_")] = conv_w
+
+
+def export_to_torch_tensor(model_name="resnet50"):
+    if model_name != "resnet50":
+        raise NotImplementedError
+    timm2ait = timm_export(model_name)
+    ait_model = timm2ait.export_model(half=True)
+    return ait_model
+
+
+@click.command()
+@click.option("--param-path", type=str, default="resnet50.pkl")
+def export_to_numpy(param_path):
+    ait_model = export_to_torch_tensor()
+    np_weights = {}
+    for k, v in ait_model.items():
+        np_weights[k] = v.detach().cpu().numpy().astype(np.float16)
+
+    with open(param_path, "wb") as f:
+        pickle.dump(np_weights, f)
+
+
+if __name__ == "__main__":
+    export_to_numpy()
diff --git a/examples/02_detectron2/README.md b/examples/02_detectron2/README.md
new file mode 100644
index 000000000..99fadec85
--- /dev/null
+++ b/examples/02_detectron2/README.md
@@ -0,0 +1,169 @@
+# Getting Started with AIT for the Inference of Detectron2 Based Models
+
+This document describes how to use AIT for the inference of Detectron2 vision models such as Mask R-CNN and Faster R-CNN.
+
+For an end-to-end example with the API, see `prepare_and_run_rcnn.sh`, which covers how to prepare and run inference with `mask_rcnn_R_50_FPN`.
+
+## Create the AIT Model from a Config File
+
+1. Pick a model and its config file from `configs`, for example, `mask_rcnn_R_50_FPN.yaml`.
+
+2. Build the AIT model by running `compile_model.py` with the config file, for example:
+
+```
+cfg=examples/02_detectron2/configs/mask_rcnn_R_50_FPN.yaml
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 examples/02_detectron2/compile_model.py \
+    --config $cfg \
+    --batch 1
+```
+
+The parameters of the built AIT model are not initialized yet and are therefore filled with random values. They are initialized in the next step, when the weights of the pre-trained model are exported to the AIT model. Check `tmp/mask_rcnn_R_50_FPN/params.json` for the list of parameters in the AIT model and their shapes.
+
+## Download the Detectron2 Pre-trained Model, and Export the Weights to the AIT Model
+
+1. For example, download the Detectron2 `mask_rcnn_R_50_FPN` pre-trained model and save it to `tmp/pt_mask_rcnn_R_50_FPN.pkl`:
+
+```
+wget https://dl.fbaipublicfiles.com/detectron2/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x/137849600/model_final_f10217.pkl -O tmp/pt_mask_rcnn_R_50_FPN.pkl
+```
+
+2. Export the weights from the pre-trained model to the AIT model by running `tools/convert_pt2ait.py`:
+
+```
+python3 examples/02_detectron2/tools/convert_pt2ait.py \
+    --d2-weight tmp/pt_mask_rcnn_R_50_FPN.pkl \
+    --ait-weight tmp/ait_mask_rcnn_R_50_FPN.pt \
+    --model-name mask_rcnn_R_50_FPN
+```
+
+The weights are exported to AIT and saved as `tmp/ait_mask_rcnn_R_50_FPN.pt` for the inference run.
+
+## Download the Inference Dataset and Run the AIT Model
+
+1.
For example, download the COCO 2017 Dataset: + +``` +mkdir -p ~/.torch/datasets/coco + +wget https://dl.fbaipublicfiles.com/detectron2/annotations/coco/val2017_100.tgz -O ~/.torch/datasets/coco/val2017_100.tgz +tar xzf ~/.torch/datasets/coco/val2017_100.tgz -C ~/.torch/datasets/coco && rm -f ~/.torch/datasets/coco/val2017_100.tgz +``` + +2. Run inference of the AIT model on the inputs with `demo.py`: + +``` +python3 examples/02_detectron2/demo.py \ + --weight tmp/ait_mask_rcnn_R_50_FPN.pt \ + --config examples/02_detectron2/configs/mask_rcnn_R_50_FPN.yaml \ + --batch 1 --input "~/.torch/datasets/coco/val2017/*.jpg" \ + --confidence-threshold 0.5 \ + --display \ + --cudagraph +``` + +## Multi-GPU profiling +AIT requires to do profiling to decide best algorithms for CUTLASS and CK. +To enable multiple GPUs profiling, set the environment variable `CUDA_VISIBLE_DEVICES` on NVIDIA platform and `HIP_VISIBLE_DEVICES` on AMD platform with available GPU ids. + + +## Results +_PT = PyTorch 1.12 Eager_ +### A100-40GB / CUDA 11.6 + +- Input size: 448x608 + +| Batch size | PT Latency (ms) | PT FPS | AIT Latency (ms) | AIT FPS | +|------------|-----------------|--------|------------------|---------| +| 1 | 21.70 | 46.09 | 4.40 | 227.27 | +| 2 | 29.71 | 67.32 | 6.68 | 299.40 | +| 4 | 35.67 | 112.13 | 11.12 | 359.71 | +| 8 | 59.71 | 133.98 | 22.24 | 359.71 | +| 16 | 112.91 | 141.70 | 36.64 | 436.68 | +| 32 | 224.24 | 142.70 | 71.04 | 450.45 | +| 64 | 448.84 | 142.59 | 140.16 | 456.62 | + +- Input size: 800x1344 + +| Batch size | PT Latency (ms) | PT FPS | AIT Latency (ms) | AIT FPS | +|------------|-----------------|--------|------------------|---------| +| 1 | 22.99 | 43.50 | 8.50 | 117.65 | +| 2 | 34.48 | 58.01 | 13.42 | 149.03 | +| 4 | 65.00 | 61.54 | 22.88 | 174.83 | +| 8 | 125.25 | 63.87 | 41.44 | 193.05 | +| 16 | 246.49 | 64.91 | 78.56 | 203.67 | +| 32 | 503.21 | 63.59 | 154.56 | 207.04 | +| 64 | OOM | OOM | 304.64 | 210.08 | + + +### MI-250 / ROCm 5.2.3 / HIPCC-10736 +_PT = PyTorch 1.12 Eager_ +#### 1 GCDs + +- Input size: 448x608 + +| Batch size | PT Latency (ms) | PT FPS | AIT Latency (ms) | AIT FPS | +|------------|-----------------|--------|------------------|---------| +| 1 | 24.75 | 40.41 | 10.63 | 94.07 | +| 2 | 29.28 | 68.30 | 15.96 | 125.31 | +| 4 | 42.45 | 94.24 | 26.24 | 152.44 | +| 8 | 79.73 | 100.34 | 51.04 | 156.74 | +| 16 | 141.84 | 112.81 | 89.12 | 179.53 | +| 32 | 284.39 | 112.52 | 161.92 | 197.63 | +| 64 | 600.84 | 106.52 | Error | Error | + +- Input size: 800x1344 + +| Batch size | PT Latency (ms) | PT FPS | AIT Latency (ms) | AIT FPS | +|------------|-----------------|--------|------------------|---------| +| 1 | 26.80 | 37.31 | 19.23 | 52.00 | +| 2 | 43.61 | 45.86 | 30.28 | 66.05 | +| 4 | 98.88 | 40.45 | 51.56 | 77.58 | +| 8 | 189.45 | 42.23 | 98.80 | 80.97 | +| 16 | 389.94 | 41.03 | 177.28 | 90.25 | +| 32 | 807.22 | 39.64 | 333.44 | 95.97 | +| 64 | 1768.66 | 36.19 | Error | Error | + +#### 2 GCDs + +- Input size: 448x608 + +| Batch size | AIT Latency (ms) | AIT FPS | +|------------|------------------|---------| +| 1 | | | +| 2 | 12.78 | 156.49 | +| 4 | 20.66 | 193.61 | +| 8 | 32.16 | 248.76 | +| 16 | 61.52 | 260.08 | +| 32 | 106.08 | 301.66 | +| 64 | 194.24 | 329.49 | + + +- Input size: 800x1344 + +| Batch size | AIT Latency (ms) | AIT FPS | +|------------|------------------|---------| +| 1 | | | +| 2 | 22 | 90.91 | +| 4 | 34 | 117.65 | +| 8 | 55.52 | 144.09 | +| 16 | 104.48 | 153.14 | +| 32 | 190.24 | 168.21 | +| 64 | 362.88 | 176.37 | + + +### Sample outputs + 
+![sample](https://raw.githubusercontent.com/AITemplate/webdata/main/imgs/example_d2_1.jpg)
+
+![sample](https://raw.githubusercontent.com/AITemplate/webdata/main/imgs/example_d2_2.jpg)
+
+![sample](https://raw.githubusercontent.com/AITemplate/webdata/main/imgs/example_d2_3.jpg)
+
+
+### Notes on Performance Results
+
+- For NVIDIA A100, our test cluster does not allow locking the clock frequency. We use a longer warm-up to collect more stable results, but a small variance relative to results measured with a locked frequency is expected.
+- To benchmark MI-250, first run `python3 benchmark_ait.py` to generate all the required model dynamic library files on a single GCD. Then run `./benchmark_mi250.sh {batch_size}` to simulate data-parallel execution on 2 GCDs, with each GCD processing half of the batch.
+- To benchmark MI-250 on 1 GCD, we lock the frequency with the command `rocm-smi -d x --setperfdeterminism 1700`, where `x` is the GPU id.
+- To benchmark MI-250 on 2 GCDs, we observed a performance regression with the ROCm perf-determinism mode, so the 2-GCD numbers are collected with that mode reset via `rocm-smi -d x --resetperfdeterminism`, where `x` is the GPU id.
+- Performance results are what we were able to reproduce; they should not be used for any other purpose.
diff --git a/examples/02_detectron2/compile_model.py b/examples/02_detectron2/compile_model.py
new file mode 100644
index 000000000..4bf5d4d25
--- /dev/null
+++ b/examples/02_detectron2/compile_model.py
@@ -0,0 +1,149 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
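One note on reading the benchmark tables above: the FPS column is simply the batch size divided by the latency in seconds. A tiny helper, hypothetical and not part of this patch, that reproduces a couple of the A100 448x608 rows:

```python
def fps(batch_size: int, latency_ms: float) -> float:
    """Images per second for a batch processed in `latency_ms` milliseconds."""
    return batch_size * 1000.0 / latency_ms


print(round(fps(1, 4.40), 2))     # 227.27, the AIT batch-1 row
print(round(fps(64, 140.16), 2))  # 456.62, the AIT batch-64 row
```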
+# +import json +import os + +import click + +import numpy as np +import torch +from aitemplate.compiler import compile_model, Model + +from aitemplate.frontend import Tensor +from aitemplate.testing import detect_target +from configs import get_cfg_defaults +from modeling.meta_arch import GeneralizedRCNN + +# pylint: disable=W0102 + + +def rand_init(shape): + if len(shape) == 1: + arr = np.zeros(shape).astype("float16") + else: + fout = shape[0] + fin = shape[-1] + scale = np.sqrt(2) / np.sqrt(fout + fin) + arr = np.random.normal(0, scale, shape).astype("float16") + return torch.from_numpy(arr).cuda().half() + + +def mark_output(y): + if type(y) is not tuple: + y = (y,) + for i in range(len(y)): + y[i]._attrs["is_output"] = True + y[i]._attrs["name"] = "output_%d" % (i) + y_shape = [d._attrs["values"][0] for d in y[i]._attrs["shape"]] + print("output_{} shape: {}".format(i, y_shape)) + + +def get_shape(x): + shape = [it.value() for it in x._attrs["shape"]] + return shape + + +def extract_params_meta(net): + ret = [] + params = net.parameters() + for p in params: + t = p.tensor() + name = t._attrs["name"] + shape = [x._attrs["values"][0] for x in t._attrs["shape"]] + ret.append([name, shape]) + return ret + + +def benchmark(cfg, mod=None): + im_shape = (cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MAX_SIZE_TEST, 3) + HH, WW, CC = im_shape + BS = cfg.SOLVER.IMS_PER_BATCH + inputs = np.random.normal(0, 1, (BS, HH, WW, CC)).astype("float16") + + model_name = cfg.MODEL.NAME + if mod is None: + mod = Model(os.path.join("./tmp", model_name, "test.so")) + + ait_mod = GeneralizedRCNN(cfg) + + for name, param in ait_mod.named_parameters(): + shape = get_shape(param.tensor()) + arr = rand_init(shape) + mod.set_constant_with_tensor(name.replace(".", "_"), arr) + + x_input = torch.tensor(inputs).cuda().half() + x = x_input.contiguous() + + GeneralizedRCNN(cfg).set_anchors(mod) + + topk = cfg.POSTPROCESS.TOPK + outputs = [ + torch.empty([BS, 1], dtype=torch.int64).cuda(), + torch.empty([BS, topk, 4]).cuda().half(), + torch.empty([BS, topk]).cuda().half(), + torch.empty([BS, topk], dtype=torch.int64).cuda(), + ] + if cfg.MODEL.MASK_ON: + mask_size = cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION * 2 + outputs.append(torch.empty([BS, topk, mask_size, mask_size]).cuda().half()) + + mod.benchmark_with_tensors([x], outputs, count=100, repeat=2, graph_mode=True) + + +def compile_module(cfg): + model_name = cfg.MODEL.NAME + target = detect_target() + + im_shape = (cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MAX_SIZE_TEST, 3) + HH, WW, CC = im_shape + BS = cfg.SOLVER.IMS_PER_BATCH + x = Tensor(shape=[BS, HH, WW, CC], dtype="float16", name="input_0", is_input=True) + model = GeneralizedRCNN(cfg) + model.name_parameter_tensor() + + y = model(x) + mark_output(y) + module = compile_model(y, target, "./tmp", model_name) + + with open(os.path.join("./tmp", model_name, "params.json"), "w") as fo: + fo.write(json.dumps(extract_params_meta(model))) + + benchmark(cfg, module) + + +@click.command() +@click.option("--config", default="", metavar="FILE", help="path to config file") +@click.option("--bench-config", default="", metavar="FILE", help="path to config file") +@click.option("--batch", default=0, help="batch size") +@click.option("--eval/--no-eval", default=False, help="perform evaluation only") +def compile_and_benchmark(config, bench_config, batch, eval): + cfg = get_cfg_defaults() + cfg.merge_from_file(config) + if bench_config != "": + cfg.merge_from_file(bench_config) + if batch > 0: + cfg.SOLVER.IMS_PER_BATCH = batch + cfg.freeze() 
+ print(cfg.MODEL.NAME) + + if eval: + benchmark(cfg) + else: + compile_module(cfg) + + +if __name__ == "__main__": + np.random.seed(4896) + compile_and_benchmark() diff --git a/examples/02_detectron2/configs/__init__.py b/examples/02_detectron2/configs/__init__.py new file mode 100644 index 000000000..679ca77c9 --- /dev/null +++ b/examples/02_detectron2/configs/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from .config import get_cfg_defaults + +__all__ = ["get_cfg_defaults"] diff --git a/examples/02_detectron2/configs/config.py b/examples/02_detectron2/configs/config.py new file mode 100644 index 000000000..c9cf1e5c3 --- /dev/null +++ b/examples/02_detectron2/configs/config.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from yacs.config import CfgNode + + +def get_cfg_defaults() -> CfgNode: + """ + Get a copy of the default config. + Returns: + a detectron2 CfgNode instance. + """ + from .defaults import _C + + return _C.clone() diff --git a/examples/02_detectron2/configs/defaults.py b/examples/02_detectron2/configs/defaults.py new file mode 100644 index 000000000..c2bb11eb7 --- /dev/null +++ b/examples/02_detectron2/configs/defaults.py @@ -0,0 +1,668 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from yacs.config import CfgNode as CN + +# ----------------------------------------------------------------------------- +# Config definition +# ----------------------------------------------------------------------------- + +_C = CN() + +# The version number, to upgrade from old configs to new ones if any +# changes happen. It's recommended to keep a VERSION in your config file. 
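Before walking through the individual options, it may help to see how this defaults tree is consumed by the example scripts above (`compile_model.py`, `demo.py`): the `_C` node is cloned via `get_cfg_defaults()`, merged with a YAML config, optionally overridden, and then frozen. A minimal sketch, assuming it is run from `examples/02_detectron2` so the `configs` package is importable:

```python
from configs import get_cfg_defaults  # returns _C.clone(); see configs/config.py

cfg = get_cfg_defaults()
cfg.merge_from_file("configs/mask_rcnn_R_50_FPN.yaml")
cfg.SOLVER.IMS_PER_BATCH = 2  # override the batch size before freezing
cfg.freeze()                  # make the config immutable

print(cfg.MODEL.NAME)        # "mask_rcnn_R_50_FPN"
print(cfg.POSTPROCESS.TOPK)  # 100, as set in that YAML
```

The `_C.VERSION` assignment that immediately follows tags the schema so that old config files can be upgraded if the defaults change.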
+_C.VERSION = 2 + +_C.MODEL = CN() +_C.MODEL.NAME = "" +_C.MODEL.LOAD_PROPOSALS = False +_C.MODEL.MASK_ON = False +_C.MODEL.KEYPOINT_ON = False +_C.MODEL.DEVICE = "cuda" +_C.MODEL.META_ARCHITECTURE = "GeneralizedRCNN" + +# Path (a file path, or URL like detectron2://.., https://..) to a checkpoint file +# to be loaded to the model. You can find available models in the model zoo. +_C.MODEL.WEIGHTS = "" + +# Values to be used for image normalization (BGR order, since INPUT.FORMAT defaults to BGR). +# To train on images of different number of channels, just set different mean & std. +# Default values are the mean pixel value from ImageNet: [103.53, 116.28, 123.675] +_C.MODEL.PIXEL_MEAN = [103.530, 116.280, 123.675] +# When using pre-trained models in Detectron1 or any MSRA models, +# std has been absorbed into its conv1 weights, so the std needs to be set 1. +# Otherwise, you can use [57.375, 57.120, 58.395] (ImageNet std) +_C.MODEL.PIXEL_STD = [1.0, 1.0, 1.0] + +# ----------------------------------------------------------------------------- +# POST PROCESS +# ----------------------------------------------------------------------------- +_C.POSTPROCESS = CN() +_C.POSTPROCESS.POST_ON = True +_C.POSTPROCESS.USE_TOPK = True +_C.POSTPROCESS.TOPK = 130 + +# ----------------------------------------------------------------------------- +# INPUT +# ----------------------------------------------------------------------------- +_C.INPUT = CN() +# By default, {MIN,MAX}_SIZE options are used in transforms.ResizeShortestEdge. +# Please refer to ResizeShortestEdge for detailed definition. +# Size of the smallest side of the image during training +_C.INPUT.MIN_SIZE_TRAIN = (800,) +# Sample size of smallest side by choice or random selection from range give by +# INPUT.MIN_SIZE_TRAIN +_C.INPUT.MIN_SIZE_TRAIN_SAMPLING = "choice" +# Maximum size of the side of the image during training +_C.INPUT.MAX_SIZE_TRAIN = 1333 +# Size of the smallest side of the image during testing. Set to zero to disable resize in testing. +_C.INPUT.MIN_SIZE_TEST = 800 +# Maximum size of the side of the image during testing +_C.INPUT.MAX_SIZE_TEST = 1333 +# Mode for flipping images used in data augmentation during training +# choose one of ["horizontal, "vertical", "none"] +_C.INPUT.RANDOM_FLIP = "horizontal" + +# `True` if cropping is used for data augmentation during training +_C.INPUT.CROP = CN({"ENABLED": False}) +# Cropping type. See documentation of `detectron2.data.transforms.RandomCrop` for explanation. +_C.INPUT.CROP.TYPE = "relative_range" +# Size of crop in range (0, 1] if CROP.TYPE is "relative" or "relative_range" and in number of +# pixels if CROP.TYPE is "absolute" +_C.INPUT.CROP.SIZE = [0.9, 0.9] + + +# Whether the model needs RGB, YUV, HSV etc. +# Should be one of the modes defined here, as we use PIL to read the image: +# https://pillow.readthedocs.io/en/stable/handbook/concepts.html#concept-modes +# with BGR being the one exception. One can set image format to BGR, we will +# internally use RGB for conversion and flip the channels over +_C.INPUT.FORMAT = "BGR" +# The ground truth mask format that the model will use. +# Mask R-CNN supports either "polygon" or "bitmask" as ground truth. +_C.INPUT.MASK_FORMAT = "polygon" # alternative: "bitmask" + + +# ----------------------------------------------------------------------------- +# Dataset +# ----------------------------------------------------------------------------- +_C.DATASETS = CN() +# List of the dataset names for training. 
Must be registered in DatasetCatalog +# Samples from these datasets will be merged and used as one dataset. +_C.DATASETS.TRAIN = () +# List of the pre-computed proposal files for training, which must be consistent +# with datasets listed in DATASETS.TRAIN. +_C.DATASETS.PROPOSAL_FILES_TRAIN = () +# Number of top scoring precomputed proposals to keep for training +_C.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN = 2000 +# List of the dataset names for testing. Must be registered in DatasetCatalog +_C.DATASETS.TEST = () +# List of the pre-computed proposal files for test, which must be consistent +# with datasets listed in DATASETS.TEST. +_C.DATASETS.PROPOSAL_FILES_TEST = () +# Number of top scoring precomputed proposals to keep for test +_C.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST = 1000 + +# ----------------------------------------------------------------------------- +# DataLoader +# ----------------------------------------------------------------------------- +_C.DATALOADER = CN() +# Number of data loading threads +_C.DATALOADER.NUM_WORKERS = 4 +# If True, each batch should contain only images for which the aspect ratio +# is compatible. This groups portrait images together, and landscape images +# are not batched with portrait images. +_C.DATALOADER.ASPECT_RATIO_GROUPING = True +# Options: TrainingSampler, RepeatFactorTrainingSampler +_C.DATALOADER.SAMPLER_TRAIN = "TrainingSampler" +# Repeat threshold for RepeatFactorTrainingSampler +_C.DATALOADER.REPEAT_THRESHOLD = 0.0 +# Tf True, when working on datasets that have instance annotations, the +# training dataloader will filter out images without associated annotations +_C.DATALOADER.FILTER_EMPTY_ANNOTATIONS = True + +# ---------------------------------------------------------------------------- # +# Backbone options +# ---------------------------------------------------------------------------- # +_C.MODEL.BACKBONE = CN() + +_C.MODEL.BACKBONE.NAME = "build_resnet_backbone" +# Freeze the first several stages so they are not trained. +# There are 5 stages in ResNet. The first is a convolution, and the following +# stages are each group of residual blocks. +_C.MODEL.BACKBONE.FREEZE_AT = 2 + + +# ---------------------------------------------------------------------------- # +# FPN options +# ---------------------------------------------------------------------------- # +_C.MODEL.FPN = CN() +# Names of the input feature maps to be used by FPN +# They must have contiguous power of 2 strides +# e.g., ["res2", "res3", "res4", "res5"] +_C.MODEL.FPN.IN_FEATURES = [] +_C.MODEL.FPN.OUT_CHANNELS = 256 + +# Options: "" (no norm), "GN" +_C.MODEL.FPN.NORM = "" + +# Types for fusing the FPN top-down and lateral features. 
Can be either "sum" or "avg" +_C.MODEL.FPN.FUSE_TYPE = "sum" + + +# ---------------------------------------------------------------------------- # +# Proposal generator options +# ---------------------------------------------------------------------------- # +_C.MODEL.PROPOSAL_GENERATOR = CN() +# Current proposal generators include "RPN", "RRPN" and "PrecomputedProposals" +_C.MODEL.PROPOSAL_GENERATOR.NAME = "RPN" +# Proposal height and width both need to be greater than MIN_SIZE +# (a the scale used during training or inference) +_C.MODEL.PROPOSAL_GENERATOR.MIN_SIZE = 0 + + +# ---------------------------------------------------------------------------- # +# Anchor generator options +# ---------------------------------------------------------------------------- # +_C.MODEL.ANCHOR_GENERATOR = CN() +# The generator can be any name in the ANCHOR_GENERATOR registry +_C.MODEL.ANCHOR_GENERATOR.NAME = "DefaultAnchorGenerator" +# Anchor sizes (i.e. sqrt of area) in absolute pixels w.r.t. the network input. +# Format: list[list[float]]. SIZES[i] specifies the list of sizes to use for +# IN_FEATURES[i]; len(SIZES) must be equal to len(IN_FEATURES) or 1. +# When len(SIZES) == 1, SIZES[0] is used for all IN_FEATURES. +_C.MODEL.ANCHOR_GENERATOR.SIZES = [[32, 64, 128, 256, 512]] +# Anchor aspect ratios. For each area given in `SIZES`, anchors with different aspect +# ratios are generated by an anchor generator. +# Format: list[list[float]]. ASPECT_RATIOS[i] specifies the list of aspect ratios (H/W) +# to use for IN_FEATURES[i]; len(ASPECT_RATIOS) == len(IN_FEATURES) must be true, +# or len(ASPECT_RATIOS) == 1 is true and aspect ratio list ASPECT_RATIOS[0] is used +# for all IN_FEATURES. +_C.MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS = [[0.5, 1.0, 2.0]] +# Anchor angles. +# list[list[float]], the angle in degrees, for each input feature map. +# ANGLES[i] specifies the list of angles for IN_FEATURES[i]. +_C.MODEL.ANCHOR_GENERATOR.ANGLES = [[-90, 0, 90]] +# Relative offset between the center of the first anchor and the top-left corner of the image +# Value has to be in [0, 1). Recommend to use 0.5, which means half stride. +# The value is not expected to affect model accuracy. +_C.MODEL.ANCHOR_GENERATOR.OFFSET = 0.0 + +# ---------------------------------------------------------------------------- # +# RPN options +# ---------------------------------------------------------------------------- # +_C.MODEL.RPN = CN() +_C.MODEL.RPN.HEAD_NAME = "StandardRPNHead" # used by RPN_HEAD_REGISTRY + +# Names of the input feature maps to be used by RPN +# e.g., ["p2", "p3", "p4", "p5", "p6"] for FPN +_C.MODEL.RPN.IN_FEATURES = ["res4"] +# Remove RPN anchors that go outside the image by BOUNDARY_THRESH pixels +# Set to -1 or a large value, e.g. 
100000, to disable pruning anchors +_C.MODEL.RPN.BOUNDARY_THRESH = -1 +# IOU overlap ratios [BG_IOU_THRESHOLD, FG_IOU_THRESHOLD] +# Minimum overlap required between an anchor and ground-truth box for the +# (anchor, gt box) pair to be a positive example (IoU >= FG_IOU_THRESHOLD +# ==> positive RPN example: 1) +# Maximum overlap allowed between an anchor and ground-truth box for the +# (anchor, gt box) pair to be a negative examples (IoU < BG_IOU_THRESHOLD +# ==> negative RPN example: 0) +# Anchors with overlap in between (BG_IOU_THRESHOLD <= IoU < FG_IOU_THRESHOLD) +# are ignored (-1) +_C.MODEL.RPN.IOU_THRESHOLDS = [0.3, 0.7] +_C.MODEL.RPN.IOU_LABELS = [0, -1, 1] +# Number of regions per image used to train RPN +_C.MODEL.RPN.BATCH_SIZE_PER_IMAGE = 256 +# Target fraction of foreground (positive) examples per RPN minibatch +_C.MODEL.RPN.POSITIVE_FRACTION = 0.5 +# Options are: "smooth_l1", "giou", "diou", "ciou" +_C.MODEL.RPN.BBOX_REG_LOSS_TYPE = "smooth_l1" +_C.MODEL.RPN.BBOX_REG_LOSS_WEIGHT = 1.0 +# Weights on (dx, dy, dw, dh) for normalizing RPN anchor regression targets +_C.MODEL.RPN.BBOX_REG_WEIGHTS = (1.0, 1.0, 1.0, 1.0) +# The transition point from L1 to L2 loss. Set to 0.0 to make the loss simply L1. +_C.MODEL.RPN.SMOOTH_L1_BETA = 0.0 +_C.MODEL.RPN.LOSS_WEIGHT = 1.0 +# Number of top scoring RPN proposals to keep before applying NMS +# When FPN is used, this is *per FPN level* (not total) +_C.MODEL.RPN.PRE_NMS_TOPK_TRAIN = 12000 +_C.MODEL.RPN.PRE_NMS_TOPK_TEST = 6000 +# Number of top scoring RPN proposals to keep after applying NMS +# When FPN is used, this limit is applied per level and then again to the union +# of proposals from all levels +# NOTE: When FPN is used, the meaning of this config is different from Detectron1. +# It means per-batch topk in Detectron1, but per-image topk here. +# See the "find_top_rpn_proposals" function for details. +_C.MODEL.RPN.POST_NMS_TOPK_TRAIN = 2000 +_C.MODEL.RPN.POST_NMS_TOPK_TEST = 1000 +# NMS threshold used on RPN proposals +_C.MODEL.RPN.NMS_THRESH = 0.7 +# Set this to -1 to use the same number of output channels as input channels. +_C.MODEL.RPN.CONV_DIMS = [-1] + +_C.MODEL.RPN.RPN_DIM = 256 + +# ---------------------------------------------------------------------------- # +# ROI HEADS options +# ---------------------------------------------------------------------------- # +_C.MODEL.ROI_HEADS = CN() +_C.MODEL.ROI_HEADS.NAME = "Res5ROIHeads" +# Number of foreground classes +_C.MODEL.ROI_HEADS.NUM_CLASSES = 80 +# Names of the input feature maps to be used by ROI heads +# Currently all heads (box, mask, ...) use the same input feature map list +# e.g., ["p2", "p3", "p4", "p5"] is commonly used for FPN +_C.MODEL.ROI_HEADS.IN_FEATURES = ["res4"] +# IOU overlap ratios [IOU_THRESHOLD] +# Overlap threshold for an RoI to be considered background (if < IOU_THRESHOLD) +# Overlap threshold for an RoI to be considered foreground (if >= IOU_THRESHOLD) +_C.MODEL.ROI_HEADS.IOU_THRESHOLDS = [0.5] +_C.MODEL.ROI_HEADS.IOU_LABELS = [0, 1] +# RoI minibatch size *per image* (number of regions of interest [ROIs]) during training +# Total number of RoIs per training minibatch = +# ROI_HEADS.BATCH_SIZE_PER_IMAGE * SOLVER.IMS_PER_BATCH +# E.g., a common configuration is: 512 * 16 = 8192 +_C.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512 +# Target fraction of RoI minibatch that is labeled foreground (i.e. 
class > 0) +_C.MODEL.ROI_HEADS.POSITIVE_FRACTION = 0.25 + +# Only used on test mode + +# Minimum score threshold (assuming scores in a [0, 1] range); a value chosen to +# balance obtaining high recall with not having too many low precision +# detections that will slow down inference post processing steps (like NMS) +# A default threshold of 0.0 increases AP by ~0.2-0.3 but significantly slows down +# inference. +_C.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.05 +# Overlap threshold used for non-maximum suppression (suppress boxes with +# IoU >= this threshold) +_C.MODEL.ROI_HEADS.NMS_THRESH_TEST = 0.5 +# If True, augment proposals with ground-truth boxes before sampling proposals to +# train ROI heads. +_C.MODEL.ROI_HEADS.PROPOSAL_APPEND_GT = True + +# ---------------------------------------------------------------------------- # +# Box Head +# ---------------------------------------------------------------------------- # +_C.MODEL.ROI_BOX_HEAD = CN() +# C4 don't use head name option +# Options for non-C4 models: FastRCNNConvFCHead, +_C.MODEL.ROI_BOX_HEAD.NAME = "" +# Options are: "smooth_l1", "giou", "diou", "ciou" +_C.MODEL.ROI_BOX_HEAD.BBOX_REG_LOSS_TYPE = "smooth_l1" +# The final scaling coefficient on the box regression loss, used to balance the magnitude of its +# gradients with other losses in the model. See also `MODEL.ROI_KEYPOINT_HEAD.LOSS_WEIGHT`. +_C.MODEL.ROI_BOX_HEAD.BBOX_REG_LOSS_WEIGHT = 1.0 +# Default weights on (dx, dy, dw, dh) for normalizing bbox regression targets +# These are empirically chosen to approximately lead to unit variance targets +_C.MODEL.ROI_BOX_HEAD.BBOX_REG_WEIGHTS = (10.0, 10.0, 5.0, 5.0) +# The transition point from L1 to L2 loss. Set to 0.0 to make the loss simply L1. +_C.MODEL.ROI_BOX_HEAD.SMOOTH_L1_BETA = 0.0 +_C.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION = 14 +_C.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO = 0 +# Type of pooling operation applied to the incoming feature map for each RoI +_C.MODEL.ROI_BOX_HEAD.POOLER_TYPE = "ROIAlignV2" + +_C.MODEL.ROI_BOX_HEAD.NUM_FC = 0 +# Hidden layer dimension for FC layers in the RoI box head +_C.MODEL.ROI_BOX_HEAD.FC_DIM = 1024 +_C.MODEL.ROI_BOX_HEAD.NUM_CONV = 0 +# Channel dimension for Conv layers in the RoI box head +_C.MODEL.ROI_BOX_HEAD.CONV_DIM = 256 +# Normalization method for the convolution layers. +# Options: "" (no norm), "GN", "SyncBN". +_C.MODEL.ROI_BOX_HEAD.NORM = "" +# Whether to use class agnostic for bbox regression +_C.MODEL.ROI_BOX_HEAD.CLS_AGNOSTIC_BBOX_REG = False +# If true, RoI heads use bounding boxes predicted by the box head rather than proposal boxes. +_C.MODEL.ROI_BOX_HEAD.TRAIN_ON_PRED_BOXES = False + +# Federated loss can be used to improve the training of LVIS +_C.MODEL.ROI_BOX_HEAD.USE_FED_LOSS = False +# Sigmoid cross entrophy is used with federated loss +_C.MODEL.ROI_BOX_HEAD.USE_SIGMOID_CE = False +# The power value applied to image_count when calcualting frequency weight +_C.MODEL.ROI_BOX_HEAD.FED_LOSS_FREQ_WEIGHT_POWER = 0.5 +# Number of classes to keep in total +_C.MODEL.ROI_BOX_HEAD.FED_LOSS_NUM_CLASSES = 50 + +# ---------------------------------------------------------------------------- # +# Cascaded Box Head +# ---------------------------------------------------------------------------- # +_C.MODEL.ROI_BOX_CASCADE_HEAD = CN() +# The number of cascade stages is implicitly defined by the length of the following two configs. 
+_C.MODEL.ROI_BOX_CASCADE_HEAD.BBOX_REG_WEIGHTS = ( + (10.0, 10.0, 5.0, 5.0), + (20.0, 20.0, 10.0, 10.0), + (30.0, 30.0, 15.0, 15.0), +) +_C.MODEL.ROI_BOX_CASCADE_HEAD.IOUS = (0.5, 0.6, 0.7) + + +# ---------------------------------------------------------------------------- # +# Mask Head +# ---------------------------------------------------------------------------- # +_C.MODEL.ROI_MASK_HEAD = CN() +_C.MODEL.ROI_MASK_HEAD.NAME = "MaskRCNNConvUpsampleHead" +_C.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION = 14 +_C.MODEL.ROI_MASK_HEAD.POOLER_SAMPLING_RATIO = 0 +_C.MODEL.ROI_MASK_HEAD.NUM_CONV = 0 # The number of convs in the mask head +_C.MODEL.ROI_MASK_HEAD.CONV_DIM = 256 +# Normalization method for the convolution layers. +# Options: "" (no norm), "GN", "SyncBN". +_C.MODEL.ROI_MASK_HEAD.NORM = "" +# Whether to use class agnostic for mask prediction +_C.MODEL.ROI_MASK_HEAD.CLS_AGNOSTIC_MASK = False +# Type of pooling operation applied to the incoming feature map for each RoI +_C.MODEL.ROI_MASK_HEAD.POOLER_TYPE = "ROIAlignV2" + + +# ---------------------------------------------------------------------------- # +# Keypoint Head +# ---------------------------------------------------------------------------- # +_C.MODEL.ROI_KEYPOINT_HEAD = CN() +_C.MODEL.ROI_KEYPOINT_HEAD.NAME = "KRCNNConvDeconvUpsampleHead" +_C.MODEL.ROI_KEYPOINT_HEAD.POOLER_RESOLUTION = 14 +_C.MODEL.ROI_KEYPOINT_HEAD.POOLER_SAMPLING_RATIO = 0 +_C.MODEL.ROI_KEYPOINT_HEAD.CONV_DIMS = tuple(512 for _ in range(8)) +_C.MODEL.ROI_KEYPOINT_HEAD.NUM_KEYPOINTS = 17 # 17 is the number of keypoints in COCO. + +# Images with too few (or no) keypoints are excluded from training. +_C.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE = 1 +# Normalize by the total number of visible keypoints in the minibatch if True. +# Otherwise, normalize by the total number of keypoints that could ever exist +# in the minibatch. +# The keypoint softmax loss is only calculated on visible keypoints. +# Since the number of visible keypoints can vary significantly between +# minibatches, this has the effect of up-weighting the importance of +# minibatches with few visible keypoints. (Imagine the extreme case of +# only one visible keypoint versus N: in the case of N, each one +# contributes 1/N to the gradient compared to the single keypoint +# determining the gradient direction). Instead, we can normalize the +# loss by the total number of keypoints, if it were the case that all +# keypoints were visible in a full minibatch. (Returning to the example, +# this means that the one visible keypoint contributes as much as each +# of the N keypoints.) +_C.MODEL.ROI_KEYPOINT_HEAD.NORMALIZE_LOSS_BY_VISIBLE_KEYPOINTS = True +# Multi-task loss weight to use for keypoints +# Recommended values: +# - use 1.0 if NORMALIZE_LOSS_BY_VISIBLE_KEYPOINTS is True +# - use 4.0 if NORMALIZE_LOSS_BY_VISIBLE_KEYPOINTS is False +_C.MODEL.ROI_KEYPOINT_HEAD.LOSS_WEIGHT = 1.0 +# Type of pooling operation applied to the incoming feature map for each RoI +_C.MODEL.ROI_KEYPOINT_HEAD.POOLER_TYPE = "ROIAlignV2" + +# ---------------------------------------------------------------------------- # +# Semantic Segmentation Head +# ---------------------------------------------------------------------------- # +_C.MODEL.SEM_SEG_HEAD = CN() +_C.MODEL.SEM_SEG_HEAD.NAME = "SemSegFPNHead" +_C.MODEL.SEM_SEG_HEAD.IN_FEATURES = ["p2", "p3", "p4", "p5"] +# Label in the semantic segmentation ground truth that is ignored, i.e., no loss is calculated for +# the correposnding pixel. 
+_C.MODEL.SEM_SEG_HEAD.IGNORE_VALUE = 255 +# Number of classes in the semantic segmentation head +_C.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 54 +# Number of channels in the 3x3 convs inside semantic-FPN heads. +_C.MODEL.SEM_SEG_HEAD.CONVS_DIM = 128 +# Outputs from semantic-FPN heads are up-scaled to the COMMON_STRIDE stride. +_C.MODEL.SEM_SEG_HEAD.COMMON_STRIDE = 4 +# Normalization method for the convolution layers. Options: "" (no norm), "GN". +_C.MODEL.SEM_SEG_HEAD.NORM = "GN" +_C.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT = 1.0 + +_C.MODEL.PANOPTIC_FPN = CN() +# Scaling of all losses from instance detection / segmentation head. +_C.MODEL.PANOPTIC_FPN.INSTANCE_LOSS_WEIGHT = 1.0 + +# options when combining instance & semantic segmentation outputs +_C.MODEL.PANOPTIC_FPN.COMBINE = CN( + {"ENABLED": True} +) # "COMBINE.ENABLED" is deprecated & not used +_C.MODEL.PANOPTIC_FPN.COMBINE.OVERLAP_THRESH = 0.5 +_C.MODEL.PANOPTIC_FPN.COMBINE.STUFF_AREA_LIMIT = 4096 +_C.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = 0.5 + + +# ---------------------------------------------------------------------------- # +# RetinaNet Head +# ---------------------------------------------------------------------------- # +_C.MODEL.RETINANET = CN() + +# This is the number of foreground classes. +_C.MODEL.RETINANET.NUM_CLASSES = 80 + +_C.MODEL.RETINANET.IN_FEATURES = ["p3", "p4", "p5", "p6", "p7"] + +# Convolutions to use in the cls and bbox tower +# NOTE: this doesn't include the last conv for logits +_C.MODEL.RETINANET.NUM_CONVS = 4 + +# IoU overlap ratio [bg, fg] for labeling anchors. +# Anchors with < bg are labeled negative (0) +# Anchors with >= bg and < fg are ignored (-1) +# Anchors with >= fg are labeled positive (1) +_C.MODEL.RETINANET.IOU_THRESHOLDS = [0.4, 0.5] +_C.MODEL.RETINANET.IOU_LABELS = [0, -1, 1] + +# Prior prob for rare case (i.e. foreground) at the beginning of training. +# This is used to set the bias for the logits layer of the classifier subnet. +# This improves training stability in the case of heavy class imbalance. 
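The comment above is easier to follow with the usual RetinaNet-style formula spelled out: the classifier bias is initialized so that the initial sigmoid score equals the prior probability, i.e. `bias = -log((1 - p) / p)`. A small worked check, illustrative only and not part of this patch, for the `PRIOR_PROB` value assigned just below:

```python
import math

prior_prob = 0.01  # the MODEL.RETINANET.PRIOR_PROB default below
bias = -math.log((1.0 - prior_prob) / prior_prob)

print(round(bias, 3))                           # -4.595
print(round(1.0 / (1.0 + math.exp(-bias)), 3))  # 0.01: sigmoid(bias) recovers the prior
```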
+_C.MODEL.RETINANET.PRIOR_PROB = 0.01 + +# Inference cls score threshold, only anchors with score > INFERENCE_TH are +# considered for inference (to improve speed) +_C.MODEL.RETINANET.SCORE_THRESH_TEST = 0.05 +# Select topk candidates before NMS +_C.MODEL.RETINANET.TOPK_CANDIDATES_TEST = 1000 +_C.MODEL.RETINANET.NMS_THRESH_TEST = 0.5 + +# Weights on (dx, dy, dw, dh) for normalizing Retinanet anchor regression targets +_C.MODEL.RETINANET.BBOX_REG_WEIGHTS = (1.0, 1.0, 1.0, 1.0) + +# Loss parameters +_C.MODEL.RETINANET.FOCAL_LOSS_GAMMA = 2.0 +_C.MODEL.RETINANET.FOCAL_LOSS_ALPHA = 0.25 +_C.MODEL.RETINANET.SMOOTH_L1_LOSS_BETA = 0.1 +# Options are: "smooth_l1", "giou", "diou", "ciou" +_C.MODEL.RETINANET.BBOX_REG_LOSS_TYPE = "smooth_l1" + +# One of BN, SyncBN, FrozenBN, GN +# Only supports GN until unshared norm is implemented +_C.MODEL.RETINANET.NORM = "" + + +# ---------------------------------------------------------------------------- # +# ResNe[X]t options (ResNets = {ResNet, ResNeXt} +# Note that parts of a resnet may be used for both the backbone and the head +# These options apply to both +# ---------------------------------------------------------------------------- # +_C.MODEL.RESNETS = CN() + +_C.MODEL.RESNETS.DEPTH = 50 + +_C.MODEL.RESNETS.STAGES = [3, 4, 6, 3] + +_C.MODEL.RESNETS.FILTERS = [64, 256, 512, 1024, 2048] + +_C.MODEL.RESNETS.OUT_FEATURES = [ + "res4" +] # res4 for C4 backbone, res2..5 for FPN backbone + +# Number of groups to use; 1 ==> ResNet; > 1 ==> ResNeXt +_C.MODEL.RESNETS.NUM_GROUPS = 1 + +# Options: FrozenBN, GN, "SyncBN", "BN" +_C.MODEL.RESNETS.NORM = "FrozenBN" + +# Baseline width of each group. +# Scaling this parameters will scale the width of all bottleneck layers. +_C.MODEL.RESNETS.WIDTH_PER_GROUP = 64 + +# Place the stride 2 conv on the 1x1 filter +# Use True only for the original MSRA ResNet; use False for C2 and Torch models +_C.MODEL.RESNETS.STRIDE_IN_1X1 = True + +# Apply dilation in stage "res5" +_C.MODEL.RESNETS.RES5_DILATION = 1 + +# Output width of res2. Scaling this parameters will scale the width of all 1x1 convs in ResNet +# For R18 and R34, this needs to be set to 64 +_C.MODEL.RESNETS.RES2_OUT_CHANNELS = 256 +_C.MODEL.RESNETS.STEM_OUT_CHANNELS = 64 + +# Apply Deformable Convolution in stages +# Specify if apply deform_conv on Res2, Res3, Res4, Res5 +_C.MODEL.RESNETS.DEFORM_ON_PER_STAGE = [False, False, False, False] +# Use True to use modulated deform_conv (DeformableV2, https://arxiv.org/abs/1811.11168); +# Use False for DeformableV1. +_C.MODEL.RESNETS.DEFORM_MODULATED = False +# Number of groups in deformable conv. +_C.MODEL.RESNETS.DEFORM_NUM_GROUPS = 1 + + +# ---------------------------------------------------------------------------- # +# Solver +# ---------------------------------------------------------------------------- # +_C.SOLVER = CN() + +# Options: WarmupMultiStepLR, WarmupCosineLR. +# See detectron2/solver/build.py for definition. +_C.SOLVER.LR_SCHEDULER_NAME = "WarmupMultiStepLR" + +_C.SOLVER.MAX_ITER = 40000 + +_C.SOLVER.BASE_LR = 0.001 +# The end lr, only used by WarmupCosineLR +_C.SOLVER.BASE_LR_END = 0.0 + +_C.SOLVER.MOMENTUM = 0.9 + +_C.SOLVER.NESTEROV = False + +_C.SOLVER.WEIGHT_DECAY = 0.0001 +# The weight decay that's applied to parameters of normalization layers +# (typically the affine transformation) +_C.SOLVER.WEIGHT_DECAY_NORM = 0.0 + +_C.SOLVER.GAMMA = 0.1 +# The iteration number to decrease learning rate by GAMMA. 
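Together with `GAMMA` and `BASE_LR` above and the `STEPS` and `WARMUP_*` options that follow, these settings describe a `WarmupMultiStepLR` schedule: a linear warm-up from `WARMUP_FACTOR * BASE_LR` up to `BASE_LR`, then a decay by `GAMMA` at each milestone in `STEPS`. A rough, framework-free sketch of that behavior using the defaults in this file (an approximation for illustration, not the exact detectron2 implementation):

```python
import bisect


def lr_at(it, base_lr=0.001, gamma=0.1, steps=(30000,),
          warmup_factor=1.0 / 1000, warmup_iters=1000):
    # Linear warm-up below warmup_iters, then step decay at each milestone.
    warmup = 1.0
    if it < warmup_iters:
        alpha = it / warmup_iters
        warmup = warmup_factor * (1 - alpha) + alpha
    return base_lr * warmup * gamma ** bisect.bisect_right(steps, it)


print(lr_at(0))      # ~1e-06, start of warm-up
print(lr_at(1000))   # 0.001, warm-up finished
print(lr_at(30000))  # ~0.0001, decayed by gamma after the first milestone
```

Note that these solver options only matter for training; the AIT examples in this patch run inference only.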
+_C.SOLVER.STEPS = (30000,) + +_C.SOLVER.WARMUP_FACTOR = 1.0 / 1000 +_C.SOLVER.WARMUP_ITERS = 1000 +_C.SOLVER.WARMUP_METHOD = "linear" + +# Save a checkpoint after every this number of iterations +_C.SOLVER.CHECKPOINT_PERIOD = 5000 + +# Number of images per batch across all machines. This is also the number +# of training images per step (i.e. per iteration). If we use 16 GPUs +# and IMS_PER_BATCH = 32, each GPU will see 2 images per batch. +# May be adjusted automatically if REFERENCE_WORLD_SIZE is set. +_C.SOLVER.IMS_PER_BATCH = 16 + +# The reference number of workers (GPUs) this config is meant to train with. +# It takes no effect when set to 0. +# With a non-zero value, it will be used by DefaultTrainer to compute a desired +# per-worker batch size, and then scale the other related configs (total batch size, +# learning rate, etc) to match the per-worker batch size. +# See documentation of `DefaultTrainer.auto_scale_workers` for details: +_C.SOLVER.REFERENCE_WORLD_SIZE = 0 + +# Detectron v1 (and previous detection code) used a 2x higher LR and 0 WD for +# biases. This is not useful (at least for recent models). You should avoid +# changing these and they exist only to reproduce Detectron v1 training if +# desired. +_C.SOLVER.BIAS_LR_FACTOR = 1.0 +_C.SOLVER.WEIGHT_DECAY_BIAS = None # None means following WEIGHT_DECAY + +# Gradient clipping +_C.SOLVER.CLIP_GRADIENTS = CN({"ENABLED": False}) +# Type of gradient clipping, currently 2 values are supported: +# - "value": the absolute values of elements of each gradients are clipped +# - "norm": the norm of the gradient for each parameter is clipped thus +# affecting all elements in the parameter +_C.SOLVER.CLIP_GRADIENTS.CLIP_TYPE = "value" +# Maximum absolute value used for clipping gradients +_C.SOLVER.CLIP_GRADIENTS.CLIP_VALUE = 1.0 +# Floating point number p for L-p norm to be used with the "norm" +# gradient clipping type; for L-inf, please specify .inf +_C.SOLVER.CLIP_GRADIENTS.NORM_TYPE = 2.0 + +# Enable automatic mixed precision for training +# Note that this does not change model's inference behavior. +# To use AMP in inference, run inference under autocast() +_C.SOLVER.AMP = CN({"ENABLED": False}) + +# ---------------------------------------------------------------------------- # +# Specific test options +# ---------------------------------------------------------------------------- # +_C.TEST = CN() +# For end-to-end tests to verify the expected accuracy. +# Each item is [task, metric, value, tolerance] +# e.g.: [['bbox', 'AP', 38.5, 0.2]] +_C.TEST.EXPECTED_RESULTS = [] +# The period (in terms of steps) to evaluate the model during training. +# Set to 0 to disable. +_C.TEST.EVAL_PERIOD = 0 +# The sigmas used to calculate keypoint OKS. See http://cocodataset.org/#keypoints-eval +# When empty, it will use the defaults in COCO. +# Otherwise it should be a list[float] with the same length as ROI_KEYPOINT_HEAD.NUM_KEYPOINTS. +_C.TEST.KEYPOINT_OKS_SIGMAS = [] +# Maximum number of detections to return per image during inference (100 is +# based on the limit established for the COCO dataset). 
+_C.TEST.DETECTIONS_PER_IMAGE = 100 + +_C.TEST.AUG = CN({"ENABLED": False}) +_C.TEST.AUG.MIN_SIZES = (400, 500, 600, 700, 800, 900, 1000, 1100, 1200) +_C.TEST.AUG.MAX_SIZE = 4000 +_C.TEST.AUG.FLIP = True + +_C.TEST.PRECISE_BN = CN({"ENABLED": False}) +_C.TEST.PRECISE_BN.NUM_ITER = 200 + +# ---------------------------------------------------------------------------- # +# Misc options +# ---------------------------------------------------------------------------- # +# Directory where output files are written +_C.OUTPUT_DIR = "./output" +# Set seed to negative to fully randomize everything. +# Set seed to positive to use a fixed seed. Note that a fixed seed increases +# reproducibility but does not guarantee fully deterministic behavior. +# Disabling all parallelism further increases reproducibility. +_C.SEED = -1 +# Benchmark different cudnn algorithms. +# If input images have very different sizes, this option will have large overhead +# for about 10k iterations. It usually hurts total time, but can benefit for certain models. +# If input images have the same or similar sizes, benchmark is often helpful. +_C.CUDNN_BENCHMARK = False +# The period (in terms of steps) for minibatch visualization at train time. +# Set to 0 to disable. +_C.VIS_PERIOD = 0 + +# global config is for quick hack purposes. +# You can set them in command line or config files, +# and access it with: +# +# from detectron2.config import global_cfg +# print(global_cfg.HACK) +# +# Do not commit any configs into it. +_C.GLOBAL = CN() +_C.GLOBAL.HACK = 1.0 + + +# def get_cfg_defaults(): +# return _C.clone() diff --git a/examples/02_detectron2/configs/faster_rcnn_R_101_FPN.yaml b/examples/02_detectron2/configs/faster_rcnn_R_101_FPN.yaml new file mode 100644 index 000000000..b69b96822 --- /dev/null +++ b/examples/02_detectron2/configs/faster_rcnn_R_101_FPN.yaml @@ -0,0 +1,47 @@ +MODEL: + NAME: "faster_rcnn_R_101_FPN" + META_ARCHITECTURE: "GeneralizedRCNN" + BACKBONE: + NAME: "build_resnet_fpn_backbone" + RESNETS: + OUT_FEATURES: ["res2", "res3", "res4", "res5"] + DEPTH: 101 + STAGES: [3, 4, 23, 3] + FPN: + IN_FEATURES: ["res2", "res3", "res4", "res5"] + ANCHOR_GENERATOR: + SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map + ASPECT_RATIOS: [0.5, 1.0, 2.0] # Three aspect ratios (same for all in feature maps) + RPN: + IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"] + PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level + PRE_NMS_TOPK_TEST: 1000 # Per FPN level + POST_NMS_TOPK_TRAIN: 1000 + POST_NMS_TOPK_TEST: 1000 + ROI_HEADS: + NAME: "StandardROIHeads" + IN_FEATURES: ["p2", "p3", "p4", "p5"] + ROI_BOX_HEAD: + NAME: "FastRCNNConvFCHead" + NUM_FC: 2 + POOLER_RESOLUTION: 7 + ROI_MASK_HEAD: + NAME: "MaskRCNNConvUpsampleHead" + NUM_CONV: 4 + POOLER_RESOLUTION: 14 +DATASETS: + TRAIN: ("coco_2017_train",) + TEST: ("coco_2017_val",) +SOLVER: + IMS_PER_BATCH: 1 + BASE_LR: 0.02 + STEPS: (60000, 80000) + MAX_ITER: 90000 +INPUT: + MIN_SIZE_TEST: 800 + MAX_SIZE_TEST: 1344 +POSTPROCESS: + POST_ON: True + USE_TOPK: True + TOPK: 100 +VERSION: 2 diff --git a/examples/02_detectron2/configs/faster_rcnn_R_50_FPN.yaml b/examples/02_detectron2/configs/faster_rcnn_R_50_FPN.yaml new file mode 100644 index 000000000..26aa4c210 --- /dev/null +++ b/examples/02_detectron2/configs/faster_rcnn_R_50_FPN.yaml @@ -0,0 +1,45 @@ +MODEL: + NAME: "faster_rcnn_R_50_FPN" + META_ARCHITECTURE: "GeneralizedRCNN" + BACKBONE: + NAME: "build_resnet_fpn_backbone" + RESNETS: + OUT_FEATURES: ["res2", "res3", "res4", "res5"] + FPN: + IN_FEATURES: ["res2", "res3", 
"res4", "res5"] + ANCHOR_GENERATOR: + SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map + ASPECT_RATIOS: [0.5, 1.0, 2.0] # Three aspect ratios (same for all in feature maps) + RPN: + IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"] + PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level + PRE_NMS_TOPK_TEST: 1000 # Per FPN level + POST_NMS_TOPK_TRAIN: 1000 + POST_NMS_TOPK_TEST: 1000 + ROI_HEADS: + NAME: "StandardROIHeads" + IN_FEATURES: ["p2", "p3", "p4", "p5"] + ROI_BOX_HEAD: + NAME: "FastRCNNConvFCHead" + NUM_FC: 2 + POOLER_RESOLUTION: 7 + ROI_MASK_HEAD: + NAME: "MaskRCNNConvUpsampleHead" + NUM_CONV: 4 + POOLER_RESOLUTION: 14 +DATASETS: + TRAIN: ("coco_2017_train",) + TEST: ("coco_2017_val",) +SOLVER: + IMS_PER_BATCH: 1 + BASE_LR: 0.02 + STEPS: (60000, 80000) + MAX_ITER: 90000 +INPUT: + MIN_SIZE_TEST: 800 + MAX_SIZE_TEST: 1344 +POSTPROCESS: + POST_ON: True + USE_TOPK: True + TOPK: 100 +VERSION: 2 diff --git a/examples/02_detectron2/configs/mask_rcnn_R_101_FPN.yaml b/examples/02_detectron2/configs/mask_rcnn_R_101_FPN.yaml new file mode 100644 index 000000000..c2c6c946c --- /dev/null +++ b/examples/02_detectron2/configs/mask_rcnn_R_101_FPN.yaml @@ -0,0 +1,48 @@ +MODEL: + NAME: "mask_rcnn_R_101_FPN" + MASK_ON: True + META_ARCHITECTURE: "GeneralizedRCNN" + BACKBONE: + NAME: "build_resnet_fpn_backbone" + RESNETS: + OUT_FEATURES: ["res2", "res3", "res4", "res5"] + DEPTH: 101 + STAGES: [3, 4, 23, 3] + FPN: + IN_FEATURES: ["res2", "res3", "res4", "res5"] + ANCHOR_GENERATOR: + SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map + ASPECT_RATIOS: [0.5, 1.0, 2.0] # Three aspect ratios (same for all in feature maps) + RPN: + IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"] + PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level + PRE_NMS_TOPK_TEST: 1000 # Per FPN level + POST_NMS_TOPK_TRAIN: 1000 + POST_NMS_TOPK_TEST: 1000 + ROI_HEADS: + NAME: "StandardROIHeads" + IN_FEATURES: ["p2", "p3", "p4", "p5"] + ROI_BOX_HEAD: + NAME: "FastRCNNConvFCHead" + NUM_FC: 2 + POOLER_RESOLUTION: 7 + ROI_MASK_HEAD: + NAME: "MaskRCNNConvUpsampleHead" + NUM_CONV: 4 + POOLER_RESOLUTION: 14 +DATASETS: + TRAIN: ("coco_2017_train",) + TEST: ("coco_2017_val",) +SOLVER: + IMS_PER_BATCH: 1 + BASE_LR: 0.02 + STEPS: (60000, 80000) + MAX_ITER: 90000 +INPUT: + MIN_SIZE_TEST: 800 + MAX_SIZE_TEST: 1344 +POSTPROCESS: + POST_ON: True + USE_TOPK: False + TOPK: 100 +VERSION: 2 diff --git a/examples/02_detectron2/configs/mask_rcnn_R_50_FPN.yaml b/examples/02_detectron2/configs/mask_rcnn_R_50_FPN.yaml new file mode 100644 index 000000000..47149bf18 --- /dev/null +++ b/examples/02_detectron2/configs/mask_rcnn_R_50_FPN.yaml @@ -0,0 +1,46 @@ +MODEL: + NAME: "mask_rcnn_R_50_FPN" + MASK_ON: True + META_ARCHITECTURE: "GeneralizedRCNN" + BACKBONE: + NAME: "build_resnet_fpn_backbone" + RESNETS: + OUT_FEATURES: ["res2", "res3", "res4", "res5"] + FPN: + IN_FEATURES: ["res2", "res3", "res4", "res5"] + ANCHOR_GENERATOR: + SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map + ASPECT_RATIOS: [0.5, 1.0, 2.0] # Three aspect ratios (same for all in feature maps) + RPN: + IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"] + PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level + PRE_NMS_TOPK_TEST: 1000 # Per FPN level + POST_NMS_TOPK_TRAIN: 1000 + POST_NMS_TOPK_TEST: 1000 + ROI_HEADS: + NAME: "StandardROIHeads" + IN_FEATURES: ["p2", "p3", "p4", "p5"] + ROI_BOX_HEAD: + NAME: "FastRCNNConvFCHead" + NUM_FC: 2 + POOLER_RESOLUTION: 7 + ROI_MASK_HEAD: + NAME: "MaskRCNNConvUpsampleHead" + NUM_CONV: 4 + POOLER_RESOLUTION: 14 +DATASETS: + TRAIN: 
("coco_2017_train",) + TEST: ("coco_2017_val",) +SOLVER: + IMS_PER_BATCH: 1 + BASE_LR: 0.02 + STEPS: (60000, 80000) + MAX_ITER: 90000 +INPUT: + MIN_SIZE_TEST: 800 + MAX_SIZE_TEST: 1344 +POSTPROCESS: + POST_ON: True + USE_TOPK: False + TOPK: 100 +VERSION: 2 diff --git a/examples/02_detectron2/demo.py b/examples/02_detectron2/demo.py new file mode 100644 index 000000000..749a1eab8 --- /dev/null +++ b/examples/02_detectron2/demo.py @@ -0,0 +1,105 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +A main inference script for rcnn models +""" +import glob +import os + +import click +import tqdm +from configs import get_cfg_defaults +from predictor import Predictor + + +@click.command() +@click.option("--config", default="", metavar="FILE", help="path to config file") +@click.option("--bench-config", default="", metavar="FILE", help="path to config file") +@click.option( + "--input", + multiple=True, + help="A list of space separated input images; " + "or a single glob pattern such as 'directory/*.jpg'", +) +@click.option( + "--output", + help="A file or directory to save output visualizations. " + "If not given, will show output in an OpenCV window.", +) +@click.option( + "--confidence-threshold", + type=float, + default=0.5, + help="Minimum score for instance predictions to be shown", +) +@click.option("--weight", default="", metavar="FILE", help="path to model weights") +@click.option("--batch", default=0, help="batch size") +@click.option("--display/--no-display", default=False, help="display results") +@click.option("--cudagraph/--no-cudagraph", default=False, help="enable CUDA graph") +def run_model( + config, + bench_config, + input, + output, + confidence_threshold, + weight, + batch, + display, + cudagraph, +): + cfg = get_cfg_defaults() + cfg.merge_from_file(config) + if bench_config != "": + cfg.merge_from_file(bench_config) + if batch > 0: + cfg.SOLVER.IMS_PER_BATCH = batch + cfg.MODEL.WEIGHTS = weight + cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = confidence_threshold + cfg.freeze() + + assert ( + weight != "" + ), "export model first: python convert_pt2ait.py model_d2.pkl params_ait.pkl \ + --config configs/faster_rcnn_R_50_DC5.yaml" + + demo = Predictor(cfg) + print("run {} end2end".format(cfg.MODEL.NAME)) + + cnt = 0 + duration = 0 + detections = {} + bs = cfg.SOLVER.IMS_PER_BATCH + if input: + if len(input) == 1: + input = glob.glob(os.path.expanduser(input[0])) + assert input, "The input path(s) was not found" + batch_data = demo.data_loader(input) + print("{} images, run {} batch".format(len(input), len(batch_data))) + for batch in tqdm.tqdm(batch_data, disable=not output): + results = demo.run_batch(batch, cudagraph) + detections.update(results) + if display: + demo.visualize(results) + duration += demo.benchmark(batch["data"], 10, cudagraph) + cnt += 1 + + duration /= cnt * bs + print( + f"AIT Detection: Batch size: {bs}, Time per iter: {duration:.2f} ms, FPS: {1000 / duration:.2f}" + ) + + +if __name__ == "__main__": + 
run_model() diff --git a/examples/02_detectron2/modeling/backbone/__init__.py b/examples/02_detectron2/modeling/backbone/__init__.py new file mode 100644 index 000000000..e2778377d --- /dev/null +++ b/examples/02_detectron2/modeling/backbone/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# flake8: noqa +from .fpn import build_resnet_fpn_backbone, FPN +from .resnet import ( + BasicStem, + BottleneckBlock, + build_resnet_backbone, + make_stage, + ResNet, +) + +__all__ = [k for k in globals().keys() if not k.startswith("_")] diff --git a/examples/02_detectron2/modeling/backbone/fpn.py b/examples/02_detectron2/modeling/backbone/fpn.py new file mode 100644 index 000000000..fe14b3b98 --- /dev/null +++ b/examples/02_detectron2/modeling/backbone/fpn.py @@ -0,0 +1,228 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import math + +from aitemplate.compiler import ops +from aitemplate.frontend import nn + +from .resnet import build_resnet_backbone +from .utils import ShapeSpec + + +class FPN(nn.Module): + """ + This module implements :paper:`FPN`. + It creates pyramid features built on top of some input feature maps. + """ + + def __init__( + self, + bottom_up, + in_features, + out_channels, + norm="", + top_block=None, + fuse_type="sum", + square_pad=0, + ): + """ + Args: + bottom_up (Backbone): module representing the bottom up subnetwork. + Must be a subclass of :class:`Backbone`. The multi-scale feature + maps generated by the bottom up network, and listed in `in_features`, + are used to generate FPN levels. + in_features (list[str]): names of the input feature maps coming + from the backbone to which FPN is attached. For example, if the + backbone produces ["res2", "res3", "res4"], any *contiguous* sublist + of these may be used; order must be from high to low resolution. + out_channels (int): number of channels in the output feature maps. + norm (str): the normalization to use. + top_block (nn.Module or None): if provided, an extra operation will + be performed on the output of the last (smallest resolution) + FPN output, and the result will extend the result list. The top_block + further downsamples the feature map. It must have an attribute + "num_levels", meaning the number of extra FPN levels added by + this block, and "in_feature", which is a string representing + its input feature (e.g., p5). 
+ fuse_type (str): types for fusing the top down features and the lateral + ones. It can be "sum" (default), which sums up element-wise; or "avg", + which takes the element-wise mean of the two. + square_pad (int): If > 0, require input images to be padded to specific square size. + """ + super().__init__() + assert in_features, in_features + + # Feature map strides and channels from the bottom up network (e.g. ResNet) + input_shapes = bottom_up.output_shape() + strides = [input_shapes[f].stride for f in in_features] + in_channels_per_feature = [input_shapes[f].channels for f in in_features] + + _assert_strides_are_log2_contiguous(strides) + lateral_convs = [] + output_convs = [] + + # use_bias = norm == "" + for idx, in_channels in enumerate(in_channels_per_feature): + lateral_conv = nn.Conv2dBias( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + ) + output_conv = nn.Conv2dBias( + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + ) + stage = int(math.log2(strides[idx])) + self.add_module("fpn_lateral{}".format(stage), lateral_conv) + self.add_module("fpn_output{}".format(stage), output_conv) + + lateral_convs.append(lateral_conv) + output_convs.append(output_conv) + # Place convs into top-down order (from low to high resolution) + # to make the top-down computation in forward clearer. + self.lateral_convs = lateral_convs[::-1] + self.output_convs = output_convs[::-1] + self.top_block = top_block + self.in_features = tuple(in_features) + self.bottom_up = bottom_up + # Return feature names are "p", like ["p2", "p3", ..., "p6"] + self._out_feature_strides = { + "p{}".format(int(math.log2(s))): s for s in strides + } + # top block output feature maps. + if self.top_block is not None: + for s in range(stage, stage + self.top_block.num_levels): + self._out_feature_strides["p{}".format(s + 1)] = 2 ** (s + 1) + + self._out_features = list(self._out_feature_strides.keys()) + self._out_feature_channels = {k: out_channels for k in self._out_features} + self._size_divisibility = strides[-1] + self._square_pad = square_pad + assert fuse_type in {"avg", "sum"} + self._fuse_type = fuse_type + + def size_divisibility(self): + return self._size_divisibility + + def padding_constraints(self): + return {"square_size": self._square_pad} + + def forward(self, x): + """ + Args: + input (dict[str->Tensor]): mapping feature map name (e.g., "res5") to + feature map tensor for each feature level in high to low resolution order. + Returns: + dict[str->Tensor]: + mapping from feature map name to FPN feature map tensor + in high to low resolution order. Returned feature names follow the FPN + paper convention: "p", where stage has stride = 2 ** stage e.g., + ["p2", "p3", ..., "p6"]. 
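+        Example (an illustrative sketch; it assumes ``cfg`` is a config node and
+        ``x`` an NHWC input tensor, and the exact keys depend on ``in_features``
+        and ``top_block``)::
+
+            fpn = build_resnet_fpn_backbone(cfg)
+            outs = fpn(x)
+            # outs maps e.g. "p2" ... "p6" to the corresponding FPN feature tensors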
+ """ + bottom_up_features = self.bottom_up(x) + + results = [] + prev_features = self.lateral_convs[0](bottom_up_features[self.in_features[-1]]) + results.append(self.output_convs[0](prev_features)) + + # Reverse feature maps into top-down order (from low to high resolution) + for idx, (lateral_conv, output_conv) in enumerate( + zip(self.lateral_convs, self.output_convs) + ): + # Slicing of ModuleList is not supported https://github.com/pytorch/pytorch/issues/47336 + # Therefore we loop over all modules but skip the first one + if idx > 0: + features = self.in_features[-idx - 1] + features = bottom_up_features[features] + # top_down_features = F.interpolate(prev_features, scale_factor=2.0, mode="nearest") + lateral_features = lateral_conv(features) + # prev_features = lateral_features + top_down_features + interpolate_op = ops.upsampling2d_add(scale_factor=2.0, mode="nearest") + prev_features = interpolate_op(prev_features, lateral_features) + if self._fuse_type == "avg": + prev_features /= 2 + results.insert(0, output_conv(prev_features)) + + if self.top_block is not None: + if self.top_block.in_feature in bottom_up_features: + top_block_in_feature = bottom_up_features[self.top_block.in_feature] + else: + top_block_in_feature = results[ + self._out_features.index(self.top_block.in_feature) + ] + results.extend(self.top_block(top_block_in_feature)) + assert len(self._out_features) == len(results) + return {f: res for f, res in zip(self._out_features, results)} + + def output_shape(self): + return { + name: ShapeSpec( + channels=self._out_feature_channels[name], + stride=self._out_feature_strides[name], + ) + for name in self._out_features + } + + +def _assert_strides_are_log2_contiguous(strides): + """ + Assert that each stride is 2x times its preceding stride, i.e. "contiguous in log2". + """ + for i, stride in enumerate(strides[1:], 1): + assert ( + stride == 2 * strides[i - 1] + ), "Strides {} {} are not log2 contiguous".format(stride, strides[i - 1]) + + +class LastLevelMaxPool(nn.Module): + """ + This module is used in the original FPN to generate a downsampled + P6 feature from P5. + """ + + def __init__(self): + super().__init__() + self.num_levels = 1 + self.in_feature = "p5" + self.pool = nn.MaxPool2d(1, 2, 0) + + def forward(self, x): + return [self.pool(x)] + + +def build_resnet_fpn_backbone(cfg): + """ + Args: + cfg: a detectron2 CfgNode + Returns: + backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`. + """ + bottom_up = build_resnet_backbone(cfg) + in_features = cfg.MODEL.FPN.IN_FEATURES + out_channels = cfg.MODEL.FPN.OUT_CHANNELS + backbone = FPN( + bottom_up=bottom_up, + in_features=in_features, + out_channels=out_channels, + norm=cfg.MODEL.FPN.NORM, + top_block=LastLevelMaxPool(), + fuse_type=cfg.MODEL.FPN.FUSE_TYPE, + ) + return backbone diff --git a/examples/02_detectron2/modeling/backbone/resnet.py b/examples/02_detectron2/modeling/backbone/resnet.py new file mode 100644 index 000000000..5b3777a0d --- /dev/null +++ b/examples/02_detectron2/modeling/backbone/resnet.py @@ -0,0 +1,459 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import numpy as np +from aitemplate.frontend import nn +from aitemplate.testing import detect_target + +from .utils import ShapeSpec + + +class CNNBlockBase(nn.Module): + """ + A CNN block is assumed to have input channels, output channels and a stride. + The input and output of `forward()` method must be NHWC tensors. + The method can perform arbitrary computation but must match the given + channels and stride specification. + Attribute: + in_channels (int): + out_channels (int): + stride (int): + """ + + def __init__(self, in_channels, out_channels, stride): + """ + The `__init__` method of any subclass should also contain these arguments. + Args: + in_channels (int): + out_channels (int): + stride (int): + """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + + +class BasicStem(CNNBlockBase): + """ + The standard ResNet stem (layers before the first residual block), + with a conv, relu and max_pool. + """ + + def __init__(self, in_channels=3, out_channels=64, norm="BN"): + super().__init__(in_channels, out_channels, 4) + conv_op = ( + nn.Conv2dBiasReluFewChannels + if detect_target().name() == "cuda" + else nn.Conv2dBiasRelu + ) + self.conv1 = conv_op(in_channels, out_channels, 7, 2, 7 // 2) + self.pool = nn.MaxPool2d(3, 2, 1) + + def forward(self, x): + x = self.conv1(x) + x = self.pool(x) + return x + + +class BasicBlock(CNNBlockBase): + """ + The basic residual block for ResNet-18 and ResNet-34 defined in :paper:`ResNet`, + with two 3x3 conv layers and a projection shortcut if needed. + """ + + def __init__(self, in_channels, out_channels, *, stride=1, norm="BN"): + super().__init__(in_channels, out_channels, stride) + + def forward(self, x): + raise NotImplementedError() + + +class BottleneckBlock(CNNBlockBase): + """ + The standard bottleneck residual block used by ResNet-50, 101 and 152 + defined in :paper:`ResNet`. It contains 3 conv layers with kernels + 1x1, 3x3, 1x1, and a projection shortcut if needed. + """ + + def __init__( + self, + in_channels, + out_channels, + *, + bottleneck_channels, + stride=1, + num_groups=1, + norm="BN", + stride_in_1x1=False, + dilation=1, + ): + """ + Args: + bottleneck_channels (int): number of output channels for the 3x3 + "bottleneck" conv layers. + num_groups (int): number of groups for the 3x3 conv layer. + norm (str or callable): normalization for all conv layers. + See :func:`layers.get_norm` for supported format. + stride_in_1x1 (bool): when stride>1, whether to put stride in the + first 1x1 convolution or the bottleneck 3x3 convolution. + dilation (int): the dilation rate of the 3x3 conv layer. 
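+        Example (illustrative)::
+
+            block = BottleneckBlock(
+                in_channels=256,
+                out_channels=512,
+                bottleneck_channels=128,
+                stride=2,
+                stride_in_1x1=True,
+            )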
+ """ + super().__init__(in_channels, out_channels, stride) + + if in_channels != out_channels: + self.shortcut = nn.Conv2dBias(in_channels, out_channels, 1, stride, 0) + else: + self.shortcut = None + + # The original MSRA ResNet models have stride in the first 1x1 conv + # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have + # stride in the 3x3 conv + stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride) + + self.conv1 = nn.Conv2dBiasRelu( + in_channels, bottleneck_channels, 1, stride_1x1, 0 + ) + + self.conv2 = nn.Conv2dBiasRelu( + bottleneck_channels, + bottleneck_channels, + 3, + stride_3x3, + 1 * dilation, + dilation, + ) + + self.conv3 = nn.Conv2dBiasAddRelu(bottleneck_channels, out_channels, 1, 1, 0) + + # for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]: + # if layer is not None: # shortcut can be None + # weight_init.c2_msra_fill(layer) + + # Zero-initialize the last normalization in each residual branch, + # so that at the beginning, the residual branch starts with zeros, + # and each residual block behaves like an identity. + # See Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour": + # "For BN layers, the learnable scaling coefficient γ is initialized + # to be 1, except for each residual block's last BN + # where γ is initialized to be 0." + + # nn.init.constant_(self.conv3.norm.weight, 0) + # TODO this somehow hurts performance when training GN models from scratch. + # Add it as an option when we need to use this code to train a backbone. + + def forward(self, x): + out = self.conv1(x) + out = self.conv2(out) + + if self.shortcut is not None: + shortcut = self.shortcut(x) + else: + shortcut = x + + out = self.conv3(out, shortcut) + return out + + +class ResNet(nn.Module): + """ + Implement :paper:`ResNet`. + """ + + def __init__(self, stem, stages, num_classes=None, out_features=None, freeze_at=0): + """ + Args: + stem (nn.Module): a stem module + stages (list[list[CNNBlockBase]]): several (typically 4) stages, + each contains multiple :class:`CNNBlockBase`. + num_classes (None or int): if None, will not perform classification. + Otherwise, will create a linear layer. + out_features (list[str]): name of the layers whose outputs should + be returned in forward. Can be anything in "stem", "linear", or "res2" ... + If None, will return the output of the last layer. + freeze_at (int): The number of stages at the beginning to freeze. + see :meth:`freeze` for detailed explanation. + """ + super().__init__() + self.stem = stem + self.num_classes = num_classes + + current_stride = self.stem.stride + self._out_feature_strides = {"stem": current_stride} + self._out_feature_channels = {"stem": self.stem.out_channels} + + self.stage_names, self.stages = [], [] + + if out_features is not None: + # Avoid keeping unused layers in this module. 
They consume extra memory + # and may cause allreduce to fail + num_stages = max( + [ + {"res2": 1, "res3": 2, "res4": 3, "res5": 4}.get(f, 0) + for f in out_features + ] + ) + stages = stages[:num_stages] + + for i, blocks in enumerate(stages): + assert len(blocks) > 0, len(blocks) + for block in blocks: + assert isinstance(block, CNNBlockBase), block + + name = "res" + str(i + 2) + stage = nn.Sequential(*blocks) + + self.add_module(name, stage) + self.stage_names.append(name) + self.stages.append(stage) + + self._out_feature_strides[name] = current_stride = int( + current_stride * np.prod([k.stride for k in blocks]) + ) + self._out_feature_channels[name] = curr_channels = blocks[-1].out_channels + self.stage_names = tuple(self.stage_names) # Make it static for scripting + + if num_classes is not None: + self.avgpool = nn.AvgPool2d(7, 1, 0) + self.linear = nn.Linear(curr_channels, num_classes) + + if out_features is None: + out_features = [name] + self._out_features = out_features + assert len(self._out_features) + children = [x[0] for x in self.named_children()] + for out_feature in self._out_features: + assert out_feature in children, "Available children: {}".format( + ", ".join(children) + ) + + def forward(self, x): + """ + Args: + x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``. + Returns: + dict[str->Tensor]: names and the corresponding features + """ + # assert x.dim() == 4, f"ResNet takes an input of shape (N, C, H, W). Got {x.shape} instead!" + outputs = {} + x = self.stem(x) + if "stem" in self._out_features: + outputs["stem"] = x + for name, stage in zip(self.stage_names, self.stages): + x = stage(x) + if name in self._out_features: + outputs[name] = x + if self.num_classes is not None: + x = self.avgpool(x) + x = self.linear(x) + if "linear" in self._out_features: + outputs["linear"] = x + return outputs + + def output_shape(self): + return { + name: ShapeSpec( + channels=self._out_feature_channels[name], + stride=self._out_feature_strides[name], + ) + for name in self._out_features + } + + @staticmethod + def make_stage(block_class, num_blocks, *, in_channels, out_channels, **kwargs): + """ + Create a list of blocks of the same type that forms one ResNet stage. + Args: + block_class (type): a subclass of CNNBlockBase that's used to create all blocks in this + stage. A module of this type must not change spatial resolution of inputs unless its + stride != 1. + num_blocks (int): number of blocks in this stage + in_channels (int): input channels of the entire stage. + out_channels (int): output channels of **every block** in the stage. + kwargs: other arguments passed to the constructor of + `block_class`. If the argument name is "xx_per_block", the + argument is a list of values to be passed to each block in the + stage. Otherwise, the same argument is passed to every block + in the stage. + Returns: + list[CNNBlockBase]: a list of block module. + Examples: + :: + stage = ResNet.make_stage( + BottleneckBlock, 3, in_channels=16, out_channels=64, + bottleneck_channels=16, num_groups=1, + stride_per_block=[2, 1, 1], + dilations_per_block=[1, 1, 2] + ) + Usually, layers that produce the same feature map spatial size are defined as one + "stage" (in :paper:`FPN`). Under such definition, ``stride_per_block[1:]`` should + all be 1. 
+ """ + blocks = [] + for i in range(num_blocks): + curr_kwargs = {} + for k, v in kwargs.items(): + if k.endswith("_per_block"): + assert len(v) == num_blocks, ( + f"Argument '{k}' of make_stage should have the " + f"same length as num_blocks={num_blocks}." + ) + newk = k[: -len("_per_block")] + assert ( + newk not in kwargs + ), f"Cannot call make_stage with both {k} and {newk}!" + curr_kwargs[newk] = v[i] + else: + curr_kwargs[k] = v + + blocks.append( + block_class( + in_channels=in_channels, out_channels=out_channels, **curr_kwargs + ) + ) + in_channels = out_channels + return blocks + + @staticmethod + def make_default_stages(depth, block_class=None, **kwargs): + """ + Created list of ResNet stages from pre-defined depth (one of 18, 34, 50, 101, 152). + If it doesn't create the ResNet variant you need, please use :meth:`make_stage` + instead for fine-grained customization. + Args: + depth (int): depth of ResNet + block_class (type): the CNN block class. Has to accept + `bottleneck_channels` argument for depth > 50. + By default it is BasicBlock or BottleneckBlock, based on the + depth. + kwargs: + other arguments to pass to `make_stage`. Should not contain + stride and channels, as they are predefined for each depth. + Returns: + list[list[CNNBlockBase]]: modules in all stages; see arguments of + :class:`ResNet.__init__`. + """ + num_blocks_per_stage = { + 18: [2, 2, 2, 2], + 34: [3, 4, 6, 3], + 50: [3, 4, 6, 3], + 101: [3, 4, 23, 3], + 152: [3, 8, 36, 3], + }[depth] + if block_class is None: + block_class = BasicBlock if depth < 50 else BottleneckBlock + if depth < 50: + in_channels = [64, 64, 128, 256] + out_channels = [64, 128, 256, 512] + else: + in_channels = [64, 256, 512, 1024] + out_channels = [256, 512, 1024, 2048] + ret = [] + for (n, s, i, o) in zip( + num_blocks_per_stage, [1, 2, 2, 2], in_channels, out_channels + ): + if depth >= 50: + kwargs["bottleneck_channels"] = o // 4 + ret.append( + ResNet.make_stage( + block_class=block_class, + num_blocks=n, + stride_per_block=[s] + [1] * (n - 1), + in_channels=i, + out_channels=o, + **kwargs, + ) + ) + return ret + + +def make_stage(*args, **kwargs): + """ + Deprecated alias for backward compatibiltiy. + """ + return ResNet.make_stage(*args, **kwargs) + + +def build_resnet_backbone(cfg): + """ + Create a ResNet instance from config. + Returns: + ResNet: a :class:`ResNet` instance. 
+ """ + norm = cfg.MODEL.RESNETS.NORM + stem = BasicStem( + in_channels=3, + out_channels=cfg.MODEL.RESNETS.STEM_OUT_CHANNELS, + norm=norm, + ) + + # fmt: off + freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT + out_features = cfg.MODEL.RESNETS.OUT_FEATURES + depth = cfg.MODEL.RESNETS.DEPTH + num_groups = cfg.MODEL.RESNETS.NUM_GROUPS + width_per_group = cfg.MODEL.RESNETS.WIDTH_PER_GROUP + bottleneck_channels = num_groups * width_per_group + in_channels = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS + out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS + stride_in_1x1 = cfg.MODEL.RESNETS.STRIDE_IN_1X1 + res5_dilation = cfg.MODEL.RESNETS.RES5_DILATION + # fmt: on + assert res5_dilation in {1, 2}, "res5_dilation cannot be {}.".format(res5_dilation) + + num_blocks_per_stage = { + 18: [2, 2, 2, 2], + 34: [3, 4, 6, 3], + 50: [3, 4, 6, 3], + 101: [3, 4, 23, 3], + 152: [3, 8, 36, 3], + }[depth] + + if depth in [18, 34]: + assert ( + out_channels == 64 + ), "Must set MODEL.RESNETS.RES2_OUT_CHANNELS = 64 for R18/R34" + assert ( + res5_dilation == 1 + ), "Must set MODEL.RESNETS.RES5_DILATION = 1 for R18/R34" + assert num_groups == 1, "Must set MODEL.RESNETS.NUM_GROUPS = 1 for R18/R34" + + stages = [] + + for idx, stage_idx in enumerate(range(2, 6)): + # res5_dilation is used this way as a convention in R-FCN & Deformable Conv paper + dilation = res5_dilation if stage_idx == 5 else 1 + first_stride = 1 if idx == 0 or (stage_idx == 5 and dilation == 2) else 2 + stage_kargs = { + "num_blocks": num_blocks_per_stage[idx], + "stride_per_block": [first_stride] + [1] * (num_blocks_per_stage[idx] - 1), + "in_channels": in_channels, + "out_channels": out_channels, + "norm": norm, + } + # Use BasicBlock for R18 and R34. + if depth in [18, 34]: + stage_kargs["block_class"] = BasicBlock + else: + stage_kargs["bottleneck_channels"] = bottleneck_channels + stage_kargs["stride_in_1x1"] = stride_in_1x1 + stage_kargs["dilation"] = dilation + stage_kargs["num_groups"] = num_groups + stage_kargs["block_class"] = BottleneckBlock + blocks = ResNet.make_stage(**stage_kargs) + in_channels = out_channels + out_channels *= 2 + bottleneck_channels *= 2 + stages.append(blocks) + return ResNet(stem, stages, out_features=out_features, freeze_at=freeze_at) diff --git a/examples/02_detectron2/modeling/backbone/utils.py b/examples/02_detectron2/modeling/backbone/utils.py new file mode 100644 index 000000000..81a0cb203 --- /dev/null +++ b/examples/02_detectron2/modeling/backbone/utils.py @@ -0,0 +1,30 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class ShapeSpec: + """ + A simple structure that contains basic shape specification about a tensor. + It is often used as the auxiliary inputs/outputs of models, + to complement the lack of shape inference ability among pytorch modules. 
+ """ + + channels: Optional[int] = None + height: Optional[int] = None + width: Optional[int] = None + stride: Optional[int] = None diff --git a/examples/02_detectron2/modeling/meta_arch/__init__.py b/examples/02_detectron2/modeling/meta_arch/__init__.py new file mode 100644 index 000000000..0093ee3f1 --- /dev/null +++ b/examples/02_detectron2/modeling/meta_arch/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# flake8: noqa +from .rcnn import GeneralizedRCNN + +__all__ = list(globals().keys()) diff --git a/examples/02_detectron2/modeling/meta_arch/rcnn.py b/examples/02_detectron2/modeling/meta_arch/rcnn.py new file mode 100644 index 000000000..1f60c0171 --- /dev/null +++ b/examples/02_detectron2/modeling/meta_arch/rcnn.py @@ -0,0 +1,56 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import torch +from aitemplate.frontend import nn +from aitemplate.frontend.nn.proposal import gen_batch_inds + +from ..backbone import build_resnet_fpn_backbone +from ..proposal_generator import build_rpn_head +from ..roi_heads import build_roi_heads + + +class GeneralizedRCNN(nn.Module): + def __init__(self, cfg): + super().__init__() + im_shape = (cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MAX_SIZE_TEST) + self._batch_size = cfg.SOLVER.IMS_PER_BATCH + self._mask_on = cfg.MODEL.MASK_ON + self._num_mask_roi = cfg.POSTPROCESS.TOPK + + self.backbone = build_resnet_fpn_backbone(cfg) + self.proposal_generator = build_rpn_head(cfg, im_shape) + self.roi_heads = build_roi_heads(cfg, im_shape) + self._params = self.get_params() + + def forward(self, x): + features = self.backbone(x) + rois, proposals = self.proposal_generator(features) + results = self.roi_heads(features, rois, proposals) + return results + + def set_anchors(self, mod): + self.proposal_generator.set_anchors(mod) + if self._mask_on: + batch_inds_mask = gen_batch_inds(self._batch_size, self._num_mask_roi) + weight = torch.from_numpy(batch_inds_mask).cuda().half() + mod.set_constant_with_tensor("batch_inds_mask", weight) + + def get_params(self): + params = self.proposal_generator.get_params() + if self._mask_on: + params["batch_inds_mask"] = gen_batch_inds( + self._batch_size, self._num_mask_roi + ) + return params diff --git a/examples/02_detectron2/modeling/proposal_generator/__init__.py b/examples/02_detectron2/modeling/proposal_generator/__init__.py new file mode 100644 index 000000000..07de3d901 --- /dev/null +++ b/examples/02_detectron2/modeling/proposal_generator/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# flake8: noqa +from .rpn import build_rpn_head, RPN, StandardRPNHead + +__all__ = list(globals().keys()) diff --git a/examples/02_detectron2/modeling/proposal_generator/rpn.py b/examples/02_detectron2/modeling/proposal_generator/rpn.py new file mode 100644 index 000000000..ce7a0f2bc --- /dev/null +++ b/examples/02_detectron2/modeling/proposal_generator/rpn.py @@ -0,0 +1,177 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import numpy as np +import torch +from aitemplate.compiler import ops +from aitemplate.compiler.base import Tensor +from aitemplate.frontend import nn + + +class StandardRPNHead(nn.Module): + """ + Standard RPN classification and regression heads described in :paper:`Faster R-CNN`. + Uses a 3x3 conv to produce a shared hidden state from which one 1x1 conv predicts + objectness logits for each anchor and a second 1x1 conv predicts bounding-box deltas + specifying how to deform each anchor into an object proposal. + """ + + def __init__( + self, + in_planes, + rpn_dim=256, + scales=((32,), (64,), (128,), (256,), (512,)), + ratios=(0.5, 1, 2), + ): + super().__init__() + num_anchors = len(scales) * len(ratios) + self.conv = nn.Conv2dBiasRelu(in_planes, rpn_dim, 3, 1, 1) + self.objectness_logits = nn.Conv2dBiasSigmoid(rpn_dim, num_anchors, 1, 1, 0) + self.anchor_deltas = nn.Conv2dBias(rpn_dim, num_anchors * 4, 1, 1, 0) + + def forward(self, features): + pred_objectness_logits = [] + pred_anchor_deltas = [] + for _, x in features.items(): + t = ops.conv2d_bias_relu(stride=1, pad=1)( + x, self.conv.weight.tensor(), self.conv.bias.tensor() + ) + logit = ops.conv2d_bias_sigmoid(stride=1, pad=0)( + t, + self.objectness_logits.weight.tensor(), + self.objectness_logits.bias.tensor(), + ) + reg = ops.conv2d_bias(stride=1, pad=0)( + t, self.anchor_deltas.weight.tensor(), self.anchor_deltas.bias.tensor() + ) + pred_objectness_logits.append(logit) + pred_anchor_deltas.append(reg) + + return pred_objectness_logits, pred_anchor_deltas + + +class RPN(nn.Module): + """ + Region Proposal Network, introduced by :paper:`Faster R-CNN`. + """ + + def __init__( + self, + cfg, + im_shape, + dtype="float16", + ): + super().__init__() + # fmt: off + in_planes = cfg.MODEL.FPN.OUT_CHANNELS + batch_size = cfg.SOLVER.IMS_PER_BATCH + rpn_pre_nms_top_n = cfg.MODEL.RPN.PRE_NMS_TOPK_TEST + rpn_post_nms_top_n = cfg.MODEL.RPN.POST_NMS_TOPK_TEST + self.iou_threshold = cfg.MODEL.RPN.NMS_THRESH + self.rpn_min_size = cfg.MODEL.PROPOSAL_GENERATOR.MIN_SIZE + self.scales = cfg.MODEL.ANCHOR_GENERATOR.SIZES + self.ratios = cfg.MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS + # fmt: on + self.rpn_pre_nms_top_n = rpn_pre_nms_top_n + self.rpn_post_nms_top_n = rpn_post_nms_top_n + self.topk = rpn_pre_nms_top_n + self.dtype = dtype + self.im_shape = im_shape + self.feat_strides = (4, 8, 16, 32, 64) + self.batch_size = batch_size + self.batch_inds = np.zeros((batch_size, rpn_post_nms_top_n, 1)).astype(dtype) + + self.rpn_head = StandardRPNHead( + in_planes, + in_planes, + scales=self.scales[0], + ratios=self.ratios, + ) + + self.proposal = nn.FPNProposal( + im_shape=im_shape, + feat_strides=self.feat_strides, + scales=self.scales, + ratios=self.ratios, + clip_box=True, + nms_on=False, + rpn_pre_nms_top_n=self.rpn_pre_nms_top_n, + rpn_post_nms_top_n=self.rpn_post_nms_top_n, + iou_threshold=self.iou_threshold, + rpn_min_size=self.rpn_min_size, + batch_size=batch_size, + ) + + def forward(self, features): + N = self.batch_size + pred_logits, pred_deltas = self.rpn_head(features) + pred_rois = self.proposal(pred_deltas) + + proposal_list = [] + score_list = [] + for rois, logit in zip(pred_rois, pred_logits): + rois = ops.reshape()(rois, [N, -1, 4]) + if self.topk > 0 and rois.shape()[1].value() > self.topk: + score_inds = ops.topk(k=self.topk)(ops.reshape()(logit, [N, -1])) + boxes_topk = ops.batch_gather()(rois, score_inds) + scores_topk = ops.batch_gather()( + ops.reshape()(logit, [N, -1, 1]), score_inds + ) + proposal_list.append(boxes_topk) + 
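+                # keep the scores gathered with the same top-k indices as the boxes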
score_list.append(ops.reshape()(scores_topk, [N, -1])) + else: + proposal_list.append(rois) + score_list.append(ops.reshape()(logit, [N, -1])) + + proposals_concat = ops.concatenate()(proposal_list, dim=1) + scores_concat = ops.concatenate()(score_list, dim=1) + scores_r = ops.reshape()(scores_concat, [N, -1]) + proposals_r = ops.reshape()(proposals_concat, [N, -1, 4]) + + dets = ops.nms( + self.rpn_pre_nms_top_n, + self.rpn_post_nms_top_n, + self.iou_threshold, + self.rpn_min_size, + )(proposals_r, scores_r) + + batch_inds = Tensor( + shape=[N, self.rpn_post_nms_top_n, 1], + dtype=self.dtype, + name="batch_inds", + value=0, + ) + ret = ops.reshape()(ops.concatenate()([batch_inds, dets], dim=2), [-1, 5]) + return ret, ops.reshape()(dets, [-1, 4]) + + def set_anchors(self, mod): + param = {"batch_inds": self.batch_inds.copy()} + for idx, _ in enumerate(self.feat_strides): + param["anchors_%d" % (idx + 2)] = self.proposal._anchors[idx].copy() + + weights = {name: torch.from_numpy(w).cuda().half() for name, w in param.items()} + for name, weight in weights.items(): + mod.set_constant_with_tensor(name, weight) + + def get_params(self): + params = { + "anchors_%d" % (idx + 2): anchor.copy() + for idx, anchor in enumerate(self.proposal._anchors) + } + params["batch_inds"] = self.batch_inds + return params + + +def build_rpn_head(cfg, input_shape): + return RPN(cfg, input_shape) diff --git a/examples/02_detectron2/modeling/roi_heads/__init__.py b/examples/02_detectron2/modeling/roi_heads/__init__.py new file mode 100644 index 000000000..f812e3c17 --- /dev/null +++ b/examples/02_detectron2/modeling/roi_heads/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# flake8: noqa +from .box_head import build_box_head, FastRCNNConvFCHead +from .mask_head import MaskRCNNConvUpsampleHead +from .roi_heads import build_roi_heads, StandardROIHeads + +__all__ = list(globals().keys()) diff --git a/examples/02_detectron2/modeling/roi_heads/box_head.py b/examples/02_detectron2/modeling/roi_heads/box_head.py new file mode 100644 index 000000000..0269a6a4a --- /dev/null +++ b/examples/02_detectron2/modeling/roi_heads/box_head.py @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from typing import Tuple + +from aitemplate.compiler import ops +from aitemplate.frontend import nn + +from .fast_rcnn import FastRCNNOutputLayers + + +class FastRCNNConvFCHead(nn.Module): + """ + A head with a multi_level roi align layer and two fc layers. + """ + + def __init__( + self, + num_rois: int, + num_classes: int, + feat_dim: int, + fc_dim: int, + pooled_size: int, + im_shape: Tuple[int, int], + ): + super().__init__() + self.num_rois = num_rois + HH, WW = im_shape + self.roi_align = ops.multi_level_roi_align( + num_rois=num_rois, + pooled_size=pooled_size, + spatial_scale=1.0, + sampling_ratio=0, + position_sensitive=False, + continuous_coordinate=False, + im_shape=im_shape, + ) + in_channel = int(feat_dim * pooled_size**2) + mid_channel = fc_dim + + self.fc1 = nn.Linear(in_channel, mid_channel, specialization="relu") + self.fc2 = nn.Linear(mid_channel, mid_channel, specialization="relu") + + def forward(self, feat, rois): + roi_feat = self.roi_align(feat[0], feat[1], feat[2], feat[3], rois) + roi_feat = ops.reshape()(roi_feat, [ops.size()(roi_feat, 0), -1]) + fc1 = self.fc1(roi_feat) + fc2 = self.fc2(fc1) + return fc2 + + +def build_box_head(cfg, input_shape): + """ + Build a box head through `FastRCNNOutputLayers`. + """ + return FastRCNNOutputLayers(cfg, input_shape) diff --git a/examples/02_detectron2/modeling/roi_heads/fast_rcnn.py b/examples/02_detectron2/modeling/roi_heads/fast_rcnn.py new file mode 100644 index 000000000..d825a59a0 --- /dev/null +++ b/examples/02_detectron2/modeling/roi_heads/fast_rcnn.py @@ -0,0 +1,209 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import Tuple + +from aitemplate.compiler import ops +from aitemplate.frontend import nn, Tensor + + +class fast_rcnn_inference: + def __init__( + self, + im_shape: Tuple[int, int], + num_rois: int, + num_classes: int, + clip_box: bool = True, + nms_on: bool = True, + use_topk: bool = True, + topk_per_image: int = 100, + iou_thresh: float = 0.5, + roi_align_on: bool = False, + batch_size: int = 1, + dtype: str = "float16", + ): + self.im_h, self.im_w = im_shape + self.num_rois = num_rois + self.num_classes = num_classes + self.dtype = dtype + self.clip_box = clip_box + self.topk_per_image = topk_per_image + self.iou_threshold = iou_thresh + self.nms_on = nms_on + self.use_topk = use_topk + self.roi_align_on = roi_align_on + self.batch_size = batch_size + self.class_agnostic_nms = False + + def __call__(self, boxes, scores, deltas): + """ + Args: + boxes (list[Tensor]): A list of Tensors of predicted class-specific or class-agnostic + boxes for each image. Element i has shape (Ri, K * 4) if doing + class-specific regression, or (Ri, 4) if doing class-agnostic + regression, where Ri is the number of predicted objects for image i. + + scores (list[Tensor]): A list of Tensors of predicted class scores for each image. + Element i has shape (Ri, K + 1), where Ri is the number of predicted objects + for image i. 
Compatible with the output of :meth:`FastRCNNOutputLayers.predict_probs`. + + deltas: refers to the 4-d (dx, dy, dw, dh) deltas that parameterize the box + transform (see :class:`fast_rcnn_inference.box_transform`). + + Returns: + proposals. + """ + proposals = self.box_transform(boxes, deltas) + if self.nms_on: + return self.nms_wrapper(proposals, scores) + else: + return proposals + + def nms_wrapper(self, proposals, scores): + N = self.batch_size + proposals_p = ops.permute102()(proposals) + scores_x = ops.dynamic_slice()( + scores, start_indices=[0, 0], end_indices=[self.num_rois, self.num_classes] + ) + + OP = ops.efficient_nms( + self.num_rois // N, self.topk_per_image, self.iou_threshold, 0 + ) + args = ( + ops.reshape()(proposals_p, [N, -1, self.num_classes, 4]), + ops.reshape()(scores_x, [N, -1, self.num_classes]), + ) + detections = OP(*args) + if self.roi_align_on: + batch_inds = Tensor( + shape=[N, self.topk_per_image, 1], + dtype=self.dtype, + name="batch_inds_mask", + value=0, + ) + rois = ops.reshape()( + ops.concatenate()([batch_inds, detections[1]], dim=2), [-1, 5] + ) + return detections + (rois,) + else: + return detections + + def layout_transform(self, delta): + return ops.permute210()( + ops.reshape()(delta, [1, self.num_rois, self.num_classes]) + ) + + def apply_weight(self, deltas, weights=(0.1, 0.2)): + ww = weights[0] + wh = weights[1] + + deltas_r = ops.reshape()(deltas, [self.num_rois, -1, 4]) + (delta_x, delta_y, delta_w, delta_h) = ops.split()(deltas_r, 1, dim=2) + delta_xm = delta_x * ww + delta_ym = delta_y * ww + delta_wm = delta_w * wh + delta_hm = delta_h * wh + + return ( + self.layout_transform(delta_xm), + self.layout_transform(delta_ym), + self.layout_transform(delta_wm), + self.layout_transform(delta_hm), + ) + + def box_transform(self, boxes, deltas): + """ + The box-to-box transform defined in R-CNN. The transformation is parameterized by 4 deltas: (dx, dy, dw, dh). The transformation scales the box’s width and height by exp(dw), exp(dh) and shifts a box’s center by the offset (dx * width, dy * height). 
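+        Illustrative per-box update (the 0.1/0.2 delta weighting applied in
+        ``apply_weight`` is omitted here)::
+
+            ctr_x' = ctr_x + dx * width      w' = exp(dw) * width
+            ctr_y' = ctr_y + dy * height     h' = exp(dh) * height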
+ """ + const_0_5 = 0.5 + + (delta_x, delta_y, delta_w, delta_h) = self.apply_weight(deltas) + + boxes_r = ops.reshape()(boxes, [self.num_rois, 4]) + (anchor_x1, anchor_y1, anchor_x2, anchor_y2) = ops.split()(boxes_r, 1, dim=1) + widths = ops.reshape()(anchor_x2 - anchor_x1, [self.num_rois, 1]) + heights = ops.reshape()(anchor_y2 - anchor_y1, [self.num_rois, 1]) + + width_mid = widths * const_0_5 + height_mid = heights * const_0_5 + ctr_x = anchor_x1 + width_mid + ctr_y = anchor_y1 + height_mid + + pred_ctr_x = (delta_x * widths) + ctr_x + + pred_ctr_y = (delta_y * heights) + ctr_y + pred_w = ops.exp(delta_w) * widths + pred_h = ops.exp(delta_h) * heights + + p_x1 = pred_ctr_x - (const_0_5 * pred_w) + p_y1 = pred_ctr_y - (const_0_5 * pred_h) + p_x2 = pred_ctr_x + (const_0_5 * pred_w) + p_y2 = pred_ctr_y + (const_0_5 * pred_h) + + if self.clip_box: + f_x1, f_y1, f_x2, f_y2 = self.box_clip(p_x1, p_y1, p_x2, p_y2) + proposals = ops.concatenate()([f_x1, f_y1, f_x2, f_y2], dim=2) + else: + proposals = ops.concatenate()([p_x1, p_y1, p_x2, p_y2], dim=2) + + return proposals + + def box_clip(self, p_x1, p_y1, p_x2, p_y2): + x_min = 0 + x_max_h = self.im_h + x_max_w = self.im_w + + f_x1 = ops.hardtanh(p_x1, x_min, x_max_w) + f_y1 = ops.hardtanh(p_y1, x_min, x_max_h) + f_x2 = ops.hardtanh(p_x2, x_min, x_max_w) + f_y2 = ops.hardtanh(p_y2, x_min, x_max_h) + return f_x1, f_y1, f_x2, f_y2 + + +class FastRCNNOutputLayers(nn.Module): + """ + Two linear layers for predicting Fast R-CNN outputs: + + 1. proposal-to-detection box regression deltas + 2. classification scores + + , and a postprocess procedure. + """ + + def __init__(self, cfg, im_shape): + super().__init__() + in_channel = cfg.MODEL.ROI_BOX_HEAD.FC_DIM + num_classes = cfg.MODEL.ROI_HEADS.NUM_CLASSES + + self.cls_score = nn.Linear(in_channel, num_classes + 1) + self.bbox_pred = nn.Linear(in_channel, num_classes * 4) + + self.postprocess = fast_rcnn_inference( + im_shape=im_shape, + num_classes=num_classes, + num_rois=cfg.MODEL.RPN.POST_NMS_TOPK_TEST * cfg.SOLVER.IMS_PER_BATCH, + use_topk=cfg.POSTPROCESS.USE_TOPK, + roi_align_on=True if cfg.MODEL.MASK_ON else False, + topk_per_image=cfg.POSTPROCESS.TOPK, + iou_thresh=cfg.MODEL.ROI_HEADS.IOU_THRESHOLDS[0], + clip_box=True, + nms_on=True, + batch_size=cfg.SOLVER.IMS_PER_BATCH, + ) + + def forward(self, x, proposals): + rcnn_logit = self.cls_score(x) + rcnn_logit = ops.softmax()(rcnn_logit, -1) + rcnn_reg = self.bbox_pred(x) + return self.postprocess(proposals, rcnn_logit, rcnn_reg) diff --git a/examples/02_detectron2/modeling/roi_heads/mask_head.py b/examples/02_detectron2/modeling/roi_heads/mask_head.py new file mode 100644 index 000000000..94e022205 --- /dev/null +++ b/examples/02_detectron2/modeling/roi_heads/mask_head.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from typing import Tuple + +from aitemplate.compiler import ops +from aitemplate.frontend import nn + + +class MaskRCNNConvUpsampleHead(nn.Module): + """ + A mask head with several conv layers, plus an upsample layer (or `ConvTranspose2d`). + Predictions are made with a final 1x1 conv layer. + """ + + def __init__( + self, + num_rois: int, + num_classes: int, + feat_dim: int, + conv_dim: int, + pooled_size: int, + im_shape: Tuple[int, int], + ): + super().__init__() + HH, WW = im_shape + self.roi_align = ops.multi_level_roi_align( + num_rois=num_rois, + pooled_size=pooled_size, + spatial_scale=1.0, + sampling_ratio=0, + position_sensitive=False, + continuous_coordinate=False, + im_shape=im_shape, + ) + in_channel = feat_dim + mid_channel = conv_dim + + self.mask_fcn1 = nn.Conv2dBiasRelu(in_channel, mid_channel, 3, 1, 1) + self.mask_fcn2 = nn.Conv2dBiasRelu(mid_channel, mid_channel, 3, 1, 1) + self.mask_fcn3 = nn.Conv2dBiasRelu(mid_channel, mid_channel, 3, 1, 1) + self.mask_fcn4 = nn.Conv2dBiasRelu(mid_channel, mid_channel, 3, 1, 1) + self.deconv = nn.ConvTranspose2dBiasRelu(mid_channel, mid_channel, 2, 2, 0) + self.predictor = nn.Conv2dBiasSigmoid(mid_channel, num_classes, 1, 1, 0) + + def forward(self, feat, rois): + roi_feat = self.roi_align(feat[0], feat[1], feat[2], feat[3], rois) + conv1 = self.mask_fcn1(roi_feat) + conv2 = self.mask_fcn2(conv1) + conv3 = self.mask_fcn3(conv2) + conv4 = self.mask_fcn4(conv3) + upsp = self.deconv(conv4) + mask = self.predictor(upsp) + return mask diff --git a/examples/02_detectron2/modeling/roi_heads/roi_heads.py b/examples/02_detectron2/modeling/roi_heads/roi_heads.py new file mode 100644 index 000000000..587d9601b --- /dev/null +++ b/examples/02_detectron2/modeling/roi_heads/roi_heads.py @@ -0,0 +1,91 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import Dict + +from aitemplate.compiler import ops + +from aitemplate.frontend import nn, Tensor + +from .box_head import build_box_head, FastRCNNConvFCHead +from .mask_head import MaskRCNNConvUpsampleHead + + +class StandardROIHeads(nn.Module): + """ + The StandardROIHeads in a typical "C4" R-CNN model, where + the box and mask head share the cropping and + the per-region feature computation by a Res5 block. + See :paper:`ResNet` Appendix A. 
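+    Example (illustrative; this mirrors how ``GeneralizedRCNN`` uses it)::
+
+        roi_heads = build_roi_heads(cfg, im_shape)
+        detections = roi_heads(features, rois, proposals)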
+ """ + + def __init__(self, cfg, input_shape): + super().__init__() + self.mask_on = cfg.MODEL.MASK_ON + self.in_features = cfg.MODEL.ROI_HEADS.IN_FEATURES + self.box_predictor = build_box_head(cfg, input_shape) + + self.box_head = FastRCNNConvFCHead( + num_rois=cfg.MODEL.RPN.POST_NMS_TOPK_TEST, + num_classes=cfg.MODEL.ROI_HEADS.NUM_CLASSES, + feat_dim=cfg.MODEL.FPN.OUT_CHANNELS, + fc_dim=cfg.MODEL.ROI_BOX_HEAD.FC_DIM, + pooled_size=cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION, + im_shape=input_shape, + ) + if cfg.MODEL.MASK_ON: + self.mask_head = MaskRCNNConvUpsampleHead( + num_rois=cfg.POSTPROCESS.TOPK, + num_classes=cfg.MODEL.ROI_HEADS.NUM_CLASSES, + feat_dim=cfg.MODEL.FPN.OUT_CHANNELS, + conv_dim=cfg.MODEL.ROI_MASK_HEAD.CONV_DIM, + pooled_size=cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION, + im_shape=input_shape, + ) + + def get_shape(self, x): + shape = [it.value() for it in x._attrs["shape"]] + return shape + + def forward(self, features: Dict[str, Tensor], rois: Tensor, proposals: Tensor): + + box_features = [features[f] for f in self.in_features] + roi_feat = self.box_head(box_features, rois) + detections = self.box_predictor(roi_feat, proposals) + if self.mask_on: + num_dets, boxes, probs, class_pred, mask_rois = detections + pred_mask_logits = self.mask_head(box_features, mask_rois) + + num_rois, roi_size, _, num_classes = self.get_shape(pred_mask_logits) + batch_size = self.get_shape(boxes)[0] + batch_rois = num_rois // batch_size + + pred_mask_logits = ops.permute021()( + ops.reshape()(pred_mask_logits, [num_rois, -1, num_classes]) + ) + indices = ops.reshape()(class_pred, [num_rois, 1]) + mask_probs_pred = ops.batch_gather()(pred_mask_logits, indices) + mask_probs_pred = ops.reshape()( + mask_probs_pred, [batch_size, batch_rois, roi_size, roi_size] + ) + return num_dets, boxes, probs, class_pred, mask_probs_pred + else: + return detections + + +def build_roi_heads(cfg, input_shape): + """ + Build ROIHeads through `StandardROIHeads`. + """ + return StandardROIHeads(cfg, input_shape) diff --git a/examples/02_detectron2/predictor/__init__.py b/examples/02_detectron2/predictor/__init__.py new file mode 100644 index 000000000..96a749d96 --- /dev/null +++ b/examples/02_detectron2/predictor/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from .builtin_meta import _get_coco_instances_meta +from .predictor import Predictor + +__all__ = ["Predictor", "_get_coco_instances_meta"] diff --git a/examples/02_detectron2/predictor/builtin_meta.py b/examples/02_detectron2/predictor/builtin_meta.py new file mode 100644 index 000000000..c09e5a5ba --- /dev/null +++ b/examples/02_detectron2/predictor/builtin_meta.py @@ -0,0 +1,180 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Note: +For your custom dataset, there is no need to hard-code metadata anywhere in the code. +For example, for COCO-format dataset, metadata will be obtained automatically +when calling `load_coco_json`. For other dataset, metadata may also be obtained in other ways +during loading. +However, we hard-coded metadata for a few common dataset here. +The only goal is to allow users who don't have these dataset to use pre-trained models. +Users don't have to download a COCO json (which contains metadata), in order to visualize a +COCO model (with correct class names and colors). +""" + + +# All coco categories, together with their nice-looking visualization colors +# It's from https://github.com/cocodataset/panopticapi/blob/master/panoptic_coco_categories.json +COCO_CATEGORIES = [ + {"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"}, + {"color": [119, 11, 32], "isthing": 1, "id": 2, "name": "bicycle"}, + {"color": [0, 0, 142], "isthing": 1, "id": 3, "name": "car"}, + {"color": [0, 0, 230], "isthing": 1, "id": 4, "name": "motorcycle"}, + {"color": [106, 0, 228], "isthing": 1, "id": 5, "name": "airplane"}, + {"color": [0, 60, 100], "isthing": 1, "id": 6, "name": "bus"}, + {"color": [0, 80, 100], "isthing": 1, "id": 7, "name": "train"}, + {"color": [0, 0, 70], "isthing": 1, "id": 8, "name": "truck"}, + {"color": [0, 0, 192], "isthing": 1, "id": 9, "name": "boat"}, + {"color": [250, 170, 30], "isthing": 1, "id": 10, "name": "traffic light"}, + {"color": [100, 170, 30], "isthing": 1, "id": 11, "name": "fire hydrant"}, + {"color": [220, 220, 0], "isthing": 1, "id": 13, "name": "stop sign"}, + {"color": [175, 116, 175], "isthing": 1, "id": 14, "name": "parking meter"}, + {"color": [250, 0, 30], "isthing": 1, "id": 15, "name": "bench"}, + {"color": [165, 42, 42], "isthing": 1, "id": 16, "name": "bird"}, + {"color": [255, 77, 255], "isthing": 1, "id": 17, "name": "cat"}, + {"color": [0, 226, 252], "isthing": 1, "id": 18, "name": "dog"}, + {"color": [182, 182, 255], "isthing": 1, "id": 19, "name": "horse"}, + {"color": [0, 82, 0], "isthing": 1, "id": 20, "name": "sheep"}, + {"color": [120, 166, 157], "isthing": 1, "id": 21, "name": "cow"}, + {"color": [110, 76, 0], "isthing": 1, "id": 22, "name": "elephant"}, + {"color": [174, 57, 255], "isthing": 1, "id": 23, "name": "bear"}, + {"color": [199, 100, 0], "isthing": 1, "id": 24, "name": "zebra"}, + {"color": [72, 0, 118], "isthing": 1, "id": 25, "name": "giraffe"}, + {"color": [255, 179, 240], "isthing": 1, "id": 27, "name": "backpack"}, + {"color": [0, 125, 92], "isthing": 1, "id": 28, "name": "umbrella"}, + {"color": [209, 0, 151], "isthing": 1, "id": 31, "name": "handbag"}, + {"color": [188, 208, 182], "isthing": 1, "id": 32, "name": "tie"}, + {"color": [0, 220, 176], "isthing": 1, "id": 33, "name": "suitcase"}, + {"color": [255, 99, 164], "isthing": 1, "id": 34, "name": "frisbee"}, + {"color": [92, 0, 73], "isthing": 1, "id": 35, "name": "skis"}, + {"color": [133, 129, 255], "isthing": 1, "id": 36, "name": "snowboard"}, + {"color": [78, 180, 255], "isthing": 1, "id": 37, "name": "sports ball"}, + {"color": [0, 228, 
0], "isthing": 1, "id": 38, "name": "kite"}, + {"color": [174, 255, 243], "isthing": 1, "id": 39, "name": "baseball bat"}, + {"color": [45, 89, 255], "isthing": 1, "id": 40, "name": "baseball glove"}, + {"color": [134, 134, 103], "isthing": 1, "id": 41, "name": "skateboard"}, + {"color": [145, 148, 174], "isthing": 1, "id": 42, "name": "surfboard"}, + {"color": [255, 208, 186], "isthing": 1, "id": 43, "name": "tennis racket"}, + {"color": [197, 226, 255], "isthing": 1, "id": 44, "name": "bottle"}, + {"color": [171, 134, 1], "isthing": 1, "id": 46, "name": "wine glass"}, + {"color": [109, 63, 54], "isthing": 1, "id": 47, "name": "cup"}, + {"color": [207, 138, 255], "isthing": 1, "id": 48, "name": "fork"}, + {"color": [151, 0, 95], "isthing": 1, "id": 49, "name": "knife"}, + {"color": [9, 80, 61], "isthing": 1, "id": 50, "name": "spoon"}, + {"color": [84, 105, 51], "isthing": 1, "id": 51, "name": "bowl"}, + {"color": [74, 65, 105], "isthing": 1, "id": 52, "name": "banana"}, + {"color": [166, 196, 102], "isthing": 1, "id": 53, "name": "apple"}, + {"color": [208, 195, 210], "isthing": 1, "id": 54, "name": "sandwich"}, + {"color": [255, 109, 65], "isthing": 1, "id": 55, "name": "orange"}, + {"color": [0, 143, 149], "isthing": 1, "id": 56, "name": "broccoli"}, + {"color": [179, 0, 194], "isthing": 1, "id": 57, "name": "carrot"}, + {"color": [209, 99, 106], "isthing": 1, "id": 58, "name": "hot dog"}, + {"color": [5, 121, 0], "isthing": 1, "id": 59, "name": "pizza"}, + {"color": [227, 255, 205], "isthing": 1, "id": 60, "name": "donut"}, + {"color": [147, 186, 208], "isthing": 1, "id": 61, "name": "cake"}, + {"color": [153, 69, 1], "isthing": 1, "id": 62, "name": "chair"}, + {"color": [3, 95, 161], "isthing": 1, "id": 63, "name": "couch"}, + {"color": [163, 255, 0], "isthing": 1, "id": 64, "name": "potted plant"}, + {"color": [119, 0, 170], "isthing": 1, "id": 65, "name": "bed"}, + {"color": [0, 182, 199], "isthing": 1, "id": 67, "name": "dining table"}, + {"color": [0, 165, 120], "isthing": 1, "id": 70, "name": "toilet"}, + {"color": [183, 130, 88], "isthing": 1, "id": 72, "name": "tv"}, + {"color": [95, 32, 0], "isthing": 1, "id": 73, "name": "laptop"}, + {"color": [130, 114, 135], "isthing": 1, "id": 74, "name": "mouse"}, + {"color": [110, 129, 133], "isthing": 1, "id": 75, "name": "remote"}, + {"color": [166, 74, 118], "isthing": 1, "id": 76, "name": "keyboard"}, + {"color": [219, 142, 185], "isthing": 1, "id": 77, "name": "cell phone"}, + {"color": [79, 210, 114], "isthing": 1, "id": 78, "name": "microwave"}, + {"color": [178, 90, 62], "isthing": 1, "id": 79, "name": "oven"}, + {"color": [65, 70, 15], "isthing": 1, "id": 80, "name": "toaster"}, + {"color": [127, 167, 115], "isthing": 1, "id": 81, "name": "sink"}, + {"color": [59, 105, 106], "isthing": 1, "id": 82, "name": "refrigerator"}, + {"color": [142, 108, 45], "isthing": 1, "id": 84, "name": "book"}, + {"color": [196, 172, 0], "isthing": 1, "id": 85, "name": "clock"}, + {"color": [95, 54, 80], "isthing": 1, "id": 86, "name": "vase"}, + {"color": [128, 76, 255], "isthing": 1, "id": 87, "name": "scissors"}, + {"color": [201, 57, 1], "isthing": 1, "id": 88, "name": "teddy bear"}, + {"color": [246, 0, 122], "isthing": 1, "id": 89, "name": "hair drier"}, + {"color": [191, 162, 208], "isthing": 1, "id": 90, "name": "toothbrush"}, + {"color": [255, 255, 128], "isthing": 0, "id": 92, "name": "banner"}, + {"color": [147, 211, 203], "isthing": 0, "id": 93, "name": "blanket"}, + {"color": [150, 100, 100], "isthing": 0, "id": 95, "name": "bridge"}, 
+ {"color": [168, 171, 172], "isthing": 0, "id": 100, "name": "cardboard"}, + {"color": [146, 112, 198], "isthing": 0, "id": 107, "name": "counter"}, + {"color": [210, 170, 100], "isthing": 0, "id": 109, "name": "curtain"}, + {"color": [92, 136, 89], "isthing": 0, "id": 112, "name": "door-stuff"}, + {"color": [218, 88, 184], "isthing": 0, "id": 118, "name": "floor-wood"}, + {"color": [241, 129, 0], "isthing": 0, "id": 119, "name": "flower"}, + {"color": [217, 17, 255], "isthing": 0, "id": 122, "name": "fruit"}, + {"color": [124, 74, 181], "isthing": 0, "id": 125, "name": "gravel"}, + {"color": [70, 70, 70], "isthing": 0, "id": 128, "name": "house"}, + {"color": [255, 228, 255], "isthing": 0, "id": 130, "name": "light"}, + {"color": [154, 208, 0], "isthing": 0, "id": 133, "name": "mirror-stuff"}, + {"color": [193, 0, 92], "isthing": 0, "id": 138, "name": "net"}, + {"color": [76, 91, 113], "isthing": 0, "id": 141, "name": "pillow"}, + {"color": [255, 180, 195], "isthing": 0, "id": 144, "name": "platform"}, + {"color": [106, 154, 176], "isthing": 0, "id": 145, "name": "playingfield"}, + {"color": [230, 150, 140], "isthing": 0, "id": 147, "name": "railroad"}, + {"color": [60, 143, 255], "isthing": 0, "id": 148, "name": "river"}, + {"color": [128, 64, 128], "isthing": 0, "id": 149, "name": "road"}, + {"color": [92, 82, 55], "isthing": 0, "id": 151, "name": "roof"}, + {"color": [254, 212, 124], "isthing": 0, "id": 154, "name": "sand"}, + {"color": [73, 77, 174], "isthing": 0, "id": 155, "name": "sea"}, + {"color": [255, 160, 98], "isthing": 0, "id": 156, "name": "shelf"}, + {"color": [255, 255, 255], "isthing": 0, "id": 159, "name": "snow"}, + {"color": [104, 84, 109], "isthing": 0, "id": 161, "name": "stairs"}, + {"color": [169, 164, 131], "isthing": 0, "id": 166, "name": "tent"}, + {"color": [225, 199, 255], "isthing": 0, "id": 168, "name": "towel"}, + {"color": [137, 54, 74], "isthing": 0, "id": 171, "name": "wall-brick"}, + {"color": [135, 158, 223], "isthing": 0, "id": 175, "name": "wall-stone"}, + {"color": [7, 246, 231], "isthing": 0, "id": 176, "name": "wall-tile"}, + {"color": [107, 255, 200], "isthing": 0, "id": 177, "name": "wall-wood"}, + {"color": [58, 41, 149], "isthing": 0, "id": 178, "name": "water-other"}, + {"color": [183, 121, 142], "isthing": 0, "id": 180, "name": "window-blind"}, + {"color": [255, 73, 97], "isthing": 0, "id": 181, "name": "window-other"}, + {"color": [107, 142, 35], "isthing": 0, "id": 184, "name": "tree-merged"}, + {"color": [190, 153, 153], "isthing": 0, "id": 185, "name": "fence-merged"}, + {"color": [146, 139, 141], "isthing": 0, "id": 186, "name": "ceiling-merged"}, + {"color": [70, 130, 180], "isthing": 0, "id": 187, "name": "sky-other-merged"}, + {"color": [134, 199, 156], "isthing": 0, "id": 188, "name": "cabinet-merged"}, + {"color": [209, 226, 140], "isthing": 0, "id": 189, "name": "table-merged"}, + {"color": [96, 36, 108], "isthing": 0, "id": 190, "name": "floor-other-merged"}, + {"color": [96, 96, 96], "isthing": 0, "id": 191, "name": "pavement-merged"}, + {"color": [64, 170, 64], "isthing": 0, "id": 192, "name": "mountain-merged"}, + {"color": [152, 251, 152], "isthing": 0, "id": 193, "name": "grass-merged"}, + {"color": [208, 229, 228], "isthing": 0, "id": 194, "name": "dirt-merged"}, + {"color": [206, 186, 171], "isthing": 0, "id": 195, "name": "paper-merged"}, + {"color": [152, 161, 64], "isthing": 0, "id": 196, "name": "food-other-merged"}, + {"color": [116, 112, 0], "isthing": 0, "id": 197, "name": "building-other-merged"}, + {"color": [0, 
114, 143], "isthing": 0, "id": 198, "name": "rock-merged"}, + {"color": [102, 102, 156], "isthing": 0, "id": 199, "name": "wall-other-merged"}, + {"color": [250, 141, 255], "isthing": 0, "id": 200, "name": "rug-merged"}, +] + + +def _get_coco_instances_meta(): + thing_ids = [k["id"] for k in COCO_CATEGORIES if k["isthing"] == 1] + thing_colors = [k["color"] for k in COCO_CATEGORIES if k["isthing"] == 1] + assert len(thing_ids) == 80, len(thing_ids) + # Mapping from the incontiguous COCO category id to an id in [0, 79] + thing_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(thing_ids)} + thing_classes = [k["name"] for k in COCO_CATEGORIES if k["isthing"] == 1] + ret = { + "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id, + "thing_classes": thing_classes, + "thing_colors": thing_colors, + } + return ret diff --git a/examples/02_detectron2/predictor/predictor.py b/examples/02_detectron2/predictor/predictor.py new file mode 100644 index 000000000..324a138c2 --- /dev/null +++ b/examples/02_detectron2/predictor/predictor.py @@ -0,0 +1,359 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import itertools +import os +from typing import Tuple + +import cv2 +import numpy as np +import torch + +from aitemplate.compiler import Model +from modeling.meta_arch import GeneralizedRCNN +from PIL import Image + +from .builtin_meta import _get_coco_instances_meta + + +class Predictor: + """ + Use this class to create AIT inference instances for detectron2 models. It includes procedures that is to 1) preprocess the input images, 2) load the weights of the AIT model, 3) run the AIT model and visualize the outputs, 4) benchmark the AIT model. + """ + + def __init__(self, cfg, workdir="./tmp"): + self.cfg = cfg + self.model_name = cfg.MODEL.NAME + self.batch_size = cfg.SOLVER.IMS_PER_BATCH + self.im_shape = (cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MAX_SIZE_TEST) + self.pixel_mean = cfg.MODEL.PIXEL_MEAN + self.pixel_std = cfg.MODEL.PIXEL_STD + self.mask_on = cfg.MODEL.MASK_ON + self.model = GeneralizedRCNN(cfg) + self.weights = self.get_parameters() + self.module = self.init_modules(cfg.MODEL.NAME, workdir) + self.num_classes = cfg.MODEL.ROI_HEADS.NUM_CLASSES + self.min_size = cfg.INPUT.MIN_SIZE_TEST + self.max_size = cfg.INPUT.MAX_SIZE_TEST + self.interp_method = Image.BILINEAR + + def get_parameters(self): + """ + Obtain the weights. + """ + parameters = { + name: w.contiguous().cuda().half() + for name, w in torch.load(self.cfg.MODEL.WEIGHTS).items() + } + for name, param in self.model._params.items(): + parameters[name] = torch.from_numpy(param).cuda().half() + return parameters + + def preprocess(self, im_path: str, pad_value: float = 0.0): + """ + Image preprocess: resize the image (see `apply_transform`), normalize the pixels, + and add padding. 
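+        The image is read with OpenCV and rotated 90 degrees if it is taller than it is wide.
+        Returns the padded float16 input of shape (1, MIN_SIZE_TEST, MAX_SIZE_TEST, 3), together
+        with the original image, its shape, and the resize scale used for later postprocessing.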
+ """ + # HH, WW = self.im_shape + ori_img = cv2.imread(im_path) + ori_shape = ori_img.shape + if ori_shape[0] > ori_shape[1]: + img = np.rot90(ori_img, k=1) + else: + img = ori_img + inputs = self.apply_transform(img) + resize_scale = img.shape[0] / inputs.shape[0] + pixel_mean = np.array(self.pixel_mean).reshape(1, 1, -1) + pixel_std = np.array(self.pixel_std).reshape(1, 1, -1) + inputs = (inputs - pixel_mean) / pixel_std + padding_size = ( + (0, self.min_size - inputs.shape[0]), + (0, self.max_size - inputs.shape[1]), + (0, 0), + ) + inputs = np.pad(inputs, padding_size, constant_values=pad_value) + inputs = inputs[np.newaxis, :] + return inputs.astype("float16"), ori_img, ori_shape, resize_scale + + def apply_transform(self, img): + """ + Resize the image while keeping the aspect ratio unchanged. + It attempts to scale the shorter edge to the given `short_edge_length`, + as long as the longer edge does not exceed `max_size`. + If `max_size` is reached, then downscale so that the longer edge does not exceed max_size. + """ + h, w = img.shape[:2] + new_h, new_w = Predictor.get_output_shape(h, w, self.min_size, self.max_size) + if len(img.shape) > 2 and img.shape[2] == 1: + pil_image = Image.fromarray(img[:, :, 0], mode="L") + else: + pil_image = Image.fromarray(img) + pil_image = pil_image.resize((new_w, new_h), self.interp_method) + ret = np.asarray(pil_image) + if len(img.shape) > 2 and img.shape[2] == 1: + ret = np.expand_dims(ret, -1) + return ret + + def apply_bbox(self, bbox, im_w, im_h): + if im_h > im_w: + x0 = bbox[:, 0][..., np.newaxis] + y0 = bbox[:, 1][..., np.newaxis] + x1 = bbox[:, 2][..., np.newaxis] + y1 = bbox[:, 3][..., np.newaxis] + bbox = np.hstack((im_w - y1, x0, im_w - y0, x1)) + return bbox + + @staticmethod + def get_output_shape( + oldh: int, oldw: int, short_edge_length: int, max_size: int + ) -> Tuple[int, int]: + """ + Compute the output size given input size and target short edge length. + """ + h, w = oldh, oldw + size = short_edge_length * 1.0 + scale = size / min(h, w) + if h < w: + newh, neww = size, scale * w + else: + newh, neww = scale * h, size + if max(newh, neww) > max_size: + scale = max_size * 1.0 / max(newh, neww) + newh = newh * scale + neww = neww * scale + neww = int(neww + 0.5) + newh = int(newh + 0.5) + return (newh, neww) + + def data_loader(self, image_list): + """ + Load the images and convert them to batched data. + """ + batch_data = [] + HH, WW = self.im_shape + batch = np.zeros((self.batch_size, HH, WW, 3), dtype="float16") + img_paths, img_shapes, img_scales, raw_images = [], [], [], [] + num_samples = len(image_list) + max_iter = ( + (num_samples + self.batch_size - 1) // self.batch_size * self.batch_size + ) + datasets = itertools.cycle(image_list) + for idx in range(max_iter): + im_path = next(datasets) + input_data, raw_input, im_shape, im_scale = self.preprocess(im_path) + im_name = im_path.split("/")[-1] + img_paths.append(im_name) + img_shapes.append(im_shape) + img_scales.append(im_scale) + raw_images.append(raw_input) + batch[idx % self.batch_size, :, :, :] = input_data + if (idx + 1) % self.batch_size == 0: + batch_data.append( + { + "data": batch.astype("float16"), + "image_shape": img_shapes, + "image_scale": img_scales, + "path": img_paths, + "image": raw_images, + } + ) + img_paths, img_shapes, img_scales, raw_images = [], [], [], [] + return batch_data + + def init_modules(self, detection_model_name, workdir): + """ + Load the AIT module of the detection model, and set the weights. 
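+
+        The compiled module is loaded from `<workdir>/<model_name>/test.so`, and the exported
+        weights are bound to the module with `Model.set_constant_with_tensor`.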
+ """ + mod = Model(os.path.join(workdir, detection_model_name, "test.so")) + for name, weight in self.weights.items(): + mod.set_constant_with_tensor(name, weight) + + return mod + + def run_batch(self, batch_data, graph_mode=False): + """ + Run the inference of the AIT model with batched data. + """ + score_thresh = self.cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST + results = {} + inputs = batch_data["data"] + image_list = batch_data["path"] + image_shapes = batch_data["image_shape"] + image_scales = batch_data["image_scale"] + images = batch_data["image"] + ret = self.run_on_image(inputs, graph_mode=graph_mode) + batched_boxes, batched_scores, batched_classes = ret[:3] + if self.mask_on: + batched_masks = ret[-1] + for i in range(self.batch_size): + boxes, scores, classes = ( + batched_boxes[i, :], + batched_scores[i, :], + batched_classes[i, :], + ) + + filter_mask = scores > score_thresh + filter_inds = filter_mask.nonzero()[0] + scores = scores[filter_inds] + boxes = boxes[filter_inds, :] * image_scales[i] + boxes = self.apply_bbox(boxes, image_shapes[i][1], image_shapes[i][0]) + classes = classes[filter_inds] + + results[image_list[i]] = { + "boxes": boxes, + "scores": scores, + "classes": classes, + "image_height": image_shapes[i][0], + "image_width": image_shapes[i][1], + "num_instances": boxes.shape[0], + "image": images[i], + } + if self.mask_on: + mask_pred = batched_masks[i, filter_inds, :, :] + results[image_list[i]]["masks"] = mask_pred + return results + + @staticmethod + def overlay(image, mask, color, alpha_transparency=0.5): + for channel in range(3): + image[:, :, channel] = np.where( + mask == 1, + image[:, :, channel] * (1 - alpha_transparency) + + alpha_transparency * color[channel] * 255, + image[:, :, channel], + ) + return image + + def visualize( + self, detections, output_path="./tmp/outputs", thickness=1, mask_thresh=0.5 + ): + """ + Visualize the outputs. 
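+
+        Draws the predicted boxes, class names, and scores on each image, overlays the predicted
+        masks when MASK_ON is set, and writes the annotated images to `output_path`.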
+ """ + os.makedirs(output_path, exist_ok=True) + meta_data = _get_coco_instances_meta() + thing_colors = meta_data["thing_colors"] + thing_classes = meta_data["thing_classes"] + for file_name, result in detections.items(): + img = result["image"] + boxes = result["boxes"] + classes = result["classes"] + scores = result["scores"] + for pred_box, pred_class, pred_score in zip(boxes, classes, scores): + box = pred_box.astype("int") + start_point = (box[0], box[1]) + end_point = (box[2], box[3]) + color = tuple(thing_colors[pred_class]) + img = cv2.rectangle(img, start_point, end_point, color, thickness) + text = thing_classes[pred_class] + ": " + str(pred_score) + img = cv2.putText( + img, + text, + (box[0], box[1] - 10), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + color, + thickness, + ) + + if self.mask_on: + masks = result["masks"] + im_height, im_width = img.shape[:2] + for pred_box, pred_class, mask in zip(boxes, classes, masks): + np_color = np.array(thing_colors[pred_class]) / 255 + if im_height > im_width: + mask = np.rot90(mask, k=-1) + box = pred_box.astype("int") + det_width = box[2] - box[0] + det_height = box[3] - box[1] + mask = mask.astype(np.float32) + small_mask = Image.fromarray(mask) + mask = small_mask.resize( + (det_width, det_height), resample=self.interp_method + ) + mask = np.array(mask, copy=False) + mask = np.array(mask > mask_thresh, dtype=np.uint8) + padded_mask = np.zeros((im_height, im_width), dtype=np.uint8) + x_0 = max(box[0], 0) + x_1 = min(box[2], im_width) + y_0 = max(box[1], 0) + y_1 = min(box[3], im_height) + padded_mask[y_0:y_1, x_0:x_1] = mask[ + (y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0]) + ] + img = Predictor.overlay(img, padded_mask, np_color) + cv2.imwrite(os.path.join(output_path, file_name), img) + + def run_on_image(self, inputs, graph_mode=False): + """ + Call the AIT module for the inference of the model on given inputs, and return the outputs. + """ + topk = self.cfg.POSTPROCESS.TOPK + mod = self.module + if type(inputs) is np.ndarray: + arr = torch.from_numpy(inputs).cuda() + else: + arr = inputs.contiguous() + + inputs = [arr] + + outputs = [ + torch.empty([self.batch_size, 1], dtype=torch.int64).cuda(), + torch.empty([self.batch_size, topk, 4]).cuda().half(), + torch.empty([self.batch_size, topk]).cuda().half(), + torch.empty([self.batch_size, topk], dtype=torch.int64).cuda(), + ] + if self.mask_on: + mask_size = self.cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION * 2 + mask_blob = torch.empty([self.batch_size, topk, mask_size, mask_size]) + outputs.append(mask_blob.cuda().half()) + mod.run_with_tensors(inputs, outputs, graph_mode=graph_mode) + + ret = [ + outputs[1].cpu().numpy(), + outputs[2].cpu().numpy(), + outputs[3].cpu().numpy(), + ] + if self.mask_on: + ret.append(outputs[-1].cpu().numpy()) + return ret + + def benchmark(self, inputs, count=10, graph_mode=False): + """ + Benchmark the inference of the AIT model on given inputs, and return the runtime in ms. 
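+
+        The latency comes from `Model.benchmark_with_tensors` with the given iteration `count`
+        and `repeat=2`; the `graph_mode` flag is forwarded to the runtime.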
+ """ + mod = self.module + if type(inputs) is np.ndarray: + arr = torch.from_numpy(inputs).cuda() + else: + arr = inputs.cuda().contiguous() + topk = self.cfg.POSTPROCESS.TOPK + outputs = [ + torch.empty([self.batch_size, 1], dtype=torch.int64).cuda(), + torch.empty([self.batch_size, topk, 4]).cuda().half(), + torch.empty([self.batch_size, topk]).cuda().half(), + torch.empty([self.batch_size, topk], dtype=torch.int64).cuda(), + ] + if self.mask_on: + mask_blob = torch.empty([self.batch_size, topk, 28, 28]) + outputs.append(mask_blob.cuda().half()) + + duration, _, _ = mod.benchmark_with_tensors( + [arr], + outputs, + count=count, + repeat=2, + graph_mode=graph_mode, + ) + return duration diff --git a/examples/02_detectron2/prepare_and_run_rcnn.sh b/examples/02_detectron2/prepare_and_run_rcnn.sh new file mode 100755 index 000000000..f1edabe26 --- /dev/null +++ b/examples/02_detectron2/prepare_and_run_rcnn.sh @@ -0,0 +1,59 @@ +#!/bin/bash -e +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +### Download COCO 2017 Dataset + +#### Download image annotations +BASE=https://dl.fbaipublicfiles.com/detectron2 +ROOT=~/.torch/datasets +mkdir -p $ROOT/coco/annotations +echo "$ROOT" + +for anno in instances_val2017_100 \ + person_keypoints_val2017_100 ; do + + dest=$ROOT/coco/annotations/$anno.json + [[ -s $dest ]] && { + echo "$dest exists. Skipping ..." + } || { + wget $BASE/annotations/coco/$anno.json -O $dest + } +done + +#### Download images +dest=$ROOT/coco/val2017_100.tgz +[[ -d $ROOT/coco/val2017 ]] && { + echo "$ROOT/coco/val2017 exists. Skipping ..." +} || { + wget $BASE/annotations/coco/val2017_100.tgz -O $dest + tar xzf $dest -C $ROOT/coco/ && rm -f $dest +} +IMG_PATH=$ROOT/coco/val2017 + +### Download Pre-trained Model + +MODEL_PATH=~/.torch/model +mkdir -p $MODEL_PATH +MODEL_NAME=mask_rcnn_R_50_FPN + +wget $BASE/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x/137849600/model_final_f10217.pkl -O tmp/pt_$MODEL_NAME.pkl + +### Build AIT Model, Export the Pre-trained Weights and Run Inference + +cfg=examples/02_detectron2/configs/$MODEL_NAME.yaml +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 examples/02_detectron2/compile_model.py \ + --config $cfg \ + --batch 1 + +python3 examples/02_detectron2/tools/convert_pt2ait.py \ + --d2-weight tmp/pt_$MODEL_NAME.pkl \ + --ait-weight tmp/ait_$MODEL_NAME.pt \ + --model-name $MODEL_NAME + +python3 examples/02_detectron2/demo.py \ + --weight tmp/ait_$MODEL_NAME.pt \ + --config $cfg \ + --batch 1 --input "$IMG_PATH/*.jpg" \ + --confidence-threshold 0.5 \ + --display \ + --cudagraph diff --git a/examples/02_detectron2/tools/convert_pt2ait.py b/examples/02_detectron2/tools/convert_pt2ait.py new file mode 100644 index 000000000..584e14560 --- /dev/null +++ b/examples/02_detectron2/tools/convert_pt2ait.py @@ -0,0 +1,157 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +""" +script for converting model from detectron2 to aitemplate +""" + +import json +import os +import pickle as pkl + +import click + +import numpy as np +import torch +from aitemplate.testing import detect_target + +# pylint: disable=C0103 + + +class detectron2_export: + def __init__(self, model_name): + self.model_name = model_name + + def export_model(self, model): + fuse_model = {} + bn_keys = set() + for k, _ in model.items(): + if "norm" in k: + param_name = k.split(".norm")[0] + if param_name in bn_keys: + continue + bn_keys.add(param_name) + self.transform_params(param_name, model, fuse_model, fuse_bn=True) + else: + self.transform_params(k, model, fuse_model, fuse_bn=False) + + ait_model = { + k.replace(".", "_"): weight + for k, weight in fuse_model.items() + if "anchors" not in k + } + + if detect_target().name() == "cuda": + self.export_conv0(ait_model, fuse_model) + + self.check_model(ait_model) + return ait_model + + def check_model(self, ait_model): + with open(os.path.join("./tmp", self.model_name, "params.json")) as fi: + param_map = json.load(fi) + for name, shape in param_map: + assert ait_model[name].shape == tuple( + shape + ), "weight shape mismatch {} {} expected {}".format( + name, ait_model[name].shape, shape + ) + + def fuse_conv_bn_weights( + self, conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=False + ): + conv_w = torch.tensor(conv_w) + bn_rm = torch.tensor(bn_rm) + bn_rv = torch.tensor(bn_rv) + bn_w = torch.tensor(bn_w) + bn_b = torch.tensor(bn_b) + bn_eps = torch.tensor(bn_eps) + + if conv_b is None: + conv_b = torch.zeros_like(bn_rm) + if bn_w is None: + bn_w = torch.ones_like(bn_rm) + if bn_b is None: + bn_b = torch.zeros_like(bn_rm) + bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps) + + if transpose: + shape = [1, -1] + [1] * (len(conv_w.shape) - 2) + else: + shape = [-1, 1] + [1] * (len(conv_w.shape) - 2) + + conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape(shape) + conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b + + for arr in [conv_w.numpy(), conv_b.numpy()]: + if np.isnan(arr).any(): + print("fuse bn error") + return conv_w, conv_b + + def transform_params(self, param_name, obj, fuse_model, fuse_bn=True): + if not fuse_bn: + arr = obj[param_name] + if len(arr.shape) == 4: + arr = np.transpose(arr, (0, 2, 3, 1)) + elif "fc1.weight" in param_name: + arr = arr.reshape((1024, -1, 7, 7)) + arr = np.transpose(arr, (0, 2, 3, 1)) + arr = arr.reshape((1024, -1)) + fuse_model[param_name] = torch.tensor(arr) + + else: + conv_k = "%s.weight" % (param_name) + conv_b = "%s.bias" % (param_name) + bn_w_k = "%s.norm.weight" % (param_name) + bn_b_k = "%s.norm.bias" % (param_name) + bn_rm_k = "%s.norm.running_mean" % (param_name) + bn_rv_k = "%s.norm.running_var" % (param_name) + fused_conv_weight, fused_conv_bias = self.fuse_conv_bn_weights( + obj[conv_k], + None, + obj[bn_rm_k], + obj[bn_rv_k], + 1e-5, + obj[bn_w_k], + obj[bn_b_k], + ) + fuse_model[conv_k] = fused_conv_weight.permute((0, 2, 3, 1)) + fuse_model[conv_b] = fused_conv_bias + + def export_conv0(self, ait_model, fuse_model): + pt_name = "backbone.bottom_up.stem.conv1.weight" + x = fuse_model[pt_name] + conv_w = torch.zeros((64, 7, 7, 4)) + conv_w[:, :, :, :3] = x + ait_model[pt_name.replace(".", "_")] = conv_w + + +@click.command() +@click.option("--model-name", default="", metavar="FILE", help="path to ait param file") +@click.option("--d2-weight", default="", metavar="FILE", help="D2 weight") +@click.option("--ait-weight", default="", metavar="FILE", help="AIT weight") +def 
export_pt_model_to_ait(model_name, d2_weight, ait_weight):
+    d2ait = detectron2_export(model_name)
+    with open(d2_weight, "rb") as f:
+        file = f.read()
+        obj = pkl.loads(file, encoding="latin1")
+        pt_model = obj["model"]
+
+    ait_model = d2ait.export_model(pt_model)
+
+    torch.save(ait_model, ait_weight)
+
+
+if __name__ == "__main__":
+    export_pt_model_to_ait()
diff --git a/examples/03_bert/README.md b/examples/03_bert/README.md
new file mode 100644
index 000000000..2c6e4a489
--- /dev/null
+++ b/examples/03_bert/README.md
@@ -0,0 +1,303 @@
+# BERT
+
+This directory contains an AIT demo for the [BERT language representation model](https://huggingface.co/docs/transformers/v4.22.1/en/model_doc/bert).
+
+Only `bert-base-uncased` is included.
+
+## Prerequisites
+
+Install the dependencies:
+```
+python3 -m pip install transformers click torch
+```
+
+## Benchmarking
+
+To run a basic benchmark, use `benchmark_ait.py`:
+
+```
+python3 examples/03_bert/benchmark_ait.py
+```
+
+There are two options for the hidden activation, `gelu` and `fast_gelu` (`fast_gelu` by default).
+`gelu` is not supported on AMD hardware yet.
+
+```
+python3 examples/03_bert/benchmark_ait.py --activation gelu
+python3 examples/03_bert/benchmark_ait.py --activation fast_gelu
+```
+
+The batch size and sequence length can also be configured on the command line:
+```
+python3 examples/03_bert/benchmark_ait.py --batch-size 1 --seq-length 128
+```
+
+PyTorch eager mode benchmarks can also be run:
+```
+python3 examples/03_bert/benchmark_pt.py
+```
+
+To include the BERT embeddings in the benchmark (rather than the encoders alone), pass `--encoders-only False`.
+
+## Quick Demo
+
+To run a quick demo with a simple prompt, use `demo.py`:
+```
+python3 examples/03_bert/demo.py --prompt "The quick brown fox jumps over the lazy dog."
+```
+
+The demo prints the resulting logits. Only sequence lengths up to 512 are supported.
+
+## Multi-GPU profiling
+AIT needs to run profiling to decide the best CUTLASS and CK algorithms.
+To profile with multiple GPUs, set the environment variable `CUDA_VISIBLE_DEVICES` on the NVIDIA platform or `HIP_VISIBLE_DEVICES` on the AMD platform.
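+
+For example (the device ids below are only an illustration; list whichever devices you want the profiler to use):
+
+```
+# NVIDIA
+CUDA_VISIBLE_DEVICES=0,1,2,3 python3 examples/03_bert/benchmark_ait.py
+# AMD
+HIP_VISIBLE_DEVICES=0,1 python3 examples/03_bert/benchmark_ait.py
+```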
+ +## Reference Speed vs PyTorch Eager +_PT = PyTorch 1.12 Eager_ +_OOM = Out of Memory_ + +### A100-40GB / CUDA 11.6.2 + +- Sequence length 64 + +| Batch size | PT Latency (ms) | PT QPS (seq/s) | AIT Latency (ms) | AIT QPS (seq/s) | +|------------|-----------------|----------------|------------------|-----------------| +| 1 | 7.96 | 125.65 | 0.71 | 1399.64 | +| 2 | 8.38 | 238.59 | 0.74 | 2719.15 | +| 4 | 8.29 | 482.30 | 0.80 | 4994.37 | +| 8 | 8.51 | 939.97 | 0.95 | 8439.67 | +| 16 | 8.09 | 1978.47 | 1.41 | 11385.85 | +| 32 | 9.19 | 3481.34 | 2.23 | 14357.58 | +| 64 | 9.12 | 7016.80 | 4.14 | 15458.15 | +| 128 | 14.52 | 8814.57 | 8.00 | 15991.44 | +| 256 | 27.75 | 9224.39 | 15.99 | 16006.79 | + + +- Sequence length 128 + +| Batch size | PT Latency (ms) | PT QPS (seq/s) | AIT Latency (ms) | AIT QPS (seq/s) | +|------------|-----------------|----------------|------------------|-----------------| +| 1 | 8.02 | 124.72 | 0.78 | 1281.52 | +| 2 | 8.29 | 241.22 | 0.85 | 2364.94 | +| 4 | 8.51 | 470.29 | 0.99 | 4044.33 | +| 8 | 8.12 | 985.72 | 1.43 | 5600.93 | +| 16 | 9.22 | 1735.20 | 2.21 | 7232.47 | +| 32 | 9.11 | 3512.80 | 4.17 | 7677.82 | +| 64 | 15.29 | 4184.93 | 8.05 | 7949.06 | +| 128 | 29.44 | 4347.33 | 16.03 | 7987.11 | +| 256 | 56.34 | 4543.88 | 31.57 | 8109.08 | + + +- Sequence length 384 + +| Batch size | PT Latency (ms) | PT QPS (seq/s) | AIT Latency (ms) | AIT QPS (seq/s) | +|------------|-----------------|----------------|------------------|-----------------| +| 1 | 8.72 | 114.73 | 1.63 | 611.91 | +| 2 | 8.31 | 240.73 | 1.97 | 1013.19 | +| 4 | 8.64 | 463.10 | 2.55 | 1569.23 | +| 8 | 9.32 | 858.70 | 3.95 | 2025.62 | +| 16 | 13.90 | 1151.03 | 6.80 | 2354.21 | +| 32 | 26.72 | 1197.74 | 13.30 | 2405.46 | +| 64 | 51.02 | 1254.34 | 26.68 | 2398.95 | +| 128 | 100.26 | 1276.67 | 51.60 | 2480.67 | +| 256 | OOM | OOM | 101.55 | 2520.81 | + + +- Sequence length 1024 + +| Batch size | PT Latency (ms) | PT QPS (seq/s) | AIT Latency (ms) | AIT QPS (seq/s) | +|------------|-----------------|----------------|------------------|-----------------| +| 1 | 9.74 | 102.65 | 2.20 | 454.12 | +| 2 | 11.38 | 175.75 | 4.15 | 481.95 | +| 4 | 13.61 | 293.90 | 8.36 | 478.44 | +| 8 | 25.79 | 310.15 | 12.53 | 638.53 | +| 16 | 49.91 | 320.59 | 21.61 | 740.48 | +| 32 | 97.00 | 329.91 | 42.84 | 746.88 | +| 64 | 191.14 | 334.83 | 83.95 | 762.39 | +| 128 | OOM | OOM | 163.96 | 780.70 | +| 256 | OOM | OOM | 324.22 | 789.58 | + + + +- Sequence length 4096 + +| Batch size | PT Latency (ms) | PT QPS (seq/s) | AIT Latency (ms) | AIT QPS (seq/s) | +|------------|-----------------|----------------|------------------|-----------------| +| 1 | 32.82 | 30.47 | 18.23 | 54.87 | +| 2 | 65.25 | 30.65 | 35.64 | 56.11 | +| 4 | 128.73 | 31.07 | 103.67 | 38.58 | +| 8 | OOM | OOM | 119.45 | 66.98 | +| 16 | OOM | OOM | 166.25 | 96.24 | +| 32 | OOM | OOM | 333.98 | 95.81 | +| 64 | OOM | OOM | 662.29 | 96.63 | +| 128 | OOM | OOM | 1313.77 | 97.43 | +| 256 | | | | | + + + +### MI-250 / ROCm 5.2.3 / HIPCC-10736 + +#### 1 GCD + +- Sequence length 64 + +| Batch size | PT Latency (ms) | PT QPS (seq/s) | AIT Latency (ms) | AIT QPS (seq/s) | +|------------|-----------------|----------------|------------------|-----------------| +| 1 | 5.72 | 174.72 | 2.78 | 359.88 | +| 2 | 5.96 | 335.38 | 2.87 | 697.76 | +| 4 | 5.85 | 684.16 | 2.85 | 1404.31 | +| 8 | 6.15 | 1300.72 | 3.15 | 2540.72 | +| 16 | 6.14 | 2605.40 | 3.78 | 4231.12 | +| 32 | 7.73 | 4138.06 | 5.34 | 5993.50 | +| 64 | 14.38 | 4451.07 | 9.10 | 7030.42 | +| 128 | 26.18 | 4889.95 | 16.45 | 
7780.40 | +| 256 | 49.95 | 5125.04 | 31.90 | 8023.98 | + + +- Sequence length 128 + +| Batch size | PT Latency (ms) | PT QPS (seq/s) | AIT Latency (ms) | AIT QPS (seq/s) | +|------------|-----------------|----------------|------------------|-----------------| +| 1 | 5.76 | 173.55 | 2.68 | 373.03 | +| 2 | 6.06 | 330.18 | 2.87 | 697.33 | +| 4 | 5.96 | 670.65 | 3.02 | 1324.91 | +| 8 | 6.03 | 1326.23 | 3.65 | 2194.62 | +| 16 | 9.35 | 1711.55 | 4.98 | 3212.12 | +| 32 | 16.46 | 1943.61 | 8.48 | 3775.22 | +| 64 | 30.83 | 2075.74 | 15.44 | 4146.40 | +| 128 | 58.74 | 2179.24 | 30.57 | 4187.68 | +| 256 | 115.27 | 2220.87 | 59.28 | 4318.61 | + + +- Sequence length 384 + +| Batch size | PT Latency (ms) | PT QPS (seq/s) | AIT Latency (ms) | AIT QPS (seq/s) | +|------------|-----------------|----------------|------------------|-----------------| +| 1 | 5.78 | 172.87 | 2.97 | 336.14 | +| 2 | 6.02 | 332.30 | 3.45 | 579.89 | +| 4 | 8.00 | 499.85 | 4.68 | 854.16 | +| 8 | 13.79 | 580.01 | 7.47 | 1070.24 | +| 16 | 24.39 | 656.06 | 13.04 | 1226.77 | +| 32 | 45.56 | 702.33 | 24.26 | 1318.80 | +| 64 | 87.84 | 728.57 | 47.87 | 1336.92 | +| 128 | 172.57 | 741.71 | 95.22 | 1344.26 | +| 256 | 352.27 | 726.71 | 185.94 | 1376.78 | + + + +- Sequence length 1024 + +| Batch size | PT Latency (ms) | PT QPS (seq/s) | AIT Latency (ms) | AIT QPS (seq/s) | +|------------|-----------------|----------------|------------------|-----------------| +| 1 | 6.86 | 145.71 | 4.20 | 237.84 | +| 2 | 12.41 | 161.21 | 5.82 | 343.62 | +| 4 | 22.25 | 179.80 | 10.20 | 392.26 | +| 8 | 41.94 | 190.73 | 18.91 | 423.05 | +| 16 | 81.03 | 197.45 | 37.86 | 422.60 | +| 32 | 159.06 | 201.19 | 71.65 | 446.62 | +| 64 | 321.51 | 199.06 | 148.86 | 429.95 | +| 128 | OOM | OOM | 277.53 | 461.21 | +| 256 | OOM | OOM | 563.07 | 454.65 | + + +- Sequence length 4096 + +| Batch size | PT Latency (ms) | PT QPS (seq/s) | AIT Latency (ms) | AIT QPS (seq/s) | +|------------|-----------------|----------------|------------------|-----------------| +| 1 | 49.89 | 20.04 | 16.18 | 61.81 | +| 2 | 93.22 | 21.45 | 30.67 | 65.21 | +| 4 | 183.57 | 21.79 | 66.78 | 59.90 | +| 8 | 366.57 | 21.82 | 117.49 | 68.09 | +| 16 | OOM | OOM | 231.15 | 69.22 | +| 32 | OOM | OOM | 459.46 | 69.65 | +| 64 | OOM | OOM | 1031.86 | 62.02 | +| 128 | | | | | +| 256 | | | | | + + +#### 2 GCDs + +- Sequence length 64 + +| Batch size | PT Latency (ms) | PT QPS (seq/s) | AIT Latency (ms) | AIT QPS (seq/s) | +|------------|-----------------|----------------|------------------|-----------------| +| 1 | | | | | +| 2 | 5.52 | 362.55 | 2.80 | 714.99 | +| 4 | 6.04 | 661.73 | 2.89 | 1385.05 | +| 8 | 6.07 | 1317.20 | 2.82 | 2835.38 | +| 16 | 6.02 | 2659.82 | 3.29 | 4866.99 | +| 32 | 6.09 | 5257.45 | 3.83 | 8352.10 | +| 64 | 8.53 | 7506.95 | 5.81 | 11013.02 | +| 128 | 15.34 | 8346.14 | 10.00 | 12806.23 | +| 256 | 28.44 | 9002.30 | 18.92 | 13528.13 | + + +- Sequence length 128 + +| Batch size | PT Latency (ms) | PT QPS (seq/s) | AIT Latency (ms) | AIT QPS (seq/s) | +|------------|-----------------|----------------|------------------|-----------------| +| 1 | | | | | +| 2 | 5.58 | 358.62 | 2.68 | 745.20 | +| 4 | 6.20 | 644.91 | 2.83 | 1411.55 | +| 8 | 6.08 | 1316.09 | 3.21 | 2492.88 | +| 16 | 5.89 | 2716.79 | 3.86 | 4144.50 | +| 32 | 9.86 | 3247.03 | 5.41 | 5915.33 | +| 64 | 17.71 | 3614.25 | 9.64 | 6640.53 | +| 128 | 32.74 | 3909.15 | 17.81 | 7186.25 | +| 256 | 62.73 | 4080.77 | 35.73 | 7165.20 | + + +- Sequence length 384 + +| Batch size | PT Latency (ms) | PT QPS (seq/s) | AIT Latency (ms) | AIT QPS (seq/s) 
|
+|------------|-----------------|----------------|------------------|-----------------|
+| 1 | | | | |
+| 2 | 5.57 | 358.88 | 3.09 | 647.71 |
+| 4 | 6.12 | 653.83 | 3.62 | 1104.69 |
+| 8 | 8.35 | 958.19 | 4.94 | 1620.06 |
+| 16 | 14.29 | 1119.38 | 8.29 | 1930.01 |
+| 32 | 26.10 | 1226.17 | 14.96 | 2139.07 |
+| 64 | 50.01 | 1279.72 | 28.22 | 2268.02 |
+| 128 | 97.55 | 1312.15 | 55.94 | 2288.37 |
+| 256 | 193.00 | 1326.44 | 111.27 | 2300.68 |
+
+
+
+- Sequence length 1024
+
+| Batch size | PT Latency (ms) | PT QPS (seq/s) | AIT Latency (ms) | AIT QPS (seq/s) |
+|------------|-----------------|----------------|------------------|-----------------|
+| 1 | | | | |
+| 2 | 6.80 | 294.16 | 4.36 | 458.93 |
+| 4 | 13.01 | 307.55 | 6.43 | 622.23 |
+| 8 | 23.39 | 341.99 | 11.52 | 694.52 |
+| 16 | 44.45 | 359.94 | 21.83 | 732.90 |
+| 32 | 87.23 | 366.84 | 43.73 | 731.77 |
+| 64 | 172.92 | 370.12 | 82.92 | 771.85 |
+| 128 | 352.09 | 363.54 | 173.14 | 739.29 |
+| 256 | OOM | OOM | 322.97 | 792.64 |
+
+
+- Sequence length 4096
+
+| Batch size | PT Latency (ms) | PT QPS (seq/s) | AIT Latency (ms) | AIT QPS (seq/s) |
+|------------|-----------------|----------------|------------------|-----------------|
+| 1 | | | | |
+| 2 | 54.67 | 36.58 | 18.31 | 109.23 |
+| 4 | 104.19 | 38.39 | 35.09 | 113.99 |
+| 8 | 206.62 | 38.72 | 77.03 | 103.86 |
+| 16 | 412.58 | 38.78 | 133.59 | 119.77 |
+| 32 | OOM | OOM | 263.40 | 121.49 |
+| 64 | OOM | OOM | 524.11 | 122.11 |
+| 128 | OOM | OOM | 1186.20 | 107.91 |
+| 256 | | | | |
+
+
+### Notes on Performance Results
+
+- For NVIDIA A100, our test cluster does not allow locking the GPU frequency. We use a longer warm-up to collect more stable results, but a small variance relative to locked-frequency results is expected.
+- To benchmark MI-250, first run `python3 benchmark_ait.py` to generate all the necessary model dynamic library files on a single GCD. Then run `./benchmark_mi250.sh {batch_size}` to simulate data-parallel execution on 2 GCDs, with each GCD processing half of the batch.
+- To benchmark MI-250 with 1 GCD, we lock the frequency with the command `rocm-smi -d x --setperfdeterminism 1700`, where `x` is the GPU id.
+- To benchmark MI-250 with 2 GCDs, we observed a performance regression with the ROCm perf-determinism mode, so the 2-GCD numbers are collected with that mode disabled via `rocm-smi -d x --resetperfdeterminism`, where `x` is the GPU id.
+- The PyTorch Eager results do not reflect [BetterTransformer](https://pytorch.org/blog/a-better-transformer-for-fast-transformer-encoder-inference/), mainly because the BetterTransformer integration into the TIMM/Transformers packages has not landed yet.
+- The performance results are what we were able to reproduce; they should not be used for any other purpose.
diff --git a/examples/03_bert/benchmark_ait.py b/examples/03_bert/benchmark_ait.py
new file mode 100644
index 000000000..9847cb910
--- /dev/null
+++ b/examples/03_bert/benchmark_ait.py
@@ -0,0 +1,298 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# +import os +from collections import OrderedDict + +from typing import Dict, List + +import click +import numpy as np +import torch +from aitemplate.compiler import compile_model, Model + +from aitemplate.frontend import Tensor +from aitemplate.testing import detect_target + +from modeling.bert import BertBaseEncodersOnly, BertBaseUncased +from modeling.torch_model import BertBaseUncased as BertPt + + +def mark_output(y: Tensor) -> None: + if type(y) is not tuple: + y = (y,) + for i in range(len(y)): + y[i]._attrs["is_output"] = True + y[i]._attrs["name"] = "output_%d" % (i) + y_shape = [d._attrs["values"][0] for d in y[i]._attrs["shape"]] + print("output_{} shape: {}".format(i, y_shape)) + + +def create_bert_inputs( + batch_size: int, seq_length: int, dtype: str = "int64" +) -> List[Tensor]: + input_ids = Tensor( + shape=[batch_size, seq_length], + name="input_ids", + dtype=dtype, + is_input=True, + ) + token_type_ids = Tensor( + shape=[batch_size, seq_length], + name="token_type_ids", + dtype=dtype, + is_input=True, + ) + position_ids = Tensor( + shape=[batch_size, seq_length], + name="position_ids", + dtype=dtype, + is_input=True, + ) + return [input_ids, token_type_ids, position_ids] + + +def create_bert_encoders_input( + batch_size: int, seq_length: int, hidden: int, dtype: str = "float16" +): + encoder_input = Tensor( + shape=[batch_size, seq_length, hidden], + name="input", + dtype=dtype, + is_input=True, + ) + return [encoder_input] + + +def create_bert_inputs_pt( + batch_size: int, seq_length: int, dtype: torch.dtype = torch.int64 +) -> Dict[str, torch.Tensor]: + input_ids = torch.randn(batch_size, seq_length).to(dtype).cuda() + token_type_ids = torch.randn(batch_size, seq_length).to(dtype).cuda() + position_ids = torch.randn(batch_size, seq_length).to(dtype).cuda() + + return { + "input_ids": input_ids, + "token_type_ids": token_type_ids, + "position_ids": position_ids, + } + + +def create_bert_encoders_inputs_pt( + batch_size: int, seq_length: int, hidden_size: int +) -> Dict[str, torch.Tensor]: + encoder_input = torch.randn([batch_size, seq_length, hidden_size]).cuda().half() + return {"input": encoder_input} + + +def map_pt_params( + ait_bert, pt_bert, batch_size: int, seq_length: int +) -> Dict[str, torch.Tensor]: + pt_params = dict(pt_bert.named_parameters()) + mapped_pt_params = OrderedDict() + for name, _ in ait_bert.named_parameters(): + ait_name = name.replace(".", "_") + if name in pt_params: + mapped_pt_params[ait_name] = pt_params[name] + continue + + if name.endswith("self.qkv.weight"): + prefix = name[: -len("qkv.weight")] + q_weight = pt_params[prefix + "query.weight"] + k_weight = pt_params[prefix + "key.weight"] + v_weight = pt_params[prefix + "value.weight"] + qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0) + mapped_pt_params[ait_name] = qkv_weight + elif name.endswith("self.qkv.bias"): + prefix = name[: -len("qkv.bias")] + q_bias = pt_params[prefix + "query.bias"] + k_bias = pt_params[prefix + "key.bias"] + v_bias = pt_params[prefix + "value.bias"] + qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=0) + mapped_pt_params[ait_name] = qkv_bias + elif name.endswith("self.proj.weight"): + prefix = name[: -len("self.proj.weight")] + pt_name = prefix + "output.dense.weight" + mapped_pt_params[ait_name] = pt_params[pt_name] + elif name.endswith("self.proj.bias"): + prefix = name[: -len("self.proj.bias")] + pt_name = prefix + "output.dense.bias" + mapped_pt_params[ait_name] = pt_params[pt_name] + elif name.endswith("cu_length"): + cu_len = 
np.cumsum([0] + [seq_length] * batch_size).astype("int32") + mapped_pt_params[ait_name] = torch.from_numpy(cu_len).cuda() + else: + pt_param = pt_bert.get_parameter(name) + mapped_pt_params[ait_name] = pt_param + + return mapped_pt_params + + +def benchmark( + batch_size: int, + seq_length: int, + hidden_size: int, + mod: Model, + graph_mode: bool, + encoders_only: bool, +): + if encoders_only: + inputs = create_bert_encoders_inputs_pt(batch_size, seq_length, hidden_size) + else: + inputs = create_bert_inputs_pt(batch_size, seq_length) + outputs = [torch.empty(mod.get_output_maximum_shape(0)).cuda().half()] + + # warm up + t, _, __ = mod.benchmark_with_tensors( + inputs, + outputs, + count=100, + repeat=4, + graph_mode=graph_mode, + ) + # benchmark + t, _, __ = mod.benchmark_with_tensors( + inputs, + outputs, + count=100, + repeat=4, + graph_mode=graph_mode, + ) + print(f"batch_size: {batch_size}, seq_length: {seq_length}, latency: {t}") + dev_flag = os.environ.get("HIP_VISIBLE_DEVICES", "-1") + dev_flag = dev_flag.replace(",", "_") + with open(f"bert_ait_benchmark_dev_{dev_flag}.txt", "a") as f: + f.write(f"batch_size: {batch_size}, seq_length: {seq_length}, latency: {t}\n") + + +def compile_module( + batch_size: int, + seq_length: int, + hidden_size: int, + activation: str, + use_fp16_acc: bool, + encoders_only: bool, + pt_model: torch.nn.Module, +) -> None: + model_name = f"BERT_{activation}_{batch_size}_{seq_length}" + target = detect_target(use_fp16_acc=use_fp16_acc) + + if encoders_only: + inputs = create_bert_encoders_input(batch_size, seq_length, hidden_size) + else: + inputs = create_bert_inputs(batch_size, seq_length) + + if encoders_only: + model = BertBaseEncodersOnly(batch_size, seq_length, hidden_act=activation) + else: + model = BertBaseUncased(batch_size, seq_length, hidden_act=activation) + + # Mark all parameters with name same to PyTorch name convention + model.name_parameter_tensor() + # Forward the input tensor to the model, get output tensor + y = model(*inputs) + # Mark output tensor + mark_output(y) + + params = map_pt_params(model, pt_model, batch_size, seq_length) + + mod = compile_model(y, target, "./tmp", model_name) + + for k, v in params.items(): + mod.set_constant_with_tensor(k, v) + + return mod + + +@click.command() +@click.option("--batch-size", type=int, default=0, help="Inference batch size") +@click.option("--seq-length", type=int, default=0, help="Inference sequence length") +@click.option( + "--activation", + type=str, + default="fast_gelu", + help="Activation function applied on BERT, currently only support fast_gelu on Rocm. CUDA supports both gelu and fast_gelu. No effect if framework is pt.", +) +@click.option( + "--graph-mode", + type=bool, + default=True, + help="Use CUDA graph or not. hipGraph is not supported yet. No effect if framework is pt.", +) +@click.option( + "--use-fp16-acc", + type=bool, + default=True, + help="Use fp16 accumulation or not (TensorRT is using fp16_acc). No effect if framework is pt.", +) +@click.option( + "--use-pretrained-pt-model", + type=bool, + default=True, + help="Whether or not to use the pretrained BERT model weights.", +) +@click.option( + "--encoders-only", + type=bool, + default=True, + help="Whether or not to run the BERT benchmark with encoders only. 
If enabled, only the transformer blocks without BERT embeddings are benchmarked.",
+)
+def compile_and_benchmark(
+    batch_size: int,
+    seq_length: int,
+    activation: str,
+    graph_mode: bool,
+    use_fp16_acc: bool,
+    use_pretrained_pt_model: bool,
+    encoders_only: bool,
+):
+    if detect_target().name() == "rocm":
+        graph_mode = False
+        assert activation in (
+            "fast_gelu",
+        ), f"Unsupported activation: {activation} on rocm"
+
+    pt_model = BertPt(pretrained=use_pretrained_pt_model)._model
+    pt_model.eval()
+    hidden_size = pt_model.config.hidden_size
+
+    if batch_size < 1:
+        batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256]
+    else:
+        batch_sizes = [batch_size]
+
+    if seq_length < 1:
+        seq_lengths = (
+            [64, 128, 384, 512, 1024, 4096] if encoders_only else [64, 128, 384, 512]
+        )
+    else:
+        seq_lengths = [seq_length]
+
+    for seq_length in seq_lengths:
+        for bs in batch_sizes:
+            mod = compile_module(
+                bs,
+                seq_length,
+                hidden_size,
+                activation,
+                use_fp16_acc,
+                encoders_only,
+                pt_model,
+            )
+            benchmark(bs, seq_length, hidden_size, mod, graph_mode, encoders_only)
+
+
+if __name__ == "__main__":
+    torch.manual_seed(4896)
+    compile_and_benchmark()
diff --git a/examples/03_bert/benchmark_mi250.sh b/examples/03_bert/benchmark_mi250.sh
new file mode 100644
index 000000000..32e935650
--- /dev/null
+++ b/examples/03_bert/benchmark_mi250.sh
@@ -0,0 +1,4 @@
+#!/bin/bash
+
+HIP_VISIBLE_DEVICES=0 python3 benchmark_ait.py --batch-size $1 &
+HIP_VISIBLE_DEVICES=1 python3 benchmark_ait.py --batch-size $1 && fg
diff --git a/examples/03_bert/benchmark_pt.py b/examples/03_bert/benchmark_pt.py
new file mode 100644
index 000000000..586df4fea
--- /dev/null
+++ b/examples/03_bert/benchmark_pt.py
@@ -0,0 +1,148 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# + +import click +import torch +from aitemplate.testing.benchmark_pt import benchmark_torch_function +from modeling.torch_model import BertBaseUncased + + +def benchmark_pt(pretrained=True, batchsize=0): + bert = BertBaseUncased(pretrained) + model = bert._model + model.eval() + + if batchsize == 0: + candidate_batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256] + else: + candidate_batch_sizes = [batchsize] + + with torch.inference_mode(): + for seq_length in [64, 128, 384, 512]: + for batch_size in candidate_batch_sizes: + try: + input_ids, token_type_ids, position_ids = bert.generate_inputs( + batch_size, seq_length + ) + bert.forward( + input_ids=input_ids, + token_type_ids=token_type_ids, + position_ids=position_ids, + ) + # warmup + t = benchmark_torch_function( + 100, + bert.forward, + input_ids=input_ids, + token_type_ids=token_type_ids, + position_ids=position_ids, + ) + # benchmark + t = benchmark_torch_function( + 100, + bert.forward, + input_ids=input_ids, + token_type_ids=token_type_ids, + position_ids=position_ids, + ) + print( + f"bert pt: batch_size: {batch_size}, seq_length: {seq_length}, {t} ms", + ) + with open("bert_pt_benchmark.txt", "a") as f: + f.write( + f"batch_size: {batch_size}, seq_length: {seq_length} latency: {t} ms\n" + ) + except RuntimeError: + # pt runs out of memory + break + + +def benchmark_pt_encoders_only(pretrained=True, batchsize=0): + model = BertBaseUncased(pretrained) + pt_bert = model._model + pt_bert.eval() + + encoder = pt_bert.bert.encoder + hidden_size = pt_bert.config.hidden_size + + if batchsize == 0: + candidate_batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256] + else: + candidate_batch_sizes = [batchsize] + + for seq_length in [64, 128, 384, 512, 1024, 4096]: + for batch_size in candidate_batch_sizes: + try: + encoder_input = ( + torch.randn([batch_size, seq_length, hidden_size]).cuda().half() + ) + encoder.forward(encoder_input) + # warmup + t = benchmark_torch_function( + 100, + encoder.forward, + encoder_input, + ) + # benchmark + t = benchmark_torch_function( + 100, + encoder.forward, + encoder_input, + ) + print( + f"bert encoders pt: batch_size: {batch_size}, seq_length: {seq_length}, {t} ms", + ) + with open("bert_encoders_pt_benchmark.txt", "a") as f: + f.write( + f"batch_size: {batch_size}, seq_length: {seq_length} latency: {t} ms\n" + ) + except RuntimeError: + # pt runs out of memory + break + + +@click.command() +@click.option( + "--use-pretrained-pt-model", + type=bool, + default=True, + help="Whether or not to use the pretrained BERT model weights.", +) +@click.option( + "--encoders-only", + type=bool, + default=True, + help="Whether or not to run the BERT benchmark with encoders only. If enabled, only the transformer blocks without BERT embeddings are benchmarked.", +) +@click.option( + "--batch-size", + type=int, + default=0, + help="The batch size to use for the benchmark. If 0, the batch size is default [1 : 128].", +) +def benchmark( + use_pretrained_pt_model: bool, + encoders_only: bool, + batch_size: int, +): + if encoders_only: + benchmark_pt_encoders_only(use_pretrained_pt_model, batch_size) + else: + benchmark_pt(use_pretrained_pt_model, batch_size) + + +if __name__ == "__main__": + torch.manual_seed(4896) + benchmark() diff --git a/examples/03_bert/demo.py b/examples/03_bert/demo.py new file mode 100644 index 000000000..d783b6423 --- /dev/null +++ b/examples/03_bert/demo.py @@ -0,0 +1,108 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import click + +import torch + +from benchmark_ait import compile_module +from modeling.torch_model import BertBaseUncased as BertPt +from transformers import BertTokenizer + + +def prepare_data(prompt: str): + tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") + result = tokenizer(prompt, return_attention_mask=False, return_tensors="pt") + target_size = result["input_ids"].size() + if target_size[1] > 512: + raise ValueError("Sequence length > 512 is not supported") + + result["position_ids"] = ( + torch.arange(target_size[1], dtype=torch.int64) + .reshape(result["input_ids"].size()) + .contiguous() + .cuda() + ) + return result + + +def run_model( + prompt: str, activation: str, graph_mode: bool, use_fp16_acc: bool, verify: bool +): + inputs = prepare_data(prompt) + inputs_pt = {name: data.cuda() for name, data in inputs.items()} + batch_size, seq_len = inputs["input_ids"].size() + + pt_model = BertPt(pretrained=True)._model + pt_model.eval() + hidden_size = pt_model.config.hidden_size + + mod = compile_module( + batch_size, seq_len, hidden_size, activation, use_fp16_acc, False, pt_model + ) + + outputs = [torch.empty(mod.get_output_maximum_shape(0)).half().cuda()] + mod.run_with_tensors(inputs_pt, outputs, graph_mode=graph_mode) + + print(f"Logits: {outputs[0]}") + if verify: + pt_outputs = pt_model.bert(**inputs_pt) + torch.allclose(outputs[0], pt_outputs.last_hidden_state, 1e-1, 1e-1) + print("Verification done!") + + +@click.command() +@click.option( + "--prompt", + type=str, + default="The quick brown fox jumps over the lazy dog.", + help="The prompt to give BERT.", +) +@click.option( + "--activation", + type=str, + default="fast_gelu", + help="Activation function applied on BERT, currently only support gelu and fast_gelu", +) +@click.option( + "--graph_mode", + type=bool, + default=True, + help="Use CUDA graph or not. (hipGraph is not supported yet)", +) +@click.option( + "--use_fp16_acc", + type=bool, + default=True, + help="Use fp16 accumulation or not (TensorRT is using fp16_acc)", +) +@click.option( + "--verify", + type=bool, + default=True, + help="Verify AIT outputs against PT", +) +def run_demo( + prompt: str, + activation: str, + graph_mode: bool, + use_fp16_acc: bool, + verify: bool, +): + run_model(prompt, activation, graph_mode, use_fp16_acc, verify) + + +if __name__ == "__main__": + torch.manual_seed(4896) + run_demo() diff --git a/examples/03_bert/modeling/__init__.py b/examples/03_bert/modeling/__init__.py new file mode 100644 index 000000000..5cf1a826f --- /dev/null +++ b/examples/03_bert/modeling/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/examples/03_bert/modeling/bert.py b/examples/03_bert/modeling/bert.py new file mode 100644 index 000000000..a3a29b54f --- /dev/null +++ b/examples/03_bert/modeling/bert.py @@ -0,0 +1,391 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import Tuple + +from aitemplate.compiler import ops +from aitemplate.frontend import nn, Tensor +from aitemplate.testing import detect_target + +# pylint: disable=W0102 + +USE_CUDA = detect_target().name() == "cuda" + + +class BertSelfOutput(nn.Module): + def __init__(self, hidden_size, layer_norm_eps): + """dense + add is included in nn.MultiheadAttention. + This class now only contains LayerNorm. + """ + super().__init__() + self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) + + def forward(self, hidden_states: Tensor) -> Tensor: + if not USE_CUDA: + hidden_states = ( + hidden_states + if hidden_states._rank() == 2 + else ops.reshape()(hidden_states, [-1, hidden_states._size(-1)]) + ) + # [B, S, H] on cuda, [B * S, H] on rocm + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertAttention(nn.Module): + def __init__( + self, + batch_size, + seq_len, + hidden_size, + num_attention_heads, + layer_norm_eps, + attention_probs_dropout_prob=0.0, + hidden_dropout_prob=0.0, + ): + super().__init__() + self.self = nn.MultiheadAttention( + dim=hidden_size, + batch_size=batch_size, + seq_len=seq_len, + num_heads=num_attention_heads, + qkv_bias=True, + attn_drop=attention_probs_dropout_prob, + proj_drop=hidden_dropout_prob, + has_residual=True, + ) + self.output = BertSelfOutput(hidden_size, layer_norm_eps) + + def forward( + self, + hidden_states: Tensor, + ) -> Tuple[Tensor]: + self_output = self.self(hidden_states, hidden_states) + attention_output = self.output(self_output) + outputs = (attention_output,) + return outputs + + +# FFN block +class BertIntermediate(nn.Module): + def __init__(self, hidden_size, intermediate_size, hidden_act): + super().__init__() + # dense + activation + self.dense = nn.Linear( + hidden_size, intermediate_size, specialization=hidden_act + ) + + def forward(self, hidden_states: Tensor) -> Tensor: + hidden_states = self.dense(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__( + self, hidden_size, intermediate_size, layer_norm_eps, hidden_dropout_prob + ): + super().__init__() + assert hidden_dropout_prob == 0.0 + # dense + add + self.dense = nn.Linear(intermediate_size, hidden_size, specialization="add") + self.dropout = nn.Dropout(hidden_dropout_prob) + 
self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) + + def forward(self, hidden_states: Tensor, input_tensor: Tensor) -> Tensor: + hidden_states = self.dense(hidden_states, input_tensor) + # hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLayer(nn.Module): + def __init__( + self, + hidden_size, + batch_size, + seq_len, + num_attention_heads, + intermediate_size, + hidden_act, + layer_norm_eps, + attention_probs_dropout_prob, + hidden_dropout_prob, + ): + super().__init__() + self.attention = BertAttention( + batch_size=batch_size, + seq_len=seq_len, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + layer_norm_eps=layer_norm_eps, + attention_probs_dropout_prob=attention_probs_dropout_prob, + hidden_dropout_prob=hidden_dropout_prob, + ) + self.intermediate = BertIntermediate(hidden_size, intermediate_size, hidden_act) + self.output = BertOutput( + hidden_size, intermediate_size, layer_norm_eps, hidden_dropout_prob + ) + + def feed_forward(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + def forward( + self, + hidden_states: Tensor, + ): + # [B, S, H] + shape = hidden_states.shape() + # [B, S, H] on cuda, [B * S, H] on rocm + self_attention_outputs = self.attention(hidden_states) + layer_output = self.feed_forward(self_attention_outputs[0]) + # [B * S, H] to [B, S, H] on rocm + layer_output = ( + layer_output + if layer_output._rank() == 3 + else ops.reshape()(layer_output, shape) + ) + return (layer_output,) + + +class BertEncoder(nn.Module): + def __init__( + self, + num_hidden_layers, + hidden_size, + batch_size, + seq_len, + num_attention_heads, + intermediate_size, + hidden_act, + layer_norm_eps, + attention_probs_dropout_prob, + hidden_dropout_prob, + ): + super().__init__() + self.layer = nn.ModuleList( + [ + BertLayer( + batch_size=batch_size, + seq_len=seq_len, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + intermediate_size=intermediate_size, + hidden_act=hidden_act, + layer_norm_eps=layer_norm_eps, + attention_probs_dropout_prob=attention_probs_dropout_prob, + hidden_dropout_prob=hidden_dropout_prob, + ) + for _ in range(num_hidden_layers) + ] + ) + + def forward( + self, + hidden_states: Tensor, + ): + for layer_module in self.layer: + layer_outputs = layer_module(hidden_states) + hidden_states = layer_outputs[0] + + return layer_outputs + + +class BertModel(nn.Module): + def __init__( + self, + batch_size, + seq_len, + vocab_size, + max_position_embeddings, + type_vocab_size, + num_hidden_layers, + hidden_size, + num_attention_heads, + intermediate_size, + hidden_act, + layer_norm_eps, + attention_probs_dropout_prob, + hidden_dropout_prob, + add_pooling_layer=False, + ): + super().__init__() + assert not add_pooling_layer + + self.embeddings = nn.BertEmbeddings( + hidden_size=hidden_size, + vocab_size=vocab_size, + max_position_embeddings=max_position_embeddings, + type_vocab_size=type_vocab_size, + layer_norm_eps=layer_norm_eps, + hidden_dropout_prob=hidden_dropout_prob, + ) + self.encoder = BertEncoder( + batch_size=batch_size, + seq_len=seq_len, + num_hidden_layers=num_hidden_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + intermediate_size=intermediate_size, + hidden_act=hidden_act, + layer_norm_eps=layer_norm_eps, + attention_probs_dropout_prob=attention_probs_dropout_prob, + 
hidden_dropout_prob=hidden_dropout_prob, + ) + + def forward( + self, + input_ids: Tensor, + token_type_ids: Tensor, + position_ids: Tensor, + ): + embedding_output = self.embeddings( + input_ids=input_ids, + token_type_ids=token_type_ids, + position_ids=position_ids, + ) + encoder_outputs = self.encoder( + embedding_output, + ) + return encoder_outputs + + +class BertModelEncodersOnly(nn.Module): + def __init__( + self, + batch_size, + seq_len, + num_hidden_layers, + hidden_size, + num_attention_heads, + intermediate_size, + hidden_act, + layer_norm_eps, + attention_probs_dropout_prob, + hidden_dropout_prob, + add_pooling_layer=False, + ): + super().__init__() + assert not add_pooling_layer + + self.encoder = BertEncoder( + batch_size=batch_size, + seq_len=seq_len, + num_hidden_layers=num_hidden_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + intermediate_size=intermediate_size, + hidden_act=hidden_act, + layer_norm_eps=layer_norm_eps, + attention_probs_dropout_prob=attention_probs_dropout_prob, + hidden_dropout_prob=hidden_dropout_prob, + ) + + def forward( + self, + encoder_input: Tensor, + ): + encoder_outputs = self.encoder(encoder_input) + return encoder_outputs + + +class BertBaseUncased(nn.Module): + """Bert base uncased with no classification head.""" + + def __init__( + self, + batch_size, + seq_len, + vocab_size=30522, + max_position_embeddings=512, + type_vocab_size=2, + num_hidden_layers=12, + hidden_size=768, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + layer_norm_eps=1e-12, + attention_probs_dropout_prob=0.0, + hidden_dropout_prob=0.0, + ): + super().__init__() + self.bert = BertModel( + batch_size=batch_size, + seq_len=seq_len, + vocab_size=vocab_size, + max_position_embeddings=max_position_embeddings, + type_vocab_size=type_vocab_size, + num_hidden_layers=num_hidden_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + intermediate_size=intermediate_size, + hidden_act=hidden_act, + layer_norm_eps=layer_norm_eps, + attention_probs_dropout_prob=attention_probs_dropout_prob, + hidden_dropout_prob=hidden_dropout_prob, + add_pooling_layer=False, + ) + + def forward( + self, + input_ids: Tensor, + token_type_ids: Tensor, + position_ids: Tensor, + ) -> Tensor: + outputs = self.bert( + input_ids=input_ids, + token_type_ids=token_type_ids, + position_ids=position_ids, + ) + return outputs + + +class BertBaseEncodersOnly(nn.Module): + """Bert base uncased with no classification head and no embeddings.""" + + def __init__( + self, + batch_size, + seq_len, + num_hidden_layers=12, + hidden_size=768, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + layer_norm_eps=1e-12, + attention_probs_dropout_prob=0.0, + hidden_dropout_prob=0.0, + ): + super().__init__() + self.bert = BertModelEncodersOnly( + batch_size=batch_size, + seq_len=seq_len, + num_hidden_layers=num_hidden_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + intermediate_size=intermediate_size, + hidden_act=hidden_act, + layer_norm_eps=layer_norm_eps, + attention_probs_dropout_prob=attention_probs_dropout_prob, + hidden_dropout_prob=hidden_dropout_prob, + add_pooling_layer=False, + ) + + def forward( + self, + encoder_input: Tensor, + ) -> Tensor: + outputs = self.bert(encoder_input) + return outputs diff --git a/examples/03_bert/modeling/torch_model.py b/examples/03_bert/modeling/torch_model.py new file mode 100644 index 000000000..cbc965c70 --- /dev/null +++ 
b/examples/03_bert/modeling/torch_model.py @@ -0,0 +1,51 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import torch +from transformers import AutoModelForMaskedLM, BertForMaskedLM + + +class BertBaseUncased: + def __init__(self, pretrained=True): + if not pretrained: + pretrained = AutoModelForMaskedLM.from_pretrained("bert-base-uncased") + self._model = BertForMaskedLM(pretrained.config).cuda().half() + else: + self._model = ( + AutoModelForMaskedLM.from_pretrained("bert-base-uncased").cuda().half() + ) + self._vocab_size = 30522 + + def forward(self, *args, **kwargs): + # runs the full model with classification head + outputs = self._model(*args, **kwargs) + return outputs.logits + + def generate_inputs(self, batch_size, seq_len): + dtype = torch.long + input_ids = torch.randint( + 0, self._vocab_size, (batch_size, seq_len), dtype=dtype + ).cuda() + token_type_ids = torch.zeros(input_ids.size(), dtype=dtype).cuda() + position_ids = ( + torch.arange(seq_len, dtype=dtype) + .reshape((1, -1)) + .expand(batch_size, -1) + .contiguous() + .cuda() + ) + return (input_ids, token_type_ids, position_ids) + + def get_parameters(self): + return dict(self._model.named_parameters()) diff --git a/examples/04_vit/README.md b/examples/04_vit/README.md new file mode 100644 index 000000000..fc747b4a1 --- /dev/null +++ b/examples/04_vit/README.md @@ -0,0 +1,126 @@ +# Vision Transformer (VIT) + +In this example, we will demo how to lower a pretrained Vision Transformer from TIMM, and run inference in AITemplate. We tested on two variants of Vision Transformer: Base version with 224x224 input / patch 16, and Large version with 384x384 input / patch 16. 
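+
+As a quick orientation before the file listing below, the following sketch shows one hypothetical way to drive this example directly from Python. It is a minimal sketch rather than a supported API: the `compile_vit` and `benchmark` signatures are taken from `benchmark_ait.py` in this folder, and the usual entry point is simply `python3 benchmark_ait.py`.
+
+```
+from benchmark_ait import compile_vit, benchmark
+
+# Compile the 224x224 base model for batch size 1, then benchmark the generated module.
+# (Use graph_mode=False on ROCm, as main() in benchmark_ait.py does.)
+mod = compile_vit("vit_base_patch16_224", batch_size=1, use_fp16_acc=True)
+benchmark("vit_base_patch16_224", 1, mod=mod, graph_mode=True)
+```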
+ +## Code structure +``` +modeling + vision_transformer.py # VIT definition using AIT's frontend API +weight_utils.py # Utils to convert TIMM VIT weights to AIT +verification.py # Numerical verification between TIMM and AIT +benchmark_pt.py # Benchmark code for PyTorch +benchmark_ait.py # Benchmark code for AITemplate +``` + +## Reference Speed vs PyTorch Eager + +### A100-40GB / CUDA 11.6.2 +_PT = PyTorch 1.12 Eager_ + +- vit_base_patch16_224 + +| Batch size | PT Latency (ms) | PT QPS (im/s) | AIT Latency (ms) | AIT QPS (im/s) | +|------------|-----------------|---------------|------------------|----------------| +| 1 | 4.95 | 202.15 | 1.02 | 979.31 | +| 2 | 5.26 | 380.43 | 1.15 | 1735.64 | +| 4 | 5.51 | 726.08 | 1.57 | 2543.72 | +| 8 | 5.56 | 1439.03 | 2.20 | 3642.16 | +| 16 | 8.59 | 1863.35 | 3.64 | 4396.74 | +| 32 | 15.95 | 2006.62 | 6.51 | 4916.93 | +| 64 | 31.48 | 2032.77 | 12.67 | 5052.52 | +| 128 | 59.86 | 2138.35 | 25.10 | 5099.77 | +| 256 | 115.00 | 2226.10 | 48.55 | 5273.03 | + + +- vit_large_patch16_384 + +| Batch size | PT Latency (ms) | PT QPS (im/s) | AIT Latency (ms) | AIT QPS (im/s) | +|------------|-----------------|---------------|------------------|----------------| +| 1 | 9.88 | 101.17 | 3.84 | 260.21 | +| 2 | 11.90 | 168.02 | 5.87 | 340.98 | +| 4 | 21.20 | 188.66 | 11.49 | 348.09 | +| 8 | 39.33 | 203.43 | 19.09 | 419.07 | +| 16 | 76.00 | 210.54 | 36.19 | 442.08 | +| 32 | 147.24 | 217.33 | 70.03 | 456.93 | +| 64 | 291.00 | 219.93 | 135.25 | 473.21 | +| 128 | 578.99 | 221.08 | 267.09 | 479.24 | +| 256 | 1204.16 | 212.60 | 538.97 | 474.98 | + + +### MI-250 / ROCm 5.2.3 / HIPCC-10736 +_PT = PyTorch 1.12 Eager_ + +#### 1 GCD + +- vit_base_patch16_224 + +| Batch size | PT Latency (ms) | PT QPS (im/s) | AIT Latency (ms) | AIT QPS (im/s) | +|------------|-----------------|---------------|------------------|----------------| +| 1 | 3.54 | 282.12 | 3.49 | 286.26 | +| 2 | 4.43 | 451.73 | 3.78 | 528.84 | +| 4 | 6.09 | 657.02 | 4.05 | 986.95 | +| 8 | 9.65 | 829.27 | 5.31 | 1507.06 | +| 16 | 16.62 | 962.98 | 8.50 | 1882.72 | +| 32 | 29.87 | 1071.25 | 14.43 | 2218.07 | +| 64 | 56.58 | 1131.08 | 26.52 | 2413.45 | +| 128 | 110.28 | 1160.73 | 51.62 | 2479.69 | +| 256 | 217.07 | 1179.35 | 102.82 | 2489.89 | + + + +- vit_large_patch16_384 + +| Batch size | PT Latency (ms) | PT QPS (im/s) | AIT Latency (ms) | AIT QPS (im/s) | +|------------|-----------------|---------------|------------------|----------------| +| 1 | 12.90 | 77.51 | 9.70 | 103.05 | +| 2 | 22.42 | 89.19 | 13.40 | 149.29 | +| 4 | 38.16 | 104.83 | 22.12 | 180.86 | +| 8 | 70.58 | 113.35 | 38.46 | 208.00 | +| 16 | 136.28 | 117.40 | 70.44 | 227.15 | +| 32 | 261.97 | 122.15 | 138.14 | 231.65 | +| 64 | 541.90 | 118.10 | 270.01 | 237.02 | +| 128 | 1108.36 | 115.49 | 534.97 | 239.27 | +| 256 | 2213.09 | 115.68 | 1063.24 | 240.77 | + + +#### 2 GCDs + +- vit_base_patch16_224 + +| Batch size | PT Latency (ms) | PT QPS (im/s) | AIT Latency (ms) | AIT QPS (im/s) | +|------------|-----------------|---------------|------------------|----------------| +| 1 | | | | | +| 2 | 3.49 | 572.95 | 3.59 | 556.55 | +| 4 | 4.11 | 974.26 | 3.97 | 1006.80 | +| 8 | 5.88 | 1359.64 | 4.23 | 1889.44 | +| 16 | 9.75 | 1641.06 | 5.71 | 2800.69 | +| 32 | 17.55 | 1823.03 | 9.34 | 3426.32 | +| 64 | 31.31 | 2043.79 | 16.24 | 3940.53 | +| 128 | 60.33 | 2121.64 | 30.97 | 4133.14 | +| 256 | 117.96 | 2170.29 | 59.82 | 4279.21 | + + +- vit_large_patch16_384 + +| Batch size | PT Latency (ms) | PT QPS (im/s) | AIT Latency (ms) | AIT QPS (im/s) | 
+|------------|-----------------|---------------|------------------|----------------|
+| 1          |                 |               |                  |                |
+| 2          | 12.73           | 157.07        | 10.52            | 190.13         |
+| 4          | 22.97           | 174.12        | 14.94            | 267.82         |
+| 8          | 39.78           | 201.08        | 24.55            | 325.85         |
+| 16         | 74.95           | 213.48        | 43.95            | 364.07         |
+| 32         | 146.18          | 218.91        | 82.04            | 390.06         |
+| 64         | 283.04          | 226.12        | 162.62           | 393.55         |
+| 128        | 583.03          | 219.54        | 313.34           | 408.51         |
+| 256        | 1197.56         | 213.77        | 621.71           | 411.77         |
+
+
+### Note for Performance Results
+
+- For NVIDIA A100, our test cluster doesn't allow locking the GPU frequency. We run a longer warm-up to collect more stable results, but small variance relative to locked-frequency numbers is still expected.
+- To benchmark MI-250, first run `python3 benchmark_ait.py` to generate all necessary model dynamic libraries on a single GCD. Then run `./benchmark_mi250.sh {batch_size}` to simulate data-parallel execution on 2 GCDs, with each GCD processing half of the batch.
+- To benchmark MI-250 with 1 GCD, we lock the frequency with `rocm-smi -d x --setperfdeterminism 1700`, where `x` is the GPU id.
+- To benchmark MI-250 with 2 GCDs, we observed a performance regression with the ROCm perf-determinism mode. The 2-GCD numbers are therefore collected without perf-determinism mode, reset with `rocm-smi -d x --resetperfdeterminism`, where `x` is the GPU id.
+- The PyTorch Eager results do not reflect [BetterTransformer](https://pytorch.org/blog/a-better-transformer-for-fast-transformer-encoder-inference/), mainly because the BetterTransformer integration into the TIMM/Transformers packages has not yet landed.
+- Performance results are what we can reproduce; they should not be used for other purposes.
diff --git a/examples/04_vit/benchmark_ait.py b/examples/04_vit/benchmark_ait.py
new file mode 100644
index 000000000..c302d297d
--- /dev/null
+++ b/examples/04_vit/benchmark_ait.py
@@ -0,0 +1,186 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# +"""benchmark for vit""" + +import os + +import click +import numpy as np +import torch +from aitemplate.compiler import compile_model, Model + +from aitemplate.frontend import Tensor +from aitemplate.testing import detect_target + +from modeling.vision_transformer import VisionTransformer +from weight_utils import export_to_torch_tensor + +# flake8: noqa + + +def mark_output(y): + if type(y) is not tuple: + y = (y,) + for i in range(len(y)): + y[i]._attrs["is_output"] = True + y[i]._attrs["name"] = "output_%d" % (i) + y_shape = [d._attrs["values"][0] for d in y[i]._attrs["shape"]] + print("output_{} shape: {}".format(i, y_shape)) + + +USE_CUDA = detect_target().name() == "cuda" + + +def compile_vit( + model_name, + batch_size, + class_token=False, + global_pool="avg", + use_fp16_acc=True, +): + img_size = 224 + patch_size = 16 + embed_dim = 768 + num_heads = 12 + depth = 12 + if model_name == "vit_base_patch16_224": + img_size = 224 + patch_size = 16 + embed_dim = 768 + num_heads = 12 + depth = 12 + elif model_name == "vit_large_patch16_384": + img_size = 384 + patch_size = 16 + embed_dim = 1024 + num_heads = 16 + depth = 24 + seqlen = (img_size // patch_size) ** 2 + (1 if class_token else 0) + ait_model = VisionTransformer( + batch_size=batch_size, + img_size=img_size, + class_token=class_token, + global_pool=global_pool, + num_heads=num_heads, + embed_dim=embed_dim, + patch_size=patch_size, + depth=depth, + act_layer="GELU", + ) + ait_model.name_parameter_tensor() + inputs_ait = Tensor( + [batch_size, img_size, img_size, 3], name="input0", is_input=True + ) + Y = ait_model(inputs_ait) + mark_output(Y) + + target = detect_target(use_fp16_acc=use_fp16_acc) + exe_module = compile_model( + Y, target, "./tmp", "vision_transformer_bs%d_seq%d" % (batch_size, seqlen) + ) + return exe_module + + +def benchmark(model_name, batch_size, mod=None, graph_mode=True): + # load mod + if model_name == "vit_base_patch16_224": + img_size = 224 + patch_size = 16 + embed_dim = 768 + num_heads = 12 + depth = 12 + elif model_name == "vit_large_patch16_384": + img_size = 384 + patch_size = 16 + embed_dim = 1024 + num_heads = 16 + depth = 24 + else: + raise NotImplementedError + + seqlen = (img_size // patch_size) ** 2 + + if mod is None: + model_dir = f"vision_transformer_bs{batch_size}_seq{seqlen}" + mod = Model(os.path.join("./tmp", model_dir, "test.so")) + + # prepare params + params_ait = export_to_torch_tensor(model_name) + if detect_target().name() == "cuda": + ait_key = "attn_cu_length" + for i in range(depth): + prefix = "blocks_%d" % (i) + cu_len = np.cumsum([0] + [seqlen] * batch_size).astype("int32") + params_ait[f"{prefix}_{ait_key}"] = torch.from_numpy(cu_len).cuda() + + # set weights + for name, weight in params_ait.items(): + mod.set_constant_with_tensor(name, weight) + + # prepare input/output tensor + inputs = [torch.randn([batch_size, img_size, img_size, 3]).cuda().half()] + ys = [] + num_ouputs = len(mod.get_output_name_to_index_map()) + for i in range(num_ouputs): + shape = mod.get_output_maximum_shape(i) + ys.append(torch.empty(shape).cuda().half()) + # warm up + t, _, __ = mod.benchmark_with_tensors( + inputs, + ys, + count=100, + repeat=4, + graph_mode=graph_mode, + ) + # benchmark + t, _, __ = mod.benchmark_with_tensors( + inputs, + ys, + count=100, + repeat=4, + graph_mode=graph_mode, + ) + print(f"batch_size: {batch_size}, latency: {t}") + dev_flag = os.environ.get("HIP_VISIBLE_DEVICES", "-1") + dev_flag = dev_flag.replace(",", "_") + with 
open(f"{model_name}_ait_benchmark_dev_{dev_flag}.txt", "a") as f: + f.write(f"batch_size: {batch_size}, latency: {t}\n") + + +@click.command() +@click.option("--model-name", type=str, default="vit_base_patch16_224") +@click.option( + "--use-fp16-acc", + type=bool, + default=True, + help="Whether to use FP16 for accumulation (similar to TensorRT)", +) +@click.option("--use-graph", type=bool, default=True, help="Whether to use CUDA graph") +@click.option("--batch-size", type=int, default=0, help="Batch size") +def main( + model_name="vit_base_patch16_224", use_fp16_acc=True, use_graph=True, batch_size=0 +): + if detect_target().name() == "rocm": + use_graph = False + if batch_size < 1: + for bs in (1, 2, 4, 8, 16, 32, 64, 128, 256): + compile_vit(model_name, bs, use_fp16_acc=use_fp16_acc) + benchmark(model_name, bs, graph_mode=use_graph) + else: + benchmark(model_name, batch_size, graph_mode=use_graph) + + +if __name__ == "__main__": + main() diff --git a/examples/04_vit/benchmark_mi250.sh b/examples/04_vit/benchmark_mi250.sh new file mode 100644 index 000000000..883846b68 --- /dev/null +++ b/examples/04_vit/benchmark_mi250.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +HIP_VISIBLE_DEVICES=0 python3 benchmark_ait.py --batch-size "$1" & +HIP_VISIBLE_DEVICES=1 python3 benchmark_ait.py --batch-size "$1" && fg diff --git a/examples/04_vit/benchmark_pt.py b/examples/04_vit/benchmark_pt.py new file mode 100644 index 000000000..48834e295 --- /dev/null +++ b/examples/04_vit/benchmark_pt.py @@ -0,0 +1,100 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import os + +import click +import torch +from aitemplate.testing.benchmark_pt import benchmark_torch_function +from timm.models.vision_transformer import VisionTransformer +from torch import nn + + +def create_vit(model_name): + if model_name == "vit_base_patch16_224": + img_size = 224 + embed_dim = 768 + class_token = False + global_pool = "avg" + depth = 12 + patch_size = 16 + num_heads = 12 + elif model_name == "vit_large_patch16_384": + img_size = 384 + embed_dim = 1024 + class_token = False + global_pool = "avg" + depth = 24 + patch_size = 16 + num_heads = 16 + else: + raise NotImplementedError + model = ( + VisionTransformer( + img_size=img_size, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + class_token=class_token, + global_pool=global_pool, + depth=depth, + patch_size=patch_size, + num_heads=num_heads, + embed_dim=embed_dim, + ) + .cuda() + .half() + ) + return model + + +def benchmark(model_name, batch_size, img_size): + if model_name == "vit_base_patch16_224": + img_size = 224 + elif model_name == "vit_large_patch16_384": + img_size = 384 + model = create_vit(model_name) + with torch.inference_mode(): + input_shape = (batch_size, 3, img_size, img_size) + input_data = torch.randn(input_shape).cuda().half() + # warm up + benchmark_torch_function(100, model, input_data) + # benchmark + t = benchmark_torch_function(100, model, input_data) + print("batch_size: {}, time: {}".format(batch_size, t)) + dev_flag = os.environ.get("HIP_VISIBLE_DEVICES", "-1") + dev_flag = dev_flag.replace(",", "_") + with open(f"{model_name}_pt_benchmark_dev_{dev_flag}.txt", "a") as f: + f.write("batch_size: {}, latency: {}\n".format(batch_size, t)) + + +@click.command() +@click.option("--model-name", type=str, default="vit_base_patch16_224") +@click.option("--batch-size", default=0, type=int) +def main(model_name, batch_size): + img_size = 224 + if model_name == "vit_base_patch16_224": + img_size = 224 + elif model_name == "vit_large_patch16_384": + img_size = 384 + else: + raise NotImplementedError + if batch_size == 0: + for batch_size in [1, 2, 4, 8, 16, 32, 64, 128, 256]: + benchmark(model_name, batch_size, img_size) + else: + benchmark(model_name, batch_size, img_size) + + +if __name__ == "__main__": + main() diff --git a/examples/04_vit/modeling/vision_transformer.py b/examples/04_vit/modeling/vision_transformer.py new file mode 100644 index 000000000..5b4fb01f1 --- /dev/null +++ b/examples/04_vit/modeling/vision_transformer.py @@ -0,0 +1,323 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from functools import partial + +from aitemplate.compiler import ops +from aitemplate.frontend import nn +from aitemplate.testing import detect_target + +# pylint: disable=W0102 + +USE_CUDA = detect_target().name() == "cuda" + + +def get_shape(x): + shape = [it.value() for it in x._attrs["shape"]] + return shape + + +class Mlp(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer="GELU", + drop=0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.fc1 = nn.Linear( + in_features, + hidden_features, + specialization="fast_gelu" if act_layer == "GELU" else "relu", + ) + self.fc2 = nn.Linear(hidden_features, out_features, specialization="add") + + def forward(self, x, res): + shape = get_shape(x) + x = self.fc1(x) + x = self.fc2(x, res) + return ops.reshape()(x, shape) + + +class Block(nn.Module): + def __init__( + self, + dim, + batch_size, + seq_len, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + drop=0.0, + attn_drop=0.0, + init_values=None, + drop_path=0.0, + act_layer="GELU", + norm_layer=nn.LayerNorm, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = nn.MultiheadAttention( + dim, + batch_size, + seq_len, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = nn.Identity() + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path1 = nn.DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=drop, + ) + self.ls2 = nn.Identity() + self.drop_path2 = nn.DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x): + x = self.attn(self.norm1(x), x) + x = self.mlp(self.norm2(x), x) + return x + + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding""" + + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + norm_layer=None, + flatten=True, + ): + super().__init__() + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size // patch_size, img_size // patch_size) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + self.embed_dim = embed_dim + + conv_op = ( + nn.Conv2dBiasFewChannels + if detect_target().name() == "cuda" + else nn.Conv2dBias + ) + self.proj = conv_op( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size + ) + self.proj_norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + B, H, W, C = get_shape(x) + x = self.proj(x) + if self.flatten: + x = ops.reshape()(x, [B, -1, self.embed_dim]) + x = self.proj_norm(x) + return x + + +class VisionTransformer(nn.Module): + """Vision Transformer + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + """ + + def __init__( + self, + img_size=224, + batch_size=1, + patch_size=16, + in_chans=3, + num_classes=1000, + global_pool="token", + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + init_values=None, + class_token=True, + no_embed_class=False, + fc_norm=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + weight_init="", + embed_layer=PatchEmbed, + norm_layer=nn.LayerNorm, + 
act_layer=None, + block_fn=Block, + dtype="float16", + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + global_pool (str): type of global pooling for final sequence (default: 'token') + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + init_values: (float): layer-scale init values + class_token (bool): use class token + fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None) + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + weight_init (str): weight init scheme + embed_layer (nn.Module): patch embedding layer + norm_layer: (nn.Module): normalization layer + act_layer: (nn.Module): MLP activation layer + """ + super().__init__() + assert global_pool in ("", "avg", "token") + assert class_token or global_pool != "token" + use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + act_layer = act_layer or nn.GELU + + self.num_classes = num_classes + self.global_pool = global_pool + self.num_features = ( + self.embed_dim + ) = embed_dim # num_features for consistency with other models + self.num_prefix_tokens = 1 if class_token else 0 + self.no_embed_class = no_embed_class + self.grad_checkpointing = False + + self.patch_embed = embed_layer( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + ) + num_patches = self.patch_embed.num_patches + + self.cls_token = ( + nn.Parameter(shape=[1, 1, embed_dim], dtype=dtype) if class_token else None + ) + self.cls_token_mask = ( + nn.Parameter(shape=[batch_size, 1, embed_dim], dtype=dtype) + if class_token + else None + ) + embed_len = ( + num_patches if no_embed_class else num_patches + self.num_prefix_tokens + ) + self.pos_embed = nn.Parameter(shape=[1, embed_len, embed_dim], dtype=dtype) + self.pos_drop = nn.Dropout(p=drop_rate) + seq_len = (img_size // patch_size) ** 2 + (1 if class_token else 0) + self.pool_size = img_size // patch_size + + self.blocks = nn.Sequential( + *[ + block_fn( + dim=embed_dim, + batch_size=batch_size, + seq_len=seq_len, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + init_values=init_values, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=0, + norm_layer=norm_layer, + act_layer=act_layer, + ) + for i in range(depth) + ] + ) + self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() + + if global_pool == "avg": + self.pool = nn.AvgPool2d(kernel_size=self.pool_size, stride=1, padding=0) + + # Classifier Head + self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() + self.head = ( + nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + ) + + def _pos_embed(self, x): + if self.no_embed_class: + # deit-3, updated JAX (big vision) + # position embedding does not overlap with class token, add then concat + x = x + self.pos_embed.tensor() + if self.cls_token is not None: + cls_token_expand = ops.expand()( + self.cls_token.tensor(), [get_shape(x)[0], -1, -1] + ) + cls_token_expand = cls_token_expand + self.cls_token_mask.tensor() + x = ops.concatenate()([cls_token_expand, x], dim=1) + 
else: + # original timm, JAX, and deit vit impl + # pos_embed has entry for class token, concat then add + if self.cls_token is not None: + cls_token_expand = ops.expand()( + self.cls_token.tensor(), [get_shape(x)[0], -1, -1] + ) + cls_token_expand = cls_token_expand + self.cls_token_mask.tensor() + x = ops.concatenate()([cls_token_expand, x], dim=1) + x = x + self.pos_embed.tensor() + return self.pos_drop(x) + + def forward_features(self, x): + x = self.patch_embed(x) + x = self._pos_embed(x) + x = self.blocks(x) + x = self.norm(x) + return x + + def _global_pool(self, x): + batch, seq, d = get_shape(x) + x = ops.reshape()(x, [batch, self.pool_size, self.pool_size, d]) + y = self.pool(x) + return ops.reshape()(y, [batch, d]) + + def forward_head(self, x, pre_logits: bool = False): + if self.global_pool: + if self.global_pool == "avg": + x = self._global_pool(x) + else: + batch, seq, d = get_shape(x) + x = ops.dynamic_slice()( + x, start_indices=[0, 0, 0], end_indices=[batch, 1, d] + ) + x = self.fc_norm(x) + return x if pre_logits else self.head(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x diff --git a/examples/04_vit/verification.py b/examples/04_vit/verification.py new file mode 100644 index 000000000..0584707bf --- /dev/null +++ b/examples/04_vit/verification.py @@ -0,0 +1,164 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import click +import numpy as np +import torch +from aitemplate.compiler import compile_model +from aitemplate.frontend import Tensor +from aitemplate.testing import detect_target +from modeling.vision_transformer import VisionTransformer +from timm.models.vision_transformer import vit_base_patch16_224, vit_large_patch16_384 + +from weight_utils import export_to_torch_tensor + + +def mark_output(y): + if type(y) is not tuple: + y = (y,) + for i in range(len(y)): + y[i]._attrs["is_output"] = True + y[i]._attrs["name"] = "output_%d" % (i) + y_shape = [d._attrs["values"][0] for d in y[i]._attrs["shape"]] + print("output_{} shape: {}".format(i, y_shape)) + + +USE_CUDA = detect_target().name() == "cuda" + + +def compile_vit( + batch_size=128, + img_size=224, + patch_size=16, + embed_dim=768, + num_heads=12, + depth=12, + class_token=True, + global_pool="token", + use_fp16_acc=True, +): + seqlen = (img_size // patch_size) ** 2 + (1 if class_token else 0) + ait_model = VisionTransformer( + batch_size=batch_size, + img_size=img_size, + class_token=class_token, + global_pool=global_pool, + num_heads=num_heads, + embed_dim=embed_dim, + patch_size=patch_size, + depth=depth, + act_layer="GELU", + ) + ait_model.name_parameter_tensor() + inputs_ait = Tensor( + [batch_size, img_size, img_size, 3], name="input0", is_input=True + ) + Y = ait_model(inputs_ait) + mark_output(Y) + + target = detect_target(use_fp16_acc=use_fp16_acc) + exe_module = compile_model( + Y, target, "./tmp", "vision_transformer_bs%d_seq%d" % (batch_size, seqlen) + ) + return exe_module + + +def verification( + model_name, + batch_size=3, + use_fp16_acc=True, +): + img_size = 224 + embed_dim = 768 + depth = 12 + patch_size = 16 + num_heads = 12 + class_token = True + global_pool = "token" + if model_name == "vit_base_patch16_224": + img_size = 224 + embed_dim = 768 + depth = 12 + patch_size = 16 + num_heads = 12 + pt_mod = vit_base_patch16_224(pretrained=True).cuda().half() + elif model_name == "vit_large_patch16_384": + img_size = 384 + embed_dim = 1024 + depth = 24 + patch_size = 16 + num_heads = 16 + pt_mod = vit_large_patch16_384(pretrained=True).cuda().half() + + seqlen = (img_size // patch_size) ** 2 + (1 if class_token else 0) + input_pt = torch.randn([batch_size, 3, img_size, img_size]).cuda().half() * 255 + pt_ys = pt_mod(input_pt) + pt_ys = pt_ys.reshape((batch_size, 1, -1)) + + ait_mod = compile_vit( + batch_size=batch_size, + img_size=img_size, + patch_size=patch_size, + embed_dim=embed_dim, + num_heads=num_heads, + depth=depth, + class_token=True, + global_pool=global_pool, + use_fp16_acc=use_fp16_acc, + ) + + # convert weights + params_ait = export_to_torch_tensor(model_name, True) + params_ait["cls_token_mask"] = torch.zeros((batch_size, 1, embed_dim)).cuda().half() + if detect_target().name() == "cuda": + ait_key = "attn_cu_length" + for i in range(depth): + prefix = "blocks_%d" % (i) + cu_len = np.cumsum([0] + [seqlen] * batch_size).astype("int32") + params_ait[f"{prefix}_{ait_key}"] = torch.from_numpy(cu_len).cuda() + + # set weights + for name, weight in params_ait.items(): + ait_mod.set_constant_with_tensor(name, weight) + + inputs = [input_pt.permute((0, 2, 3, 1)).contiguous()] + ys = [] + num_ouputs = len(ait_mod.get_output_name_to_index_map()) + for i in range(num_ouputs): + shape = ait_mod.get_output_maximum_shape(i) + ys.append(torch.empty(shape).cuda().half()) + ait_mod.run_with_tensors(inputs, ys) + eps = 1e-1 + np.testing.assert_allclose( + pt_ys.detach().cpu().numpy(), + ys[0].cpu().numpy(), + atol=eps, 
+ rtol=eps, + ) + print("vision transformer verification pass") + + +@click.command() +@click.option("--model-name", type=str, default="vit_base_patch16_224") +@click.option("--use-fp16-acc", type=bool, default=True) +def main(model_name, use_fp16_acc): + if model_name not in ("vit_base_patch16_224", "vit_large_patch16_384"): + raise ValueError( + "model name should be vit_base_patch16_224 or vit_large_patch16_384" + ) + verification(model_name, use_fp16_acc=use_fp16_acc) + + +if __name__ == "__main__": + main() diff --git a/examples/04_vit/weight_utils.py b/examples/04_vit/weight_utils.py new file mode 100644 index 000000000..49d3c9eed --- /dev/null +++ b/examples/04_vit/weight_utils.py @@ -0,0 +1,115 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""script for converting vit model from timm to ait +""" +import pickle + +import click +import torch +import torch.nn as nn +from aitemplate.testing.detect_target import detect_target +from timm.models.vision_transformer import ( + VisionTransformer, + vit_base_patch16_224, + vit_large_patch16_384, +) + + +def convert_vit(model_name, pretrained=False): + img_size = 224 + embed_dim = 768 + patch_size = 16 + depth = 12 + mod = None + if model_name == "vit_base_patch16_224": + if pretrained: + mod = vit_base_patch16_224(pretrained=pretrained).cuda().half() + else: + mod = ( + VisionTransformer( + img_size=img_size, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + class_token=False, + global_pool="avg", + depth=depth, + patch_size=patch_size, + num_heads=12, + embed_dim=embed_dim, + ) + .cuda() + .half() + ) + elif model_name == "vit_large_patch16_384": + img_size = 384 + embed_dim = 1024 + depth = 24 + if pretrained: + mod = vit_large_patch16_384(pretrained=pretrained).cuda().half() + else: + mod = ( + VisionTransformer( + img_size=img_size, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + class_token=False, + global_pool="avg", + depth=24, + patch_size=patch_size, + num_heads=16, + embed_dim=embed_dim, + ) + .cuda() + .half() + ) + else: + print(model_name) + raise NotImplementedError + params_pt = mod.named_parameters() + params_ait = {} + params_ait = {} + for key, arr in params_pt: + ait_key = key.replace(".", "_") + if len(arr.shape) == 4: + arr = arr.permute((0, 2, 3, 1)).contiguous() + if detect_target().name() == "cuda": + conv0_w_pad = ( + torch.zeros((embed_dim, patch_size, patch_size, 4)).cuda().half() + ) + conv0_w_pad[:, :, :, :3] = arr + arr = conv0_w_pad + params_ait[f"{ait_key}"] = arr + return params_ait + + +def export_to_torch_tensor(model_name, pretrained=False): + params_ait = convert_vit(model_name, pretrained) + return params_ait + + +@click.command() +@click.option("--model_name", default="vit_base_patch16_224", help="model name") +@click.option("--param-path", default="vit.pkl", help="saved numpy weights path") +@click.option("--pretrained", default=False, help="use pretrained weights") +def export_to_numpy(model_name, param_path, pretrained=False): + params_ait = 
convert_vit(model_name, pretrained)
+    params_np = {k: v.detach().cpu().numpy() for k, v in params_ait.items()}
+
+    with open(param_path, "wb") as f:
+        pickle.dump(params_np, f)
+
+
+if __name__ == "__main__":
+    export_to_numpy()
diff --git a/examples/05_stable_diffusion/README.md b/examples/05_stable_diffusion/README.md
new file mode 100644
index 000000000..a98021a3a
--- /dev/null
+++ b/examples/05_stable_diffusion/README.md
@@ -0,0 +1,136 @@
+## Stable Diffusion Example
+
+In this example, we show how to build fast AIT modules for the CLIP, UNet, and VAE models, and how to benchmark and run them.
+
+### Build Dependencies
+
+The AIT stable diffusion example depends on `diffusers` and `transformers`.
+
+Verify the library versions. We have tested transformers 4.21/4.22/4.23, diffusers 0.3/0.4, and torch 1.11/1.12.
+
+```
+>>> import transformers
+>>> transformers.__version__
+'4.21.2'
+>>> import diffusers
+>>> diffusers.__version__
+'0.3.0'
+>>> import torch
+>>> torch.__version__
+'1.12.1+cu116'
+```
+
+### Build AIT modules for CLIP, UNet, VAE
+
+Build the AIT modules by running `compile.py`:
+
+```
+python3 examples/05_stable_diffusion/compile.py
+```
+It generates three folders: `./tmp/CLIPTextModel`, `./tmp/UNet2DConditionModel`, `./tmp/AutoencoderKL`. In each folder, there is a `test.so` file, which is the generated AIT module for the model.
+
+#### Multi-GPU profiling
+AIT needs to do profiling to select the best algorithms for CUTLASS and CK.
+To enable multiple GPUs for profiling, use the environment variable `CUDA_VISIBLE_DEVICES` on the NVIDIA platform and `HIP_VISIBLE_DEVICES` on the AMD platform.
+
+### Prepare Weights and Benchmark
+
+In this step, we download the Stable Diffusion weights for each model and use them to initialize the parameters in the AIT modules. Then we benchmark the AIT modules.
+
+1. Register in Hugging Face Hub to obtain an access token for the Stable Diffusion weights. See [user access tokens](https://huggingface.co/docs/hub/security-tokens).
+
+2. (Optional) Run `benchmark.py` with the access token to initialize the weights and benchmark.
+
+```
+python3 examples/05_stable_diffusion/benchmark.py --token ACCESS_TOKEN
+```
+
+### Run Models
+
+Run the AIT models to generate an example image:
+
+```
+python3 examples/05_stable_diffusion/demo.py --token ACCESS_TOKEN
+```
+
+Check the resulting image: `example_ait.png`
+
+
+### Sample outputs
+
+Command: `python3 examples/05_stable_diffusion/demo.py --token hf_xxx --prompt "Mountain Rainier in van Gogh's world"`
+
+![sample](https://raw.githubusercontent.com/AITemplate/webdata/main/imgs/example_ait_rainier.png)
+
+Command: `python3 examples/05_stable_diffusion/demo.py --token hf_xxx --prompt "Sitting in a tea house in Japan with Mount Fuji in the background, sunset professional portrait, Nikon 85mm f/1.4G"`
+
+![sample](https://raw.githubusercontent.com/AITemplate/webdata/main/imgs/example_ait_fuji.png)
+
+Command: `python3 examples/05_stable_diffusion/demo.py --token hf_xxx --prompt "A lot of wild flowers with North Cascade Mountain in background, sunset professional photo, Unreal Engine"`
+
+![sample](https://raw.githubusercontent.com/AITemplate/webdata/main/imgs/example_ait_cascade2.png)
+
+## Results
+
+_PT = PyTorch 1.12 Eager_
+
+_OOM = Out of Memory_
+### A100-40GB / CUDA 11.6, 50 steps
+
+| Module   | PT Latency (ms) | AIT Latency (ms) |
+|----------|-----------------|------------------|
+| CLIP     | 9.48            | 0.87             |
+| UNet     | 60.52           | 22.47            |
+| VAE      | 47.78           | 37.43            |
+| Pipeline | 3058.27         | 1282.98          |
+
+- PT: 17.50 it/s
+- AIT: 42.45 it/s
+
+### RTX 3080-10GB / CUDA 11.6, 50 steps
+
+| Module   | PT Latency (ms) | AIT Latency (ms) |
+|----------|-----------------|------------------|
+| CLIP     | OOM             | 0.85             |
+| UNet     | OOM             | 40.22            |
+| VAE      | OOM             | 44.12            |
+| Pipeline | OOM             | 2163.43          |
+
+- PT: OOM
+- AIT: 24.51 it/s
+
+### MI-250 1 GCD, 50 steps
+
+| Module   | PT Latency (ms) | AIT Latency (ms) |
+|----------|-----------------|------------------|
+| CLIP     | 6.16            | 2.98             |
+| UNet     | 78.42           | 62.18            |
+| VAE      | 63.83           | 164.50           |
+| Pipeline | 4300.16         | 3476.07          |
+
+- PT: 12.43 it/s
+- AIT: 15.60 it/s
+
+## Batched Version
+
+A batched version of AIT Stable Diffusion can be found at: https://github.com/terrychenism/AIT_StableDiffusion/tree/main/examples/05_stable_diffusion
+
+
+Some reference results from that repo:
+
+### A100-40GB, 25 Steps
+
+| Batch size | AIT Latency (ms) | AVG im/s |
+|------------|------------------|----------|
+| 1          | 695              | 0.69     |
+| 3          | 1651             | 0.55     |
+| 8          | 3975             | 0.50     |
+| 16         | 7906             | 0.49     |
+
+
+
+### Note for Performance Results
+
+- For all benchmarks we render images of size 512x512.
+- For NVIDIA A100, our test cluster doesn't allow locking the GPU frequency. We run a longer warm-up to collect more stable results, but small variance relative to locked-frequency numbers is still expected.
+- To benchmark MI-250 with 1 GCD, we lock the frequency with `rocm-smi -d x --setperfdeterminism 1700`, where `x` is the GPU id.
+- Performance results are what we can reproduce and are provided for reference only; they should not be used for other purposes.
diff --git a/examples/05_stable_diffusion/benchmark.py b/examples/05_stable_diffusion/benchmark.py
new file mode 100644
index 000000000..9035ad73e
--- /dev/null
+++ b/examples/05_stable_diffusion/benchmark.py
@@ -0,0 +1,304 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import logging + +import click + +import numpy as np +import torch +from aitemplate.compiler import Model +from aitemplate.testing import detect_target +from aitemplate.testing.benchmark_pt import benchmark_torch_function +from diffusers import StableDiffusionPipeline + +from torch import autocast +from transformers import CLIPTokenizer + +USE_CUDA = detect_target().name() == "cuda" + +access_token = True +pipe = None + + +def get_int_shape(x): + shape = [it.value() for it in x._attrs["shape"]] + return shape + + +def mark_output(y): + if type(y) is not tuple: + y = (y,) + for i in range(len(y)): + y[i]._attrs["is_output"] = True + y[i]._attrs["name"] = "output_%d" % (i) + y_shape = [d._attrs["values"][0] for d in y[i]._attrs["shape"]] + print("AIT output_{} shape: {}".format(i, y_shape)) + + +def benchmark_unet( + batch_size=2, + hh=64, + ww=64, + dim=320, + benchmark_pt=False, + verify=False, +): + + exe_module = Model("./tmp/UNet2DConditionModel/test.so") + if exe_module is None: + print("Error!! Cannot find compiled module for UNet2DConditionModel.") + exit(-1) + + # run PT unet model + pt_mod = pipe.unet + pt_mod = pt_mod.eval() + + latent_model_input_pt = torch.randn(batch_size, 4, hh, ww).cuda().half() + text_embeddings_pt = torch.randn(batch_size, 64, 768).cuda().half() + timesteps_pt = torch.Tensor([1, 1]).cuda().half() + + with autocast("cuda"): + pt_ys = pt_mod( + latent_model_input_pt, + timesteps_pt, + encoder_hidden_states=text_embeddings_pt, + ).sample + + # PT benchmark + if benchmark_pt: + args = (latent_model_input_pt, 1, text_embeddings_pt) + pt_time = benchmark_torch_function(100, pt_mod, *args) + print(f"PT batch_size: {batch_size}, {pt_time} ms") + with open("sd_pt_benchmark.txt", "a") as f: + f.write(f"unet batch_size: {batch_size}, latency: {pt_time} ms\n") + + print("pt output:", pt_ys.shape) + + # run AIT unet model + inputs = { + "input0": latent_model_input_pt.permute((0, 2, 3, 1)).contiguous(), + "input1": timesteps_pt, + "input2": text_embeddings_pt, + } + + ys = [] + num_ouputs = len(exe_module.get_output_name_to_index_map()) + for i in range(num_ouputs): + shape = exe_module.get_output_maximum_shape(i) + ys.append(torch.empty(shape).cuda().half()) + exe_module.run_with_tensors(inputs, ys) + + # verification + y_transpose = ys[0].permute((0, 3, 1, 2)) + + if verify: + eps = 1e-1 + np.testing.assert_allclose( + pt_ys.detach().cpu().numpy(), + y_transpose.cpu().numpy(), + atol=eps, + rtol=eps, + ) + print("UNet2DCondition verification pass") + + # AIT benchmark + # warmup + exe_module.benchmark_with_tensors(inputs, ys, count=100, repeat=4) + # benchmark + t, _, _ = exe_module.benchmark_with_tensors(inputs, ys, count=100, repeat=4) + with open("sd_ait_benchmark.txt", "a") as f: + f.write(f"unet batch_size: {batch_size}, latency: {t} ms\n") + + +def benchmark_clip( + batch_size=1, + seqlen=64, + dim=768, + num_heads=12, + hidden_size=768, + vocab_size=49408, + max_position_embeddings=77, + benchmark_pt=False, + verify=False, +): + mask_seq = 0 + version = "openai/clip-vit-large-patch14" + + exe_module = Model("./tmp/CLIPTextModel/test.so") + if 
exe_module is None: + print("Error!! Cannot find compiled module for CLIPTextModel.") + exit(-1) + + # run PT clip + pt_mod = pipe.text_encoder + pt_mod = pt_mod.eval() + + tokenizer = CLIPTokenizer.from_pretrained(version) + text_input = tokenizer( + ["a photo of an astronaut riding a horse on mars"], + padding="max_length", + max_length=seqlen, + truncation=True, + return_tensors="pt", + ) + input_ids = text_input["input_ids"].cuda() + + attention_mask = torch.ones((batch_size, seqlen)) + attention_mask[-1, -mask_seq:] = 0 + attention_mask = None + + position_ids = torch.arange(seqlen).expand((1, -1)).cuda() + pt_ys = pt_mod(input_ids, attention_mask, position_ids) + print("pt output:", pt_ys[0].shape) + + # PT benchmark + if benchmark_pt: + args = (input_ids, attention_mask, position_ids) + pt_time = benchmark_torch_function(100, pt_mod, *args) + print(f"PT batch_size: {batch_size}, {pt_time} ms") + with open("sd_pt_benchmark.txt", "a") as f: + f.write(f"clip batch_size: {batch_size}, latency: {pt_time} ms\n") + + # run AIT clip + inputs = { + "input0": input_ids, + "input1": position_ids, + } + ys = [] + num_ouputs = len(exe_module.get_output_name_to_index_map()) + for i in range(num_ouputs): + shape = exe_module.get_output_maximum_shape(i) + ys.append(torch.empty(shape).cuda().half()) + exe_module.run_with_tensors(inputs, ys) + + # verification + if verify: + eps = 1e-1 + pt_np = pt_ys[0].detach().cpu().numpy() + np.testing.assert_allclose( + pt_np, + ys[0].cpu().numpy(), + atol=eps, + rtol=eps, + ) + print("CLIPTextTransformer verification pass") + + # AIT benchmark + # warmup + exe_module.benchmark_with_tensors(inputs, ys, count=100, repeat=4) + # benchmark + t, _, _ = exe_module.benchmark_with_tensors(inputs, ys, count=100, repeat=4) + with open("sd_ait_benchmark.txt", "a") as f: + f.write(f"clip batch_size: {batch_size}, latency: {t} ms\n") + + +def benchmark_vae(batch_size=1, height=64, width=64, benchmark_pt=False, verify=False): + + latent_channels = 4 + + exe_module = Model("./tmp/AutoencoderKL/test.so") + if exe_module is None: + print("Error!! 
Cannot find compiled module for AutoencoderKL.") + exit(-1) + + # run PT vae + pt_vae = pipe.vae + pt_vae = pt_vae.cuda().half() + pt_vae.eval() + + pt_input = torch.rand([batch_size, latent_channels, height, width]).cuda().half() + print("pt_input shape", pt_input.shape) + with autocast("cuda"): + pt_output = pt_vae.decode(pt_input).sample + pt_output = pt_output.half() + + # PT benchmark + if benchmark_pt: + args = (pt_input,) + pt_time = benchmark_torch_function(100, pt_vae.decode, *args) + print(f"PT batch_size: {batch_size}, {pt_time} ms") + with open("sd_pt_benchmark.txt", "a") as f: + f.write(f"vae batch_size: {batch_size}, latency: {pt_time} ms\n") + + # run AIT vae + y = ( + torch.empty( + pt_output.size(0), + pt_output.size(2), + pt_output.size(3), + pt_output.size(1), + ) + .cuda() + .half() + ) + ait_input_pt_tensor = torch.permute(pt_input, (0, 2, 3, 1)).contiguous() + print("input pt tensor size: ", ait_input_pt_tensor.shape) + print("output pt tensor size: ", y.shape) + exe_module.run_with_tensors([ait_input_pt_tensor], [y]) + + # verification + if verify: + y_pt = torch.permute(y, (0, 3, 1, 2)) + eps = 1e-1 + np.testing.assert_allclose( + pt_output.detach().cpu().numpy(), + y_pt.cpu().numpy(), + atol=eps, + rtol=eps, + ) + logging.info("VAE Verification done!") + + # AIT benchmark: + # warmup + exe_module.benchmark_with_tensors([ait_input_pt_tensor], [y], count=100, repeat=4) + # benchmark + t, _, _ = exe_module.benchmark_with_tensors( + [ait_input_pt_tensor], [y], count=100, repeat=4 + ) + with open("sd_ait_benchmark.txt", "a") as f: + f.write(f"vae batch_size: {batch_size}, latency: {t} ms\n") + + +@click.command() +@click.option("--token", default="", help="access token") +@click.option("--verify", type=bool, default=False, help="verify correctness") +@click.option("--benchmark-pt", type=bool, default=False, help="run pt benchmark") +def benchmark_diffusers(token, verify, benchmark_pt): + logging.getLogger().setLevel(logging.INFO) + np.random.seed(0) + torch.manual_seed(4896) + + global access_token, pipe + if token != "": + access_token = token + + pipe = StableDiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + revision="fp16", + torch_dtype=torch.float16, + use_auth_token=access_token, + ).to("cuda") + + # CLIP + benchmark_clip(benchmark_pt=benchmark_pt, verify=verify) + # UNet + benchmark_unet(batch_size=2, benchmark_pt=benchmark_pt, verify=verify) + # VAE + benchmark_vae(benchmark_pt=benchmark_pt, verify=verify) + + +if __name__ == "__main__": + benchmark_diffusers() diff --git a/examples/05_stable_diffusion/benchmark_pt.py b/examples/05_stable_diffusion/benchmark_pt.py new file mode 100644 index 000000000..3534eaf62 --- /dev/null +++ b/examples/05_stable_diffusion/benchmark_pt.py @@ -0,0 +1,46 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import click +import torch + +from aitemplate.testing.benchmark_pt import benchmark_torch_function +from diffusers import StableDiffusionPipeline + + +@click.command() +@click.option("--token", default="", help="access token") +@click.option("--prompt", default="A vision of paradise, Unreal Engine", help="prompt") +@click.option( + "--benchmark", type=bool, default=False, help="run stable diffusion e2e benchmark" +) +def run(token, prompt, benchmark): + pipe = StableDiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + revision="fp16", + torch_dtype=torch.float16, + use_auth_token=token, + ).to("cuda") + + with torch.autocast("cuda"): + image = pipe(prompt).images[0] + if benchmark: + t = benchmark_torch_function(10, pipe, prompt) + print(f"sd pt e2e: {t} ms") + + image.save("example_pt.png") + + +if __name__ == "__main__": + run() diff --git a/examples/05_stable_diffusion/compile.py b/examples/05_stable_diffusion/compile.py new file mode 100644 index 000000000..d6bd33c9f --- /dev/null +++ b/examples/05_stable_diffusion/compile.py @@ -0,0 +1,353 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import logging +from collections import OrderedDict + +import click +import numpy as np + +import torch + +from aitemplate.compiler import compile_model +from aitemplate.frontend import Tensor +from aitemplate.testing import detect_target +from diffusers import StableDiffusionPipeline + +from modeling.clip import CLIPTextTransformer as ait_CLIPTextTransformer + +from modeling.unet_2d_condition import UNet2DConditionModel as ait_UNet2DConditionModel + +from modeling.vae import AutoencoderKL as ait_AutoencoderKL + + +USE_CUDA = detect_target().name() == "cuda" + +access_token = True +pipe = None + + +def mark_output(y): + if type(y) is not tuple: + y = (y,) + for i in range(len(y)): + y[i]._attrs["is_output"] = True + y[i]._attrs["name"] = "output_%d" % (i) + y_shape = [d._attrs["values"][0] for d in y[i]._attrs["shape"]] + print("AIT output_{} shape: {}".format(i, y_shape)) + + +def map_unet_params(pt_mod, dim): + pt_params = dict(pt_mod.named_parameters()) + params_ait = {} + for key, arr in pt_params.items(): + if len(arr.shape) == 4: + arr = arr.permute((0, 2, 3, 1)).contiguous() + elif key.endswith("ff.net.0.proj.weight"): + w1, w2 = arr.chunk(2, dim=0) + params_ait[key.replace(".", "_")] = w1 + params_ait[key.replace(".", "_").replace("proj", "gate")] = w2 + continue + elif key.endswith("ff.net.0.proj.bias"): + w1, w2 = arr.chunk(2, dim=0) + params_ait[key.replace(".", "_")] = w1 + params_ait[key.replace(".", "_").replace("proj", "gate")] = w2 + continue + params_ait[key.replace(".", "_")] = arr + + params_ait["arange"] = ( + torch.arange(start=0, end=dim // 2, dtype=torch.float32).cuda().half() + ) + return params_ait + + +def map_vae_params(ait_module, pt_module, batch_size, seq_len): + pt_params = dict(pt_module.named_parameters()) + mapped_pt_params = OrderedDict() + for name, _ in ait_module.named_parameters(): + 
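+        # For each AIT parameter, find the matching PyTorch weight: conv weights are
+        # permuted to NHWC, the separate query/key/value projections are concatenated
+        # into a single fused qkv tensor, and the attention cu_length buffer is built
+        # from the batch size and sequence length.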
ait_name = name.replace(".", "_") + if name in pt_params: + if ( + "conv" in name + and "norm" not in name + and name.endswith(".weight") + and len(pt_params[name].shape) == 4 + ): + mapped_pt_params[ait_name] = torch.permute( + pt_params[name], [0, 2, 3, 1] + ).contiguous() + else: + mapped_pt_params[ait_name] = pt_params[name] + elif name.endswith("attention.qkv.weight"): + prefix = name[: -len("attention.qkv.weight")] + q_weight = pt_params[prefix + "query.weight"] + k_weight = pt_params[prefix + "key.weight"] + v_weight = pt_params[prefix + "value.weight"] + qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0) + mapped_pt_params[ait_name] = qkv_weight + elif name.endswith("attention.qkv.bias"): + prefix = name[: -len("attention.qkv.bias")] + q_bias = pt_params[prefix + "query.bias"] + k_bias = pt_params[prefix + "key.bias"] + v_bias = pt_params[prefix + "value.bias"] + qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=0) + mapped_pt_params[ait_name] = qkv_bias + elif name.endswith("attention.proj.weight"): + prefix = name[: -len("attention.proj.weight")] + pt_name = prefix + "proj_attn.weight" + mapped_pt_params[ait_name] = pt_params[pt_name] + elif name.endswith("attention.proj.bias"): + prefix = name[: -len("attention.proj.bias")] + pt_name = prefix + "proj_attn.bias" + mapped_pt_params[ait_name] = pt_params[pt_name] + elif name.endswith("attention.cu_length"): + cu_len = np.cumsum([0] + [seq_len] * batch_size).astype("int32") + mapped_pt_params[ait_name] = torch.from_numpy(cu_len).cuda() + else: + pt_param = pt_module.get_parameter(name) + mapped_pt_params[ait_name] = pt_param + + return mapped_pt_params + + +def map_clip_params(pt_mod, batch_size, seqlen, depth): + + params_pt = list(pt_mod.named_parameters()) + + params_ait = {} + pt_params = {} + for key, arr in params_pt: + pt_params[key.replace("text_model.", "")] = arr + + pt_params = dict(pt_mod.named_parameters()) + for key, arr in pt_params.items(): + name = key.replace("text_model.", "") + ait_name = name.replace(".", "_") + if name.endswith("out_proj.weight"): + ait_name = ait_name.replace("out_proj", "proj") + elif name.endswith("out_proj.bias"): + ait_name = ait_name.replace("out_proj", "proj") + elif name.endswith("q_proj.weight"): + ait_name = ait_name.replace("q_proj", "qkv") + prefix = key[: -len("q_proj.weight")] + q = pt_params[prefix + "q_proj.weight"] + k = pt_params[prefix + "k_proj.weight"] + v = pt_params[prefix + "v_proj.weight"] + qkv_weight = torch.cat([q, k, v], dim=0) + params_ait[ait_name] = qkv_weight + continue + elif name.endswith("q_proj.bias"): + ait_name = ait_name.replace("q_proj", "qkv") + prefix = key[: -len("q_proj.bias")] + q = pt_params[prefix + "q_proj.bias"] + k = pt_params[prefix + "k_proj.bias"] + v = pt_params[prefix + "v_proj.bias"] + qkv_bias = torch.cat([q, k, v], dim=0) + params_ait[ait_name] = qkv_bias + continue + elif name.endswith("k_proj.weight"): + continue + elif name.endswith("k_proj.bias"): + continue + elif name.endswith("v_proj.weight"): + continue + elif name.endswith("v_proj.bias"): + continue + params_ait[ait_name] = arr + + if USE_CUDA: + for i in range(depth): + prefix = "encoder_layers_%d_self_attn_cu_length" % (i) + cu_len = np.cumsum([0] + [seqlen] * batch_size).astype("int32") + params_ait[prefix] = torch.from_numpy(cu_len).cuda() + + return params_ait + + +def compile_unet( + batch_size=2, + hh=64, + ww=64, + dim=320, + use_fp16_acc=False, + convert_conv_to_gemm=False, +): + + ait_mod = ait_UNet2DConditionModel(sample_size=64, cross_attention_dim=768) 
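# The remaining steps mirror the other compile_* helpers in this file:
# 1. name_parameter_tensor() assigns every AIT weight a graph name derived from
#    its module path, which matches the underscore-mangled names produced by
#    map_unet_params().
# 2. Inputs are declared as AIT Tensors in channels-last layout:
#    latents [batch_size, hh, ww, 4], timesteps [2], text embeddings
#    [batch_size, 64, 768].
# 3. mark_output() tags the outputs and compile_model() bakes the mapped
#    PyTorch weights in as constants.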
+ ait_mod.name_parameter_tensor() + + # set AIT parameters + pt_mod = pipe.unet + pt_mod = pt_mod.eval() + params_ait = map_unet_params(pt_mod, dim) + + latent_model_input_ait = Tensor( + [batch_size, hh, ww, 4], name="input0", is_input=True + ) + timesteps_ait = Tensor([2], name="input1", is_input=True) + text_embeddings_pt_ait = Tensor([batch_size, 64, 768], name="input2", is_input=True) + + Y = ait_mod(latent_model_input_ait, timesteps_ait, text_embeddings_pt_ait) + mark_output(Y) + + target = detect_target( + use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm + ) + compile_model(Y, target, "./tmp", "UNet2DConditionModel", constants=params_ait) + + +def compile_clip( + batch_size=1, + seqlen=64, + dim=768, + num_heads=12, + hidden_size=768, + vocab_size=49408, + max_position_embeddings=77, + use_fp16_acc=False, + convert_conv_to_gemm=False, +): + mask_seq = 0 + causal = True + depth = 12 + + ait_mod = ait_CLIPTextTransformer( + num_hidden_layers=depth, + hidden_size=dim, + num_attention_heads=num_heads, + batch_size=batch_size, + seq_len=seqlen, + causal=causal, + mask_seq=mask_seq, + ) + ait_mod.name_parameter_tensor() + + pt_mod = pipe.text_encoder + pt_mod = pt_mod.eval() + params_ait = map_clip_params(pt_mod, batch_size, seqlen, depth) + + input_ids_ait = Tensor( + [batch_size, seqlen], name="input0", dtype="int64", is_input=True + ) + position_ids_ait = Tensor( + [batch_size, seqlen], name="input1", dtype="int64", is_input=True + ) + Y = ait_mod(input_ids=input_ids_ait, position_ids=position_ids_ait) + mark_output(Y) + + target = detect_target( + use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm + ) + compile_model(Y, target, "./tmp", "CLIPTextModel", constants=params_ait) + + +def compile_vae( + batch_size=1, height=64, width=64, use_fp16_acc=False, convert_conv_to_gemm=False +): + in_channels = 3 + out_channels = 3 + down_block_types = [ + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + ] + up_block_types = [ + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + ] + block_out_channels = [128, 256, 512, 512] + layers_per_block = 2 + act_fn = "silu" + latent_channels = 4 + sample_size = 512 + + ait_vae = ait_AutoencoderKL( + batch_size, + height, + width, + in_channels=in_channels, + out_channels=out_channels, + down_block_types=down_block_types, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + latent_channels=latent_channels, + sample_size=sample_size, + ) + ait_input = Tensor( + shape=[batch_size, height, width, latent_channels], + name="vae_input", + is_input=True, + ) + ait_vae.name_parameter_tensor() + + pt_mod = pipe.vae + pt_mod = pt_mod.eval() + params_ait = map_vae_params(ait_vae, pt_mod, batch_size, height * width) + + Y = ait_vae.decode(ait_input) + mark_output(Y) + target = detect_target( + use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm + ) + compile_model( + Y, + target, + "./tmp", + "AutoencoderKL", + constants=params_ait, + ) + + +@click.command() +@click.option("--token", default="", help="access token") +@click.option("--use-fp16-acc", default=True, help="use fp16 accumulation") +@click.option("--convert-conv-to-gemm", default=True, help="convert 1x1 conv to gemm") +def compile_diffusers(token, use_fp16_acc=True, convert_conv_to_gemm=True): + logging.getLogger().setLevel(logging.INFO) + np.random.seed(0) + torch.manual_seed(4896) + + if 
detect_target().name() == "rocm": + convert_conv_to_gemm = False + + global access_token, pipe + if token != "": + access_token = token + + pipe = StableDiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + revision="fp16", + torch_dtype=torch.float16, + use_auth_token=access_token, + ).to("cuda") + + # CLIP + compile_clip(use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm) + # UNet + compile_unet( + batch_size=2, + use_fp16_acc=use_fp16_acc, + convert_conv_to_gemm=convert_conv_to_gemm, + ) + # VAE + compile_vae(use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm) + + +if __name__ == "__main__": + compile_diffusers() diff --git a/examples/05_stable_diffusion/demo.py b/examples/05_stable_diffusion/demo.py new file mode 100644 index 000000000..5a7b8b79e --- /dev/null +++ b/examples/05_stable_diffusion/demo.py @@ -0,0 +1,46 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import click +import torch + +from aitemplate.testing.benchmark_pt import benchmark_torch_function +from pipeline_stable_diffusion_ait import StableDiffusionAITPipeline + + +@click.command() +@click.option("--token", default="", help="access token") +@click.option("--prompt", default="A vision of paradise, Unreal Engine", help="prompt") +@click.option( + "--benchmark", type=bool, default=False, help="run stable diffusion e2e benchmark" +) +def run(token, prompt, benchmark): + pipe = StableDiffusionAITPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + revision="fp16", + torch_dtype=torch.float16, + use_auth_token=token, + ).to("cuda") + + with torch.autocast("cuda"): + image = pipe(prompt).images[0] + if benchmark: + t = benchmark_torch_function(10, pipe, prompt) + print(f"sd e2e: {t} ms") + + image.save("example_ait.png") + + +if __name__ == "__main__": + run() diff --git a/examples/05_stable_diffusion/modeling/attention.py b/examples/05_stable_diffusion/modeling/attention.py new file mode 100644 index 000000000..efabc3c0c --- /dev/null +++ b/examples/05_stable_diffusion/modeling/attention.py @@ -0,0 +1,104 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Implementations are translated from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py. 
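Compared with that reference implementation, this port keeps activations in
channels-last (NHWC) layout and maps the separate query/key/value linears of
the original AttentionBlock onto AIT's fused nn.MultiheadAttention module
(the corresponding weight fusion is done in compile.py's map_vae_params).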
+""" + +from typing import Optional + +from aitemplate.compiler.ops import reshape + +from aitemplate.frontend import nn, Tensor + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted + to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + Uses three q, k, v linear layers to compute attention. + Parameters: + batch_size (:obj:`int`): The number of examples per batch. + height (:obj:`int`): Height of each image example. + width (:obj:`int`): Width of each image example. + channels (:obj:`int`): The number of channels in the input and output. + num_head_channels (:obj:`int`, *optional*): + The number of channels in each head. If None, then `num_heads` = 1. + num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm. + eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm. + """ + + def __init__( + self, + batch_size: int, + height: int, + width: int, + channels: int, + num_head_channels: Optional[int] = None, + num_groups: int = 32, + rescale_output_factor: float = 1.0, + eps: float = 1e-5, + ): + super().__init__() + self.batch_size = batch_size + self.height = height + self.width = width + self.channels = channels + self.num_heads = ( + channels // num_head_channels if num_head_channels is not None else 1 + ) + self.num_head_size = num_head_channels + self.group_norm = nn.GroupNorm(num_groups, channels, eps) + self.attention = nn.MultiheadAttention( + channels, + batch_size, + height * width, + self.num_heads, + qkv_bias=True, + has_residual=True, + ) + self.rescale_output_factor = rescale_output_factor + + def forward(self, hidden_states) -> Tensor: + """ + input hidden_states shape: [batch, height, width, channel] + output shape: [batch, height, width, channel] + """ + residual = hidden_states + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = reshape()( + hidden_states, [self.batch_size, self.height * self.width, self.channels] + ) + + batch, hw, channel = hidden_states.shape() + if ( + batch.value() != self.batch_size + or hw.value() != self.width * self.height + or channel.value() != self.channels + ): + raise RuntimeError( + "nchw params do not match! " + f"Expected: {self.batch_size}, {self.channels}, {self.height} * {self.width}, " + f"actual: {batch}, {channel}, {hw}." + ) + + res = self.attention(hidden_states, residual) * (1 / self.rescale_output_factor) + res = reshape()(res, [self.batch_size, self.height, self.width, self.channels]) + + return res diff --git a/examples/05_stable_diffusion/modeling/clip.py b/examples/05_stable_diffusion/modeling/clip.py new file mode 100644 index 000000000..c66ecfb90 --- /dev/null +++ b/examples/05_stable_diffusion/modeling/clip.py @@ -0,0 +1,590 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from inspect import isfunction +from typing import Optional + +from aitemplate.compiler import ops +from aitemplate.frontend import nn, Tensor +from aitemplate.testing import detect_target + +# pylint: disable=W0102 + +USE_CUDA = detect_target().name() == "cuda" + + +def get_shape(x): + shape = [it.value() for it in x._attrs["shape"]] + return shape + + +def exists(val): + return val is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +class CrossAttention(nn.Module): + def __init__( + self, + query_dim, + context_dim=None, + heads=8, + dim_head=64, + dropout=0.0, + dtype="float16", + ): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head**-0.5 + self.heads = heads + self.dim_head = dim_head + + self.to_q_weight = nn.Parameter(shape=[inner_dim, query_dim], dtype=dtype) + self.to_k_weight = nn.Parameter(shape=[inner_dim, context_dim], dtype=dtype) + self.to_v_weight = nn.Parameter(shape=[inner_dim, context_dim], dtype=dtype) + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None, residual=None): + nheads = self.heads + d = self.dim_head + + layout = "20314" if USE_CUDA else "m2n3" + + bs, seqlen, _ = get_shape(x) + q = ops.gemm_rcr_permute(shape=(seqlen, 1, nheads), layout=layout)( + ops.reshape()(x, [bs * seqlen, -1]), self.to_q_weight.tensor() + ) + context = default(context, x) + + seqlen = get_shape(context)[1] + k = ops.gemm_rcr_permute(shape=(seqlen, 1, nheads), layout=layout)( + ops.reshape()(context, [bs * seqlen, -1]), self.to_k_weight.tensor() + ) + v = ops.gemm_rcr_permute(shape=(seqlen, 1, nheads), layout=layout)( + ops.reshape()(context, [bs * seqlen, -1]), self.to_v_weight.tensor() + ) + + if USE_CUDA: + q = q * self.scale + attn = ops.bmm_rcr()( + (ops.reshape()(q, [bs * nheads, -1, d])), + (ops.reshape()(k, [bs * nheads, -1, d])), + ) + attn = ops.softmax()(attn, -1) + v = ops.reshape()(v, [bs * nheads, -1, d]) + out = ops.bmm_rrr_permute((nheads,))(attn, v) + else: + OP = ops.bmm_softmax_bmm_permute(shape=(nheads,), scale=self.scale) + out = OP( + (ops.reshape()(q, [bs * nheads, -1, d])), + (ops.reshape()(k, [bs * nheads, -1, d])), + (ops.reshape()(v, [bs * nheads, -1, d])), + ) + out = ops.reshape()(out, [bs, -1, nheads * d]) + proj = self.to_out(out) + proj = ops.reshape()(proj, [bs, -1, nheads * d]) + if residual is not None: + return proj + residual + else: + return proj + + +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out, specialization="mul") + self.gate = nn.Linear(dim_in, dim_out, specialization="fast_gelu") + + def forward(self, x): + return self.proj(x, self.gate(x)) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = ( + nn.Sequential( + nn.Linear(dim, inner_dim, specialization="fast_gelu"), + ) + if not glu + else GEGLU(dim, inner_dim) + ) + + self.net = nn.Sequential( + project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x, residual=None): + shape = ops.size()(x) + x = self.net(x) + x = ops.reshape()(x, shape) + if residual is not None: + return x + residual + else: + return x + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim, + n_heads, + d_head, + 
dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + ): + super().__init__() + self.attn1 = CrossAttention( + query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + ) + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + self.param = (dim, n_heads, d_head, context_dim, gated_ff, checkpoint) + + def forward(self, x, context=None): + x = self.attn1(self.norm1(x), residual=x) + x = self.attn2(self.norm2(x), context=context, residual=x) + x = self.ff(self.norm3(x), residual=x) + return x + + +def Normalize(in_channels): + return nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + + def __init__( + self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None + ): + super().__init__() + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) # Group Norm + + self.proj_in = nn.Conv2dBias( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0 + ) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim + ) + for d in range(depth) + ] + ) + + self.proj_out = nn.Conv2dBias( + inner_dim, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + b, h, w, c = get_shape(x) + x_in = x + x = self.norm(x) + x = self.proj_in(x) + x = ops.reshape()(x, [b, -1, c]) + for block in self.transformer_blocks: + x = block(x, context=context) + x = ops.reshape()(x, [b, h, w, c]) + x = self.proj_out(x) + return x + x_in + + +class CLIPAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + hidden_size=768, + num_attention_heads=12, + attention_dropout=0.0, + batch_size=1, + seq_len=16, + layer_norm_eps=1e-5, + hidden_dropout_prob=0.0, + causal=False, + mask_seq=0, + ): + super().__init__() + self.attn = nn.MultiheadAttention( + dim=hidden_size, + batch_size=batch_size, + seq_len=seq_len, + num_heads=num_attention_heads, + qkv_bias=True, + attn_drop=attention_dropout, + proj_drop=hidden_dropout_prob, + has_residual=False, + causal=causal, + mask_seq=mask_seq, + ) + + def forward( + self, + hidden_states: Tensor, + attention_mask: Optional[Tensor] = None, + causal_attention_mask: Optional[Tensor] = None, + output_attentions: Optional[bool] = False, + residual: Optional[Tensor] = None, + ): + if residual is not None: + self_output = self.attn(hidden_states, residual) + else: + self_output = self.attn(hidden_states) + return self_output + + +class QuickGELUActivation(nn.Module): + """ + Applies GELU approximation that is fast but somewhat inaccurate. 
See: https://github.com/hendrycks/GELUs + """ + + def forward(self, x): + x1 = x * 1.702 + x1 = ops.sigmoid(x1) + x = x * x1 + return x + + +class CLIPMLP(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer="GELU", + drop=0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.fc1 = nn.Linear( + in_features, + hidden_features, + ) + self.activation_fn = QuickGELUActivation() + self.fc2 = nn.Linear(hidden_features, out_features, specialization="add") + + def forward(self, x, res): + shape = get_shape(x) + x = self.fc1(x) + x = self.activation_fn(x) + x = self.fc2(x, res) + return ops.reshape()(x, shape) + + +class CLIPEncoderLayer(nn.Module): + def __init__( + self, + hidden_size=768, + num_attention_heads=12, + attention_dropout=0.0, + mlp_ratio=4.0, + batch_size=1, + seq_len=16, + causal=False, + mask_seq=0, + ): + super().__init__() + self.embed_dim = hidden_size + self.self_attn = nn.MultiheadAttention( + dim=hidden_size, + batch_size=batch_size, + seq_len=seq_len, + num_heads=num_attention_heads, + qkv_bias=True, + attn_drop=attention_dropout, + proj_drop=0, + has_residual=True, + causal=causal, + mask_seq=mask_seq, + ) + self.layer_norm1 = nn.LayerNorm(self.embed_dim) + self.mlp = CLIPMLP(hidden_size, int(hidden_size * mlp_ratio)) + self.layer_norm2 = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: Tensor, + output_attentions: Optional[bool] = False, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + `(config.encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn(hidden_states, residual) + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states, residual) + + return hidden_states + + +class CLIPEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`CLIPEncoderLayer`]. 
+ Args: + config: CLIPConfig + """ + + def __init__( + self, + num_hidden_layers=12, + output_attentions=False, + output_hidden_states=False, + use_return_dict=False, + hidden_size=768, + num_attention_heads=12, + batch_size=1, + seq_len=64, + causal=False, + mask_seq=0, + ): + super().__init__() + self.layers = nn.ModuleList( + [ + CLIPEncoderLayer( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + batch_size=batch_size, + seq_len=seq_len, + causal=causal, + mask_seq=mask_seq, + ) + for _ in range(num_hidden_layers) + ] + ) + self.output_attentions = output_attentions + self.output_hidden_states = output_hidden_states + self.use_return_dict = use_return_dict + + def forward( + self, + inputs_embeds, + attention_mask: Optional[Tensor] = None, + causal_attention_mask: Optional[Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Causal mask for the text model. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
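Note: in this AIT port the mask and flag arguments above are accepted only for
signature compatibility; the encoder simply runs each CLIPEncoderLayer in
sequence and returns the final hidden states, and causal masking is configured
at construction time via the `causal` flag passed down to each layer's
attention module.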
+ """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.use_return_dict + + encoder_states = () if output_hidden_states else None + # all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for _, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + layer_outputs = encoder_layer(hidden_states) + hidden_states = layer_outputs + + return hidden_states + + +class CLIPTextEmbeddings(nn.Module): + def __init__( + self, + hidden_size=768, + vocab_size=49408, + max_position_embeddings=77, + dtype="float16", + ): + super().__init__() + embed_dim = hidden_size + + self.token_embedding = nn.Embedding(shape=[vocab_size, embed_dim], dtype=dtype) + self.position_embedding = nn.Embedding( + shape=[max_position_embeddings, embed_dim], dtype=dtype + ) + + def forward( + self, + input_ids: Tensor, + position_ids: Tensor, + inputs_embeds: Optional[Tensor] = None, + ) -> Tensor: + + input_shape = ops.size()(input_ids) + + # [B * S] + input_ids = ops.reshape()(input_ids, [-1]) + + position_ids = ops.reshape()(position_ids, [-1]) + + if inputs_embeds is None: + inputs_embeds = ops.batch_gather()(self.token_embedding.tensor(), input_ids) + + position_embeddings = ops.batch_gather()( + self.position_embedding.tensor(), position_ids + ) + + embeddings = inputs_embeds + position_embeddings + + embeddings = ops.reshape()(embeddings, [input_shape[0], input_shape[1], -1]) + + return embeddings + + +class CLIPTextTransformer(nn.Module): + def __init__( + self, + hidden_size=768, + output_attentions=False, + output_hidden_states=False, + use_return_dict=False, + num_hidden_layers=12, + num_attention_heads=12, + batch_size=1, + seq_len=64, + causal=False, + mask_seq=0, + ): + super().__init__() + embed_dim = hidden_size + self.embeddings = CLIPTextEmbeddings() + self.encoder = CLIPEncoder( + num_hidden_layers=num_hidden_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + batch_size=batch_size, + seq_len=seq_len, + causal=causal, + mask_seq=mask_seq, + ) + self.final_layer_norm = nn.LayerNorm(embed_dim) + + self.output_attentions = output_attentions + self.output_hidden_states = output_hidden_states + self.use_return_dict = use_return_dict + + def forward( + self, + input_ids: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + position_ids: Optional[Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + r""" + Returns: + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.use_return_dict + + if input_ids is None: + raise ValueError("You have to specify either input_ids") + + hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + ) + + last_hidden_state = encoder_outputs + last_hidden_state = self.final_layer_norm(last_hidden_state) + return last_hidden_state diff --git 
a/examples/05_stable_diffusion/modeling/embeddings.py b/examples/05_stable_diffusion/modeling/embeddings.py new file mode 100644 index 000000000..36b96a4fb --- /dev/null +++ b/examples/05_stable_diffusion/modeling/embeddings.py @@ -0,0 +1,101 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import math + +from aitemplate.compiler import ops +from aitemplate.frontend import nn, Tensor + + +def get_shape(x): + shape = [it.value() for it in x._attrs["shape"]] + return shape + + +def get_timestep_embedding( + timesteps: Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the + embeddings. :return: an [N x dim] Tensor of positional embeddings. + """ + assert len(get_shape(timesteps)) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + + exponent = (-math.log(max_period)) * Tensor( + shape=[half_dim], dtype="float16", name="arange" + ) + + exponent = exponent * (1.0 / (half_dim - downscale_freq_shift)) + + emb = ops.exp(exponent) + emb = ops.reshape()(timesteps, [-1, 1]) * ops.reshape()(emb, [1, -1]) + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + if flip_sin_to_cos: + emb = ops.concatenate()( + [ops.cos(emb), ops.sin(emb)], + dim=-1, + ) + else: + emb = ops.concatenate()( + [ops.sin(emb), ops.cos(emb)], + dim=-1, + ) + return emb + + +class TimestepEmbedding(nn.Module): + def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"): + super().__init__() + + self.linear_1 = nn.Linear(channel, time_embed_dim, specialization="swish") + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim) + + def forward(self, sample): + sample = self.linear_1(sample) + sample = self.linear_2(sample) + return sample + + +class Timesteps(nn.Module): + def __init__( + self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float + ): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + + def forward(self, timesteps): + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + ) + return t_emb diff --git a/examples/05_stable_diffusion/modeling/resnet.py b/examples/05_stable_diffusion/modeling/resnet.py new file mode 100644 index 000000000..03e4f8023 --- /dev/null +++ b/examples/05_stable_diffusion/modeling/resnet.py @@ -0,0 +1,238 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
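The AIT graph built by get_timestep_embedding above reads the 0..half_dim-1 range from a constant named "arange" (supplied by compile.py's map_unet_params), scales it by -log(max_period)/(half_dim - downscale_freq_shift), exponentiates, multiplies by the timesteps, and concatenates sin/cos. A plain-PyTorch sketch of the same computation, useful for sanity-checking values (argument defaults mirror the AIT function):

import math
import torch

def timestep_embedding_ref(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1.0,
    scale: float = 1.0,
    max_period: int = 10000,
) -> torch.Tensor:
    # Same recipe as get_timestep_embedding, but in eager PyTorch.
    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(half_dim, dtype=torch.float32)
    exponent = exponent / (half_dim - downscale_freq_shift)
    emb = torch.exp(exponent)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = scale * emb
    if flip_sin_to_cos:
        return torch.cat([emb.cos(), emb.sin()], dim=-1)
    return torch.cat([emb.sin(), emb.cos()], dim=-1)

# e.g. timestep_embedding_ref(torch.tensor([0, 999]), 320, flip_sin_to_cos=True)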
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from aitemplate.compiler import ops +from aitemplate.frontend import nn + + +def get_shape(x): + shape = [it.value() for it in x._attrs["shape"]] + return shape + + +class Upsample2D(nn.Module): + """ + An upsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is + applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__( + self, + channels, + use_conv=False, + use_conv_transpose=False, + out_channels=None, + name="conv", + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + conv = None + if use_conv_transpose: + conv = nn.ConvTranspose2dBias(channels, self.out_channels, 4, 2, 1) + elif use_conv: + conv = nn.Conv2dBias(self.channels, self.out_channels, 3, 1, 1) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if name == "conv": + self.conv = conv + else: + self.Conv2d_0 = conv + + def forward(self, x): + assert get_shape(x)[-1] == self.channels + if self.use_conv_transpose: + return self.conv(x) + + x = nn.Upsampling2d(scale_factor=2.0, mode="nearest")(x) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if self.use_conv: + if self.name == "conv": + x = self.conv(x) + else: + x = self.Conv2d_0(x) + + return x + + +class Downsample2D(nn.Module): + """ + A downsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is + applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. 
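Note: tensors are channels-last (NHWC) in this port, so `channels` is checked
against the last dimension of the input, and the convolution / average-pool
submodules are AIT's channels-last ops (nn.Conv2dBias, nn.AvgPool2d).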
+ """ + + def __init__( + self, channels, use_conv=False, out_channels=None, padding=1, name="conv" + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + + if use_conv: + conv = nn.Conv2dBias( + self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + conv = nn.AvgPool2d(kernel_size=stride, stride=stride, padding=0) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if name == "conv": + self.Conv2d_0 = conv + self.conv = conv + elif name == "Conv2d_0": + self.conv = conv + else: + self.conv = conv + + def forward(self, x): + assert get_shape(x)[-1] == self.channels + x = self.conv(x) + + return x + + +class ResnetBlock2D(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + non_linearity="swish", + time_embedding_norm="default", + kernel=None, + output_scale_factor=1.0, + use_nin_shortcut=None, + up=False, + down=False, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.time_embedding_norm = time_embedding_norm + self.up = up + self.down = down + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + self.norm1 = nn.GroupNorm( + num_groups=groups, + num_channels=in_channels, + eps=eps, + affine=True, + use_swish=True, + ) + + self.conv1 = nn.Conv2dBias( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + if temb_channels is not None: + self.time_emb_proj = nn.Linear(temb_channels, out_channels) + else: + self.time_emb_proj = None + + self.norm2 = nn.GroupNorm( + num_groups=groups_out, + num_channels=out_channels, + eps=eps, + affine=True, + use_swish=True, + ) + self.dropout = nn.Dropout(dropout) + self.conv2 = nn.Conv2dBias( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + self.upsample = self.downsample = None + + self.use_nin_shortcut = ( + self.in_channels != self.out_channels + if use_nin_shortcut is None + else use_nin_shortcut + ) + + if self.use_nin_shortcut: + self.conv_shortcut = nn.Conv2dBias( + in_channels, out_channels, 1, 1, 0 + ) # kernel_size=1, stride=1, padding=0) # conv_bias_add + else: + self.conv_shortcut = None + + def forward(self, x, temb=None): + hidden_states = x + + # make sure hidden states is in float32 + # when running in half-precision + hidden_states = self.norm1( + hidden_states + ) # .float()).type(hidden_states.dtype) # fused swish + # hidden_states = self.nonlinearity(hidden_states) + + if self.upsample is not None: + x = self.upsample(x) + hidden_states = self.upsample(hidden_states) + elif self.downsample is not None: + x = self.downsample(x) + hidden_states = self.downsample(hidden_states) + + hidden_states = self.conv1(hidden_states) + + if temb is not None: + temb = self.time_emb_proj(ops.silu(temb)) + bs, dim = get_shape(temb) + temb = ops.reshape()(temb, [bs, 1, 1, dim]) + hidden_states = hidden_states + temb + + # make sure hidden states is in float32 + # when running in half-precision + hidden_states = self.norm2(hidden_states) + + hidden_states = self.dropout(hidden_states) + 
hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + x = self.conv_shortcut(x) + + out = hidden_states + x + + return out diff --git a/examples/05_stable_diffusion/modeling/unet_2d_condition.py b/examples/05_stable_diffusion/modeling/unet_2d_condition.py new file mode 100644 index 000000000..9c1d9f07c --- /dev/null +++ b/examples/05_stable_diffusion/modeling/unet_2d_condition.py @@ -0,0 +1,251 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import Optional, Tuple + +from aitemplate.frontend import nn + +from modeling.embeddings import TimestepEmbedding, Timesteps +from modeling.unet_blocks import get_down_block, get_up_block, UNetMidBlock2DCrossAttn + + +class UNet2DConditionModel(nn.Module): + r""" + UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep + and returns sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the model (such as downloading or saving, etc.) + + Parameters: + sample_size (`int`, *optional*): The size of the input sample. + in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`): + The tuple of upsample blocks to use. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. 
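Note: unlike the diffusers reference, this AIT port takes `sample` in
channels-last layout ([batch, height, width, channels]) and `forward`
returns the output sample tensor directly; `return_dict` is accepted for
signature compatibility only.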
+ """ + + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + up_block_types: Tuple[str] = ( + "UpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + ), + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: int = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: int = 8, + ): + super().__init__() + self.center_input_sample = center_input_sample + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + + # input + self.conv_in = nn.Conv2dBias(in_channels, block_out_channels[0], 3, 1, 1) + # time + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim, + downsample_padding=downsample_padding, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock2DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift="default", + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim, + resnet_groups=norm_num_groups, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[ + min(i + 1, len(block_out_channels) - 1) + ] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], + num_groups=norm_num_groups, + eps=norm_eps, + use_swish=True, + ) + + self.conv_out = nn.Conv2dBias(block_out_channels[0], out_channels, 3, 1, 1) + + def forward( + self, + sample, + timesteps, + encoder_hidden_states, + return_dict: bool = 
True, + ): + """r + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + + # 1. time + t_emb = self.time_proj(timesteps) + emb = self.time_embedding(t_emb) + + # 2. pre-process + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if ( + hasattr(downsample_block, "attentions") + and downsample_block.attentions is not None + ): + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + sample = self.mid_block( + sample, emb, encoder_hidden_states=encoder_hidden_states + ) + + # 5. up + for upsample_block in self.up_blocks: + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[ + : -len(upsample_block.resnets) + ] + + if ( + hasattr(upsample_block, "attentions") + and upsample_block.attentions is not None + ): + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples + ) + + # 6. post-process + # make sure hidden states is in float32 + # when running in half-precision + sample = self.conv_norm_out(sample) + sample = self.conv_out(sample) + return sample diff --git a/examples/05_stable_diffusion/modeling/unet_blocks.py b/examples/05_stable_diffusion/modeling/unet_blocks.py new file mode 100644 index 000000000..75de2e0c8 --- /dev/null +++ b/examples/05_stable_diffusion/modeling/unet_blocks.py @@ -0,0 +1,761 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
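A side note on the skip-connection bookkeeping in UNet2DConditionModel.forward above: each down block appends its intermediate states to down_block_res_samples, and each up block then slices off as many entries as it has resnets, consuming them in reverse. A toy, pure-Python illustration of that slicing (counts and names are made up):

# Down blocks push residuals left to right; up blocks consume them right to left.
down_block_res_samples = ["d0", "d1", "d2", "d3", "d4", "d5"]
resnets_per_up_block = [3, 3]  # illustrative counts

for num_resnets in resnets_per_up_block:
    res_samples = down_block_res_samples[-num_resnets:]
    down_block_res_samples = down_block_res_samples[:-num_resnets]
    print(res_samples)
# prints ['d3', 'd4', 'd5'] and then ['d0', 'd1', 'd2']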
+# See the License for the specific language governing permissions and + +# flake8: noqa +from aitemplate.compiler import ops + +from aitemplate.frontend import nn, Tensor +from aitemplate.testing import detect_target +from modeling.attention import AttentionBlock + +from modeling.clip import SpatialTransformer +from modeling.resnet import Downsample2D, ResnetBlock2D, Upsample2D + +# pylint: disable=W0102 + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + cross_attention_dim=None, + downsample_padding=None, +): + down_block_type = ( + down_block_type[7:] + if down_block_type.startswith("UNetRes") + else down_block_type + ) + if down_block_type == "DownBlock2D": + return DownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + ) + elif down_block_type == "AttnDownBlock2D": + return AttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + attn_num_head_channels=attn_num_head_channels, + ) + elif down_block_type == "CrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError( + "cross_attention_dim must be specified for CrossAttnDownBlock2D" + ) + return CrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + ) + elif down_block_type == "SkipDownBlock2D": + return SkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + ) + elif down_block_type == "AttnSkipDownBlock2D": + return AttnSkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + attn_num_head_channels=attn_num_head_channels, + ) + elif down_block_type == "DownEncoderBlock2D": + return DownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + ) + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + cross_attention_dim=None, +): + up_block_type = ( + up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + ) + if up_block_type == "UpBlock2D": + return UpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + 
resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif up_block_type == "CrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError( + "cross_attention_dim must be specified for CrossAttnUpBlock2D" + ) + return CrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + ) + elif up_block_type == "AttnUpBlock2D": + return AttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + attn_num_head_channels=attn_num_head_channels, + ) + elif up_block_type == "SkipUpBlock2D": + return SkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif up_block_type == "AttnSkipUpBlock2D": + return AttnSkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + attn_num_head_channels=attn_num_head_channels, + ) + elif up_block_type == "UpDecoderBlock2D": + return UpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlock2DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + attention_type="default", + output_scale_factor=1.0, + cross_attention_dim=1280, + **kwargs, + ): + super().__init__() + + self.attention_type = attention_type + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = ( + resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + ) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for _ in range(num_layers): + attentions.append( + SpatialTransformer( + in_channels, + attn_num_head_channels, + in_channels // attn_num_head_channels, + depth=1, + context_dim=cross_attention_dim, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions 
= nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn(hidden_states, encoder_hidden_states) + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class CrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + attention_type="default", + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + ): + super().__init__() + + resnets = [] + attentions = [] + + self.attention_type = attention_type + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + SpatialTransformer( + out_channels, + attn_num_head_channels, + out_channels // attn_num_head_channels, + depth=1, + context_dim=cross_attention_dim, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + in_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, context=encoder_hidden_states) + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class DownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + in_channels, + use_conv=True, + out_channels=out_channels, + 
padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb) + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + attention_type="default", + output_scale_factor=1.0, + downsample_padding=1, + add_upsample=True, + ): + super().__init__() + + resnets = [] + attentions = [] + + self.attention_type = attention_type + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + SpatialTransformer( + out_channels, + attn_num_head_channels, + out_channels // attn_num_head_channels, + depth=1, + context_dim=cross_attention_dim, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] + ) + else: + self.upsamplers = None + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + ): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = ops.concatenate()( + [hidden_states, res_hidden_states], dim=-1 + ) + + hidden_states = resnet(hidden_states, temb=temb) + hidden_states = attn(hidden_states, context=encoder_hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class UpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + 
temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] + ) + else: + self.upsamplers = None + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = ops.concatenate()( + [hidden_states, res_hidden_states], dim=-1 + ) + + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class UpDecoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] + ) + else: + self.upsamplers = None + + def forward(self, hidden_states): + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=None) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class UNetMidBlock2D(nn.Module): + def __init__( + self, + batch_size, + height, + width, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + attention_type="default", + output_scale_factor=1.0, + **kwargs, + ): + super().__init__() + + if attention_type != "default": + raise NotImplementedError( + f"attention_type must be default! 
current value: {attention_type}" + ) + + resnet_groups = ( + resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + ) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for _ in range(num_layers): + attentions.append( + AttentionBlock( + batch_size, + height, + width, + in_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None, encoder_states=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn(hidden_states) + hidden_states = resnet(hidden_states, temb) + + return hidden_states diff --git a/examples/05_stable_diffusion/modeling/vae.py b/examples/05_stable_diffusion/modeling/vae.py new file mode 100644 index 000000000..6a239f233 --- /dev/null +++ b/examples/05_stable_diffusion/modeling/vae.py @@ -0,0 +1,152 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Translated from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/vae.py. 
+""" + +from typing import Tuple + +from aitemplate.frontend import nn, Tensor +from modeling.unet_blocks import get_up_block, UNetMidBlock2D + + +class Decoder(nn.Module): + def __init__( + self, + batch_size, + height, + width, + in_channels=3, + out_channels=3, + up_block_types=("UpDecoderBlock2D",), + block_out_channels=(64,), + layers_per_block=2, + act_fn="silu", + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv2dBias( + in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1 + ) + + # mid + self.mid_block = UNetMidBlock2D( + batch_size, + height, + width, + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attn_num_head_channels=None, + resnet_groups=32, + temb_channels=None, + ) + + # up + self.up_blocks = nn.ModuleList([]) + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + prev_output_channel=None, + add_upsample=not is_final_block, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + attn_num_head_channels=None, + temb_channels=None, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + num_groups_out = 32 + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], + num_groups=num_groups_out, + eps=1e-6, + use_swish=True, + ) + self.conv_out = nn.Conv2dBias( + block_out_channels[0], out_channels, kernel_size=3, padding=1, stride=1 + ) + + def forward(self, z) -> Tensor: + sample = z + sample = self.conv_in(sample) + + # middle + sample = self.mid_block(sample) + + # up + for up_block in self.up_blocks: + sample = up_block(sample) + + sample = self.conv_norm_out(sample) + sample = self.conv_out(sample) + + return sample + + +class AutoencoderKL(nn.Module): + def __init__( + self, + batch_size: int, + height: int, + width: int, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str] = ("DownEncoderBlock2D",), + up_block_types: Tuple[str] = ("UpDecoderBlock2D",), + block_out_channels: Tuple[int] = (64,), + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 4, + sample_size: int = 32, + ): + super().__init__() + self.decoder = Decoder( + batch_size, + height, + width, + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + ) + self.post_quant_conv = nn.Conv2dBias( + latent_channels, latent_channels, kernel_size=1, stride=1, padding=0 + ) + + def decode(self, z: Tensor, return_dict: bool = True): + + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self): + raise NotImplementedError("Only decode() is implemented for AutoencoderKL!") diff --git a/examples/05_stable_diffusion/pipeline_stable_diffusion_ait.py b/examples/05_stable_diffusion/pipeline_stable_diffusion_ait.py new file mode 100644 index 000000000..211fc99d9 --- /dev/null +++ b/examples/05_stable_diffusion/pipeline_stable_diffusion_ait.py @@ -0,0 +1,371 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import inspect + +import os +import warnings +from typing import List, Optional, Union + +import torch +from aitemplate.compiler import Model + +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + StableDiffusionPipeline, + UNet2DConditionModel, +) +from diffusers.pipelines.stable_diffusion import ( + StableDiffusionPipelineOutput, + StableDiffusionSafetyChecker, +) + +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + + +class StableDiffusionAITPipeline(StableDiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + ): + super().__init__( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + + workdir = "tmp/" + self.clip_ait_exe = self.init_ait_module( + model_name="CLIPTextModel", workdir=workdir + ) + self.unet_ait_exe = self.init_ait_module( + model_name="UNet2DConditionModel", workdir=workdir + ) + self.vae_ait_exe = self.init_ait_module( + model_name="AutoencoderKL", workdir=workdir + ) + + def init_ait_module( + self, + model_name, + workdir, + ): + mod = Model(os.path.join(workdir, model_name, "test.so")) + return mod + + def unet_inference(self, latent_model_input, timesteps, encoder_hidden_states): + exe_module = self.unet_ait_exe + timesteps_pt = timesteps.expand(latent_model_input.shape[0]) + inputs = { + "input0": latent_model_input.permute((0, 2, 3, 1)) + .contiguous() + .cuda() + .half(), + "input1": timesteps_pt.cuda().half(), + "input2": encoder_hidden_states.cuda().half(), + } + ys = [] + num_ouputs = len(exe_module.get_output_name_to_index_map()) + for i in range(num_ouputs): + shape = exe_module.get_output_maximum_shape(i) + ys.append(torch.empty(shape).cuda().half()) + exe_module.run_with_tensors(inputs, ys, graph_mode=True) + noise_pred = ys[0].permute((0, 3, 1, 2)).float() + return noise_pred + + def clip_inference(self, input_ids, seqlen=64): + exe_module = self.clip_ait_exe + position_ids = torch.arange(seqlen).expand((1, -1)).cuda() + inputs = { + "input0": input_ids, + "input1": position_ids, + } + ys = [] + num_ouputs = len(exe_module.get_output_name_to_index_map()) + for i in range(num_ouputs): + shape = exe_module.get_output_maximum_shape(i) + ys.append(torch.empty(shape).cuda().half()) + exe_module.run_with_tensors(inputs, ys, graph_mode=True) + return ys[0].float() + + def vae_inference(self, vae_input): + exe_module = self.vae_ait_exe + inputs = [torch.permute(vae_input, (0, 2, 3, 1)).contiguous().cuda().half()] + ys = [] + num_ouputs = len(exe_module.get_output_name_to_index_map()) + for i in range(num_ouputs): + shape = exe_module.get_output_maximum_shape(i) + ys.append(torch.empty(shape).cuda().half()) + exe_module.run_with_tensors(inputs, ys, graph_mode=True) + vae_out = ys[0].permute((0, 3, 1, 2)).float() + return vae_out + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + height: Optional[int] = 512, + width: Optional[int] = 512, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + if "torch_device" in kwargs: + device = kwargs.pop("torch_device") + warnings.warn( + "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." + " Consider using `pipe.to(torch_device)` instead." + ) + + # Set device as before (to be removed in 0.3.0) + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.to(device) + + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError( + f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError( + f"`height` and `width` have to be divisible by 8 but are {height} and {width}." + ) + + # get prompt text embeddings + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=64, # self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_embeddings = self.clip_inference(text_input.input_ids.to(self.device)) + + # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance.
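+ # With guidance enabled, the prompt is encoded twice (unconditional and text-conditioned), + # and the two noise predictions are later combined as + # noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond).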
+ do_classifier_free_guidance = guidance_scale > 1.0 + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + max_length = text_input.input_ids.shape[-1] + uncond_input = self.tokenizer( + [""] * batch_size, + padding="max_length", + max_length=max_length, + return_tensors="pt", + ) + uncond_embeddings = self.clip_inference( + uncond_input.input_ids.to(self.device) + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + # get the initial random noise unless the user supplied it + + # Unlike in other pipelines, latents need to be generated in the target device + # for 1-to-1 results reproducibility with the CompVis implementation. + # However this currently doesn't work in `mps`. + latents_device = "cpu" if self.device.type == "mps" else self.device + latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) + if latents is None: + latents = torch.randn( + latents_shape, + generator=generator, + device=latents_device, + ) + else: + if latents.shape != latents_shape: + raise ValueError( + f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}" + ) + latents = latents.to(self.device) + + # set timesteps + accepts_offset = "offset" in set( + inspect.signature(self.scheduler.set_timesteps).parameters.keys() + ) + extra_set_kwargs = {} + if accepts_offset: + extra_set_kwargs["offset"] = 1 + + self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + + # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = latents * self.scheduler.sigmas[0] + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([latents] * 2) if do_classifier_free_guidance else latents + ) + if isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[i] + # the model input needs to be scaled to match the continuous ODE formulation in K-LMS + latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) + + # predict the noise residual + noise_pred = self.unet_inference( + latent_model_input, t, encoder_hidden_states=text_embeddings + ) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + # compute the previous noisy sample x_t -> x_t-1 + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = self.scheduler.step( + noise_pred, i, latents, **extra_step_kwargs + ).prev_sample + else: + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs + ).prev_sample + + # scale and decode the image latents with vae + latents = 1 / 0.18215 * latents + image = self.vae_inference(latents) + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + + # run safety checker + safety_cheker_input = self.feature_extractor( + self.numpy_to_pil(image), return_tensors="pt" + ).to(self.device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_cheker_input.pixel_values + ) + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput( + images=image, nsfw_content_detected=has_nsfw_concept + ) diff --git a/examples/06_how_to_add_an_op/how_to_add_an_op.py b/examples/06_how_to_add_an_op/how_to_add_an_op.py new file mode 100644 index 000000000..cd1646aeb --- /dev/null +++ b/examples/06_how_to_add_an_op/how_to_add_an_op.py @@ -0,0 +1,249 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
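+# +# Example overview: this file registers a custom "add_one" operator. The +# Operator subclass below records its input tensor and infers the output shape; +# the jinja2 templates render the CUDA/HIP kernel, function signature, +# declaration, and call site; and the registry entries +# ("cuda.add_one.gen_function", "cuda.add_one.func_decl", "cuda.add_one.func_call", +# plus their "rocm.*" counterparts) let the backend code generator pick them up. +# verify_add_one() builds Y = add_one()(X), compiles it with compile_model, and +# checks the result against x + 1.0 computed in PyTorch.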
+ +from typing import Any, Dict, List + +import jinja2 +import torch + +from aitemplate import backend +from aitemplate.backend import registry +from aitemplate.backend.backend_spec import CUDASpec, ROCMSpec +from aitemplate.compiler import compile_model +from aitemplate.compiler.base import IntVar, Operator, Tensor +from aitemplate.testing import detect_target + + +class add_one(Operator): + def __init__(self): + super().__init__() + # required, unique identity of operator category + self._attrs["op"] = "add_one" + # we can put whatever we want into the op attrs for later use + self._attrs["has_profiler"] = False + self._attrs["nop"] = False + + def __call__(self, x: Tensor) -> Tensor: + # each operator needs to keep a record of input tensors + self._attrs["inputs"] = [x] + # optional, to set depth of the op based on inputs' depth, used in DFS + self._set_depth() + # infer output shape + output_shape = self._infer_shape(x) + # create output Tensor, of which the source op is the current op + output = Tensor(output_shape, src_ops={self}) + # remember current op's outputs + self._attrs["outputs"] = [output] + return output + + def _infer_shape(self, x) -> List[IntVar]: + return x.shape() + + def gen_function(self) -> str: + # this function will be used in codegen + # here we only need to redirect to backend codegen function + target = backend.target.Target.current() + func_key = f"{target.name()}.{self._attrs['op']}.gen_function" + func = registry.get(func_key) + return func(self._attrs) + + +FUNC_TEMPLATE = jinja2.Template( + """ +{{header_files}} + +namespace { + +{{kernel}} + +} // namespace + +{{func_signature}} +{ + invoke_add_one(output, input, num_elements, stream); +} + """ +) + +FUNC_SIGNATURE = jinja2.Template( + """ +void {{func_name}}(half* output, + const half* input, + const int64_t num_elements, + {{prefix}}Stream_t stream) + """ +) + +FUNC_DECL = jinja2.Template( + """ + {{func_signature}}; + """ +) + + +FUNC_CALL_TEMPLATE = jinja2.Template( + """ +{{indent}}int64_t num_elements = 1; +{% for dim_name in dim_names %} +{{indent}}num_elements *= {{dim_name}}; +{% endfor %} + +{{indent}}{{func_name}}( +{{indent}} {{output}}, {{input}}, num_elements, stream /* default stream */ +{{indent}}); + """ +) + + +KERNEL_TEMPLATE = jinja2.Template( + """ +__global__ void add_one(half* output, const half* input, const int64_t num_elements) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < num_elements) { + output[idx] = input[idx] + half(1.0); + } +} + +void invoke_add_one(half* output, const half* input, int64_t num_elements, {{prefix}}Stream_t stream) { + if (num_elements < 1024) { + dim3 grid(1); + dim3 block(num_elements); + add_one<<<grid, block, 0, stream>>>(output, input, num_elements); + } else { + dim3 grid((num_elements + 1024 - 1) / 1024); + dim3 block(1024); + add_one<<<grid, block, 0, stream>>>(output, input, num_elements); + } +} + """ +) + + +FUNC_CALL_FP16_PARAM_TEMPLATE = jinja2.Template( + """reinterpret_cast<half*>( + {% if is_cuda %}&({% endif %}{{name}}{% if is_cuda %}->raw()){% endif %})""" +) + + +def gen_function_call(func_attrs: Dict[str, Any], indent=" ", is_cuda=False) -> str: + assert len(func_attrs["outputs"]) == 1 + assert len(func_attrs["inputs"]) == 1 + + output_name = FUNC_CALL_FP16_PARAM_TEMPLATE.render( + name=func_attrs["outputs"][0]._attrs["name"], is_cuda=is_cuda + ) + input_name = FUNC_CALL_FP16_PARAM_TEMPLATE.render( + name=func_attrs["inputs"][0]._attrs["name"], is_cuda=is_cuda + ) + + dim_names = [dim._attrs["name"] for dim in func_attrs["inputs"][0].shape()] + return
FUNC_CALL_TEMPLATE.render( + func_name=func_attrs["name"], + output=output_name, + input=input_name, + dim_names=dim_names, + indent=indent, + ) + + +def gen_function(func_attrs: Dict[str, Any], header_files: str, backend_spec) -> str: + prefix = backend_spec.prefix + return FUNC_TEMPLATE.render( + header_files=header_files, + kernel=KERNEL_TEMPLATE.render(prefix=prefix), + func_signature=FUNC_SIGNATURE.render( + func_name=func_attrs["name"], prefix=prefix + ), + ) + + +def gen_function_decl(func_attrs: Dict[str, Any], backend_spec) -> str: + return FUNC_DECL.render( + func_signature=FUNC_SIGNATURE.render( + func_name=func_attrs["name"], + prefix=backend_spec.prefix, + ).strip() + ) + + +CUDA_HEADER_FILES = """ +#include <cuda_fp16.h> +""" + + +@registry.reg("cuda.add_one.gen_function") +def cuda_add_one_gen_function(func_attrs: Dict[str, Any]) -> str: + return gen_function(func_attrs, CUDA_HEADER_FILES, CUDASpec()) + + +@registry.reg("cuda.add_one.func_decl") +def cuda_add_one_gen_function_decl(func_attrs: Dict[str, Any]) -> str: + return gen_function_decl(func_attrs, CUDASpec()) + + +@registry.reg("cuda.add_one.func_call") +def cuda_add_one_gen_function_call(func_attrs: Dict[str, Any], indent=" ") -> str: + return gen_function_call(func_attrs, indent, is_cuda=True) + + +HIP_HEADER_FILES = """ +#include <hip/hip_fp16.h> +#include <hip/hip_runtime.h> +""" + + +@registry.reg("rocm.add_one.gen_function") +def rocm_add_one_gen_function(func_attrs: Dict[str, Any]) -> str: + return gen_function(func_attrs, HIP_HEADER_FILES, ROCMSpec()) + + +@registry.reg("rocm.add_one.func_decl") +def rocm_add_one_gen_function_decl(func_attrs: Dict[str, Any]) -> str: + return gen_function_decl(func_attrs, ROCMSpec()) + + +@registry.reg("rocm.add_one.func_call") +def rocm_add_one_gen_function_call(func_attrs: Dict[str, Any], indent=" ") -> str: + return gen_function_call(func_attrs, indent, is_cuda=False) + + +def create_ait_model(shapes): + X = Tensor( + shape=shapes, + dtype="float16", + name="X", + is_input=True, + ) + Y = add_one()(X) + Y._attrs["is_output"] = True + Y._attrs["name"] = "Y" + return Y + + +def verify_add_one(): + shapes = [16, 512] + x = torch.randn(shapes).cuda().half() + y_pt = x + 1.0 + + Y = create_ait_model([16, 512]) + target = detect_target() + with compile_model(Y, target, "./tmp", "add_one") as module: + y = torch.empty(shapes).cuda().half() + inputs = {"X": x} + outputs = {"Y": y} + module.run_with_tensors(inputs, outputs) + print(torch.allclose(y, y_pt, atol=1e-2, rtol=1e-2)) + + +verify_add_one() diff --git a/examples/07_how_to_run_pt_model/how_to_run_pt_model.py b/examples/07_how_to_run_pt_model/how_to_run_pt_model.py new file mode 100644 index 000000000..993b7c69f --- /dev/null +++ b/examples/07_how_to_run_pt_model/how_to_run_pt_model.py @@ -0,0 +1,131 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
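+# +# Example overview: this script runs a small PyTorch model through AITemplate. +# It defines the same MLP twice (PTSimpleModel in PyTorch, AITSimpleModel with +# AIT's fast_gelu Linear specialization), maps the PyTorch weights to AIT +# constants via map_pt_params, compiles the graph with compile_model, verifies +# the outputs with torch.allclose, and benchmarks both implementations.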
+# +from collections import OrderedDict + +import torch + +from aitemplate.compiler import compile_model +from aitemplate.frontend import nn, Tensor +from aitemplate.testing import detect_target +from aitemplate.testing.benchmark_pt import benchmark_torch_function +from aitemplate.utils.graph_utils import sorted_graph_pseudo_code + + +class PTSimpleModel(torch.nn.Module): + def __init__(self, hidden, eps: float = 1e-5): + super().__init__() + self.dense1 = torch.nn.Linear(hidden, 4 * hidden) + self.act1 = torch.nn.functional.gelu + self.dense2 = torch.nn.Linear(4 * hidden, hidden) + self.layernorm = torch.nn.LayerNorm(hidden, eps=eps) + + def forward(self, input): + hidden_states = self.dense1(input) + hidden_states = self.act1(hidden_states) + hidden_states = self.dense2(hidden_states) + hidden_states = hidden_states + input + hidden_states = self.layernorm(hidden_states) + return hidden_states + + +class AITSimpleModel(nn.Module): + def __init__(self, hidden, eps: float = 1e-5): + super().__init__() + self.dense1 = nn.Linear(hidden, 4 * hidden, specialization="fast_gelu") + self.dense2 = nn.Linear(4 * hidden, hidden) + self.layernorm = nn.LayerNorm(hidden, eps=eps) + + def forward(self, input): + hidden_states = self.dense1(input) + hidden_states = self.dense2(hidden_states) + hidden_states = hidden_states + input + hidden_states = self.layernorm(hidden_states) + return hidden_states + + +def map_pt_params(ait_model, pt_model): + ait_model.name_parameter_tensor() + pt_params = dict(pt_model.named_parameters()) + mapped_pt_params = OrderedDict() + for name, _ in ait_model.named_parameters(): + ait_name = name.replace(".", "_") + assert name in pt_params + mapped_pt_params[ait_name] = pt_params[name] + return mapped_pt_params + + +def verify_simple_model(batch_size=1024, hidden=512): + # create pt model + pt_model = PTSimpleModel(hidden).cuda().half() + + # create pt input + x = torch.randn([batch_size, hidden]).cuda().half() + + # run pt model + pt_model.eval() + y_pt = pt_model(x) + + # create ait model + ait_model = AITSimpleModel(hidden) + X = Tensor( + shape=[batch_size, hidden], + name="X", + dtype="float16", + is_input=True, + ) + Y = ait_model(X) + Y._attrs["is_output"] = True + Y._attrs["name"] = "Y" + + # map pt weights to ait + weights = map_pt_params(ait_model, pt_model) + + # code gen + target = detect_target() + with compile_model( + Y, target, "./tmp", "simple_model_demo", constants=weights + ) as module: + # create storage for output tensor + y = torch.empty([batch_size, hidden]).cuda().half() + + # inputs and outputs dict + inputs = {"X": x} + outputs = {"Y": y} + + # run + module.run_with_tensors(inputs, outputs, graph_mode=True) + + # verify output is correct + print(torch.allclose(y, y_pt, atol=1e-2, rtol=1e-2)) + + # benchmark ait and pt + count = 1000 + ait_t, _, _ = module.benchmark_with_tensors( + inputs, outputs, graph_mode=True, count=count + ) + print(f"AITemplate time: {ait_t} ms/iter") + + pt_t = benchmark_torch_function(count, pt_model.forward, x) + print(f"PyTorch eager time: {pt_t} ms/iter") + + # check out the fused graph + # there are only fused ops in the final graph + # gemm_rcr_bias_fast_gelu, gemm_rcr_bias_add, and layernorm + graph = module.debug_sorted_graph + print("Final graph:") + print(sorted_graph_pseudo_code(graph)) + + +verify_simple_model() diff --git a/licenses/LICENSE.composable_kernel.txt b/licenses/LICENSE.composable_kernel.txt new file mode 100644 index 000000000..2fe9a8455 --- /dev/null +++ b/licenses/LICENSE.composable_kernel.txt @@ 
-0,0 +1,28 @@ +Copyright (c) 2018- , Advanced Micro Devices, Inc. (Chao Liu, Jing Zhang) +Copyright (c) 2019- , Advanced Micro Devices, Inc. (Letao Qin, Qianfeng Zhang, Liang Huang, Shaojie Wang) +Copyright (c) 2022- , Advanced Micro Devices, Inc. (Anthony Chang, Chunyu Lai, Illia Silin, Adam Osewski, Poyen Chen, Jehandad Khan) +Copyright (c) 2019-2021, Advanced Micro Devices, Inc. (Hanwen Chang) +Copyright (c) 2019-2020, Advanced Micro Devices, Inc. (Tejash Shah) +Copyright (c) 2020 , Advanced Micro Devices, Inc. (Xiaoyan Zhou) +Copyright (c) 2021-2022, Advanced Micro Devices, Inc. (Jianfeng Yan) + +SPDX-License-Identifier: MIT +Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/licenses/LICENSE.cub.txt b/licenses/LICENSE.cub.txt new file mode 100644 index 000000000..6aeea8da6 --- /dev/null +++ b/licenses/LICENSE.cub.txt @@ -0,0 +1,24 @@ +Copyright (c) 2010-2011, Duane Merrill. All rights reserved. +Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the NVIDIA CORPORATION nor the + names of its contributors may be used to endorse or promote products + derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY +DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
diff --git a/licenses/LICENSE.cutlass.txt b/licenses/LICENSE.cutlass.txt new file mode 100644 index 000000000..d9219ec9b --- /dev/null +++ b/licenses/LICENSE.cutlass.txt @@ -0,0 +1,27 @@ +Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +SPDX-License-Identifier: BSD-3-Clause + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, +this list of conditions and the following disclaimer in the documentation +and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/licenses/LICENSE.dmlc.txt b/licenses/LICENSE.dmlc.txt new file mode 100644 index 000000000..8dada3eda --- /dev/null +++ b/licenses/LICENSE.dmlc.txt @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. 
+ + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright {yyyy} {name of copyright owner} + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/licenses/LICENSE.flash_attention.txt b/licenses/LICENSE.flash_attention.txt new file mode 100644 index 000000000..261eeb9e9 --- /dev/null +++ b/licenses/LICENSE.flash_attention.txt @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/licenses/LICENSE.hipcub.txt b/licenses/LICENSE.hipcub.txt new file mode 100644 index 000000000..c284d2bd9 --- /dev/null +++ b/licenses/LICENSE.hipcub.txt @@ -0,0 +1,25 @@ +Copyright (c) 2010-2011, Duane Merrill. All rights reserved. +Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. +Modifications Copyright (c) 2019-2021, Advanced Micro Devices, Inc. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the NVIDIA CORPORATION nor the + names of its contributors may be used to endorse or promote products + derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY +DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/licenses/LICENSE.markdown_table.txt b/licenses/LICENSE.markdown_table.txt new file mode 100644 index 000000000..6a5cab0c2 --- /dev/null +++ b/licenses/LICENSE.markdown_table.txt @@ -0,0 +1,21 @@ +# MIT License + +# Copyright (c) 2020 hvalev + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. diff --git a/licenses/LICENSE.oneflow.txt b/licenses/LICENSE.oneflow.txt new file mode 100644 index 000000000..f31ebbb41 --- /dev/null +++ b/licenses/LICENSE.oneflow.txt @@ -0,0 +1,202 @@ +Copyright 2020 The OneFlow Authors. All rights reserved. 
+ Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/licenses/LICENSE.pydot.txt b/licenses/LICENSE.pydot.txt new file mode 100644 index 000000000..741171aa6 --- /dev/null +++ b/licenses/LICENSE.pydot.txt @@ -0,0 +1,21 @@ +Copyright (c) 2014 Carlos Jenkins +Copyright (c) 2014 Lance Hepler +Copyright (c) 2004 Ero Carrera + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/licenses/LICENSE.pytorch.txt b/licenses/LICENSE.pytorch.txt new file mode 100644 index 000000000..04f9ad110 --- /dev/null +++ b/licenses/LICENSE.pytorch.txt @@ -0,0 +1,77 @@ +From PyTorch: + +Copyright (c) 2016- Facebook, Inc (Adam Paszke) +Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +Copyright (c) 2011-2013 NYU (Clement Farabet) +Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) + +From Caffe2: + +Copyright (c) 2016-present, Facebook Inc. All rights reserved. + +All contributions by Facebook: +Copyright (c) 2016 Facebook Inc. + +All contributions by Google: +Copyright (c) 2015 Google Inc. +All rights reserved. + +All contributions by Yangqing Jia: +Copyright (c) 2015 Yangqing Jia +All rights reserved. + +All contributions by Kakao Brain: +Copyright 2019-2020 Kakao Brain + +All contributions by Cruise LLC: +Copyright (c) 2022 Cruise LLC. +All rights reserved. + +All contributions from Caffe: +Copyright(c) 2013, 2014, 2015, the respective contributors +All rights reserved. + +All other contributions: +Copyright(c) 2015, 2016 the respective contributors +All rights reserved. + +Caffe2 uses a copyright model similar to Caffe: each contributor holds +copyright over their contributions to Caffe2. The project versioning records +all such contribution and copyright details. 
If a contributor wants to further +mark their specific copyright on a particular contribution, they should +indicate their copyright solely in the commit message of the change when it is +committed. + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America + and IDIAP Research Institute nor the names of its contributors may be + used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. diff --git a/licenses/LICENSE.tensorrt.txt b/licenses/LICENSE.tensorrt.txt new file mode 100644 index 000000000..e29455903 --- /dev/null +++ b/licenses/LICENSE.tensorrt.txt @@ -0,0 +1,337 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. 
+ + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + Copyright 2021 NVIDIA Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + + PORTIONS LICENSED AS FOLLOWS + + > tools/pytorch-quantization/examples/torchvision/models/classification/resnet.py + + BSD 3-Clause License + + Copyright (c) Soumith Chintala 2016, + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + > samples/common/windows/getopt.c + + Copyright (c) 2002 Todd C. Miller + + Permission to use, copy, modify, and distribute this software for any + purpose with or without fee is hereby granted, provided that the above + copyright notice and this permission notice appear in all copies. + + THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + + Sponsored in part by the Defense Advanced Research Projects + Agency (DARPA) and Air Force Research Laboratory, Air Force + Materiel Command, USAF, under agreement number F39502-99-1-0512. + + + Copyright (c) 2000 The NetBSD Foundation, Inc. + All rights reserved. + + This code is derived from software contributed to The NetBSD Foundation + by Dieter Baron and Thomas Klausner. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + 2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE NETBSD FOUNDATION, INC. AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED + TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR CONTRIBUTORS + BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + POSSIBILITY OF SUCH DAMAGE. + - Copyright (c) 2002 Todd C. Miller + - Copyright (c) 2000 The NetBSD Foundation, Inc. 
+ + + > parsers/common/ieee_half.h + > samples/common/half.h + > third_party/ieee/half.h + + The MIT License + + Copyright (c) 2012-2017 Christian Rau + + Permission is hereby granted, free of charge, to any person obtaining a + copy of this software and associated documentation files (the "Software"), + to deal in the Software without restriction, including without limitation + the rights to use, copy, modify, merge, publish, distribute, sublicense, + and/or sell copies of the Software, and to permit persons to whom the + Software is furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included + in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + DEALINGS IN THE SOFTWARE. + + > plugin/multiscaleDeformableAttnPlugin/multiscaleDeformableAttn.cu + > plugin/multiscaleDeformableAttnPlugin/multiscaleDeformableAttn.h + > plugin/multiscaleDeformableAttnPlugin/multiscaleDeformableIm2ColCuda.cuh + + Copyright 2020 SenseTime + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + DETR + + Copyright 2020 - present, Facebook, Inc + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/licenses/license.header.txt b/licenses/license.header.txt new file mode 100644 index 000000000..78af24e7a --- /dev/null +++ b/licenses/license.header.txt @@ -0,0 +1,13 @@ + Copyright (c) Meta Platforms, Inc. and affiliates. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/python/aitemplate/__init__.py b/python/aitemplate/__init__.py new file mode 100644 index 000000000..ed1d8a72e --- /dev/null +++ b/python/aitemplate/__init__.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import logging +import os +import sys + +from . import backend, compiler, frontend, testing, utils +from ._libinfo import __version__ # noqa + +if not (sys.version_info[0] >= 3 and sys.version_info[1] >= 7): + PY3STATEMENT = "The minimal Python requirement is Python 3.7" + raise Exception(PY3STATEMENT) + +__all__ = ["backend", "compiler", "frontend", "testing", "utils"] + +root_logger = logging.getLogger(__name__) +info_handle = logging.StreamHandler() +formatter = logging.Formatter("%(asctime)s %(levelname)s <%(name)s> %(message)s") +info_handle.setFormatter(formatter) +root_logger.addHandler(info_handle) +root_logger.propagate = False + +DEFAULT_LOGLEVEL = logging.getLogger().level +log_level_str = os.environ.get("LOGLEVEL", None) +LOG_LEVEL = ( + getattr(logging, log_level_str.upper()) + if log_level_str is not None + else DEFAULT_LOGLEVEL +) +root_logger.setLevel(LOG_LEVEL) diff --git a/python/aitemplate/_libinfo.py b/python/aitemplate/_libinfo.py new file mode 100644 index 000000000..6aacc3444 --- /dev/null +++ b/python/aitemplate/_libinfo.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# current version +# We use the version of the incoming release for code +__version__ = "0.1.dev0" diff --git a/python/aitemplate/backend/__init__.py b/python/aitemplate/backend/__init__.py new file mode 100644 index 000000000..8e7aaca0d --- /dev/null +++ b/python/aitemplate/backend/__init__.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Backend for AITemplate. +""" +from . 
import ( # noqa + backend_spec, + builder, + codegen, + cuda, + profiler_runner, + registry, + rocm, + target, +) + +__all__ = [ + "builder", + "codegen", + "cuda", + "profiler_runner", + "registry", + "rocm", + "target", +] diff --git a/python/aitemplate/backend/backend_spec.py b/python/aitemplate/backend/backend_spec.py new file mode 100644 index 000000000..44daa1f3c --- /dev/null +++ b/python/aitemplate/backend/backend_spec.py @@ -0,0 +1,280 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Backend Specifications. +""" + +from dataclasses import dataclass, field + +from typing import Dict, List, Tuple + +import jinja2 + +from ..compiler.ops.common.epilogue import FuncEnum +from .target import Target + + +@dataclass +class BackendSpec: + dtype_to_backend_fp16_dtype: Dict[str, str] = field( + default_factory=lambda: { + "float16": "half", + } + ) + + dtype_to_backend_dtype: Dict[str, str] = field( + default_factory=lambda: { + "float16": "half", + "float": "float", + "int64": "int64_t", + } + ) + + backend_datatype_convertors: Dict[str, Dict[str, str]] = field( + default_factory=lambda: { + "half": {"float": "__half2float"}, + "float": {"half": "__float2half_rn"}, + } + ) + + read_num_elements_to_backend_type: List[Tuple[int, str]] = field( + default_factory=lambda: [ + (8, "uint4"), + (4, "uint2"), + (2, "uint"), + (1, "half"), + ] + ) + op_num_elements_to_backend_type: List[Tuple[int, str]] = field( + default_factory=lambda: [ + (2, "half2"), + (1, "half"), + ] + ) + op_type_priority_list: List[str] = field( + default_factory=lambda: [ + "half2", + "half", + "float", + ] + ) + + func_enum_to_func_name: Dict[FuncEnum, Dict[str, str]] = field( + default_factory=lambda: { + FuncEnum.ADD: { + "half2": "__hadd2", + "half": "__hadd", + "float": "__fadd_rn", + }, + FuncEnum.SUB: { + "half2": "__hsub2", + "half": "__hsub", + "float": "__fsub_rn", + }, + FuncEnum.MUL: { + "half2": "__hmul2", + "half": "__hmul", + "float": "__fmul_rn", + }, + FuncEnum.DIV: { + "half2": "__h2div", + "half": "__hdiv", + "float": "__fdiv_rn", + }, + FuncEnum.COS: { + "half2": "h2cos", + "half": "hcos", + "float": "cosf", + }, + FuncEnum.SIN: { + "half2": "h2sin", + "half": "hsin" if Target.current().name() == "cuda" else "hsin_custom", + "float": "sinf", + }, + FuncEnum.TANH: { + "half2": "fast_tanh", + "half": "fast_tanh", + "float": "tanh", + }, + FuncEnum.ABS: { + "half2": "__habs2", + "half": "__habs", + "float": "fabsf", + }, + FuncEnum.LOGE: { + "half2": "h2log", + "half": "hlog", + "float": "logf", + }, + FuncEnum.EXP: { + "half2": "h2exp", + "half": "hexp", + "float": "expf", + }, + FuncEnum.SQRT: { + "half2": "h2sqrt", + "half": "hsqrt", + "float": "sqrtf", + }, + FuncEnum.MAX: { + "half2": "hmax2_nan", + "half": "hmax_nan", + "float": "fmaxf_nan", + }, + FuncEnum.MIN: { + "half2": "hmin2_nan", + "half": "hmin_nan", + "float": "fminf_nan", + }, + FuncEnum.SIGN: { + "half2": "h2sign_custom", + "half": "sign_custom", + "float": "sign_custom", + }, + 
FuncEnum.SIGMOID: { + "half2": "h2sigmoid_custom", + "half": "hsigmoid_custom", + "float": "fsigmoid_custom", + }, + FuncEnum.LRELU: { + "half2": "leaky_relu", + "half": "leaky_relu", + "float": "leaky_relu", + }, + FuncEnum.HARDTANH: { + "half2": "h2hard_tanh", + "half": "hard_tanh", + "float": "hard_tanh", + }, + FuncEnum.RELU: {"half2": "relu", "half": "relu", "float": "relu"}, + FuncEnum.NAN_TO_NUM: { + "half2": "nan_to_num", + "half": "nan_to_num", + "float": "nan_to_num", + }, + FuncEnum.CLAMP_NAN_TO_NUM: { + "half2": "clamp_nan_to_num", + "half": "clamp_nan_to_num", + "float": "clamp_nan_to_num", + }, + FuncEnum.SILU: { + "half2": "h2silu", + "half": "hsilu", + "float": "fsilu", + }, + } + ) + + def get_backend_type( + self, + num_elements: int, + dtype: str, + num_elements_to_backend_type_list: List[Tuple[int, str]], + ) -> str: + if dtype != "float16": + raise NotImplementedError("Unsupported dtype {}!".format(dtype)) + for num, backend_type in num_elements_to_backend_type_list: + if num_elements % num == 0: + return backend_type + raise RuntimeError( + "Failed to infer data type! num_elements: {}, num_elements_to_backend_type_list: {}".format( + num_elements, num_elements_to_backend_type_list + ) + ) + + def get_candidate_op_types(self, op_t: str) -> List[str]: + res = [] + found = False + for t in self.op_type_priority_list: + if t == op_t: + found = True + if found: + res.append(t) + return res + + def get_dtype_to_dtype(self, dtype: str, type_dict: Dict[str, str]): + data_type = type_dict.get(dtype) + if not data_type: + raise NotImplementedError("Unsupported dtype {}!".format(dtype)) + return data_type + + def get_fp16_dtype(self, dtype: str): + return self.get_dtype_to_dtype(dtype, self.dtype_to_backend_fp16_dtype) + + def dtype_to_backend_type(self, dtype: str): + return self.get_dtype_to_dtype(dtype, self.dtype_to_backend_dtype) + + +@dataclass +class ROCMSpec(BackendSpec): + backend_name = "rocm" + index_type = "int64_t" + prefix = "hip" + stream = "stream" + cub = "hipcub" + + cast_to_half_ptr_template = jinja2.Template("reinterpret_cast<half*>({{name}})") + cast_to_const_half_ptr_template = jinja2.Template( + "reinterpret_cast<const half*>({{name}})" + ) + header_src_template = jinja2.Template( + """ +#include <hip/hip_fp16.h> +#include <hip/hip_runtime.h> +{{extra_header}} + """ + ) + half2_data_ref = ".data" + + dtype_to_ck_type: Dict[str, str] = field( + default_factory=lambda: { + "float16": "ck::half_t", + "float": "float", + } + ) + + def dtype_to_lib_type(self, dtype: str): + return self.get_dtype_to_dtype(dtype, self.dtype_to_ck_type) + + +@dataclass +class CUDASpec(BackendSpec): + backend_name = "cuda" + index_type = "int64_t" + prefix = "cuda" + stream = "stream" + cub = "cub" + + cast_to_half_ptr_template = jinja2.Template("reinterpret_cast<half*>({{name}})") + cast_to_const_half_ptr_template = jinja2.Template( + "reinterpret_cast<const half*>({{name}})" + ) + header_src_template = jinja2.Template( + """ +#include <cuda_fp16.h> +{{extra_header}} + """ + ) + + half2_data_ref = "" + dtype_to_cutlass_type: Dict[str, str] = field( + default_factory=lambda: { + "float16": "cutlass::half_t", + "float": "float", + } + ) + + def dtype_to_lib_type(self, dtype: str): + return self.get_dtype_to_dtype(dtype, self.dtype_to_cutlass_type) diff --git a/python/aitemplate/backend/builder.py b/python/aitemplate/backend/builder.py new file mode 100644 index 000000000..80699a79b --- /dev/null +++ b/python/aitemplate/backend/builder.py @@ -0,0 +1,295 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Builder is a module to compile generated source code files into binary objects. +""" + +from __future__ import annotations + +import multiprocessing + +import os +import pathlib +import re +import typing +from typing import Optional + +import jinja2 + +from ..utils import logger +from .target import Target +from .task_runner import BaseRunner, Task + +# pylint: disable=W0221,C0103 + + +def process_task(task: Task) -> None: + """This function extracts stdout and stderr from a finished task. + If the task process return code is not 0, will mark the task as + a failed task. + + Parameters + ---------- + task : Task + A compiling task + """ + stdout = task._stdout + stderr = task._stderr + if task._proc.returncode != 0: + task._failed = True + logger.info( + __name__, + "Failed: [{name}]\ncmd:\n{cmd}\nstderr:\n{stderr}\nstdout:{stdout}".format( + name=task._name, cmd=task._cmd, stderr=stderr, stdout=stdout + ), + ) + task._ret = -1 + else: + logger.debug( + __name__, + "Successful: [{name}]\ncmd:\n{cmd}\nstderr:\n{stderr}\nstdout:{stdout}".format( + name=task._name, cmd=task._cmd, stderr=stderr, stdout=stdout + ), + ) + task._ret = 0 + + +def process_return(task: Task) -> None: + """This function process the task. If task is timeout or failed, + raise a runtime error. + + Parameters + ---------- + task : Task + A compiling task. + + Raises + ------ + RuntimeError + Compiling failed. + """ + if not task.is_timeout() and task.is_failed(): + raise RuntimeError(f"Building failed. Logs:\n{task._stdout}\n{task._stderr}") + + +class Runner(BaseRunner): + """A parallel runner for compiling tasks. + Runner is inherited from BaseRunner. + """ + + def __init__(self, devs: list[int], timeout: int = 10): + """Initialize a parallel runner for building + + Parameters + ---------- + devs : list[int] + CPU ids for compiling + timeout : int, optional + Compiling timeout, by default 10 (seconds) + """ + super().__init__(devs, "builder", timeout) + logger.info( + __name__, + "Using {n} CPU for building".format(n=devs), + ) + self._ftask_proc = process_task + self._fret_proc = process_return + + def push(self, idx: typing.Union[int, str], cmd: str, target: Target) -> None: + """Push a building task into runner + + Parameters + ---------- + idx : Union[int, str] + Task id + cmd : str + bash command for compiling + target : Target + Target device type for building + """ + self._queue.append(Task(idx, cmd, target, shell=True)) + + def pull(self) -> list[None]: + """Pull building results. + Check whether all building tasks are successful. + + Returns + ------- + list + An empty list + """ + ret = super().pull(self._ftask_proc, self._fret_proc) + return ret + + +class Builder(object): + """Builder is a module to compile generated source code + files into binary objects. + """ + + def __init__(self, n_jobs: int = -1, timeout: int = 180) -> None: + """Initialize a parallel builder for compiling source code. 
+ + Parameters + ---------- + n_jobs : int, optional + Run how many parallel compiling job, + by default -1, which will set n_jobs to `multiprocessing.cpu_count()` + timeout : int, optional + Timeout value, by default 180 (seconds) + """ + if n_jobs < 0: + n_jobs = multiprocessing.cpu_count() + num_builder = os.environ.get("NUM_BUILDERS", None) + if num_builder is not None: + n_jobs = int(num_builder) + self._runner = Runner(n_jobs, timeout) + + def build_objs( + self, + files: list[typing.Tuple[str, str]], + cc_cmd: str, + binary_cc_cmd: Optional[str] = None, + ): + """Generate building task for each source code file, then build in parallel + + Parameters + ---------- + files : list[Tuple[str, str]] + list of tuples of source code path and object file path + cc_cmd : str + command line template for building objects + binary_cc_cmd : optional, str + command line template for turning raw binary files (those ending in .bin) into + objects. Since most compilation jobs will not need to compile these, this argument + is optional. + """ + for idx, fpair in enumerate(files): + src, target = fpair + logger.info(__name__, "Building " + target) + if src.endswith(".bin"): + if binary_cc_cmd is None: + raise ValueError( + "Cannot compile .bin file without specifying binary_cc_cmd!" + ) + + src_path = pathlib.Path(src) + target_path = pathlib.Path(target) + compile_cmd = binary_cc_cmd.format( + target=target_path.name, src=src_path.name + ) + containing_dir = str(src_path.parent.absolute()) + # Have to cd into the containing dir so ld doesn't include + # the path in the symbol names; unfortunately, there's no other + # way to control this. + if logger.is_debug(): + cmd = f"cd {containing_dir} && {compile_cmd} && cd -" + else: + # If not in debug mode, remove the original .bin file which can potentially be quite large. + cmd = f"cd {containing_dir} && {compile_cmd} && rm {src_path.name} && cd -" + else: + cmd = cc_cmd.format(target=target, src=src) + + logger.debug(__name__, f"The cmd for building {target} is : {cmd}") + self._runner.push(idx, cmd, target) + self._runner.join() + self._runner.pull() + + def build_so(self, target: Target, objs: list[str]): + """Generate a task to build all objects into a dynamic library + + Parameters + ---------- + target : Target + Device target of dynamic library + objs : list[str] + List of all object file paths for building the dynamic library. 
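+
+        A minimal usage sketch (hypothetical file names, for illustration
+        only; the real compile command comes from the active Target):
+
+            builder = Builder(n_jobs=4)
+            builder.build_objs(
+                [("./tmp/model-generated.cu", "./tmp/model-generated.obj")],
+                Target.current().compile_cmd(False),
+            )
+            builder.build_so("./tmp/test.so", ["./tmp/model-generated.obj"])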
+ """ + logger.info(__name__, "Building " + target) + cc = Target.current().cc() + compile_options = Target.current().compile_options() + fpic = "-fPIC" + if "nvcc" in cc: + fpic = "-Xcompiler=-fPIC" + cmd = ( + "{cc} -shared ".format(cc=cc) + + fpic + + " " + + compile_options + + " -o {target} {objs}".format(target=target, objs=" ".join(objs)) + ) + logger.debug(__name__, f"The cmd for building {target} is {cmd}") + self._runner.push(0, cmd, target) + self._runner.join() + self._runner.pull() + + def gen_makefile(self, file_pairs, dll_name, workdir, test_name): + + makefile_template = jinja2.Template( + """ +CC = {{cc}} +CFLAGS = {{CFLAGS}} +fPIC_flag = {{fPIC}} + +obj_files = {{obj_files}} + +%.obj : %.{{cpp}} + {{cfile_cmd}} +%.obj : %.bin + {{bfile_cmd}} + +.PHONY: all +all: {{target}} + +{{target}}: $(obj_files) + $(CC) -shared $(fPIC_flag) $(CFLAGS) -o $@ $(obj_files) + +clean: + rm -f *.obj test.so +""" + ) + + obj_files = [pair[1].split("/")[-1] for pair in file_pairs] + obj_files = " ".join(obj_files) + + cc = Target.current().cc() + compile_options = Target.current().compile_options() + + fpic, cpp = "-fPIC", "cpp" + if "nvcc" in cc: + fpic, cpp = "-Xcompiler=-fPIC", "cu" + + cfile_cmd = Target.current().compile_cmd(False).format(target="$@", src="$<") + bfile_cmd = Target.current().binary_compile_cmd() + if not bfile_cmd: + bfile_cmd = "" + else: + bfile_cmd = bfile_cmd.format(target="$@", src="$<") + + makefile_str = makefile_template.render( + cc=cc, + cpp=cpp, + CFLAGS=compile_options, + fPIC=fpic, + obj_files=obj_files, + target=dll_name, + cfile_cmd=cfile_cmd, + bfile_cmd=bfile_cmd, + ) + + dumpfile = os.path.join(workdir, test_name, "Makefile") + with open(dumpfile, "w+") as f: + # fix the makefile indentation + f.write(re.sub("^ ", "\t", makefile_str, flags=re.M)) diff --git a/python/aitemplate/backend/codegen.py b/python/aitemplate/backend/codegen.py new file mode 100644 index 000000000..fcd806882 --- /dev/null +++ b/python/aitemplate/backend/codegen.py @@ -0,0 +1,744 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +This module is for generating the final C++ +source code in files from Tensor and Operators. +Functions in this module will be used for generating +function source code files, profiler source code files, +and model driver source code files. +""" + +from __future__ import annotations + +import io +import os +from typing import Any, Dict, List, Optional, Tuple + +from aitemplate.backend.main_templates import MODEL_CONTAINER_TEMPLATE, MODEL_TEMPLATE +from aitemplate.compiler.base import Operator +from aitemplate.compiler.tensor_accessor import TensorAccessor + +from aitemplate.compiler.transform.memory_planning import Workspace + +from ..compiler.base import get_dtype_size, IntImm, IntVar, Tensor +from . 
import registry +from .target import Target + +# pylint: disable=C0103,W0613,C0301 + +DTYPE_TO_POINTERTYPE: Dict[str, str] = { + "float32": "float*", + "float": "float*", + "int": "int32_t*", + "int32": "int32_t*", + "int64": "int64_t*", +} + + +def gen_profiler(sorted_graph: list[Tensor], workdir: str, dynamic_profiling_strategy): + """Generate operator profiler source code files for the given graph + + Parameters + ---------- + sorted_graph : list[Tensor] + The network after running toposort transformation + workdir : str + Target directory for generated C++ source code files + dynamic_profiling_strategy: DynamicProfileStrategy, optional + A dynamic profiling strategy, used to filter generated profiles at compile time. + Pass-through to gen_profiler kernels of nodes in the graph. + See also: :func:`~aitemplate.compiler.transform.profile.profile` + """ + for node in sorted_graph: + for func in node.src_ops(): + if "has_profiler" in func._attrs and func._attrs["has_profiler"]: + func.gen_profiler(workdir, dynamic_profiling_strategy) + + +def gen_function_src( + sorted_graph: list[Tensor], workdir: str, model_name: str = "" +) -> list[Tuple[str, str]]: + """Generate functions source code files for the given graph + + Parameters + ---------- + sorted_graph : list[Tensor] + The network after running toposort transformation + workdir : str + Target directory for generated C++ source code files + model_name : str, optional + Sub working directory in the workdir for the given model, by default "" + + Returns + ------- + list[Tuple[str, str]] + List of tuple (source file path, object file path) + """ + target = Target.current() + file_pairs = [] + exist_func = set() + prefix = os.path.join(workdir, model_name) + for node in sorted_graph: + for func in node.src_ops(): + fname = func._attrs["name"] + if fname not in exist_func: + src_path = os.path.join(prefix, fname + target.src_extension()) + obj_path = os.path.join(prefix, fname + ".obj") + file_pairs.append((src_path, obj_path)) + with open(src_path, "w") as fo: + fo.write(func.gen_function()) + exist_func.add(fname) + return file_pairs + + +def map_set( + map_name: str, + key_name: str, + value_name: Optional[str] = None, + indent: str = " ", +) -> str: + """Generate a string setting a value in a map. + + If value name is given, sets map_name["key_name"] = value_name. Else, sets + map_name["key_name"] = key_name. Special maps like dim_map may make + additional modificiations to the LHS of this expression. + + Parameters + ---------- + map_name : str + The map to use + key_name : str + The key to set. Will be put into quotes. + value_name : Optional[str] + If set, force map_name["key_name"] = value_name + indent : str + For formatting + + Returns + ------- + str + The formatted map set statement. + """ + if value_name is not None: + value = value_name + else: + value = key_name + if map_name == "dim_map": + # Because ROCM backend uses int64_t while CUDA uses int, + # this is a temporary workaround to cast int64_t* to int*. + # FIXME: After we unified the two backends, + # reinterpret_cast should be removed. + value = f"reinterpret_cast(&{value})" + + return f'{indent}{map_name}["{key_name}"] = {value};' + + +def set_value(lhs: Any, rhs: Any, indent: str = " ") -> str: + return f"{indent}{lhs} = {rhs};" + + +def set_value_from_map(map_name: Any, var_name: Any, indent: str = " ") -> str: + """Generate a string that sets a value to something stored in a map. 
+ + Parameters + ---------- + map_name : str + The map to use + var_name : str + The var_name, used as the name of the value and the key. + indent : str + For formatting + + Returns + ------- + str + The formatted statement. + """ + key = var_name + value = var_name + return f'{indent}{value} = static_cast({map_name}["{key}"]);' + + +def dtype_to_enumerator(dtype): + def _impl(dtype): + if dtype == "float16": + return "kHalf" + elif dtype == "float32" or dtype == "float": + return "kFloat" + elif dtype == "int32" or dtype == "int": + return "kInt" + elif dtype == "int64": + return "kLong" + else: + raise AssertionError(f"unknown dtype {dtype}") + + return f"AITemplateDtype::{_impl(dtype)}" + + +def count_inputs_outputs(graph): + n_inputs = n_outputs = 0 + for node in graph: + if node._attrs["is_input"]: + n_inputs += 1 + if node._attrs["is_output"]: + n_outputs += 1 + return n_inputs, n_outputs + + +def check_not_null( + tensor: Tensor, + tensor_idx: Optional[int] = None, + skip_if_lower_bound_is_zero: bool = False, +) -> str: + """ + Generate a nullptr check to be used by pointer initialization code. + + If skip_if_lower_bound_is_zero == True, no code will be generated + when the Tensor has at least one dynamic dim with a lower bound + of zero. This is most useful for outputs; we put the nullptr + checks at the start of the inference, but we won't know output + shapes until after Run() finishes. We therefore just relax the check + for these outputs - only allow them to be null if their lower bound + is zero, otherwise never allow them to be null. + """ + name = tensor._attrs["name"] + if tensor_idx is None: + check = name + else: + check = f"params[{tensor_idx}].ptr" + + shape = ["1"] + lower_bound_is_zero = False + for dim in tensor._attrs["shape"]: + lower_bound_is_zero |= dim.lower_bound() == 0 + if skip_if_lower_bound_is_zero and lower_bound_is_zero: + return "" + if isinstance(dim, IntImm): + shape.append(str(dim._attrs["values"][0])) + else: + shape.append(dim._attrs["name"]) + + nullptr_check = f"{check} == nullptr" + condition = ( + nullptr_check + # If the lower bound of the shape is positive, never allow + # the tensor to be null. + if not lower_bound_is_zero + # Otherwise, allow it to be null only if the (possibly dynamic) + # size is zero. + else f"{nullptr_check} && {'*'.join(shape)} != 0" + ) + return f""" +if ({condition}) {{ + throw std::runtime_error("Constant {name} was not set! 
Set the value with set_constant."); +}} + """ + + +def device_copy(dst_tensor: Tensor, src_tensor: Tensor, dst_idx: int) -> str: + src_name = src_tensor._attrs["name"] + dst_ptr = f"params[{dst_idx}].ptr" + shape = ["1"] + for dim in dst_tensor._attrs["shape"]: + if isinstance(dim, IntImm): + shape.append(str(dim._attrs["values"][0])) + else: + shape.append(dim._attrs["name"]) + shape = "*".join(shape) + size = f"{shape} * {get_dtype_size(dst_tensor._attrs['dtype'])}" + return f"DEVICE_CHECK(DeviceToDeviceCopy({dst_ptr}, {src_name}, {size}, stream));" + + +class ModelContainerGenerator: + def __init__( + self, + max_blob_size: int, + max_constant_blob_size: int, + workspace: Workspace, + num_inputs: int, + num_outputs: int, + constants_data_file: io.BytesIO, + output_name_to_idx: Dict[str, int], + ): + self.target = Target.current() + self.f_var_decl = registry.get(self.target.name() + ".lib.var_decl") + self.f_ptr_decl = registry.get(self.target.name() + ".lib.ptr_decl") + + self.constants_data_file = constants_data_file + + self.exist_funcs = set() + self.func_decl = [] + self.tensor_slice = [] + self.tensor_map_set = [] + self.set_inputs = [] + self.func_seq = [] + self.tensor_decl = [] + self.dim_decl = [] + self.device_to_device_copies = [] + self.function_state = [] + self.set_up_constants = [] + self.set_up_param_names = [] + self.set_up_param_dtypes = [] + self.set_up_output_shapes = [] + self.set_up_param_dynamic_shapes = [] + self.state_record = set() + self.visited_func = set() + self.visited_dims = set() + self.set_up_constant_names = [] + self.param_name_to_ptr_idx = {} + + self.num_constants = 0 + self.constants_data_size = 0 + self.owned_constants_init = [] + + self.input_idx = 0 + self.unbound_constant_idx = 0 + self.output_name_to_idx = output_name_to_idx + + ( + self.max_blob_size, + self.max_constant_blob_size, + self.workspace, + self.num_inputs, + self.num_outputs, + ) = ( + max_blob_size, + max_constant_blob_size, + workspace, + num_inputs, + num_outputs, + ) + + def _tensor_slice_func( + self, + node: Tensor, + blob_name: str, + indent=" ", + ) -> str: + offset = node._attrs["offset"] + name = node._attrs["name"] + return f"{indent}{name} = reinterpret_cast({blob_name} + {offset});" + + def _record_param_tensor_info(self, tensor: Tensor, idx: int) -> None: + def max_value(var_or_imm): + if isinstance(var_or_imm, IntImm): + return var_or_imm.value() + else: + assert isinstance(var_or_imm, IntVar) + return var_or_imm.upper_bound() + + shape_init = ", ".join(str(max_value(dim)) for dim in tensor._attrs["shape"]) + param_shape_init = ", ".join( + f'&{dim._attrs["name"]}' for dim in tensor._attrs["shape"] + ) + self.set_up_output_shapes.append( + set_value(f"max_param_shapes_[{idx}]", f"{{{shape_init}}}") + ) + param_shape_init = ", ".join( + f'ParamDim({dim.lower_bound()}, {dim.upper_bound()}, &{dim._attrs["name"]})' + for dim in tensor._attrs["shape"] + ) + self.set_up_param_dynamic_shapes.append( + set_value(f"params[{idx}].shape_ptrs", f"{{{param_shape_init}}}") + ) + name = tensor._attrs["name"] + self.set_up_param_names.append(set_value(f"param_names_[{idx}]", f'"{name}"')) + self.set_up_param_dtypes.append( + set_value( + f"param_dtypes_[{idx}]", + dtype_to_enumerator(tensor.dtype()), + ) + ) + + def _codegen_param_setup( + self, + tensor: Tensor, + ) -> None: + """ + Generate code needed for setting up a constant in Model/ModelContainer. + """ + name = tensor._attrs["name"] + data = tensor._attrs["data"] + if data is not None: + # Owned constant. 
Set up logic for copying the constant in from *.so. + assert ( + tensor._attrs["offset"] >= 0 + ), f"Constant node '{name}' must have non-negative offset" + self.set_up_constants.append(self._tensor_slice_func(tensor, "constants")) + num_bytes = len(data) + self.constants_data_file.write(data.to_bytes()) + + constant_info = f'ConstantInfo{{"{name}", {self.constants_data_size}, {tensor._attrs["offset"]}, {num_bytes}}}' + self.owned_constants_init.append(constant_info) + self.constants_data_size += num_bytes + self.num_constants += 1 + else: + # Unbound constant. We will expect the user to set this via SetConstant. + self.set_up_constant_names.append( + set_value( + f'unbound_constant_name_to_idx_["{name}"]', + self.unbound_constant_idx, + ) + ) + self._record_param_tensor_info( + tensor, self.unbound_constant_idx + self.num_inputs + self.num_outputs + ) + self.unbound_constant_idx += 1 + self.set_inputs.append(check_not_null(tensor)) + self.set_up_constants.append( + set_value( + f'constant_name_to_ptr_["{name}"]', + f"const_cast(reinterpret_cast(&{name}))", + ) + ) + + def _codegen_input_tensor(self, tensor: Tensor) -> None: + name = tensor._attrs["name"] + view = tensor._attrs["is_view_of"] + assert ( + view is None + ), f"_codegen_input_tensor cannot be called with a view; expected a non-view tensor with is_input=True, got: {tensor}" + self.set_inputs.append( + set_value( + name, + f"static_cast(params[{self.input_idx}].ptr)", + ) + ) + self.set_inputs.append(check_not_null(tensor)) + self.param_name_to_ptr_idx[name] = self.input_idx + self._record_param_tensor_info(tensor, self.input_idx) + self.input_idx += 1 + + def _get_output_idx(self, name: str) -> int: + assert ( + name in self.output_name_to_idx + ), f"Tensor {name} was marked as an output, but its index was not found in output_name_to_index" + # Add num_inputs since we internally store outputs in the same array as inputs w/ + # inputs first + return self.output_name_to_idx[name] + self.num_inputs + + def _codegen_output_aliases_tensor(self, tensor: Tensor) -> None: + name = tensor._attrs["name"] + view = tensor._attrs["is_view_of"] + if tensor._attrs["external_tensor"] is not None: + self.set_inputs.append(set_value(name, view._attrs["name"])) + return + is_view = view is not None + if is_view: + ptr_idx = self.param_name_to_ptr_idx[view._attrs["name"]] + self.set_inputs.append(set_value(name, view._attrs["name"])) + else: + ptr_idx = self._get_output_idx(name) + self.set_inputs.append( + set_value( + name, + f"static_cast(params[{ptr_idx}].ptr)", + ) + ) + + self.param_name_to_ptr_idx[name] = ptr_idx + if tensor._attrs["is_output"]: + self._record_param_tensor_info(tensor, ptr_idx) + self.set_inputs.append( + check_not_null(tensor, skip_if_lower_bound_is_zero=True) + ) + + def _codegen_output_tensor(self, tensor: Tensor) -> None: + is_param = tensor._attrs["is_param"] + is_input = tensor._attrs["is_input"] + view = tensor._attrs["is_view_of"] + is_view = view is not None + external_tensor = tensor._attrs["external_tensor"] + name = tensor._attrs["name"] + + output_idx = self._get_output_idx(name) + + if is_param: + self._codegen_param_setup(tensor) + self._record_param_tensor_info(tensor, output_idx) + self.device_to_device_copies.append(device_copy(tensor, tensor, output_idx)) + elif external_tensor is not None: + # Special view cases for outputs; we can hit this case if the output + # is a view of a constant, input, or another output. 
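+            # (Illustrative sketch, not taken from the source: if an output y
+            #  were declared as, say, a reshape view of an input x, this branch
+            #  aliases y to its view source and emits a device-to-device copy
+            #  from the external tensor into y's output slot, so the
+            #  user-visible output buffer still receives the data.)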
+ assert ( + is_view + ), f"orig_tensor is not None, but node {name} is not marked as a view! Node: {tensor}" + self.set_inputs.append( + check_not_null(tensor, output_idx, skip_if_lower_bound_is_zero=True) + ) + self.set_inputs.append(set_value(name, view._attrs["name"])) + self.device_to_device_copies.append( + device_copy(tensor, external_tensor, output_idx) + ) + self._record_param_tensor_info(tensor, output_idx) + elif is_input: + # Inputs that are also outputs require an extra copy + self.set_inputs.append( + set_value( + name, + f"static_cast(params[{self.input_idx}].ptr)", + ) + ) + self._record_param_tensor_info(tensor, self.input_idx) + self._record_param_tensor_info(tensor, output_idx) + self.device_to_device_copies.append(device_copy(tensor, tensor, output_idx)) + self.input_idx += 1 + else: + self._codegen_output_aliases_tensor(tensor) + + def _process_dims(self, shape: List[IntVar]) -> None: + for dim in shape: + if dim._attrs["name"] in self.visited_dims: + continue + intimm = 0 + if len(dim._attrs["values"]) == 1: + intimm = dim._attrs["values"][0] + self.dim_decl.append(self.f_var_decl(dim._attrs["name"], intimm)) + self.visited_dims.add(dim._attrs["name"]) + + def _process_dims_for_tensor(self, node: Tensor) -> None: + self._process_dims(node._attrs["shape"]) + + def _process_dims_for_tensor_accessors( + self, tensor_accessors: List[TensorAccessor] + ) -> None: + if tensor_accessors is None: + return + for accessor in tensor_accessors: + self._process_dims(accessor.original_shapes) + + def _process_dims_for_op(self, node: Operator) -> None: + self._process_dims_for_tensor_accessors(node._attrs.get("input_accessors")) + self._process_dims_for_tensor_accessors(node._attrs.get("output_accessors")) + + def _process_src_ops(self, node: Tensor) -> None: + funcs = node.src_ops() + for func in funcs: + f_func_decl = registry.get( + ".".join((self.target.name(), func._attrs["op"], "func_decl")) + ) + f_func_call = registry.get( + ".".join((self.target.name(), func._attrs["op"], "func_call")) + ) + if func._attrs["name"] not in self.exist_funcs: + self.func_decl.append(f_func_decl(func._attrs)) + self.exist_funcs.add(func._attrs["name"]) + + # Only code gen func once for ops with multiple outputs + # The func can get renamed during refine_graph pass. + # We use original_name here because it's unique. 
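+            # (Hedged example: a multi-output op such as split shows up in the
+            #  src_ops of every one of its output tensors, so without the
+            #  visited_func guard its kernel call would be emitted once per
+            #  output instead of once per op.)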
+ if func._attrs["original_name"] not in self.visited_func: + self.visited_func.add(func._attrs["original_name"]) + self.func_seq.append(f_func_call(func._attrs, indent=" ")) + if "int_state_flag" in func._attrs: + if func._attrs["name"] not in self.state_record: + self.function_state.append( + f' int64_t {func._attrs["name"]}_state {{0}};' + ) + self.state_record.add(func._attrs["name"]) + self._process_dims_for_op(func) + + def append_tensor(self, node: Tensor) -> None: + if node._attrs["nop"]: + return + name = node._attrs["name"] + dtype = node._attrs["dtype"] + self.tensor_decl.append(self.f_ptr_decl(name=name, dtype=dtype)) + + is_param = node._attrs["is_param"] + is_output = node._attrs["is_output"] + has_output_aliases = node._attrs["has_output_aliases"] + is_input = node._attrs["is_input"] + view = node._attrs["is_view_of"] + is_view = view is not None + + if is_output: + # Outputs have a ton of special cases that depend on + # is_input, is_view, etc, so this condition needs to + # be checked before all the others + self._codegen_output_tensor(node) + elif is_param: + self._codegen_param_setup(node) + elif is_input: + self._codegen_input_tensor(node) + elif has_output_aliases: + # Special case: internal tensor that aliases an output. + self._codegen_output_aliases_tensor(node) + elif not is_view: + # Normal, internal tensor that is not a view: point it to the + # internal blob of memory + assert ( + node._attrs["offset"] >= 0 + ), f"Non-parameter node '{name}' must have non-negative offset" + self.tensor_slice.append(self._tensor_slice_func(node, "blob_ptr")) + else: + # Normal view, point it to the same memory as whatever it + # aliases + self.set_inputs.append(set_value(name, view._attrs["name"])) + + self._process_dims_for_tensor(node) + self._process_src_ops(node) + + def generate_source(self) -> Dict[str, str]: + """ + Perform the codegen after adding all tensors. + The dictionary returned is a map from filename -> contents. 
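+
+        As a sketch of the result (file names taken from the body below), the
+        map contains "device_functions-generated.h", "model-generated.h", and
+        the rendered model container source, keyed by file name.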
+ """ + device_functions_header_name = f"{self.target.name()}_device_functions.h" + result = {} + result[ + "device_functions-generated.h" + ] = f'#include "{device_functions_header_name}"' + + # Disable graph mode on ROCM because the updating operations + # are not supported + target_has_graph_mode = "true" if self.target.name() == "cuda" else "false" + + model_def = MODEL_TEMPLATE.render( + function_decl="\n".join(self.func_decl), + device_functions_header=device_functions_header_name, + set_inputs="\n".join(self.set_inputs), + tensor_slice="\n".join(self.tensor_slice), + tensor_map_set="\n".join(self.tensor_map_set), + set_up_constants="\n".join(self.set_up_constants), + device_to_device_copies="\n".join(self.device_to_device_copies), + set_up_param_dynamic_shapes="\n".join(self.set_up_param_dynamic_shapes), + function_seq=self.func_seq, + tensor_decl="\n".join(self.tensor_decl), + dim_decl="\n".join(self.dim_decl), + function_state="\n".join(self.function_state), + target_has_graph_mode=target_has_graph_mode, + unique_workspace_size=self.workspace.unique_size, + ) + + result["model-generated.h"] = model_def + + model_container_src_fname = f"model_container_base{self.target.src_extension()}" + model_container_base_src = MODEL_CONTAINER_TEMPLATE.render( + blob_size=self.max_blob_size, + workspace_size=self.workspace.total_size(), + num_inputs=self.num_inputs, + num_outputs=self.num_outputs, + param_size=self.max_constant_blob_size, + set_up_constant_names="\n".join(self.set_up_constant_names), + set_up_param_dtypes="\n".join(self.set_up_param_dtypes), + set_up_output_shapes="\n".join(self.set_up_output_shapes), + set_up_param_names="\n".join(self.set_up_param_names), + num_constants=self.num_constants, + num_unbound_constants=self.unbound_constant_idx, + owned_constants_init=",".join(self.owned_constants_init), + ) + result[model_container_src_fname] = model_container_base_src + return result + + +def _construct_output_name_to_index_map( + sorted_graph: List[Tensor], output_tensors: List[Tensor] +) -> Dict[str, int]: + """ + Use the given output ordering to construct a name -> index map + to be used for constructing an internal ordering during codegen. + + The indices in the map are propagated to an output's entire alias set. + If two outputs are part of the same alias set, only one of them propagates + its output index. 
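+
+    A hedged example with made-up tensor names: for output_tensors [y0, y1]
+    the map starts as {"y0": 0, "y1": 1}; if y1 is a view of some tensor t,
+    the alias propagation below assigns index 1 to t as well.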
+ """ + result = {tensor._attrs["name"]: i for i, tensor in enumerate(output_tensors)} + + # Mark alias sets + for tensor in reversed(sorted_graph): + name = tensor._attrs["name"] + orig = tensor._attrs["is_view_of"] + if orig is None: + continue + orig_name = orig._attrs["name"] + if name in result and orig_name not in result: + result[orig_name] = result[name] + + return result + + +def gen_library_src( # noqa: C901 + sorted_graph: list[Tensor], + max_blob_size: int, + max_constant_blob_size: int, + workspace: Workspace, + workdir: str, + output_tensors: List[Tensor], + model_name: str = "", +) -> list[Tuple[str, str]]: + """Generate model driver source code files for the given graph + + Parameters + ---------- + sorted_graph : list[Tensor] + The network after running toposort transformation + max_blob_size : int + Total memory for input/output tensor and intermediate results, + calculated by memory planning transformation + workspace : Workspace + Workspace sizes, computed by memory planning + workdir : str + Target directory for generated C++ source code files + model_name : str, optional + Sub working directory in the workdir for the given model, by default "" + + Returns + ------- + list[Tuple[str, str]] + List of tuple (source file path, object file path) + """ + + def to_obj_name(name: str): + name, _ = os.path.splitext(name) + return f"{name}.obj" + + num_inputs, num_outputs = count_inputs_outputs(sorted_graph) + prefix = os.path.join(workdir, model_name) + constants_fname = os.path.join(prefix, "constants.bin") + constants_data_file = open(constants_fname, "wb") + + output_name_to_index = _construct_output_name_to_index_map( + sorted_graph, output_tensors + ) + + model_container_generator = ModelContainerGenerator( + max_blob_size, + max_constant_blob_size, + workspace, + num_inputs, + num_outputs, + constants_data_file, + output_name_to_index, + ) + for node in sorted_graph: + model_container_generator.append_tensor(node) + constants_data_file.close() + + files = model_container_generator.generate_source() + to_build = [(constants_fname, to_obj_name(constants_fname))] + for fname, contents in files.items(): + fname_full = os.path.join(prefix, fname) + with open(fname_full, "w") as fo: + fo.write(contents) + if not fname_full.endswith(".h"): + to_build.append((fname_full, to_obj_name(fname_full))) + + # Copy over static csrc/headers + sources = model_container_generator.target.copy_headers_and_csrc_to_workdir(prefix) + for fname in sources: + to_build.append((fname, to_obj_name(fname))) + + return to_build diff --git a/python/aitemplate/backend/common/concatenate_common.py b/python/aitemplate/backend/common/concatenate_common.py new file mode 100644 index 000000000..99f24bb03 --- /dev/null +++ b/python/aitemplate/backend/common/concatenate_common.py @@ -0,0 +1,839 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +backend concatenate function common templates. +""" +import jinja2 + +from . 
import tensor_accessor_codegen + +FUNC_DECL_TEMPLATE = jinja2.Template( + """ +void {{func_name}}( + {{elem_output_type}} * /*output*/, + {{index_type}} *[] /*output_shape*/, + const {{elem_input_type}} *[] /*inputs*/, + const {{index_type}} *[], /* real_input_shapes, representing shapes of those inputs + whose masks are False, i.e. inputs that will be + copied to the output tensor by concat.*/ + const {{index_type}} *[], /* all_input_shapes, including both kinds of inputs, + i.e. not matter input_mask being True or False */ + const bool [] /*input_masks*/, + const {{index_type}} [] /*concat_dim_sizes*/, + {{index_type}} /*concat_dim*/, + {{index_type}} /*rank*/, + {{index_type}} /*num_real_inputs*/, + {{index_type}} /*num_all_inputs*/, + {{prefix}}Stream_t +); +""" +) + + +KERNEL_SRC_TEMPLATE = jinja2.Template( + """ +#include +#include +#include +#include +#include + +{{header_src}} + +#ifndef CHECK_ERROR_CAT +#define CHECK_ERROR_CAT(expr) \\ + do { \\ + {{prefix}}Error_t status = (expr); \\ + if (status != {{prefix}}Success) { \\ + auto msg = std::string("Got error: ") + \\ + {{prefix}}GetErrorString(status) + \\ + " at " + __FILE__ + ": " + std::to_string(__LINE__); \\ + std::cerr << msg << std::endl; \\ + throw std::runtime_error(msg); \\ + } \\ + } while (0) +#endif // CHECK_ERROR_CAT + +#ifndef LAUNCH_CHECK_CAT +#define LAUNCH_CHECK_CAT() CHECK_ERROR_CAT({{prefix}}GetLastError()) +#endif // LAUNCH_CHECK_CAT + +{% if element_func_def %} +{{element_func_def}} +{% endif %} + +namespace { + +{{tensor_accessor_libs}} + +// TODO: support strided tensor with TensorAccessor +// For strided tensor, the index can be much larger than original if the stride is large +bool can_use_32bit_index_math(const int64_t elements, int64_t max_elem=std::numeric_limits::max()) { + if (elements >= max_elem) { + return false; + } + if (elements == 0) { + return max_elem > 0; + } + + return true; +} + +template +struct InputMetaData { + const T *inputs[NumInputs]; /* pointer to each input */ + TensorAccessor input_accessors[NumInputs]; + int64_t concat_dim_offsets[NumInputs]; /* offset of each input along + the concat dimension */ + int64_t concat_dim_values[NumInputs]; /* concat dimension value of + each input */ + int64_t num_elems[NumInputs]; /* number of elements of each input */ +}; + +template <{{index_type}} Rank> +struct OutputMetaData { + int64_t output_shape[Rank]; + int64_t output_strides[Rank]; +}; + +__host__ __device__ __forceinline__ +int64_t get_num_elems(const {{index_type}} *shape, {{index_type}} rank) { + int64_t num = 1; + for ({{index_type}} i = 0; i < rank; i++) { + num *= shape[i]; + } + return num; +} + +template +__host__ __device__ int64_t compute_output_elem_offset( + const int64_t *output_shape, + const int64_t *output_strides, + const INDEX_T input_concat_dim_value, + const INDEX_T concat_dim, + INDEX_T linear_idx) { + INDEX_T offset = 0; + for (INDEX_T i = Rank - 1; i >= 1; --i) { + INDEX_T cur_dim_size = + i == concat_dim ? 
input_concat_dim_value : output_shape[i]; + INDEX_T next_dim_idx = linear_idx / cur_dim_size; + INDEX_T cur_dim_idx = linear_idx - cur_dim_size * next_dim_idx; + INDEX_T cur_dim_offset = cur_dim_idx * static_cast(output_strides[i]); + offset += cur_dim_offset; + linear_idx = next_dim_idx; + } + return offset + linear_idx * static_cast(output_strides[0]); +} +} // namespace + +template +__global__ void +concatenate_kernel( + ELEM_T *orig_output, + OutputMetaData output_meta, + InputMetaData input_meta, + const INDEX_T concat_dim, + const INDEX_T output_concat_dim_stride) { + const INDEX_T tid = blockIdx.x * blockDim.x + threadIdx.x; + const INDEX_T block_y = blockIdx.y % NumInputs; + READ_T* output = reinterpret_cast(orig_output); + + READ_T* input = const_cast( + reinterpret_cast(input_meta.inputs[block_y])); + const TensorAccessor &input_accessor = input_meta.input_accessors[block_y]; + INDEX_T input_offset = input_meta.concat_dim_offsets[block_y]; + INDEX_T num_input_elems = input_meta.num_elems[block_y]; + INDEX_T input_concat_dim_value = input_meta.concat_dim_values[block_y]; + INDEX_T output_offset = input_offset * output_concat_dim_stride; + + constexpr unsigned read_t_sz = sizeof(READ_T); + constexpr unsigned elem_t_sz = sizeof(ELEM_T); + assert(read_t_sz >= elem_t_sz && (read_t_sz % elem_t_sz == 0)); + constexpr INDEX_T n_of_elem_t = read_t_sz / elem_t_sz; + // number of READ_T elements per thread + INDEX_T reads_per_thread_in_read_t = ElemsPerThread / n_of_elem_t; + const INDEX_T num_elems_in_read_t = num_input_elems / n_of_elem_t; + INDEX_T read_idx = tid; + +#pragma unroll + for (INDEX_T i = 0; i < reads_per_thread_in_read_t; + i++, read_idx += blockDim.x * gridDim.x) { + if (read_idx >= num_elems_in_read_t) { + break; + } + READ_T tmp_v = *(input_accessor.get(input, read_idx)); + /* make sure to adjust read_idx, which refers to location at + (read_idx * n_of_elem_t) actually */ + + INDEX_T output_elem_offset = + compute_output_elem_offset(output_meta.output_shape, + output_meta.output_strides, + input_concat_dim_value, + concat_dim, + read_idx * n_of_elem_t); + {% if element_func %} + output[(output_offset + output_elem_offset) / n_of_elem_t] = {{element_func}}(tmp_v); + {% else %} + output[(output_offset + output_elem_offset) / n_of_elem_t] = tmp_v; + {% endif %} + } +} + +enum class LoadVecType { + VT_HALF = 0, + VT_FLOAT, + VT_FLOAT2, + VT_FLOAT4 +}; + +template +static inline LoadVecType get_vec_type({{index_type}} dim_size) { + {{index_type}} size_elem_t = sizeof(ELEM_T); + +#define HANDLE_ONE_VEC_TYPE(load_vec_type, vec_type) \\ + if (sizeof(vec_type) % size_elem_t == 0) { \\ + {{index_type}} n_of_elem_t = sizeof(vec_type) / size_elem_t; \\ + if (dim_size % n_of_elem_t == 0) { \\ + return load_vec_type; \\ + } \\ + } + + HANDLE_ONE_VEC_TYPE(LoadVecType::VT_FLOAT4, float4) + HANDLE_ONE_VEC_TYPE(LoadVecType::VT_FLOAT2, float2) + HANDLE_ONE_VEC_TYPE(LoadVecType::VT_FLOAT, float) + HANDLE_ONE_VEC_TYPE(LoadVecType::VT_HALF, half) + +#undef HANDLE_ONE_VEC_TYPE + throw std::runtime_error( + "Cannot resolve LoadVecType." 
+ ); +} + +template +void concatenate_kernel_launcher( + ELEM_T *output, + const {{index_type}} *output_shape, + const ELEM_T *inputs[], + const {{index_type}} *real_input_shapes[], + const TensorAccessor *input_accessors[], + const int64_t concat_dim_offsets[], + const {{index_type}} concat_dim, + LoadVecType min_vec_type, + {{prefix}}Stream_t stream) { + + OutputMetaData output_meta; + output_meta.output_strides[Rank - 1] = 1; + output_meta.output_shape[Rank - 1] = output_shape[Rank - 1]; + for (INDEX_T i = Rank - 2; i >= 0; i--) { + output_meta.output_strides[i] = + output_meta.output_strides[i+1] * output_shape[i+1]; + output_meta.output_shape[i] = output_shape[i]; + } + + InputMetaData input_meta; + INDEX_T max_num_input_elems = 0; + for (INDEX_T i = 0; i < NumInputs; i++) { + INDEX_T num_elems = get_num_elems(real_input_shapes[i], Rank); + input_meta.inputs[i] = inputs[i]; + input_meta.input_accessors[i] = *(input_accessors[i]); + input_meta.concat_dim_offsets[i] = concat_dim_offsets[i]; + input_meta.concat_dim_values[i] = real_input_shapes[i][concat_dim]; + input_meta.num_elems[i] = num_elems; + + max_num_input_elems = num_elems > max_num_input_elems ? + num_elems : max_num_input_elems; + } + + constexpr INDEX_T elems_per_block = ThreadsPerBlock * ElemsPerThread; + INDEX_T m = (max_num_input_elems % elems_per_block != 0); + INDEX_T num_blocks_x = + (max_num_input_elems / elems_per_block) + m; + dim3 grid_config = dim3(static_cast(num_blocks_x), NumInputs); + +#define HANDLE_ONE_VEC_TYPE(load_vec_type, vec_type) \\ + case load_vec_type: { \\ + if (ElemsPerThread * sizeof(ELEM_T) < sizeof(vec_type)) { \\ + throw std::runtime_error( \\ + std::string("No valid kernel available for ") + #vec_type); \\ + } \\ + concatenate_kernel \\ + <<>>( \\ + output, \\ + output_meta, \\ + input_meta, \\ + concat_dim, \\ + output_meta.output_strides[concat_dim]); \\ + LAUNCH_CHECK_CAT(); \\ + break; \\ + } + + switch (min_vec_type) { + HANDLE_ONE_VEC_TYPE(LoadVecType::VT_FLOAT4, float4) + HANDLE_ONE_VEC_TYPE(LoadVecType::VT_FLOAT2, float2) + HANDLE_ONE_VEC_TYPE(LoadVecType::VT_FLOAT, float) + HANDLE_ONE_VEC_TYPE(LoadVecType::VT_HALF, half) + default: + throw std::runtime_error("Invalid LoadVecType\\n"); + } + +#undef HANDLE_ONE_VEC_TYPE +} + +#undef CHECK_ERROR_CAT +#undef LAUNCH_CHECK_CAT +""" +) + + +DUMMY_KERNEL_TEMPLATE = jinja2.Template( + """ +#include +#include +#include +#include +#include +{{header_src}} + +void {{func_name}}( + {{elem_output_type}} *output, + {{index_type}} *output_shape[], + const {{elem_input_type}} *inputs[], + const {{index_type}} *real_input_shapes[], + const {{index_type}} *all_input_shapes[], + const bool input_masks[], + const {{index_type}} concat_dim_sizes[], + {{index_type}} concat_dim, + {{index_type}} rank, + {{index_type}} num_real_inputs, + {{index_type}} num_all_inputs, + {{prefix}}Stream_t stream + ) { +} +""" +) + + +INPUT_ACCESSOR_DEFS_TEMPLATE = jinja2.Template( + """ +{{input_accessors}} + +{{indent}}const TensorAccessor *input_accessors[{{num_real_inputs}}] = { + +{{indent}} {{input_accessor_refs}} + +{{indent}}}; +""" +) + + +EXEC_COND_TEMPLATE = jinja2.Template( + """ +{{indent}}if (rank == {{rank}} && num_real_inputs == {{num_real_inputs}}) { + +{{input_accessor_defs}} + +{{indent}} LoadVecType min_vec_type = LoadVecType::VT_FLOAT4; +{{indent}} int64_t accessor_idx = 0; +{{indent}} for ({{index_type}} i = 0; i < num_all_inputs; i++) { +{{indent}} int local_alignment; +{{indent}} if (!input_masks[i] || +{{indent}} 
input_accessors[accessor_idx]->stride_dim == -1) { +{{indent}} local_alignment = all_input_shapes[i][rank - 1]; +{{indent}} // int64_t is ok here because this happens on CPU +{{indent}} for (int64_t j = rank - 2; j >= concat_dim; j--) { +{{indent}} local_alignment *= all_input_shapes[i][j]; +{{indent}} } +{{indent}} } else { +{{indent}} local_alignment = +{{indent}} input_accessors[accessor_idx]->max_alignment(); +{{indent}} } +{{indent}} LoadVecType vec_type = get_vec_type<{{elem_type}}>(local_alignment); +{{indent}} min_vec_type = vec_type < min_vec_type ? vec_type : min_vec_type; +{{indent}} if (input_masks[i]) { +{{indent}} accessor_idx++; +{{indent}} } +{{indent}} } + +{{indent}} {{index_type}} local_output_shape[] = { +{% for idx in range(rank - 1) %} +{{indent}} *(output_shape[{{idx}}]), +{% endfor %} +{{indent}} *(output_shape[{{rank - 1}}]) +{{indent}} }; + +{{indent}}/* TODO: more profiling on ElemsPerThread and ThreadsPerBlock */ +{{indent}}if (use_int32_index_math) { +{{indent}} concatenate_kernel_launcher<{{elem_type}}, +{{indent}} int32_t, +{{indent}} {{rank}}/*Rank*/, +{{indent}} {{num_real_inputs}}/*NumInputs*/, +{{indent}} {{elems_per_thread}}/*ElemsPerThread*/, +{{indent}} {{threads_per_block}}/*THREADS_PER_BLOCK*/>( +{{indent}} output, local_output_shape, inputs, real_input_shapes, input_accessors, +{{indent}} concat_dim_offsets.data(), concat_dim, min_vec_type, stream); +{{indent}}} else { +{{indent}} concatenate_kernel_launcher<{{elem_type}}, +{{indent}} int64_t, +{{indent}} {{rank}}/*Rank*/, +{{indent}} {{num_real_inputs}}/*NumInputs*/, +{{indent}} {{elems_per_thread}}/*ElemsPerThread*/, +{{indent}} {{threads_per_block}}/*THREADS_PER_BLOCK*/>( +{{indent}} output, local_output_shape, inputs, real_input_shapes, input_accessors, +{{indent}} concat_dim_offsets.data(), concat_dim, min_vec_type, stream); +{{indent}}} +{{indent}}return; +{{indent}}} +""" +) + + +SRC_TEMPLATE = jinja2.Template( + """ +{{kernel_src}} + +void {{func_name}}( + {{elem_output_type}} *output, + {{index_type}} *output_shape[], + const {{elem_input_type}} *inputs[], + const {{index_type}} *real_input_shapes[], /* real_input_shapes, representing + shapes of those inputs whose masks are False, + i.e. inputs that will be copied to the output + tensor by concat.*/ + const {{index_type}} *all_input_shapes[], /* all_input_shapes include both + kinds of inputs, i.e. 
no matter input_mask being + True or False */ + const bool input_masks[], + const {{index_type}} concat_dim_sizes[], + {{index_type}} concat_dim, + {{index_type}} rank, + {{index_type}} num_real_inputs, + {{index_type}} num_all_inputs, + {{prefix}}Stream_t stream + ) { + + if (rank <= 0) { + throw std::runtime_error("rank must be larger than 0!"); + } + if (concat_dim >= rank) { + throw std::runtime_error("concat_dim must be smaller than rank!"); + } + if (num_real_inputs < 1) { + throw std::runtime_error("the number of inputs must >= 1!"); + } + + for ({{index_type}} i = 0; i < rank; i++) { + if (i == concat_dim) continue; + {{index_type}} dim = real_input_shapes[0][i]; + for ({{index_type}} j = 1; j < num_real_inputs; j++) { + if (real_input_shapes[j][i] != dim) { + throw std::runtime_error( + "invalid input shape, func_name: {{func_name}}, dim: " + + std::to_string(dim) + ", input_shape: " + + std::to_string(real_input_shapes[j][i]) + ); + } + } + } + + {{index_type}} output_concat_dim_value = 0; + std::vector concat_dim_offsets; + + for ({{index_type}} i = 0; i < num_all_inputs; i++) { + if (input_masks[i]) { + concat_dim_offsets.push_back(output_concat_dim_value); + } + output_concat_dim_value += concat_dim_sizes[i]; + } + for ({{index_type}} i = 0; i < rank; i++) { + if (i == concat_dim) { + *(output_shape[i]) = output_concat_dim_value; + } else { + *(output_shape[i]) = real_input_shapes[0][i]; + } + } + + // If all input tensors are empty we are done + bool empty = false; + bool use_int32_index_math = true; + for (int i = 0; i < num_real_inputs; i++) { + int64_t num_elems = get_num_elems(real_input_shapes[i], rank); + if (get_num_elems(real_input_shapes[i], rank) != 0) { + empty = false; + // make sure input is valid for each non-zero-size tensor + if (!inputs[i]) { + throw std::runtime_error("NULL input is found at: " + std::to_string(i)); + } + } + if (input_masks[i]) { + use_int32_index_math &= can_use_32bit_index_math(num_elems); + } + } + + if (empty) { + return; + } + + // if the output has any zero dim size, we are done + for (int i = 0; i < rank; i++) { + if (*output_shape[i] == 0) + return; + } + // make sure output is valid + if (!output) { + throw std::runtime_error("output is NULL!"); + } + +{{exec_paths}} + + throw std::runtime_error( + "Unsupported concat kernel specialization!" 
+ ); +} +""" +) + + +INPUT_SHAPE_DEF_TEMPLATE = jinja2.Template( + """ +{{indent}}{{index_type}} {{input_shape_name}}[] = { +{{indent}} {{input_dims}} +{{indent}}}; +""" +) + + +FUNC_CALL_TEMPLATE = jinja2.Template( + """ +{{indent}}{ + +{{indent}} const {{input_elem_type}} *inputs[] = { +{{indent}} {{inputs}} +{{indent}} }; + +{{real_input_shape_defs}} + +{{indent}} const {{index_type}} *real_input_shapes[] = { +{{indent}} {{real_input_shapes}} +{{indent}} }; + +{{all_input_shape_defs}} + +{{indent}} const {{index_type}} *all_input_shapes[] = { +{{indent}} {{all_input_shapes}} +{{indent}} }; + +{{indent}} {{index_type}} *{{output}}_shape[] = { +{{indent}} {{output_dim_refs}} +{{indent}} }; + +{{indent}} {{index_type}} concat_dim_sizes[] = { +{{indent}} {{concat_dim_sizes}} +{{indent}} }; + +{{indent}} bool input_masks[] = { +{{indent}} {{input_masks}} +{{indent}} }; + +{{indent}} {{func_name}}( +{{indent}} {{output_ptr}}, +{{indent}} {{output}}_shape, +{{indent}} inputs, +{{indent}} real_input_shapes, +{{indent}} all_input_shapes, +{{indent}} input_masks, +{{indent}} concat_dim_sizes, +{{indent}} {{concat_dim}}/*concat_dim*/, +{{indent}} {{rank}}/*rank*/, +{{indent}} {{num_real_inputs}}/*num_real_inputs*/, +{{indent}} {{num_all_inputs}}/*num_all_inputs*/, +{{indent}} stream +{{indent}} ); +{{indent}}} +""" +) + + +def gen_function_decl(func_attrs, backend_spec): + """Generate function declaration. + + Parameters + ---------- + func_attrs : Dict[str, Any] + Stores the operation attributes. + Returns + ------- + str + Rendered function declaration. + """ + # get dtype from orig_x in case actual "inputs" is turned into empty + # by some transformation + orig_x = func_attrs["original_inputs"][0] + y = func_attrs["outputs"][0] + input_type = backend_spec.dtype_to_backend_type(orig_x._attrs["dtype"]) + output_type = backend_spec.dtype_to_backend_type(y._attrs["dtype"]) + return FUNC_DECL_TEMPLATE.render( + func_name=func_attrs["name"], + elem_output_type=output_type, + elem_input_type=input_type, + index_type=backend_spec.index_type, + prefix=backend_spec.prefix, + ) + + +def gen_function( + func_attrs, + backend_spec, + element_func=None, + element_func_def=None, +): + """Generates function body. + + Parameters + ---------- + func_attrs : Dict[str, Any] + Stores the operation attributes. + index_type: str + Index type. + prefix: str + Backend function prefix, hip/cuda + dtype_to_backend_type: Dict[str, str] + header_src_template: jinja Template + Header src template. + + Returns + ------- + str + Rendered function body. 
+ """ + inputs = func_attrs["inputs"] + original_inputs = func_attrs["original_inputs"] + orig_x = original_inputs[0] + y = func_attrs["outputs"][0] + x_shape = orig_x._attrs["shape"] + + input_type = backend_spec.dtype_to_backend_type(orig_x._attrs["dtype"]) + output_type = backend_spec.dtype_to_backend_type(y._attrs["dtype"]) + + # TODO: support type cast + if input_type != output_type: + raise NotImplementedError("input type must equal to output type") + + def _stride(shape, dim): + stride = 1 + for v in shape[dim:]: + stride = stride * v._attrs["values"][0] + return stride + + concat_dim = func_attrs["concat_dim"] + assert concat_dim < len(x_shape) + strides = [_stride(i._attrs["shape"], concat_dim) for i in inputs] + # the max number of elements in each concat loop iteration + elems_per_iter = max(strides) if len(strides) > 0 else 0 + threads_per_block = 128 + # minimal number of elems per thread is 8, max is 480 + elems_per_thread = min(480, (int((elems_per_iter / threads_per_block + 8) / 8) * 8)) + + input_accessors = [] + input_accessor_refs = [] + for i in range(len(inputs)): + accessor_name = f"input_accessor{i}" + input_accessor_refs.append(f"&{accessor_name}") + input_accessors.append( + tensor_accessor_codegen.TENSOR_ACCESSOR_TEMPLATE.render( + name=accessor_name, tensor_accessor=func_attrs["input_accessors"][i] + ) + ) + input_accessor_defs = INPUT_ACCESSOR_DEFS_TEMPLATE.render( + indent=" ", + input_accessors="".join(input_accessors), + num_real_inputs=len(inputs), + input_accessor_refs=", ".join(input_accessor_refs), + ) + + # TODO: consider to add profiling paths for tuning + # elems_per_thread and threads_per_block + exec_paths = EXEC_COND_TEMPLATE.render( + indent=" ", + rank=len(x_shape), + num_real_inputs=len(inputs), + input_accessor_defs=input_accessor_defs, + elem_type=input_type, + elems_per_thread=elems_per_thread, + threads_per_block=threads_per_block, + index_type=backend_spec.index_type, + ) + + header_src = backend_spec.header_src_template.render() + if len(inputs) > 0: + tensor_accessor_libs = tensor_accessor_codegen.get_libs() + kernel_src = KERNEL_SRC_TEMPLATE.render( + element_func=element_func, + element_func_def=element_func_def, + header_src=header_src, + index_type=backend_spec.index_type, + prefix=backend_spec.prefix, + tensor_accessor_libs=tensor_accessor_libs, + ) + return SRC_TEMPLATE.render( + kernel_src=kernel_src, + func_name=func_attrs["name"], + elem_input_type=input_type, + elem_output_type=output_type, + exec_paths=exec_paths, + index_type=backend_spec.index_type, + prefix=backend_spec.prefix, + ) + + return DUMMY_KERNEL_TEMPLATE.render( + func_name=func_attrs["name"], + elem_input_type=input_type, + elem_output_type=output_type, + header_src=header_src, + index_type=backend_spec.index_type, + prefix=backend_spec.prefix, + ) + + +def gen_function_call( + func_attrs, + backend_spec, + indent=" ", +): + """Generates function call. + + Parameters + ---------- + func_attrs : Dict[str, Any] + Stores the operation attributes. + index_type: str + Index type. + cast_to_const_half_ptr_template: jinja template + Cast to const half ptr template. + cast_to_half_ptr_template: jinja template + Cast to half ptr template. + dtype_to_backend_type: Dict[str, str] + Stores python dtype to backend (rocm, cuda) type. + indent : str, optional + Indent for template, by default " ". + + Returns + ------- + str + Rendered function call. 
+ """ + inputs = func_attrs["inputs"] + input_accessors = func_attrs["input_accessors"] + assert len(inputs) == len(input_accessors), ( + "expected inputs and input_accessors to have the same length, but got: " + f'{len(inputs)}, {len(input_accessors)}, op: {func_attrs["name"]}' + ) + original_inputs = func_attrs["original_inputs"] + orig_x = original_inputs[0] + y = func_attrs["outputs"][0] + concat_dim = func_attrs["concat_dim"] + + input_names = ",\n ".join( + [ + backend_spec.cast_to_const_half_ptr_template.render(name=i._attrs["name"]) + for i in inputs + ] + ) + real_input_shape_defs = [] + real_input_shape_names = [] + for idx, (i, input_accessor) in enumerate(zip(inputs, input_accessors)): + input_shape_name = f'{i._attrs["name"]}_shape_{idx}' + orig_input_shape = input_accessor.original_shapes + dims = ", ".join([dim._attrs["name"] for dim in orig_input_shape]) + one_shape_def = INPUT_SHAPE_DEF_TEMPLATE.render( + indent=" ", + input_shape_name=input_shape_name, + input_dims=dims, + index_type=backend_spec.index_type, + ) + real_input_shape_defs.append(one_shape_def) + real_input_shape_names.append(input_shape_name) + + y_shape = y._attrs["shape"] + y_dim_refs = ", ".join(["&" + dim._attrs["name"] for dim in y_shape]) + casted_y_ptr = backend_spec.cast_to_half_ptr_template.render(name=y._attrs["name"]) + + input_masks = func_attrs["input_masks"] + input_indices = [idx for idx, m in enumerate(input_masks) if m is True] + assert len(inputs) == len(input_indices) + concat_dim_sizes = [ + "-1" if mask else str(original_inputs[idx]._attrs["shape"][concat_dim].value()) + for idx, mask in enumerate(input_masks) + ] + + # update dim size for real inputs + for input_accessor, input_index in zip(input_accessors, input_indices): + dim = input_accessor.original_shapes[concat_dim]._attrs["name"] + concat_dim_sizes[input_index] = dim + + input_masks_str = ", ".join( + ["true" if mask is True else "false" for mask in input_masks] + ) + + # all input shape defs and names, including those that are masked out + all_input_shape_defs = [] + all_input_shape_names = [] + # first, create shape defs for inputs that have been masked off + for ( + mask, + orig_input, + ) in zip(input_masks, original_inputs): + if mask is False: + orig_input_shape_name = f'orig_{orig_input._attrs["name"]}_shape' + if orig_input_shape_name not in all_input_shape_names: + dims = ", ".join( + [str(dim._attrs["values"][0]) for dim in orig_input._attrs["shape"]] + ) + one_shape_def = INPUT_SHAPE_DEF_TEMPLATE.render( + indent=" ", + input_shape_name=orig_input_shape_name, + input_dims=dims, + index_type=backend_spec.index_type, + ) + all_input_shape_defs.append(one_shape_def) + all_input_shape_names.append(orig_input_shape_name) + else: + all_input_shape_names.append("") + # update all_input_shapes with real input shapes + for idx, (input_tensor, input_index) in enumerate(zip(inputs, input_indices)): + input_shape_name = f'{input_tensor._attrs["name"]}_shape_{idx}' + all_input_shape_names[input_index] = input_shape_name + + return FUNC_CALL_TEMPLATE.render( + indent=indent, + input_elem_type=backend_spec.dtype_to_backend_type(orig_x._attrs["dtype"]), + inputs=input_names, + real_input_shape_defs="".join(real_input_shape_defs), + real_input_shapes=", ".join(real_input_shape_names), + all_input_shape_defs="".join(all_input_shape_defs), + all_input_shapes=", ".join(all_input_shape_names), + input_masks=input_masks_str, + concat_dim_sizes=", ".join(concat_dim_sizes), + output_dim_refs=y_dim_refs, + func_name=func_attrs["name"], + 
output=y._attrs["name"], + output_ptr=casted_y_ptr, + concat_dim=concat_dim, + rank=len(orig_x._attrs["shape"]), + num_real_inputs=len(inputs), + num_all_inputs=len(original_inputs), + index_type=backend_spec.index_type, + ) diff --git a/python/aitemplate/backend/common/elementwise_common.py b/python/aitemplate/backend/common/elementwise_common.py new file mode 100644 index 000000000..14872058a --- /dev/null +++ b/python/aitemplate/backend/common/elementwise_common.py @@ -0,0 +1,881 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Backend-agnostic functions for elementwise codegen. +""" + +from dataclasses import dataclass +from typing import Any, Dict, List, Tuple + +import jinja2 + +from ...compiler.base import IntImm, IntVar, Operator, Tensor +from ...compiler.tensor_accessor import TensorAccessor +from ...utils import shape_utils +from ..backend_spec import BackendSpec +from . import tensor_accessor_codegen + +CONSTANT_TEMPLATE = jinja2.Template( + """ +#define FUSED_ELE_THREAD_SIZE 256 + +const int N_ELEMENTS_PER_THREAD = sizeof({{read_t}}) / sizeof({{data_t}}); +const int N_ELEMENTS_PER_READ = sizeof({{read_t}}) / sizeof({{data_t}}); +const int N_OPS_PER_THREAD = sizeof({{read_t}}) / sizeof({{op_t}}); + """ +) + +KERNEL_DECL_INPUT_PARAM_TEMPLATE = jinja2.Template("const {{read_t}}* input{{idx}}") +KERNEL_DECL_OUTPUT_PARAM_TEMPLATE = jinja2.Template("{{read_t}}* output{{idx}}") + +KERNEL_TMP_INPUT_TEMPLATE = jinja2.Template("p_tmp_i{{idx}}[i]") +KERNEL_TMP_OUTPUT_TEMPLATE = jinja2.Template("p_tmp_o{{idx}}[i]") + + +GET_STRIDED_ADDRESS_TEMPLATE = jinja2.Template( + """ + {% if tensor_accessor.is_contiguous %} + {{data_ptr}} = get_strided_address( + {{data_ptr}}, {{data_idx}}, {{tensor_accessor.offset}}, 0, 0); + {% else %} + {{data_ptr}} = get_strided_address( + {{data_ptr}}, {{data_idx}}, + {{tensor_accessor.offset}}, + {{tensor_accessor.original_total_elements_from_stride_dim}}, + {{tensor_accessor.actual_total_elements_from_stride_dim}}); + {% endif %} + """ +) + + +KERNEL_READ_INPUT_TEMPLATE = jinja2.Template( + """ + {{read_t}} *{{input_name}} = const_cast<{{read_t}}*>(input{{input_idx}}); + {{get_strided_address}} + {{read_t}} tmp_i{{input_idx}} = *{{input_name}}; + const {{op_t}}* p_tmp_i{{input_idx}} = reinterpret_cast(&tmp_i{{input_idx}}); + + """ +) + + +KERNEL_DEFINE_OUTPUTS_TEMPLATE = jinja2.Template( + """ + {% for idx in indexes %} + {{read_t}} tmp_o{{idx}}; + {{op_t}}* p_tmp_o{{idx}} = reinterpret_cast<{{op_t}}*>(&tmp_o{{idx}}); + {% endfor %} + """ +) + + +KERNEL_WRITE_OUTPUT_TEMPLATE = jinja2.Template( + """ + {{get_strided_address}} + *{{output_name}} = tmp_o{{output_idx}}; + """ +) + + +KERNEL_TEMPLATE = jinja2.Template( + """ +__global__ void +{{func_name}}({{output_params}}, {{input_params}}, {{dynamic_dims}} int n_elements) { + const int bid = blockIdx.x; + const int tid = threadIdx.x; + const int idx = bid * FUSED_ELE_THREAD_SIZE + tid; + const int idx_elem = idx * N_ELEMENTS_PER_THREAD; + if (idx_elem 
>= n_elements) { + return; + } + {{read_inputs}} + {{define_outputs}} +#pragma unroll + for (int i = 0; i < N_OPS_PER_THREAD; ++i) { + {{fused_funcs}} + } + {{write_outputs}} +} + """ +) + +FUNC_DECL_INPUT_PARAM_TEMPLATE = jinja2.Template("const {{data_t}}* input{{idx}}") +FUNC_DECL_OUTPUT_PARAM_TEMPLATE = jinja2.Template("{{data_t}}* output{{idx}}") +KERNEL_CALL_INPUT_PARAM_TEMPLATE = jinja2.Template( + "reinterpret_cast(input{{idx}})" +) +KERNEL_CALL_OUTPUT_PARAM_TEMPLATE = jinja2.Template( + "reinterpret_cast<{{read_t}}*>(output{{idx}})" +) + +FUNC_TEMPLATE = jinja2.Template( + """ +{{head}} + +namespace { + +{{constant}} + +{{custom_libs}} + +{{tensor_accessor_lib}} + +{{kernel_function}} + +} // namespace + +void invoke_{{func_name}}({{output_params}}, {{input_params}}, {{dynamic_dims_decl}} int n_elements, {{prefix}}Stream_t stream) { + if (n_elements == 0) { + return; + } + int block_size = static_cast(std::ceil(static_cast(n_elements) / N_ELEMENTS_PER_THREAD / FUSED_ELE_THREAD_SIZE)); + {{func_name}}<<>>( + {{kernel_call_output_params}}, + {{kernel_call_input_params}}, + {{dynamic_dims_call}} + n_elements + ); +} + """ +) + +FUNC_DECL_TEMPLATE = jinja2.Template( + """ +void invoke_{{func_name}}({{output_params}}, {{input_params}}, {{dynamic_dims}} int n_elements, {{prefix}}Stream_t stream); + """ +) + +FUNC_CALL_TEMPLATE = jinja2.Template( + """ +{{indent}}{ + {{indent}}int {{func_name}}_n_elements = {{calculate_n}}; + {{indent}}invoke_{{func_name}}({{output_params}}, {{input_params}}, {{dynamic_dims}} {{func_name}}_n_elements, {{stream}}); +{{indent}}} + """ +) + + +@dataclass +class ElementwiseMetaData: + func_name: str + op_t: str + args: List[Tensor] + outputs: List[Tensor] + + +@dataclass +class FusedElementwiseMetaData: + # Input / output Tensors and TensorAccessors. + inputs: List[Tensor] + outputs: List[Tensor] + input_accessors: List[TensorAccessor] + output_accessors: List[TensorAccessor] + + # Original input / output Tensors before graph transformation. + # Kept here for elementwise -> fused elementwise Tensor mapping. 
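+    # (Illustrative, hypothetical example: if an add and a relu were fused
+    #  into one elementwise kernel, original_inputs would still name the
+    #  tensors the standalone ops consumed, which is what lets results be
+    #  mapped back to the pre-fusion graph.)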
+ original_inputs: List[Tensor] + original_outputs: List[Tensor] + + read_t: str + op_t: str + data_t: str + input_broadcast_sizes: List[List[IntVar]] + dynamic_dims: List[IntVar] + sub_funcs: List[ElementwiseMetaData] + + +def gen_function_single_thread( + fused_func_metadata, + input_names, + output_names, + type_converter, +) -> str: + """Per thread elementwise function codegen.""" + tensor_to_expr: Dict[Tensor, str] = {} + body = "" + + for tensor, name in zip(fused_func_metadata.original_inputs, input_names): + tensor_to_expr[tensor] = name + + tmp_output_idx: int = 0 + for func_metadata in fused_func_metadata.sub_funcs: + params: List[str] = [] + func_op_t = func_metadata.op_t + input_converter = None + output_converter = None + if func_op_t != fused_func_metadata.op_t: + input_converter = type_converter.get(fused_func_metadata.op_t).get( + func_op_t + ) + output_converter = type_converter.get(func_op_t).get( + fused_func_metadata.op_t + ) + assert ( + input_converter is not None + ), "Unsupported convertion from {} to {}".format( + fused_func_metadata.op_t, func_op_t + ) + assert ( + output_converter is not None + ), "Unsupported convertion from {} to {}".format( + func_op_t, fused_func_metadata.op_t + ) + + for arg in func_metadata.args: + if arg in tensor_to_expr: + param = tensor_to_expr[arg] + params.append( + "{}({})".format(input_converter, param) + if input_converter is not None + else param + ) + elif arg.is_a_const_num(): + if func_op_t[-1] == "2": + params.append( + "{}({},{})".format( + func_op_t, + str(arg._attrs["value"]), + str(arg._attrs["value"]), + ) + ) + else: + params.append("{}({})".format(func_op_t, str(arg._attrs["value"]))) + else: + raise RuntimeError( + "Cannot generate expression for node {}, ops: {}".format( + arg, func_metadata + ) + ) + assert ( + len(func_metadata.outputs) == 1 + ), "Operator has more than 1 output! Operator: {}".format(func_metadata) + + output = func_metadata.outputs[0] + func_def = "{}({})".format(func_metadata.func_name, ",".join(params)) + func_def = ( + "{}({})".format(output_converter, func_def) + if output_converter is not None + else func_def + ) + if len(output._attrs["dst_ops"]) > 1: + name = "tmp_" + (str)(tmp_output_idx) + tmp_output_idx += 1 + body += "{} {} = {};\n".format(fused_func_metadata.op_t, name, func_def) + tensor_to_expr[output] = name + else: + tensor_to_expr[output] = func_def + + for tensor, name in zip(fused_func_metadata.original_outputs, output_names): + if tensor not in tensor_to_expr: + raise RuntimeError( + "Cannot generate expression for node {}, outputs: {}".format( + tensor, fused_func_metadata.original_outputs + ) + ) + expr = tensor_to_expr[tensor] + body += "{} = {};\n".format(name, expr) + + return body + + +def _get_sub_func_metadata( + ops: List[Operator], data_t: str, op_t: str, backend_spec: BackendSpec +) -> Tuple[List[ElementwiseMetaData], str]: + candidate_op_types = backend_spec.get_candidate_op_types(op_t) + func_enums = [] + for op in ops: + func_enum = op._attrs["func"] + func_enums.append(func_enum) + funcs = backend_spec.func_enum_to_func_name.get(func_enum) + if funcs is None: + raise NotImplementedError("Func {} is not supported!".format(func_enum)) + for candidate_op_t in candidate_op_types: + func_name = funcs.get(candidate_op_t) + if func_name is not None: + candidate_op_types = backend_spec.get_candidate_op_types(candidate_op_t) + break + if len(candidate_op_types) == 0: + raise RuntimeError( + "Cannot find a common rocm data type! 
candidate_op_types: {}, op_t: {}.".format( + candidate_op_types, op_t + ) + ) + if op_t in set(candidate_op_types): + op_t = candidate_op_types[0] + else: + op_t = data_t + candidate_op_types = backend_spec.get_candidate_op_types(op_t) + + sub_func_metadata = [] + for op in ops: + func_enum = op._attrs["func"] + funcs = backend_spec.func_enum_to_func_name.get(func_enum) + func_name = None + func_op_t = None + for candidate_op_t in candidate_op_types: + func_name = funcs.get(candidate_op_t) + if func_name is not None: + func_op_t = candidate_op_t + break + if func_name is None: + raise NotImplementedError( + "Unsupported func {} and op type {}!".format(func_enum, op_t) + ) + sub_func_metadata.append( + ElementwiseMetaData( + func_name, func_op_t, op._attrs["args"], op._attrs["outputs"] + ) + ) + return (sub_func_metadata, op_t) + + +def _get_types_and_sizes( + inputs: List[Tensor], + input_accessors: List[TensorAccessor], + output_accessors: List[TensorAccessor], + backend_spec: BackendSpec, +) -> Tuple[int, List[List[IntVar]], str]: + """ + Returns Tuple(alignment, input_broadcast_sizes, dtype) + """ + + # Handle input broadcast. + output_shape = output_accessors[0].original_shapes + dtype = "float16" + input_broadcast_sizes = [] + min_num_elements = None + for input_tensor, input_accessor in zip(inputs, input_accessors): + if input_tensor._attrs["dtype"] != "float16": + raise NotImplementedError( + "Unsupported dtype {}!".format(input_tensor._attrs["dtype"]) + ) + input_shape = input_accessor.original_shapes + broadcastable, _ = shape_utils.get_broadcast_max_shape( + output_shape, input_shape + ) + if not broadcastable: + raise RuntimeError( + "Input shape {} is not compatible with output shape {}!".format( + input_shape, output_shape + ) + ) + num_rightmost_non_broadcast_elements = len(input_shape) + extended_input_shape = list(input_shape) + if input_shape == output_shape: + input_broadcast_sizes.append(None) + else: + extended_input_shape = [IntImm(1)] * len(output_shape) + extended_input_shape[len(output_shape) - len(input_shape) :] = input_shape + input_broadcast_sizes.append(extended_input_shape) + for i in reversed(range(len(extended_input_shape))): + if extended_input_shape[i] != output_shape[i]: + num_rightmost_non_broadcast_elements -= i + 1 + break + num_elements_for_alignments = shape_utils.get_num_rightmost_static_elements( + extended_input_shape, num_rightmost_non_broadcast_elements + ) + if not min_num_elements: + min_num_elements = num_elements_for_alignments + else: + min_num_elements = min(min_num_elements, num_elements_for_alignments) + alignment = tensor_accessor_codegen.find_max_alignment( + min_num_elements, output_accessors + ) + # Note that we use the same alignment for accessing inputs and outputs, although + # they may have different alignment requirements. We may lose perf a little bit, + # but reduce the complexity of our jinja template. We can do some perf + # experiments later to determine if we want to chase more perf gains. 
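As a rough illustration of the alignment selection described in the comment above (a sketch, not the real tensor_accessor_codegen.find_max_alignment; the candidate widths in elements are an assumption matching common CUDA vector sizes):

def max_alignment(num_contiguous_elems: int, candidates=(8, 4, 2, 1)) -> int:
    # Pick the widest vectorized access whose element count evenly divides the
    # number of rightmost static, non-broadcast elements shared by all accessors.
    for width in candidates:
        if num_contiguous_elems % width == 0:
            return width
    return 1

# An accessor whose innermost static extent is 6 halfs can only be read
# 2 halfs (4 bytes) at a time, even though 8-wide reads would be preferred.
print(max_alignment(6))  # 2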
+ alignment = tensor_accessor_codegen.find_max_alignment(alignment, input_accessors) + return alignment, input_broadcast_sizes, dtype + + +def _get_dynamic_dims(output_accessors: List[TensorAccessor]) -> List[IntVar]: + res = {} + for output_accessor in output_accessors: + for dim in output_accessor.original_shapes: + if not isinstance(dim, IntImm): + res[dim._attrs["name"]] = dim + return res.values() + + +def _parse_func_metadata( + ops: List[Operator], + inputs: List[Tensor], + outputs: List[Tensor], + input_accessors: List[TensorAccessor], + output_accessors: List[TensorAccessor], + original_inputs: List[Tensor], + original_outputs: List[Tensor], + backend_spec: BackendSpec, +) -> FusedElementwiseMetaData: + alignment, input_broadcast_sizes, dtype = _get_types_and_sizes( + inputs, input_accessors, output_accessors, backend_spec + ) + read_type = backend_spec.get_backend_type( + alignment, dtype, backend_spec.read_num_elements_to_backend_type + ) + op_type = backend_spec.get_backend_type( + alignment, dtype, backend_spec.op_num_elements_to_backend_type + ) + data_type = backend_spec.get_fp16_dtype(dtype) + sub_func_metadata, op_type = _get_sub_func_metadata( + ops, data_type, op_type, backend_spec + ) + dynamic_dims = _get_dynamic_dims(output_accessors) + + return FusedElementwiseMetaData( + inputs, + outputs, + input_accessors, + output_accessors, + original_inputs, + original_outputs, + read_type, + op_type, + data_type, + input_broadcast_sizes, + dynamic_dims, + sub_func_metadata, + ) + + +def _gen_int_var_product_str( + int_vars: List[IntVar], +) -> str: + res = [] + for int_var in int_vars: + if isinstance(int_var, IntImm): + res.append(str(int_var._attrs["values"][0])) + elif isinstance(int_var, IntVar): + res.append(int_var._attrs["name"]) + else: + raise RuntimeError( + "A dim must be an IntVar! Current type: {}".format(type(int_var)) + ) + return " * ".join(res) + + +def _gen_input_broadcast_calculator_str( + input_shape: List[IntVar], + output_shape: List[IntVar], +) -> str: + output_num_elements = [] + output_strides = [] + input_strides = [] + + start_idx = 0 + for i, (input_dim, output_dim) in enumerate(zip(input_shape, output_shape)): + if input_dim != output_dim: + assert input_dim == IntImm( + 1 + ), "Unexpected shapes! 
Input: {}, output: {}".format( + input_shape, output_shape + ) + input_strides.append(input_shape[i:]) + output_strides.append(output_shape[i:]) + output_num_elements.append(output_shape[start_idx:]) + start_idx = i + 1 + if start_idx < len(output_shape): + input_strides.append([IntImm(1)]) + output_strides.append([IntImm(1)]) + output_num_elements.append(output_shape[start_idx:]) + + res = [] + for (output_num_element, output_stride, input_stride) in zip( + output_num_elements, output_strides, input_strides + ): + res.append( + "{} % ({}) / ({}) * ({})".format( + "idx * N_ELEMENTS_PER_THREAD", + _gen_int_var_product_str(output_num_element), + _gen_int_var_product_str(output_stride), + _gen_int_var_product_str(input_stride), + ) + ) + + return " + ".join(res) + + +def _gen_input_broadcast_size_str( + input_broadcast_sizes: List[List[IntVar]], + output_shape: List[IntVar], +) -> List[str]: + res = [] + for input_broadcast_size in input_broadcast_sizes: + if input_broadcast_size is None: + res.append("") + else: + res.append( + _gen_input_broadcast_calculator_str(input_broadcast_size, output_shape) + ) + return res + + +def _gen_dynamic_dim_str( + index_type: str, dynamic_dims: List[IntVar], has_type: bool +) -> str: + type_str = index_type + " " if has_type else "" + res = ", ".join([type_str + dim._attrs["name"] for dim in dynamic_dims]) + if res: + res += ", " + return res + + +def _gen_read_inputs_str( + fused_elementwise_metadata: FusedElementwiseMetaData, broadcast_sizes: List[str] +): + read_inputs = [] + for input_idx, (input_accessor, broadcast_size) in enumerate( + zip(fused_elementwise_metadata.input_accessors, broadcast_sizes) + ): + input_name = f"input_tmp{input_idx}" + data_idx = ( + "idx" + if not broadcast_size + else f"({broadcast_size}) / N_ELEMENTS_PER_THREAD" + ) + get_strided_addr_str = GET_STRIDED_ADDRESS_TEMPLATE.render( + tensor_accessor=input_accessor, + data_ptr=input_name, + data_t=fused_elementwise_metadata.data_t, + read_t=fused_elementwise_metadata.read_t, + data_idx=data_idx, + ) + read_input = KERNEL_READ_INPUT_TEMPLATE.render( + get_strided_address=get_strided_addr_str, + input_name=input_name, + input_idx=input_idx, + read_t=fused_elementwise_metadata.read_t, + op_t=fused_elementwise_metadata.op_t, + ) + read_inputs.append(read_input) + read_inputs_str = "\n".join(read_inputs) + return read_inputs_str + + +def _gen_write_outputs_str(fused_elementwise_metadata: FusedElementwiseMetaData): + write_outputs = [] + for output_idx, output_accessor in enumerate( + fused_elementwise_metadata.output_accessors + ): + output_name = f"output{output_idx}" + get_strided_addr_str = GET_STRIDED_ADDRESS_TEMPLATE.render( + tensor_accessor=output_accessor, + data_ptr=output_name, + data_t=fused_elementwise_metadata.data_t, + read_t=fused_elementwise_metadata.read_t, + data_idx="idx", + ) + write_out = KERNEL_WRITE_OUTPUT_TEMPLATE.render( + get_strided_address=get_strided_addr_str, + output_name=output_name, + output_idx=output_idx, + ) + write_outputs.append(write_out) + write_outputs_str = "\n".join(write_outputs) + return write_outputs_str + + +def _gen_kernel_function( + func_attrs: Dict[str, Any], + index_type: str, + fused_elementwise_metadata: FusedElementwiseMetaData, + backend_datatype_convertors: Dict[str, Dict[str, str]], +) -> str: + output_params_decl = ",".join( + [ + KERNEL_DECL_OUTPUT_PARAM_TEMPLATE.render( + read_t=fused_elementwise_metadata.read_t, idx=i + ) + for i, _ in enumerate(fused_elementwise_metadata.outputs) + ] + ) + input_params_decl = ",".join( + [ 
+ KERNEL_DECL_INPUT_PARAM_TEMPLATE.render( + read_t=fused_elementwise_metadata.read_t, idx=i + ) + for i, _ in enumerate(fused_elementwise_metadata.inputs) + ] + ) + + broadcast_sizes = _gen_input_broadcast_size_str( + fused_elementwise_metadata.input_broadcast_sizes, + fused_elementwise_metadata.output_accessors[0].original_shapes, + ) + read_inputs_str = _gen_read_inputs_str(fused_elementwise_metadata, broadcast_sizes) + + define_outputs = KERNEL_DEFINE_OUTPUTS_TEMPLATE.render( + read_t=fused_elementwise_metadata.read_t, + op_t=fused_elementwise_metadata.op_t, + indexes=list(range(len(fused_elementwise_metadata.outputs))), + ) + write_outputs_str = _gen_write_outputs_str(fused_elementwise_metadata) + + input_names = [ + KERNEL_TMP_INPUT_TEMPLATE.render(idx=i) + for i, _ in enumerate(fused_elementwise_metadata.inputs) + ] + output_names = [ + KERNEL_TMP_OUTPUT_TEMPLATE.render(idx=i) + for i, _ in enumerate(fused_elementwise_metadata.outputs) + ] + fused_funcs = gen_function_single_thread( + fused_elementwise_metadata, + input_names, + output_names, + backend_datatype_convertors, + ) + + kernel_func = KERNEL_TEMPLATE.render( + func_name=func_attrs["name"], + output_params=output_params_decl, + input_params=input_params_decl, + dynamic_dims=_gen_dynamic_dim_str( + index_type, fused_elementwise_metadata.dynamic_dims, has_type=True + ), + read_inputs=read_inputs_str, + define_outputs=define_outputs, + write_outputs=write_outputs_str, + fused_funcs=fused_funcs, + ) + return kernel_func + + +def fused_elementwise_gen_function( + func_attrs: Dict[str, Any], + custom_libs: str, + head_template: str, + backend_spec: BackendSpec, +) -> str: + """Generates fused_elementwise function definition.""" + + ops = func_attrs["elementwise_ops"] + inputs = func_attrs["inputs"] + outputs = func_attrs["outputs"] + input_accessors = func_attrs["input_accessors"] + output_accessors = func_attrs["output_accessors"] + original_inputs = func_attrs["original_inputs"] + original_outputs = func_attrs["original_outputs"] + fused_elementwise_metadata = _parse_func_metadata( + ops, + inputs, + outputs, + input_accessors, + output_accessors, + original_inputs, + original_outputs, + backend_spec, + ) + # Dump data types into func_attr for testing purpose. 
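For intuition about the fused_funcs body rendered above by gen_function_single_thread: for a hypothetical fused y = tanh(a + b) running with op_t == "half2", the generated per-lane statement would look roughly like the string below. The intrinsic names are placeholders; the real ones come from backend_spec.func_enum_to_func_name.

# Hypothetical example of the rendered per-thread body (function names are stand-ins).
example_fused_body = "p_tmp_o0[i] = fast_tanh(__hadd2(p_tmp_i0[i], p_tmp_i1[i]));"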
+ func_attrs["read_t"] = fused_elementwise_metadata.read_t + func_attrs["op_t"] = fused_elementwise_metadata.op_t + func_attrs["data_t"] = fused_elementwise_metadata.data_t + + tensor_accessor_lib = tensor_accessor_codegen.get_libs() + tensor_accessor_lib_str = "\n\n" + tensor_accessor_lib + "\n\n" + + kernel_function = _gen_kernel_function( + func_attrs, + backend_spec.index_type, + fused_elementwise_metadata, + backend_spec.backend_datatype_convertors, + ) + output_params_decl = ",".join( + [ + FUNC_DECL_OUTPUT_PARAM_TEMPLATE.render( + data_t=fused_elementwise_metadata.data_t, idx=i + ) + for i, _ in enumerate(fused_elementwise_metadata.outputs) + ] + ) + input_params_decl = ",".join( + [ + FUNC_DECL_INPUT_PARAM_TEMPLATE.render( + data_t=fused_elementwise_metadata.data_t, idx=i + ) + for i, _ in enumerate(fused_elementwise_metadata.inputs) + ] + ) + kernel_call_output_params = ",".join( + [ + KERNEL_CALL_OUTPUT_PARAM_TEMPLATE.render( + read_t=fused_elementwise_metadata.read_t, idx=i + ) + for i, _ in enumerate(fused_elementwise_metadata.outputs) + ] + ) + kernel_call_input_params = ",".join( + [ + KERNEL_CALL_INPUT_PARAM_TEMPLATE.render( + read_t=fused_elementwise_metadata.read_t, idx=i + ) + for i, _ in enumerate(fused_elementwise_metadata.inputs) + ] + ) + constant = CONSTANT_TEMPLATE.render( + read_t=fused_elementwise_metadata.read_t, + op_t=fused_elementwise_metadata.op_t, + data_t=fused_elementwise_metadata.data_t, + ) + + function = FUNC_TEMPLATE.render( + prefix=backend_spec.prefix, + head=backend_spec.header_src_template.render(extra_header=head_template), + constant=constant, + custom_libs=custom_libs, + tensor_accessor_lib=tensor_accessor_lib_str, + kernel_function=kernel_function, + func_name=func_attrs["name"], + output_params=output_params_decl, + input_params=input_params_decl, + dynamic_dims_decl=_gen_dynamic_dim_str( + backend_spec.index_type, + fused_elementwise_metadata.dynamic_dims, + has_type=True, + ), + dynamic_dims_call=_gen_dynamic_dim_str( + backend_spec.index_type, + fused_elementwise_metadata.dynamic_dims, + has_type=False, + ), + kernel_call_output_params=kernel_call_output_params, + kernel_call_input_params=kernel_call_input_params, + ) + return function + + +def fused_elementwise_gen_function_decl( + func_attrs, + backend_spec: BackendSpec, +): + """Generates fused_elementwise function declaration.""" + + func_name = func_attrs["name"] + ops = func_attrs["elementwise_ops"] + inputs = func_attrs["inputs"] + outputs = func_attrs["outputs"] + input_accessors = func_attrs["input_accessors"] + output_accessors = func_attrs["output_accessors"] + original_inputs = func_attrs["original_inputs"] + original_outputs = func_attrs["original_outputs"] + fused_elementwise_metadata = _parse_func_metadata( + ops, + inputs, + outputs, + input_accessors, + output_accessors, + original_inputs, + original_outputs, + backend_spec, + ) + output_params_decl = ",".join( + [ + FUNC_DECL_OUTPUT_PARAM_TEMPLATE.render( + data_t=fused_elementwise_metadata.data_t, idx=i + ) + for i, _ in enumerate(fused_elementwise_metadata.outputs) + ] + ) + input_params_decl = ",".join( + [ + FUNC_DECL_INPUT_PARAM_TEMPLATE.render( + data_t=fused_elementwise_metadata.data_t, idx=i + ) + for i, _ in enumerate(fused_elementwise_metadata.inputs) + ] + ) + + function_decl = FUNC_DECL_TEMPLATE.render( + prefix=backend_spec.prefix, + func_name=func_name, + output_params=output_params_decl, + input_params=input_params_decl, + dynamic_dims=_gen_dynamic_dim_str( + backend_spec.index_type, + 
fused_elementwise_metadata.dynamic_dims, + has_type=True, + ), + ) + return function_decl + + +def fused_elementwise_gen_function_call( + func_attrs, + indent: str, + backend_spec: BackendSpec, +): + """Generates fused_elementwise function call.""" + ops = func_attrs["elementwise_ops"] + inputs = func_attrs["inputs"] + outputs = func_attrs["outputs"] + input_accessors = func_attrs["input_accessors"] + output_accessors = func_attrs["output_accessors"] + original_inputs = func_attrs["original_inputs"] + original_outputs = func_attrs["original_outputs"] + fused_elementwise_metadata = _parse_func_metadata( + ops, + inputs, + outputs, + input_accessors, + output_accessors, + original_inputs, + original_outputs, + backend_spec, + ) + + output_params_vec = [] + for output in outputs: + if output._attrs["dtype"] != "float16": + raise NotImplementedError( + "Unsupported dtype {}".format(output._attrs["dtype"]) + ) + output_params_vec.append( + backend_spec.cast_to_half_ptr_template.render(name=output._attrs["name"]) + ) + output_params = ",".join(output_params_vec) + + input_params_vec = [] + for inp in inputs: + if inp._attrs["dtype"] != "float16": + raise NotImplementedError( + "Unsupported dtype {}".format(inp._attrs["dtype"]) + ) + input_params_vec.append( + backend_spec.cast_to_half_ptr_template.render(name=inp._attrs["name"]) + ) + input_params = ",".join(input_params_vec) + + num_elements_calculator = _gen_int_var_product_str( + output_accessors[0].original_shapes + ) + + return FUNC_CALL_TEMPLATE.render( + stream=backend_spec.stream, + func_name=func_attrs["name"], + calculate_n=num_elements_calculator, + output_params=output_params, + input_params=input_params, + dynamic_dims=_gen_dynamic_dim_str( + backend_spec.index_type, + fused_elementwise_metadata.dynamic_dims, + has_type=False, + ), + indent=indent, + ) diff --git a/python/aitemplate/backend/common/gemm_common.py b/python/aitemplate/backend/common/gemm_common.py new file mode 100644 index 000000000..eb7bad8b4 --- /dev/null +++ b/python/aitemplate/backend/common/gemm_common.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Backend-agnostic functions for gemm codegen. 
+""" + +from typing import Dict + +import jinja2 + +from aitemplate.compiler.ops.gemm_universal.gemm_common import DimInfo, Source + +SHAPE_EVAL_TEMPLATE = jinja2.Template( + """ +{{indent}}{{dtype}} {{name}} = {{dim_calculator}}; +""" +) + + +def gen_dim_calculator(dim_info: DimInfo, is_ptr: bool) -> str: + prefix = "*" if is_ptr else "" + if dim_info.source == Source.INPUT: + if dim_info.tensor_idx == 0: + prefix += "a_dim" + else: + assert dim_info.tensor_idx == 1, f"Unsupported gemm dim: {dim_info}" + prefix += "b_dim" + else: + assert ( + dim_info.source == Source.OUTPUT and dim_info.tensor_idx == 0 + ), f"Unsupported gemm dim: {dim_info}" + prefix += "c_dim" + dim_names = ["(" + prefix + str(idx) + ")" for idx in dim_info.dim_idx] + return " * ".join(dim_names) + + +def gen_shape_eval_code( + indent: int, dtype: str, dim_info_dict: Dict[str, DimInfo], is_ptr: bool +) -> str: + shape_eval_list = [] + for name, dim_info_list in dim_info_dict.items(): + dim_info = None + for d in dim_info_list: + if d.placeholder: + continue + + dim_info = d + break + assert dim_info is not None, f"Couldn't find valid dim info for dim {name}" + + shape_eval_list.append( + SHAPE_EVAL_TEMPLATE.render( + dtype=dtype, + indent=" " * indent, + name=name, + dim_calculator=gen_dim_calculator(dim_info, is_ptr), + ) + ) + return "\n".join(shape_eval_list) diff --git a/python/aitemplate/backend/common/split_common.py b/python/aitemplate/backend/common/split_common.py new file mode 100644 index 000000000..9205c90ee --- /dev/null +++ b/python/aitemplate/backend/common/split_common.py @@ -0,0 +1,569 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Backend-agnostic function templates for split. 
+""" +import jinja2 + +FUNC_DECL_TEMPLATE = jinja2.Template( + """ +void {{func_name}}( + {{elem_output_type}} *[] /*outputs*/, + {{index_type}} **[] /*output_shapes*/, + const {{elem_input_type}} * /*input*/, + const {{index_type}} * /*input_shape*/, + {{index_type}} /*num_splits*/, + {{index_type}} [] /*split_sizes*/, + {{index_type}} /*split_dim*/, + {{index_type}} /*rank*/, + {{prefix}}Stream_t stream +); +""" +) + + +KERNEL_SRC_TEMPLATE = jinja2.Template( + """ +#include +#include +#include +#include +#include + +{{header_src}} + +#ifndef CHECK_ERROR_SPLIT +#define CHECK_ERROR_SPLIT(expr) \\ + do { \\ + {{prefix}}Error_t status = (expr); \\ + if (status != {{prefix}}Success) { \\ + auto msg = std::string("Got error: ") + \\ + {{prefix}}GetErrorString(status) + \\ + " at " + __FILE__ + ": " + std::to_string(__LINE__); \\ + std::cerr << msg << std::endl; \\ + throw std::runtime_error(msg); \\ + } \\ + } while (0) +#endif // CHECK_ERROR_SPLIT + +#ifndef LAUNCH_CHECK_SPLIT +#define LAUNCH_CHECK_SPLIT() CHECK_ERROR_SPLIT({{prefix}}GetLastError()) +#endif // LAUNCH_CHECK_SPLIT + +template +struct OutputMetaData { + T* outputs[NumSplits]; /* pointer to each output */ + int64_t split_dim_offsets[NumSplits]; /* offset of each output along + the split dimension */ + int64_t split_dim_sizes[NumSplits]; /* cat dimension size of each output */ + int64_t num_elems[NumSplits]; /* number of the elements of each output */ +}; + +template <{{index_type}} Rank> +struct InputMetaData { + {{index_type}} input_shape[Rank]; + int64_t input_strides[Rank]; +}; + +__host__ __device__ __forceinline__ +int64_t get_num_elems(const {{index_type}} *shape, {{index_type}} rank) { + {{index_type}} num = 1; + for ({{index_type}} i = 0; i < rank; i++) { + num *= shape[i]; + } + return num; +} + +template <{{index_type}} Rank> +__host__ __device__ int64_t compute_input_elem_offset( + const {{index_type}} *input_shape, + int64_t *input_strides, + int64_t split_dim_size, + {{index_type}} split_dim, + int64_t linear_idx) { + int64_t offset = 0; + for ({{index_type}} i = Rank - 1; i >= 1; --i) { + int64_t cur_dim_size = i == split_dim ? 
split_dim_size : input_shape[i]; + int64_t next_dim_idx = linear_idx / cur_dim_size; + int64_t cur_dim_idx = linear_idx - cur_dim_size * next_dim_idx; + int64_t cur_dim_offset = cur_dim_idx * input_strides[i]; + offset += cur_dim_offset; + linear_idx = next_dim_idx; + } + return offset + linear_idx * input_strides[0]; +} + +template +__global__ void +split_kernel( + const ELEM_T *orig_input, + InputMetaData input_meta, + OutputMetaData output_meta, + const {{index_type}} split_dim, + const int64_t input_split_dim_stride) { + // split is the inverse of concat, so we + // (1) use blockIdx.y to specify the blocks for each ouput; and + // (2) use tid to access each output; + const {{index_type}} tid = blockIdx.x * blockDim.x + threadIdx.x; + const READ_T* input = reinterpret_cast(orig_input); + + READ_T* output = + reinterpret_cast(output_meta.outputs[blockIdx.y]); + int64_t output_offset = output_meta.split_dim_offsets[blockIdx.y]; + int64_t num_output_elems = output_meta.num_elems[blockIdx.y]; + int64_t split_dim_size = output_meta.split_dim_sizes[blockIdx.y]; + int64_t input_offset = output_offset * input_split_dim_stride; + + unsigned read_t_sz = sizeof(READ_T); + unsigned elem_t_sz = sizeof(ELEM_T); + assert(read_t_sz >= elem_t_sz && (read_t_sz % elem_t_sz == 0)); + {{index_type}} n_of_elem_t = read_t_sz / elem_t_sz; + // number of READ_T elements per thread + {{index_type}} reads_per_thread_in_read_t = ElemsPerThread / n_of_elem_t; + const {{index_type}} num_elems_in_read_t = num_output_elems / n_of_elem_t; + {{index_type}} read_idx = tid; + +#pragma unroll + for ({{index_type}} i = 0; i < reads_per_thread_in_read_t; + i++, read_idx += blockDim.x * gridDim.x) { + if (read_idx >= num_elems_in_read_t) { + break; + } + /* make sure to adjust read_idx, which refers to location at + (read_idx * n_of_elem_t) actually */ + int64_t input_elem_offset = + compute_input_elem_offset(input_meta.input_shape, + input_meta.input_strides, + split_dim_size, + split_dim, + read_idx * n_of_elem_t); + + READ_T tmp_v = input[(input_offset + input_elem_offset) / n_of_elem_t]; + output[read_idx] = tmp_v; + } +} + +enum class LoadVecType { + VT_HALF = 0, + VT_FLOAT, + VT_FLOAT2, + VT_FLOAT4 +}; + +template +static inline LoadVecType get_vec_type( + const {{index_type}} *shape, {{index_type}} rank, {{index_type}} dim) { + assert(rank > 0); + assert(dim < rank && dim >= 0); + int64_t running_stride = shape[rank - 1]; + for ({{index_type}} i = rank - 2; i >= dim; i--) { + running_stride *= shape[i]; + } + {{index_type}} size_elem_t = sizeof(ELEM_T); + +#define HANDLE_ONE_VEC_TYPE(load_vec_type, vec_type) \\ + if (sizeof(vec_type) % size_elem_t == 0) { \\ + {{index_type}} n_of_elem_t = sizeof(vec_type) / size_elem_t; \\ + if (running_stride % n_of_elem_t == 0) { \\ + return load_vec_type; \\ + } \\ + } + + HANDLE_ONE_VEC_TYPE(LoadVecType::VT_FLOAT4, float4) + HANDLE_ONE_VEC_TYPE(LoadVecType::VT_FLOAT2, float2) + HANDLE_ONE_VEC_TYPE(LoadVecType::VT_FLOAT, float) + HANDLE_ONE_VEC_TYPE(LoadVecType::VT_HALF, half) + +#undef HANDLE_ONE_VEC_TYPE + throw std::runtime_error( + "Cannot resolve LoadVecType." 
+ ); +} + +template +void split_kernel_launcher( + ELEM_T *outputs[], + {{index_type}} *output_shapes[], + const ELEM_T *input, + const {{index_type}} *input_shape, + const {{index_type}} split_dim, + {{prefix}}Stream_t stream +) { + + InputMetaData input_meta; + input_meta.input_strides[Rank - 1] = 1; + input_meta.input_shape[Rank - 1] = input_shape[Rank - 1]; + for ({{index_type}} i = Rank - 2; i >= 0; i--) { + input_meta.input_strides[i] = + input_meta.input_strides[i+1] * input_shape[i+1]; + input_meta.input_shape[i] = input_shape[i]; + } + + OutputMetaData output_meta; + {{index_type}} offset = 0; + LoadVecType min_vec_type = LoadVecType::VT_FLOAT4; + for ({{index_type}} i = 0; i < NumSplits; i++) { + output_meta.outputs[i] = outputs[i]; + output_meta.split_dim_offsets[i] = offset; + output_meta.split_dim_sizes[i] = output_shapes[i][split_dim]; + output_meta.num_elems[i] = get_num_elems(output_shapes[i], Rank); + offset += output_meta.split_dim_sizes[i]; + LoadVecType vec_type = + get_vec_type(output_shapes[i], Rank, split_dim); + min_vec_type = vec_type < min_vec_type ? vec_type : min_vec_type; + } + + int64_t max_num_output_elems = 0; + for ({{index_type}} i = 0; i < NumSplits; i++) { + {{index_type}} num_outputs = get_num_elems(output_shapes[i], Rank); + max_num_output_elems = num_outputs > max_num_output_elems ? + num_outputs : max_num_output_elems; + } + {{index_type}} m = (max_num_output_elems % (ThreadsPerBlock * ElemsPerThread) != 0); + {{index_type}} num_blocks_x = + (max_num_output_elems / (ThreadsPerBlock * ElemsPerThread)) + m; + dim3 grid_config = dim3(num_blocks_x, NumSplits); + +#define HANDLE_ONE_VEC_TYPE(load_vec_type, vec_type) \\ + case load_vec_type: { \\ + if (ElemsPerThread * sizeof(ELEM_T) < sizeof(vec_type)) { \\ + throw std::runtime_error( \\ + std::string("No valid kernel available for ") + #vec_type); \\ + } \\ + split_kernel \\ + <<>>( \\ + input, \\ + input_meta, \\ + output_meta, \\ + split_dim, \\ + input_meta.input_strides[split_dim]); \\ + LAUNCH_CHECK_SPLIT(); \\ + break; \\ + } + + switch (min_vec_type) { + HANDLE_ONE_VEC_TYPE(LoadVecType::VT_FLOAT4, float4) + HANDLE_ONE_VEC_TYPE(LoadVecType::VT_FLOAT2, float2) + HANDLE_ONE_VEC_TYPE(LoadVecType::VT_FLOAT, float) + HANDLE_ONE_VEC_TYPE(LoadVecType::VT_HALF, half) + default: + throw std::runtime_error("Invalid LoadVecType\\n"); + } + +#undef HANDLE_ONE_VEC_TYPE +} + +#undef CHECK_ERROR_SPLIT +#undef LAUNCH_CHECK_SPLIT + +""" +) + + +EXEC_COND_TEMPLATE = jinja2.Template( + """ +{{indent}}if (rank == {{rank}} && num_splits == {{num_splits}}) { +{% for split_idx in range(num_splits) %} +{{indent}} {{index_type}} local_shape{{split_idx}}[{{rank}}]; +{% for rank_idx in range(rank) %} +{{indent}} local_shape{{split_idx}}[{{rank_idx}}] = input_shape[{{rank_idx}}]; +{% endfor %} +{{indent}} local_shape{{split_idx}}[split_dim] = split_sizes[{{split_idx}}]; + +{% endfor %} + +{{indent}} {{index_type}}* local_output_shapes[{{num_splits}}] = { +{% for idx in range(num_splits - 1) %} +{{indent}} local_shape{{idx}}, +{% endfor %} +{{indent}} local_shape{{num_splits - 1}} +{{indent}} }; +{{indent}} /* TODO: more profiling on ElemsPerThread and ThreadsPerBlock */ +{{indent}} split_kernel_launcher<{{elem_type}}, +{{indent}} {{rank}}/*Rank*/, +{{indent}} {{num_splits}}/*NumSplits*/, +{{indent}} {{elems_per_thread}}/*ElemsPerThread*/, +{{indent}} {{threads_per_block}}/*THREADS_PER_BLOCK*/>( +{{indent}} outputs, local_output_shapes, input, input_shape, split_dim, stream); +{{indent}} return; +{{indent}}} +""" +) + + 
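For intuition on the launcher and the exec path above: splitting a [4, 6, 8] input along split_dim = 1 with split_sizes (2, 4) must yield outputs shaped [4, 2, 8] and [4, 4, 8], and the generated function checks that the split sizes sum back to the input's split-dim extent. A minimal Python sketch of that bookkeeping:

def split_output_shapes(input_shape, split_dim, split_sizes):
    # Mirrors the shape update and sum check performed by the generated code.
    assert sum(split_sizes) == input_shape[split_dim], "unmatched split dim size"
    shapes = []
    for size in split_sizes:
        shape = list(input_shape)
        shape[split_dim] = size
        shapes.append(shape)
    return shapes

print(split_output_shapes([4, 6, 8], 1, [2, 4]))  # [[4, 2, 8], [4, 4, 8]]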
+SRC_TEMPLATE = jinja2.Template( + """ +{{kernel_src}} +void {{func_name}}( + {{elem_output_type}}* outputs[], + {{index_type}} **output_shapes[], + const {{elem_input_type}}* input, + const {{index_type}} *input_shape, + {{index_type}} num_splits, + {{index_type}} split_sizes[], + {{index_type}} split_dim, + {{index_type}} rank, + {{prefix}}Stream_t stream + ) { + + if (rank <= 0) { + throw std::runtime_error("rank must be larger than 0!"); + } + if (split_dim >= rank) { + throw std::runtime_error("cat_dim must be smaller than rank!"); + } + if (num_splits < 1) { + throw std::runtime_error("the number of splits must be larger than 0!"); + } + + // now we update the shape for each output + for ({{index_type}} i = 0; i < num_splits; i++) { + {{index_type}} **shape_ptr = output_shapes[i]; + for ({{index_type}} dim_idx = 0; dim_idx < rank; dim_idx++) { + *(shape_ptr[dim_idx]) = input_shape[dim_idx]; + } + // update dim size for the split axis + *(shape_ptr[split_dim]) = split_sizes[i]; + } + + {{index_type}} split_dim_size = input_shape[split_dim]; + {{index_type}} sum_of_split_sizes = 0; + for ({{index_type}} i = 0; i < num_splits; i++) { + sum_of_split_sizes += split_sizes[i]; + } + if (split_dim_size != sum_of_split_sizes) { + throw std::runtime_error("unmatched split dim size!"); + } + + // If split dim is zero, we are done + if (split_dim_size == 0) { + return; + } + // If the input tensor is empty, we are done + if (get_num_elems(input_shape, rank) == 0) { + return; + } + // make sure input and outputs are valid + if (!input) { + throw std::runtime_error("input is NULL!"); + } + for (int i = 0; i < num_splits; i++) { + if (!outputs[i]) { + throw std::runtime_error("NULL output found at: " + std::to_string(i)); + } + } + +{{exec_paths}} + + throw std::runtime_error( + "Unsupported cat kernel specialization!" + ); +} +""" +) + + +OUTPUT_SHAPE_DEF_TEMPLATE = jinja2.Template( + """ +{{indent}}{{index_type}} *{{output_shape_name}}[] = { +{{indent}} {{output_dim_refs}} +{{indent}}}; +""" +) + + +FUNC_CALL_TEMPLATE = jinja2.Template( + """ +{{indent}}{ + +{{indent}} {{output_elem_type}} *outputs[] = { +{{indent}} {{outputs}} +{{indent}} }; + +{{output_shape_defs}} + +{{indent}} {{index_type}} **output_shapes[] = { +{{indent}} {{output_shapes}} +{{indent}} }; + +{{indent}} const {{index_type}} {{input_name}}_shape[] = { +{{indent}} {{input_dims}} +{{indent}} }; + +{{indent}} {{index_type}} split_sizes[] = { +{{indent}} {{split_sizes}} +{{indent}} }; + +{{indent}} {{func_name}}( +{{indent}} outputs, +{{indent}} output_shapes, +{{indent}} {{input_ptr}}, +{{indent}} {{input_name}}_shape, +{{indent}} {{num_splits}}/*num_splits*/, +{{indent}} split_sizes, +{{indent}} {{split_dim}}/*split_dim*/, +{{indent}} {{rank}}/*rank*/, +{{indent}} stream +{{indent}} ); +{{indent}}} +""" +) + + +def gen_function_decl(func_attrs, backend_spec): + """Generate function declaration. + + Parameters + ---------- + func_attrs : Dict[str, Any] + Stores the operation attributes. + Returns + ------- + str + Rendered function declaration. + """ + x = func_attrs["inputs"][0] + y = func_attrs["outputs"][0] + input_type = backend_spec.dtype_to_backend_type(x._attrs["dtype"]) + output_type = backend_spec.dtype_to_backend_type(y._attrs["dtype"]) + return FUNC_DECL_TEMPLATE.render( + index_type=backend_spec.index_type, + prefix=backend_spec.prefix, + func_name=func_attrs["name"], + elem_output_type=output_type, + elem_input_type=input_type, + ) + + +def gen_function(func_attrs, backend_spec): + """Generates function body. 
+ + Parameters + ---------- + func_attrs : Dict[str, Any] + Stores the operation attributes. + + Returns + ------- + str + Rendered function body. + """ + inputs = func_attrs["inputs"] + x = inputs[0] + y = func_attrs["outputs"][0] + x_shape = x._attrs["shape"] + + input_type = backend_spec.dtype_to_backend_type(x._attrs["dtype"]) + output_type = backend_spec.dtype_to_backend_type(y._attrs["dtype"]) + + # TODO: consider to add profiling paths for tuning + # elems_per_thread and threads_per_block + exec_paths = EXEC_COND_TEMPLATE.render( + indent=" ", + rank=len(x_shape), + num_splits=len(func_attrs["split_sizes"]), + elem_type=input_type, + elems_per_thread=128, + threads_per_block=128, + index_type=backend_spec.index_type, + ) + header_src = backend_spec.header_src_template.render() + kernel_src = KERNEL_SRC_TEMPLATE.render( + index_type=backend_spec.index_type, + prefix=backend_spec.prefix, + header_src=header_src, + ) + return SRC_TEMPLATE.render( + kernel_src=kernel_src, + func_name=func_attrs["name"], + elem_input_type=input_type, + elem_output_type=output_type, + exec_paths=exec_paths, + index_type=backend_spec.index_type, + prefix=backend_spec.prefix, + ) + + +def gen_function_call(func_attrs, backend_spec, indent=" "): + """Generates function call. + + Parameters + ---------- + func_attrs : Dict[str, Any] + Stores the operation attributes. + indent : str, optional + Indent for template, by default " ". + + Returns + ------- + str + Rendered function call. + """ + x = func_attrs["inputs"][0] + outputs = func_attrs["outputs"] + y = outputs[0] + split_dim = func_attrs["split_dim"] + num_splits = len(func_attrs["split_sizes"]) + + output_names = ",\n ".join( + [ + backend_spec.cast_to_half_ptr_template.render(name=i._attrs["name"]) + for i in outputs + ] + ) + + output_shape_defs = [] + output_shape_names = [] + for i in outputs: + output_shape_name = "{}_shape".format(i._attrs["name"]) + if output_shape_name not in output_shape_names: + dim_refs = ", ".join( + ["&" + dim._attrs["name"] for dim in i._attrs["shape"]] + ) + one_shape_def = OUTPUT_SHAPE_DEF_TEMPLATE.render( + indent=" ", + output_shape_name=output_shape_name, + output_dim_refs=dim_refs, + index_type=backend_spec.index_type, + ) + output_shape_defs.append(one_shape_def) + output_shape_names.append(output_shape_name) + + x_shape = x._attrs["shape"] + x_dims = ", ".join([dim._attrs["name"] for dim in x_shape]) + casted_x_ptr = backend_spec.cast_to_const_half_ptr_template.render( + name=x._attrs["name"] + ) + + split_sizes = ", ".join([str(i) for i in func_attrs["split_sizes"]]) + + return FUNC_CALL_TEMPLATE.render( + indent=indent, + output_elem_type=backend_spec.dtype_to_backend_type(y._attrs["dtype"]), + outputs=output_names, + output_shape_defs="".join(output_shape_defs), + output_shapes=", ".join(output_shape_names), + input_dims=x_dims, + func_name=func_attrs["name"], + input_name=x._attrs["name"], + input_ptr=casted_x_ptr, + split_dim=split_dim, + rank=len(x._attrs["shape"]), + num_splits=num_splits, + split_sizes=split_sizes, + index_type=backend_spec.index_type, + ) diff --git a/python/aitemplate/backend/common/tensor/argmax_common.py b/python/aitemplate/backend/common/tensor/argmax_common.py new file mode 100644 index 000000000..bb422646e --- /dev/null +++ b/python/aitemplate/backend/common/tensor/argmax_common.py @@ -0,0 +1,456 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +argmax kernel codegen. +""" + +import os +from typing import Any, Dict, List, Tuple + +import jinja2 + +from ... import builder +from ...target import Target + +# pylint: disable=C0301 + +FUNC_CALL_FP16_PARAM_TEMPLATE = jinja2.Template( + "reinterpret_cast(&({{name}}->raw()))" +) + +FUNC_CALL_INT64_PARAM_TEMPLATE = jinja2.Template("reinterpret_cast({{name}})") + +FUNC_TEMPLATE = jinja2.Template( + """ +{{header_files}} + +namespace { + +{{kernel}} + +} // namespace + +{{func_signature}} +{ + + argmax_launcher(stream, elem_cnt, instance_size, instance_num, input, workspace, output); +} + """ +) + +KERNEL_TEMPLATE = jinja2.Template( + """ +const int32_t kThreadsNumPerBlock = 256; +const int32_t kMaxBlocksNum = 8192; + +#define GPU_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +inline size_t GetAlignedSize(size_t size) { + const size_t kAlignSize = 512; + return (size + kAlignSize - 1) / kAlignSize * kAlignSize; +} + +template + +class TmpBufferManager final { + public: + TmpBufferManager(int32_t capacity, void* ptr, int32_t instance_num) + : capacity_{capacity}, key_value_out_elem_cnt_{instance_num} { + const int32_t key_value_out_aligned_bytes = GetAlignedSize( + key_value_out_elem_cnt_ * sizeof({{cub}}::KeyValuePair)); + + key_value_out_ptr_ = reinterpret_cast<{{cub}}::KeyValuePair*>(ptr); + temp_storage_ptr_ = reinterpret_cast( + reinterpret_cast(key_value_out_ptr_) + + key_value_out_aligned_bytes); + + temp_storage_bytes_ = capacity_ - key_value_out_aligned_bytes; + } + ~TmpBufferManager() = default; + + {{cub}}::KeyValuePair* KeyValueOutPtr() const { + return key_value_out_ptr_; + } + void* TempStoragePtr() const { + return temp_storage_ptr_; + } + + int32_t TempStorageBytes() const { + return temp_storage_bytes_; + } + + private: + int32_t capacity_; + + {{cub}}::KeyValuePair* key_value_out_ptr_; + void* temp_storage_ptr_; + + int32_t key_value_out_elem_cnt_; + int32_t temp_storage_bytes_; +}; + +class MultiplyFunctor final { + public: + MultiplyFunctor(int32_t num_col) : num_col_(num_col) {} + __host__ __device__ __forceinline__ int32_t operator()(int32_t idx) const { + return idx * num_col_; + } + + private: + int32_t num_col_; +}; + +template + +size_t InferTempStorageForArgMax(int32_t num_row, int32_t num_col) { + using SegmentOffsetIter = {{cub}}::TransformInputIterator< + int32_t, + MultiplyFunctor, + {{cub}}::CountingInputIterator>; + + {{cub}}::CountingInputIterator counting_iter(0); + MultiplyFunctor multiply_functor(num_col); + SegmentOffsetIter segment_offset_iter(counting_iter, multiply_functor); + + size_t temp_storage_bytes = 0; + auto err = {{cub}}::DeviceSegmentedReduce:: + ArgMax*, SegmentOffsetIter>( + /* d_temp_storage */ nullptr, + /* temp_storage_bytes */ temp_storage_bytes, + /* d_in */ nullptr, + /* d_out */ nullptr, + /* num_segments */ num_row, + /* d_begin_offsets */ segment_offset_iter, + /* d_end_offsets */ segment_offset_iter + 1, + + 
/* stream */ 0); + return temp_storage_bytes; +} + +template +void ArgMax( + const T* in_ptr, + int32_t num_row, + int32_t num_col, + void* temp_storage_ptr, + int32_t temp_storage_bytes, + {{cub}}::KeyValuePair* out_ptr, + {{prefix}}Stream_t stream) { + size_t rt_inferred_temp_storage_bytes = + InferTempStorageForArgMax(num_row, num_col); + + using SegmentOffsetIter = {{cub}}::TransformInputIterator< + int32_t, + MultiplyFunctor, + {{cub}}::CountingInputIterator>; + + {{cub}}::CountingInputIterator counting_iter(0); + MultiplyFunctor multiply_functor(num_col); + SegmentOffsetIter segment_offset_iter(counting_iter, multiply_functor); + + auto err = {{cub}}::DeviceSegmentedReduce::ArgMax( + /* d_temp_storage */ temp_storage_ptr, + /* temp_storage_bytes */ rt_inferred_temp_storage_bytes, + /* d_in */ in_ptr, + /* d_out */ out_ptr, + /* num_segments */ num_row, + /* d_begin_offsets */ segment_offset_iter, + /* d_end_offsets */ segment_offset_iter + 1, + /* stream */ stream); +} + +template +__global__ void WriteKeysToOutput( + const int32_t instance_num, + const int32_t instance_size, + const {{cub}}::KeyValuePair* key_value_out_ptr, + int64_t* out_ptr) { + GPU_KERNEL_LOOP(i, instance_num) { + out_ptr[i] = key_value_out_ptr[i].key{% if is_hipcub %} - instance_size * i{% endif %}; + } +} + +// ALIGNPTR +int64_t* alignPtr(int64_t* ptr, uintptr_t to) { + uintptr_t addr = (uintptr_t)ptr; + if (addr % to) { + addr += to - addr % to; + } + return (int64_t*)addr; +} + +inline int32_t BlocksNum4ThreadsNum(const int32_t n) { + return std::min( + (n + kThreadsNumPerBlock - 1) / kThreadsNumPerBlock, + kMaxBlocksNum); +} + +template +void argmax_launcher( + {{prefix}}Stream_t stream, + const {{index_type}} elem_cnt, + const {{index_type}} instance_size, + const {{index_type}} instance_num, + const void* input, + void* workspace, + void* output) { + const uintptr_t ALIGNMENT = 32; + int64_t* vworkspace = alignPtr((int64_t*)workspace, ALIGNMENT); + T* tmp_buffer = (T*)vworkspace; + + TmpBufferManager buffer_manager( + static_cast(elem_cnt), tmp_buffer, instance_num); + + ArgMax( + (const T*)input, + instance_num, + instance_size, + buffer_manager.TempStoragePtr(), + buffer_manager.TempStorageBytes(), + buffer_manager.KeyValueOutPtr(), + stream); + + WriteKeysToOutput + <<>>( + instance_num, instance_size, buffer_manager.KeyValueOutPtr(), (int64_t*)output); +} +""" +) + + +PROFILER_TEMPLATE = jinja2.Template( + """ +#include +{{header_files}} +size_t GLOBAL_WORKSPACE_SIZE = 0; + +namespace { +{{kernel}} +} // namespace + +int main(int argc, char** argv) { + int instance_size = std::stoi(argv[1]); + int instance_num = std::stoi(argv[2]); + + float runtime_ms = 0; + int32_t key_value_out_bytes = GetAlignedSize(instance_num * sizeof({{cub}}::KeyValuePair)); + size_t temp_storage_bytes = InferTempStorageForArgMax(instance_num, instance_size); + GLOBAL_WORKSPACE_SIZE = GetAlignedSize(key_value_out_bytes + temp_storage_bytes); + + std::cout << "TIME:" << runtime_ms << std::endl; + std::cout << "WS:" << GLOBAL_WORKSPACE_SIZE << std::endl; +} + """ +) + +FUNC_SIGNATURE = jinja2.Template( + """ +void {{func_name}}(int64_t* output, + const half* input, + const {{index_type}} elem_cnt, + const {{index_type}} instance_size, + const {{index_type}} instance_num, + uint8_t* workspace, + {{prefix}}Stream_t stream) + """ +) + +FUNC_DECL = jinja2.Template( + """ + {{func_signature}}; + """ +) + +FUNC_CALL_TEMPLATE = jinja2.Template( + """ +{{indent}}{{func_name}}( +{{indent}} {{output}}, {{input}}, +{{indent}} {{elem_cnt}}, 
+{{indent}} {{instance_size}}, +{{indent}} {{instance_num}}, +{{indent}} global_workspace, stream /* default stream */ +{{indent}}); + """ +) + + +def gen_function(func_attrs: Dict[str, Any], header_files: str, backend_spec) -> str: + """Generates function. + + Parameters + ---------- + func_attrs : Dict[str, Any] + Stores the operation attributes. + header_files : str + Includes the header files for a backend. + backend_spec : class + Specifies the backend configurations. + + Returns + ------- + str + Rendered function. + """ + index_type = backend_spec.index_type + prefix = backend_spec.prefix + cub = backend_spec.cub + return FUNC_TEMPLATE.render( + header_files=header_files, + func_signature=FUNC_SIGNATURE.render( + func_name=func_attrs["name"], index_type=index_type, prefix=prefix + ), + kernel=KERNEL_TEMPLATE.render( + cub=cub, index_type=index_type, prefix=prefix, is_hipcub=(cub == "hipcub") + ), + ) + + +def gen_function_decl(func_attrs: Dict[str, Any], backend_spec) -> str: + """Generates function decl. + + Parameters + ---------- + func_attrs : Dict[str, Any] + Stores the operation attributes. + backend_spec : class + Specifies the backend configurations. + + Returns + ------- + str + Rendered function decl. + """ + return FUNC_DECL.render( + func_signature=FUNC_SIGNATURE.render( + func_name=func_attrs["name"], + index_type=backend_spec.index_type, + prefix=backend_spec.prefix, + ), + ).strip() + + +def gen_function_call(func_attrs: Dict[str, Any], backend_spec, indent=" ") -> str: + """Generates function call. + + Parameters + ---------- + func_attrs : Dict[str, Any] + Stores the operation attributes. + backend_spec : class + Specifies the backend configurations. + indent : str, optional + Indent for template, by default " ". + + Returns + ------- + str + Rendered function call. + """ + output_name = "" + assert len(func_attrs["outputs"]) == 1 + assert len(func_attrs["inputs"]) == 1 + + output_name = FUNC_CALL_INT64_PARAM_TEMPLATE.render( + name=func_attrs["outputs"][0]._attrs["name"] + ) + input_name = backend_spec.cast_to_half_ptr_template.render( + name=func_attrs["inputs"][0]._attrs["name"] + ) + + x = func_attrs["inputs"][0] + xshape = x._attrs["shape"] + + elem_cnt = 1 + for shape in xshape: + elem_cnt *= shape._attrs["values"][0] + instance_size = xshape[-1]._attrs["values"][0] + instance_num = elem_cnt // instance_size + + return FUNC_CALL_TEMPLATE.render( + func_name=func_attrs["name"], + output=output_name, + input=input_name, + elem_cnt=elem_cnt, + instance_size=instance_size, + instance_num=instance_num, + indent=indent, + ) + + +def add_profiler( + file_pairs: List[Tuple[str, str]], + workdir: str, + op_type: str, + output_name: str, + code: str, +): + prefix = os.path.join(workdir, "profiler", op_type) + if not os.path.exists(prefix): + os.makedirs(prefix) + src_path = os.path.join(prefix, output_name + ".cu") + obj_path = os.path.join(prefix, output_name) + if os.path.exists(obj_path): + return + with open(src_path, "w") as f: + f.write(code) + file_pairs.append((src_path, obj_path)) + + +def gen_profiler( + func_attrs: Dict[str, Any], workdir: str, header_files: str, backend_spec +): + """Generates code for argmax profiling. + + Parameters + ---------- + func_attrs : Dict[str, Any] + Stores the operation attributes. + workdir: str + Target directory for generated C++ source code files + header_files : str + Includes the header files for a backend. + backend_spec : class + Specifies the backend configurations. 
+ + Returns + ------- + None + """ + op_type = func_attrs["op"] + file_pairs = [] + index_type = backend_spec.index_type + prefix = backend_spec.prefix + cub = backend_spec.cub + code = PROFILER_TEMPLATE.render( + header_files=header_files, + func_signature=FUNC_SIGNATURE.render( + func_name=func_attrs["name"], index_type=index_type, prefix=prefix + ), + kernel=KERNEL_TEMPLATE.render( + cub=cub, index_type=index_type, prefix=prefix, is_hipcub=(cub == "hipcub") + ), + cub=cub, + ) + op_name = func_attrs["op"] + add_profiler(file_pairs, workdir, op_type, op_name, code) + # build + target = Target.current() + compile_engine = builder.Builder() + compile_engine.build_objs(file_pairs, target.compile_cmd(executable=True)) diff --git a/python/aitemplate/backend/common/tensor/batch_gather_common.py b/python/aitemplate/backend/common/tensor/batch_gather_common.py new file mode 100644 index 000000000..86bbea7a0 --- /dev/null +++ b/python/aitemplate/backend/common/tensor/batch_gather_common.py @@ -0,0 +1,221 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +batch_gather kernel codegen. +""" + +from typing import Any, Dict + +import jinja2 + +# pylint: disable=C0301 + +FUNC_CALL_FP16_PARAM_TEMPLATE = jinja2.Template( + """reinterpret_cast( + {% if is_cuda %}&({% endif %}{{name}}{% if is_cuda %}->raw()){% endif %})""" +) + +FUNC_CALL_INT64_PARAM_TEMPLATE = jinja2.Template("reinterpret_cast({{name}})") + +FUNC_TEMPLATE = jinja2.Template( + """ +{{header_files}} + +namespace { + +{{kernel}} + +} // namespace + +{{func_signature}} +{ + batch_gather_launcher(stream, batch_num, indices_num, instance_size, gather_dim_size, input, indices, workspace, output); +} + """ +) + +FUNC_SIGNATURE = jinja2.Template( + """ +void {{func_name}}(half* output, + const half* input, + const int64_t* indices, + const {{index_type}} batch_num, + const {{index_type}} indices_num, + const {{index_type}} instance_size, + const {{index_type}} gather_dim_size, + uint8_t* workspace, + {{prefix}}Stream_t stream) + """ +) + +FUNC_DECL = jinja2.Template( + """ + {{func_signature}}; + """ +) + +FUNC_CALL_TEMPLATE = jinja2.Template( + """ +{{indent}}{{func_name}}( +{{indent}} {{output}}, {{input}}, {{indices}}, +{{indent}} {{batch_num}}, +{{indent}} {{indices_num}}, +{{indent}} {{instance_size}}, +{{indent}} {{gather_dim_size}}, +{{indent}} global_workspace, stream /* default stream */ +{{indent}}); + """ +) + +KERNEL_TEMPLATE = jinja2.Template( + """ +const int64_t kThreadsNumPerBlock = 256; +const int64_t kMaxBlocksNum = 8192; + +#define GPU_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +template +__device__ int64_t GetInOffset( + const int64_t out_offset, + const K* indices, + const int64_t indices_num, + const int64_t instance_size, + const int64_t gather_dim_size) { + const int64_t batch_idx = out_offset / (indices_num * instance_size); + const int64_t indices_idx = + out_offset % (indices_num * 
instance_size) / instance_size; + const int64_t inner_idx = out_offset % instance_size; + const int64_t idx = indices[batch_idx * indices_num + indices_idx]; + assert(idx >= 0 && idx < gather_dim_size); + return batch_idx * gather_dim_size * instance_size + idx * instance_size + + inner_idx; +} + +template +__global__ void BatchGatherGpu( + const int64_t elem_cnt, + const T* in, + const K* indices, + const int64_t indices_num, + const int64_t instance_size, + const int64_t gather_dim_size, + T* out) { + GPU_KERNEL_LOOP(i, elem_cnt) { + out[i] = in[GetInOffset( + i, indices, indices_num, instance_size, gather_dim_size)]; + } +} + +inline int64_t BlocksNum4ThreadsNum(const int64_t n) { + return std::min( + (n + kThreadsNumPerBlock - 1) / kThreadsNumPerBlock, + kMaxBlocksNum); +} +template +void batch_gather_launcher( + {{prefix}}Stream_t stream, + const {{index_type}} batch_num, + const {{index_type}} indices_num, + const {{index_type}} instance_size, + const {{index_type}} gather_dim_size, + const T* input, + const K* indices, + void* workspace, + T* output) { + const int64_t elem_cnt = batch_num * indices_num * instance_size; + BatchGatherGpu + <<>>( + elem_cnt, + input, + indices, + indices_num, + instance_size, + gather_dim_size, + output); +} + """ +) + + +def gen_function_call(func_attrs: Dict[str, Any], indent=" ", is_cuda=False) -> str: + output_name = "" + assert len(func_attrs["outputs"]) == 1 + assert len(func_attrs["inputs"]) == 2 + + output_name = FUNC_CALL_FP16_PARAM_TEMPLATE.render( + name=func_attrs["outputs"][0]._attrs["name"], is_cuda=is_cuda + ) + input_name = FUNC_CALL_FP16_PARAM_TEMPLATE.render( + name=func_attrs["inputs"][0]._attrs["name"], is_cuda=is_cuda + ) + indices_name = FUNC_CALL_INT64_PARAM_TEMPLATE.render( + name=func_attrs["inputs"][1]._attrs["name"] + ) + + x = func_attrs["inputs"][0] + xshape = x._attrs["shape"] + indices = func_attrs["inputs"][1] + ind_shape = indices._attrs["shape"] + y = func_attrs["outputs"][0] + yshape = y._attrs["shape"] + + axis = len(ind_shape) - 1 + batch_num = 1 + for i in range(axis): + batch_num *= yshape[i]._attrs["values"][0] + + indices_num = yshape[axis]._attrs["values"][0] + + instance_size = 1 + for i in range(axis + 1, len(yshape)): + instance_size *= yshape[i]._attrs["values"][0] + + gather_dim_size = xshape[axis]._attrs["values"][0] + + return FUNC_CALL_TEMPLATE.render( + func_name=func_attrs["name"], + output=output_name, + input=input_name, + indices=indices_name, + batch_num=batch_num, + indices_num=indices_num, + instance_size=instance_size, + gather_dim_size=gather_dim_size, + indent=indent, + ) + + +def gen_function(func_attrs: Dict[str, Any], header_files: str, backend_spec) -> str: + index_type = backend_spec.index_type + prefix = backend_spec.prefix + return FUNC_TEMPLATE.render( + header_files=header_files, + kernel=KERNEL_TEMPLATE.render(index_type=index_type, prefix=prefix), + func_signature=FUNC_SIGNATURE.render( + func_name=func_attrs["name"], index_type=index_type, prefix=prefix + ), + ) + + +def gen_function_decl(func_attrs: Dict[str, Any], backend_spec) -> str: + return FUNC_DECL.render( + func_signature=FUNC_SIGNATURE.render( + func_name=func_attrs["name"], + index_type=backend_spec.index_type, + prefix=backend_spec.prefix, + ).strip() + ) diff --git a/python/aitemplate/backend/common/tensor/permute021_common.py b/python/aitemplate/backend/common/tensor/permute021_common.py new file mode 100644 index 000000000..db5ed63fd --- /dev/null +++ b/python/aitemplate/backend/common/tensor/permute021_common.py @@ 
-0,0 +1,304 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Common implementations for all backends for permute021. + +For three dimension input, shift the second and the third dimension. +i.e. Output[d0, d2, d1] = Input[d0, d1, d2] + +""" +from typing import Any, Dict + +import jinja2 + +# pylint: disable=C0301,W0613,W0612 + +FUNC_DECL_TEMPLATE = jinja2.Template( + """ +void {{func_name}}( + {{lib_dtype}}*, + {{lib_dtype}}*, + int64_t*, + int64_t*, + int64_t*, + int64_t*, + int64_t*, + int64_t*, + {{prefix}}Stream_t +); +""" +) + +FUNC_CALL_TEMPLATE = jinja2.Template( + """ +{{indent}}{{func_name}}( +{{indent}} ({{lib_dtype}}*)({{in_ptr}}), +{{indent}} ({{lib_dtype}}*)({{out_ptr}}), +{{indent}} {{x_dim0}}, +{{indent}} {{x_dim1}}, +{{indent}} {{x_dim2}}, +{{indent}} {{y_dim0}}, +{{indent}} {{y_dim1}}, +{{indent}} {{y_dim2}}, +{{indent}} stream +{{indent}}); +""" +) + + +EXEC_TEMPLATE = jinja2.Template( + """ +{{indent}}permute021_launcher( +{{indent}} in_ptr, +{{indent}} out_ptr, +{{indent}} *x_dim0, +{{indent}} *x_dim1, +{{indent}} *x_dim2, +{{indent}} stream +{{indent}}); +{{indent}}return; +""" +) + +SRC_TEMPLATE = jinja2.Template( + """ +{{header_files}} + +namespace { +template +__global__ void nhwc_to_nchw_kernel(T *output, + const T *input, + const int n, + const int h, + const int w, + const int c) { + + const int hw = h*w; + const int hwc = hw*c; + __shared__ T shbuf[32 * (32 + 1)]; + const int32_t tid = threadIdx.y*blockDim.x + threadIdx.x; + const int32_t wid = tid / 32; + const int32_t lid = tid % 32; + const int32_t ni = blockIdx.z; + const int32_t hwi0 = blockIdx.y * 32; + const int32_t ci0 = blockIdx.x * 32; + + const size_t input_idx = ni * hwc + (hwi0 + wid) * c + ci0; + const T *A = input + input_idx; + if (ci0 + lid < c) { + const int lid_x_33 = lid * 33; + if ((hwi0 + 32) <= hw) { + int hwi = wid; // between 0 and 7 + #pragma unroll + for (int cLoopIdx = 0; cLoopIdx < 4; cLoopIdx++) { + shbuf[lid_x_33 + hwi] = A[lid]; + A = &A[8 * c]; + hwi += 8; + } + } else { + for (int hwi = wid; hwi < 32; hwi += 8) { + if ((hwi + hwi0) < hw) { + shbuf[lid_x_33 + hwi] = A[lid]; + } + A = &A[8 * c]; + } + } + } + __syncthreads(); + + const int32_t hwiOut = hwi0 + lid; + output = &output[ni * hwc + hwiOut]; + if (hwiOut < hw) { + if (ci0 + 32 < c) { + int cI = wid; + #pragma unroll + for (int hwLoopIdx = 0; hwLoopIdx < 4; ++hwLoopIdx) { + output[(ci0 + cI) * hw] = shbuf[(cI)*33 + lid]; + cI += 8; + } + } else { + for (int cI = wid; cI < 32; cI += 8) { + if (ci0 + cI < c) { + output[(ci0 + cI) * hw] = shbuf[(cI)*33 + lid]; + } + } + } + } +} + +void permute021_launcher({{lib_dtype}}* in_ptr, + {{lib_dtype}}* out_ptr, + int x_dim0, + int x_dim1, + int x_dim2, + {{prefix}}Stream_t stream) { + const int n = x_dim0; + const int h = 1; + const int w = x_dim1; + const int c = x_dim2; + dim3 grid((c + 31)/32, (h*w + 31)/32, n); + dim3 block(32, 8); + nhwc_to_nchw_kernel<{{lib_dtype}}><<>>( + ({{lib_dtype}}*)out_ptr, + (const 
{{lib_dtype}}*)in_ptr, + n, + h, + w, + c + ); +} +} // namespace + +void {{function_name}} ( + {{lib_dtype}}* in_ptr, + {{lib_dtype}}* out_ptr, + int64_t* x_dim0, + int64_t* x_dim1, + int64_t* x_dim2, + int64_t* y_dim0, + int64_t* y_dim1, + int64_t* y_dim2, + {{prefix}}Stream_t stream +) { + if (!in_ptr) { + throw std::runtime_error("in_ptr is NULL!"); + } + if (!out_ptr) { + throw std::runtime_error("in_ptr is NULL!"); + } + {{shape_function}} + {{exec_paths}} +} + +""" +) + + +def gen_function( + func_attrs: Dict[str, Any], + template_path: str, + shape_eval_template, + shape_save_template, + header_files: str, + backend_spec, +) -> str: + """ + Parameters + ---------- + func_attrs : Dict[str, Any] + Attributes from Operator + template_path : str + path to library used + shape_eval_template : jinja template + shape_save_template : jinja template + header_files : str + header files included in the function + backend_spec : class + specifies backend configs + + Returns + ------- + str + Source code for function generated. + """ + + func_name = func_attrs["name"] + x = func_attrs["inputs"][0] + xdtype = x._attrs["dtype"] + shape_eval_func = shape_eval_template.render( + indent=" ", + dtype="int64_t ", + x_dim0="*x_dim0", + x_dim1="*x_dim1", + x_dim2="*x_dim2", + ) + shape_save_func = shape_save_template.render( + indent=" ", + y_dim0="*y_dim0", + y_dim1="*y_dim1", + y_dim2="*y_dim2", + ) + shape_func = shape_eval_func + shape_save_func + exec_paths = EXEC_TEMPLATE.render() + return SRC_TEMPLATE.render( + function_name=func_name, + header_files=header_files, + shape_function=shape_func, + exec_paths=exec_paths, + lib_dtype=backend_spec.dtype_to_lib_type(xdtype), + prefix=backend_spec.prefix, + ) + + +def gen_function_decl(func_attrs: Dict[str, Any], backend_spec) -> str: + """ + Parameters + ---------- + func_attrs : dict + Attributes from Operator + backend_spec : class + specifies backend configs + + Returns + ------- + str + Function declaration + """ + + func_name = func_attrs["name"] + x = func_attrs["inputs"][0] + xdtype = x._attrs["dtype"] + return FUNC_DECL_TEMPLATE.render( + func_name=func_name, + lib_dtype=backend_spec.dtype_to_lib_type(xdtype), + prefix=backend_spec.prefix, + ) + + +def gen_function_call(func_attrs: Dict[str, Any], backend_spec, indent=" ") -> str: + """ + Parameters + ---------- + func_attrs : dict + Attributes from Operator + backend_spec : class + specifies backend configs + indent : str, optional + Indentation for function call template, by default " " + + Returns + ------- + str + Driver code for invoking call + """ + + x = func_attrs["inputs"][0] + xshape = x._attrs["shape"] + xdtype = x._attrs["dtype"] + y = func_attrs["outputs"][0] + yshape = y._attrs["shape"] + return FUNC_CALL_TEMPLATE.render( + func_name=func_attrs["name"], + in_ptr=x._attrs["name"], + out_ptr=y._attrs["name"], + x_dim0="&" + xshape[0]._attrs["name"], + x_dim1="&" + xshape[1]._attrs["name"], + x_dim2="&" + xshape[2]._attrs["name"], + y_dim0="&" + yshape[0]._attrs["name"], + y_dim1="&" + yshape[1]._attrs["name"], + y_dim2="&" + yshape[2]._attrs["name"], + indent=indent, + lib_dtype=backend_spec.dtype_to_lib_type(xdtype), + ) diff --git a/python/aitemplate/backend/common/tensor/permute102_common.py b/python/aitemplate/backend/common/tensor/permute102_common.py new file mode 100644 index 000000000..807e65bef --- /dev/null +++ b/python/aitemplate/backend/common/tensor/permute102_common.py @@ -0,0 +1,310 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
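An illustrative NumPy sketch of the permute semantics these templates generate kernels for (NumPy is assumed here; this reference is not part of the generated sources). permute021 above swaps the last two dimensions of a 3-D tensor, and the permute102 variant that follows swaps the first two:

import numpy as np

def permute021_reference(x: np.ndarray) -> np.ndarray:
    # Output[d0, d2, d1] = Input[d0, d1, d2]
    assert x.ndim == 3
    return np.ascontiguousarray(x.transpose(0, 2, 1))

def permute102_reference(x: np.ndarray) -> np.ndarray:
    # Output[d1, d0, d2] = Input[d0, d1, d2]
    assert x.ndim == 3
    return np.ascontiguousarray(x.transpose(1, 0, 2))

x = np.arange(2 * 3 * 4, dtype=np.float16).reshape(2, 3, 4)
assert permute021_reference(x)[1, 3, 0] == x[1, 0, 3]
assert permute102_reference(x)[2, 1, 0] == x[1, 2, 0]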
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Common implementations for all backends for permute012. + +For three dimension input, shift the first and the second dimension. +i.e. Output[d1, d0, d2] = Input[d0, d1, d2] + +This is a naive modification over cutlass nhwc to nchw op: +https://github.com/NVIDIA/cutlass/blob/master/tools/util/include/cutlass/util/device_nhwc_to_nchw.h +At implementation, it creates d1/32 x d2/32 x d0 blocks, each with 32 x 8 threads, +and each thread processes 4 elements. + +We change the write stage of this cutlass permute op for d1 & d0. +It might not be the most effecient version as applying different dimension on threads +may relate to cache's performance. +""" +from typing import Any, Dict + +import jinja2 + +# pylint: disable=C0301,W0613,W0612 + +FUNC_DECL_TEMPLATE = jinja2.Template( + """ +void {{func_name}}( + {{lib_dtype}}*, + {{lib_dtype}}*, + int64_t*, + int64_t*, + int64_t*, + int64_t*, + int64_t*, + int64_t*, + {{prefix}}Stream_t +); +""" +) + +FUNC_CALL_TEMPLATE = jinja2.Template( + """ +{{indent}}{{func_name}}( +{{indent}} ({{lib_dtype}}*){{in_ptr}}, +{{indent}} ({{lib_dtype}}*){{out_ptr}}, +{{indent}} {{x_dim0}}, +{{indent}} {{x_dim1}}, +{{indent}} {{x_dim2}}, +{{indent}} {{y_dim0}}, +{{indent}} {{y_dim1}}, +{{indent}} {{y_dim2}}, +{{indent}} stream +{{indent}}); +""" +) + + +EXEC_TEMPLATE = jinja2.Template( + """ +{{indent}}permute102_launcher( +{{indent}} in_ptr, +{{indent}} out_ptr, +{{indent}} *x_dim0, +{{indent}} *x_dim1, +{{indent}} *x_dim2, +{{indent}} stream +{{indent}}); +{{indent}}return; +""" +) + +SRC_TEMPLATE = jinja2.Template( + """ +{{header_files}} + +#define TILE_SIZE 32 +#define CH_K 4 + +namespace { +template +__global__ void nhwc_to_nchw_kernel(T *output, + const T *input, + const int n, + const int h, + const int w, + const int c) { + + const int hw = h*w; + const int hwc = hw*c; + __shared__ T shbuf[TILE_SIZE * (TILE_SIZE + 1)]; + const int32_t tid = threadIdx.y*blockDim.x + threadIdx.x; + const int32_t wid = tid / TILE_SIZE;//th.y:0-7 + const int32_t lid = tid % TILE_SIZE;//th.x:0-31 + const int32_t ni0 = blockIdx.z; + const int32_t hwi0 = blockIdx.y * TILE_SIZE;//parallel 8*seq 4 + const int32_t ci0 = blockIdx.x * TILE_SIZE;//parallel 32 + const size_t input_idx = ni0 * hwc + (hwi0 + wid) * c + ci0; + const T *A = input + input_idx; + if (ci0 + lid < c) { + const int lid_x_33 = lid * (TILE_SIZE + 1); + if ((hwi0 + TILE_SIZE - TILE_SIZE / CH_K) <= hw) { + int hwi = wid; // between 0 and 7 + #pragma unroll + for (int cLoopIdx = 0; cLoopIdx < CH_K; cLoopIdx++) { + shbuf[lid_x_33 + hwi] = A[lid]; + A = &A[TILE_SIZE / CH_K * c];//because c is distributed on threads y + hwi += TILE_SIZE / CH_K; + } + } else { + for (int hwi = wid; hwi < TILE_SIZE; hwi += TILE_SIZE / CH_K) { + if ((hwi + hwi0) < hw) { + shbuf[lid_x_33 + hwi] = A[lid]; + } + A = &A[TILE_SIZE / CH_K * c]; + } + } + } + __syncthreads(); + + const int32_t hwiOut = hwi0 + lid; + const int nc = n*c; + output = &output[hwiOut*nc]; + if(hwiOut < hw){ + if(ci0 + 
TILE_SIZE < c){ + int cI = wid; + #pragma unroll + for(int hwLoopIdx = 0; hwLoopIdx < CH_K; ++hwLoopIdx){ + output[ni0*c + ci0 + cI] = shbuf[(cI)* (TILE_SIZE + 1) + lid]; + cI += TILE_SIZE / CH_K; + } + } else { + for(int cI = wid; cI < TILE_SIZE; cI += TILE_SIZE / CH_K){ + if(ci0+cI<<>>( + out_ptr, + (const {{lib_dtype}}*)in_ptr, + n, + h, + w, + c + ); +} +} // namespace + +void {{function_name}} ( + {{lib_dtype}}* in_ptr, + {{lib_dtype}}* out_ptr, + int64_t* x_dim0, + int64_t* x_dim1, + int64_t* x_dim2, + int64_t* y_dim0, + int64_t* y_dim1, + int64_t* y_dim2, + {{prefix}}Stream_t stream +) { + if (!in_ptr) { + throw std::runtime_error("in_ptr is NULL!"); + } + if (!out_ptr) { + throw std::runtime_error("in_ptr is NULL!"); + } + {{shape_function}} + {{exec_paths}} +} + +""" +) + + +def gen_function( + func_attrs: Dict[str, Any], + template_path: str, + shape_eval_template, + shape_save_template, + header_files: str, + backend_spec, +) -> str: + """ + Parameters + ---------- + func_attrs : Dict[str, Any] + Attributes from Operator + template_path : str + path to library used + shape_eval_template : jinja template + shape_save_template : jinja template + backend_spec : class + specifies backend configs + + Returns + ------- + str + Source code for function generated. + """ + func_name = func_attrs["name"] + x = func_attrs["inputs"][0] + xdtype = x._attrs["dtype"] + shape_eval_func = shape_eval_template.render( + indent=" ", + dtype="int64_t ", + x_dim0="*x_dim0", + x_dim1="*x_dim1", + x_dim2="*x_dim2", + ) + shape_save_func = shape_save_template.render( + indent=" ", + y_dim0="*y_dim0", + y_dim1="*y_dim1", + y_dim2="*y_dim2", + ) + shape_func = shape_eval_func + shape_save_func + exec_paths = EXEC_TEMPLATE.render() + return SRC_TEMPLATE.render( + function_name=func_name, + shape_function=shape_func, + exec_paths=exec_paths, + header_files=header_files, + lib_dtype=backend_spec.dtype_to_lib_type(xdtype), + prefix=backend_spec.prefix, + ) + + +def gen_function_decl(func_attrs: Dict[str, Any], backend_spec) -> str: + """ + Parameters + ---------- + func_attrs : dict + Attributes from Operator + backend_spec : class + specifies backend configs + + Returns + ------- + str + Function declaration + """ + func_name = func_attrs["name"] + x = func_attrs["inputs"][0] + xdtype = x._attrs["dtype"] + return FUNC_DECL_TEMPLATE.render( + func_name=func_name, + lib_dtype=backend_spec.dtype_to_lib_type(xdtype), + prefix=backend_spec.prefix, + ) + + +def gen_function_call(func_attrs: Dict[str, Any], backend_spec, indent=" ") -> str: + """ + Parameters + ---------- + func_attrs : dict + Attributes from Operator + backend_spec : class + specifies backend configs + indent : str, optional + Indentation for function call template, by default " " + + Returns + ------- + str + Driver code for invoking call + """ + x = func_attrs["inputs"][0] + xshape = x._attrs["shape"] + xdtype = x._attrs["dtype"] + y = func_attrs["outputs"][0] + yshape = y._attrs["shape"] + return FUNC_CALL_TEMPLATE.render( + func_name=func_attrs["name"], + in_ptr=x._attrs["name"], + out_ptr=y._attrs["name"], + x_dim0="&" + xshape[0]._attrs["name"], + x_dim1="&" + xshape[1]._attrs["name"], + x_dim2="&" + xshape[2]._attrs["name"], + y_dim0="&" + yshape[0]._attrs["name"], + y_dim1="&" + yshape[1]._attrs["name"], + y_dim2="&" + yshape[2]._attrs["name"], + indent=indent, + lib_dtype=backend_spec.dtype_to_lib_type(xdtype), + ) diff --git a/python/aitemplate/backend/common/tensor/permute210_common.py 
b/python/aitemplate/backend/common/tensor/permute210_common.py new file mode 100644 index 000000000..fa1d5d25a --- /dev/null +++ b/python/aitemplate/backend/common/tensor/permute210_common.py @@ -0,0 +1,289 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Common implementations for all backends for permute210. + +For three dimension input, shift the first and the third dimension. +i.e. Output[d2, d1, d0] = Input[d0, d1, d2] + +We invoke kernel with the following settings: +thread blocks of (TILE_SIZE x TILE_SIZE/4), +grid size of (ceil(d1/TILE_SIZE) x d2 x ceil(d3/TILE_SIZE)) +For each, we have shared memory of size (TILE_SIZE, TILE_SIZE+1) + +The 4 for thread blocks indicates each thread is responsible of 4 elements. +We use TILE_SIZE = 32 for the time being. +""" +from typing import Any, Dict + +import jinja2 + +# pylint: disable=C0301,W0613,W0612 + +FUNC_DECL_TEMPLATE = jinja2.Template( + """ +void {{func_name}}( + {{lib_dtype}}*, + {{lib_dtype}}*, + int64_t*, + int64_t*, + int64_t*, + int64_t*, + int64_t*, + int64_t*, + {{prefix}}Stream_t +); +""" +) + +FUNC_CALL_TEMPLATE = jinja2.Template( + """ +{{indent}}{{func_name}}( +{{indent}} static_cast<{{lib_dtype}}*>({{in_ptr}}), +{{indent}} static_cast<{{lib_dtype}}*>({{out_ptr}}), +{{indent}} {{x_dim0}}, +{{indent}} {{x_dim1}}, +{{indent}} {{x_dim2}}, +{{indent}} {{y_dim0}}, +{{indent}} {{y_dim1}}, +{{indent}} {{y_dim2}}, +{{indent}} stream +{{indent}}); +""" +) + + +EXEC_TEMPLATE = jinja2.Template( + """ +{{indent}}permute210_launcher( +{{indent}} in_ptr, +{{indent}} out_ptr, +{{indent}} *x_dim0, +{{indent}} *x_dim1, +{{indent}} *x_dim2, +{{indent}} stream +{{indent}}); +{{indent}}return; +""" +) + +SRC_TEMPLATE = jinja2.Template( + """ +{{header_files}} + +#define TILE_SIZE 32 + +namespace { +template +__global__ void permute210_kernel(T *output, + const T *input, + const int n, + const int c, + const int w) { + __shared__ T shbuf[TILE_SIZE][TILE_SIZE + 1]; + + int32_t strides[2] = { c * w, w }; + int32_t offset = blockIdx.y * strides[1]; // We are slicing through static c. + + int32_t xBlock = blockIdx.x * TILE_SIZE; + int32_t yBlock = blockIdx.z * TILE_SIZE; + int32_t x = xBlock + threadIdx.x; + int32_t y = yBlock + threadIdx.y; + + const int32_t inputIdx = y * strides[0] + offset + xBlock; + const T *A = input + inputIdx; + + if (x < w) { + if (y + 24 < n) { // This guards (y, y+8, y+16, y+24) are within boundary. 
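+                      // (blockDim.y is TILE_SIZE / 4 == 8, so the unrolled loop below reads
+                      // rows y, y + 8, y + 16 and y + 24 of the 32-row tile; checking
+                      // y + 24 < n therefore covers all four reads.)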
+ int tid = threadIdx.y; + #pragma unroll + for (int loopIdx = 0; loopIdx < 4; loopIdx++) { + shbuf[threadIdx.x][tid] = A[threadIdx.x]; + A = &A[8 * strides[0]]; + tid += 8; + } + } else { + #pragma unroll + for (int tid = threadIdx.y; tid < 32; tid += 8) { + if (yBlock + tid < n) { + shbuf[threadIdx.x][tid] = A[threadIdx.x]; + } + A = &A[8 * strides[0]]; + } + } + } + __syncthreads(); + + // Now, we do the computation of transposes toward the new indices + strides[0] = c * n; + strides[1] = n; + offset = blockIdx.y * strides[1]; + + xBlock = blockIdx.z * TILE_SIZE; + yBlock = blockIdx.x * TILE_SIZE; + x = xBlock + threadIdx.x; + y = yBlock + threadIdx.y; + + output = &output[y * strides[0] + offset + xBlock]; + if (x < n) { + if (y + 24 < w) { + int tid = threadIdx.y; + #pragma unroll + for (int loopIdx = 0; loopIdx < 4; loopIdx++) { + output[threadIdx.x] = shbuf[tid][threadIdx.x]; + output = &output[8 * strides[0]]; + tid += 8; + } + } else { + #pragma unroll + for (int tid = threadIdx.y; tid < 32; tid += 8) { + if (yBlock + tid < w) { + output[threadIdx.x] = shbuf[tid][threadIdx.x]; + } + output = &output[8 * strides[0]]; + } + } + } +} + +void permute210_launcher({{lib_dtype}}* in_ptr, + {{lib_dtype}}* out_ptr, + int x_dim0, + int x_dim1, + int x_dim2, + {{prefix}}Stream_t stream) { + dim3 grid((x_dim2 + (TILE_SIZE-1))/TILE_SIZE, x_dim1, (x_dim0 + (TILE_SIZE-1))/TILE_SIZE); + dim3 block(TILE_SIZE, TILE_SIZE/4); + permute210_kernel<{{lib_dtype}}><<>>( + out_ptr, + (const {{lib_dtype}}*)in_ptr, + x_dim0, + x_dim1, + x_dim2 + ); +} +} // namespace + +void {{function_name}} ( + {{lib_dtype}}* in_ptr, + {{lib_dtype}}* out_ptr, + int64_t* x_dim0, + int64_t* x_dim1, + int64_t* x_dim2, + int64_t* y_dim0, + int64_t* y_dim1, + int64_t* y_dim2, + {{prefix}}Stream_t stream +) { + if (!in_ptr) { + throw std::runtime_error("in_ptr is NULL!"); + } + if (!out_ptr) { + throw std::runtime_error("in_ptr is NULL!"); + } + {{exec_paths}} +} + +""" +) + + +def gen_function(func_attrs: Dict[str, Any], header_files: str, backend_spec) -> str: + """ + Parameters + ---------- + func_attrs : dict + Attributes from Operator + header_files : str + header files included in the function + backend_spec : class + specifies the backend configs + + Returns + ------- + str + Source code for function generated. 
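For reference, a short Python sketch (assuming NumPy; not part of the generated sources) of the permute210 semantics and of the launch geometry used by permute210_launcher above: grid = (ceil(d2/TILE_SIZE), d1, ceil(d0/TILE_SIZE)) with 32 x 8 thread blocks, each thread moving four elements.

import math
import numpy as np

TILE_SIZE = 32

def permute210_reference(x: np.ndarray) -> np.ndarray:
    # Output[d2, d1, d0] = Input[d0, d1, d2]
    assert x.ndim == 3
    return np.ascontiguousarray(x.transpose(2, 1, 0))

def permute210_launch_geometry(d0: int, d1: int, d2: int):
    grid = (math.ceil(d2 / TILE_SIZE), d1, math.ceil(d0 / TILE_SIZE))
    block = (TILE_SIZE, TILE_SIZE // 4)  # 32 x 8 threads, 4 elements per thread
    return grid, block

x = np.random.rand(5, 3, 70).astype(np.float16)
assert permute210_reference(x)[4, 2, 1] == x[1, 2, 4]
assert permute210_launch_geometry(*x.shape) == ((3, 3, 1), (32, 8))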
+ """ + func_name = func_attrs["name"] + x = func_attrs["inputs"][0] + xdtype = x._attrs["dtype"] + exec_paths = EXEC_TEMPLATE.render() + return SRC_TEMPLATE.render( + function_name=func_name, + header_files=header_files, + exec_paths=exec_paths, + prefix=backend_spec.prefix, + lib_dtype=backend_spec.dtype_to_lib_type(xdtype), + ) + + +def gen_function_decl(func_attrs: Dict[str, Any], backend_spec) -> str: + """ + Parameters + ---------- + func_attrs : dict + Attributes from Operator + backend_spec : class + specifies the backend configs + + Returns + ------- + str + Function declaration + """ + func_name = func_attrs["name"] + x = func_attrs["inputs"][0] + xdtype = x._attrs["dtype"] + return FUNC_DECL_TEMPLATE.render( + func_name=func_name, + prefix=backend_spec.prefix, + lib_dtype=backend_spec.dtype_to_lib_type(xdtype), + ) + + +def gen_function_call(func_attrs: Dict[str, Any], backend_spec, indent=" ") -> str: + """ + Parameters + ---------- + func_attrs : dict + Attributes from Operator + backend_spec : class + specifies the backend configs + indent : str, optional + Indentation for function call template, by default " " + + Returns + ------- + str + Driver code for invoking call + """ + x = func_attrs["inputs"][0] + xshape = x._attrs["shape"] + xdtype = x._attrs["dtype"] + y = func_attrs["outputs"][0] + yshape = y._attrs["shape"] + return FUNC_CALL_TEMPLATE.render( + func_name=func_attrs["name"], + in_ptr=x._attrs["name"], + out_ptr=y._attrs["name"], + x_dim0="&" + xshape[0]._attrs["name"], + x_dim1="&" + xshape[1]._attrs["name"], + x_dim2="&" + xshape[2]._attrs["name"], + y_dim0="&" + yshape[0]._attrs["name"], + y_dim1="&" + yshape[1]._attrs["name"], + y_dim2="&" + yshape[2]._attrs["name"], + indent=indent, + lib_dtype=backend_spec.dtype_to_lib_type(xdtype), + ) diff --git a/python/aitemplate/backend/common/tensor/slice_common.py b/python/aitemplate/backend/common/tensor/slice_common.py new file mode 100644 index 000000000..fb17116de --- /dev/null +++ b/python/aitemplate/backend/common/tensor/slice_common.py @@ -0,0 +1,902 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Slice backend common implementation. 
+""" +import jinja2 + +CAST_TO_CONST_HALF_PTR_TEMPLATE = jinja2.Template("reinterpret_cast({{name}})") + + +CAST_TO_HALF_PTR_TEMPLATE = jinja2.Template("reinterpret_cast({{name}})") + + +SHAPE_UPDATE_FUNC = jinja2.Template( + """ +{{indent}}int64_t output_scatter_dim_value = 0; +{{indent}}for ({{index_type}} i = 0; i < num_inputs; i++) { +{{indent}} output_scatter_dim_value += +{{indent}} slice_end_indices[i][scatter_dim] - slice_start_indices[i][scatter_dim]; +{{indent}}} +{{indent}} +{{indent}}for ({{index_type}} i = 0; i < rank; i++) { +{{indent}} if (i == scatter_dim) { +{% if update_output_shape %} +{{indent}} *output_shape[i] = output_scatter_dim_value; +{% else %} +{{indent}} // skip updating output_shape[i] +{% endif %} +{{indent}} } else { +{{indent}} int64_t dim = slice_end_indices[0][i] - slice_start_indices[0][i]; +{{indent}} for ({{index_type}} j = 1; j < num_inputs; j++) { +{{indent}} if (slice_end_indices[j][i] - slice_start_indices[j][i] != dim) { +{{indent}} throw std::runtime_error("invalid indices"); +{{indent}} } +{% if update_output_shape %} +{{indent}} *output_shape[i] = dim; +{% else %} +{{indent}} // skip updating output_shape[i] +{% endif %} +{{indent}} } +{{indent}} } +{{indent}}} +""" +) + + +FUNC_DECL_TEMPLATE = jinja2.Template( + """ +void {{func_name}}( + {{elem_output_type}} * /*output*/, + int64_t *[] /*output_shape*/, + const {{elem_input_type}} *[] /*inputs*/, + const int64_t *[] /*input_shapes*/, + const int64_t *[] /*orig_slice_start_indices*/, + const int64_t *[] /*orig_slice_end_indices*/, + {{index_type}} /*scatter_dim*/, + {{index_type}} /*rank*/, + {{index_type}} /*num_inputs*/, + {{prefix}}Stream_t + ); +""" +) + + +KERNEL_SRC_TEMPLATE = jinja2.Template( + """ +{{header_src}} + +#include +#include +#include +#include +#include +#include + +{% if element_func_def %} +//#include +{% endif %} + +namespace { +#ifndef CHECK_ERROR_SLICE +#define CHECK_ERROR_SLICE(expr) \\ + do { \\ + {{prefix}}Error_t status = (expr); \\ + if (status != {{prefix}}Success) { \\ + auto msg = std::string("Got error: ") + \\ + {{prefix}}GetErrorString(status) + \\ + " at " + __FILE__ + ": " + std::to_string(__LINE__); \\ + std::cerr << msg << std::endl; \\ + throw std::runtime_error(msg); \\ + } \\ + } while (0) +#endif // CHECK_ERROR_SLICE + +#ifndef LAUNCH_CHECK_SLICE +#define LAUNCH_CHECK_SLICE() CHECK_ERROR_SLICE({{prefix}}GetLastError()) +#endif // LAUNCH_CHECK_SLICE + +{% if element_func_def %} +{{element_func_def}} +{% endif %} + +template +struct SliceMetaData { + const T *inputs[NumInputs]; + int64_t slice_start_indices[NumInputs][Rank]; + int64_t slice_end_indices[NumInputs][Rank]; + {{index_type}} dim; // scatter dimension + int64_t input_strides[NumInputs][Rank]; + int64_t num_elems[NumInputs]; + int64_t offsets[NumInputs]; // value of (dim_offset * output_dim_stride) at + // the dim axis in the output, where dim_offset + // is the offset of the scattered input at the + // dimension axis in the output + int64_t dim_sizes[NumInputs]; // dimension size of the input to be scattered + // at the dim axis +}; + +template <{{index_type}} Rank, {{index_type}} NumInputs> +struct ScatterMetaData { + int64_t output_shape[Rank]; + int64_t output_strides[Rank]; +}; + +__host__ __device__ __forceinline__ +int64_t get_num_elems(const int64_t *shape, {{index_type}} rank) { + {{index_type}} num = 1; + for ({{index_type}} i = 0; i < rank; i++) { + num *= shape[i]; + } + return num; +} + +template <{{index_type}} Rank> +__host__ __device__ int64_t compute_input_linear_index( + 
const int64_t *input_strides, + const int64_t *slice_start_indices, + const int64_t *slice_end_indices, + int64_t linear_idx) { + int64_t input_offset = slice_start_indices[0] * input_strides[0]; + for ({{index_type}} i = Rank - 1; i > 0; i--) { + {{index_type}} curr_output_dim_size = slice_end_indices[i] - slice_start_indices[i]; + int64_t curr_output_idx = linear_idx % curr_output_dim_size; + int64_t curr_input_idx = curr_output_idx + slice_start_indices[i]; + input_offset += curr_input_idx * input_strides[i]; + linear_idx /= curr_output_dim_size; + } + return input_offset + linear_idx * input_strides[0]; +} + +template <{{index_type}} Rank> +__host__ __device__ int64_t compute_output_elem_offset( + const int64_t *output_shape, + const int64_t *output_strides, + int64_t scatter_dim_size, + const {{index_type}} scatter_dim, + int64_t linear_idx) { + int64_t offset = 0; + for ({{index_type}} i = Rank - 1; i >= 1; --i) { + int64_t cur_dim_size = i == scatter_dim ? scatter_dim_size : output_shape[i]; + int64_t next_dim_idx = linear_idx / cur_dim_size; + int64_t cur_dim_idx = linear_idx - cur_dim_size * next_dim_idx; + int64_t cur_dim_offset = cur_dim_idx * output_strides[i]; + offset += cur_dim_offset; + linear_idx = next_dim_idx; + } + return offset + linear_idx * output_strides[0]; +} + +template +__global__ void +slice_scatter_kernel( + ELEM_T *orig_output, + SliceMetaData slice_meta_data, + ScatterMetaData scatter_meta_data) { + const {{index_type}} tid = blockIdx.x * blockDim.x + threadIdx.x; + const {{index_type}} block_y = blockIdx.y % NumInputs; + + READ_T* output = reinterpret_cast(orig_output); + const READ_T* input = + reinterpret_cast(slice_meta_data.inputs[block_y]); + int64_t num_elems = slice_meta_data.num_elems[block_y]; + const int64_t *input_strides = slice_meta_data.input_strides[block_y]; + const int64_t *slice_start_indices = + slice_meta_data.slice_start_indices[block_y]; + const int64_t *slice_end_indices = + slice_meta_data.slice_end_indices[block_y]; + + {{index_type}} scatter_dim = slice_meta_data.dim; + int64_t scatter_dim_size = slice_meta_data.dim_sizes[block_y]; + int64_t scatter_offset = slice_meta_data.offsets[block_y]; + + unsigned read_t_sz = sizeof(READ_T); + unsigned elem_t_sz = sizeof(ELEM_T); + assert(read_t_sz >= elem_t_sz && (read_t_sz % elem_t_sz == 0)); + {{index_type}} n_of_elem_t = read_t_sz / elem_t_sz; + // number of READ_T elements per thread + {{index_type}} reads_per_thread_in_read_t = ElemsPerThread / n_of_elem_t; + const int64_t num_elems_in_read_t = num_elems / n_of_elem_t; + {{index_type}} read_idx = tid; + +#pragma unroll + for ({{index_type}} i = 0; i < reads_per_thread_in_read_t; + i++, read_idx += blockDim.x * gridDim.x) { + if (read_idx >= num_elems_in_read_t) { + break; + } + /* make sure to adjust read_idx, which refers to location at + (read_idx * n_of_elem_t) actually */ + int64_t input_idx = compute_input_linear_index( + input_strides, + slice_start_indices, + slice_end_indices, + read_idx * n_of_elem_t); + int64_t output_elem_offset = compute_output_elem_offset( + scatter_meta_data.output_shape, + scatter_meta_data.output_strides, + scatter_dim_size, + scatter_dim, + read_idx * n_of_elem_t); + + READ_T tmp_v = input[input_idx / n_of_elem_t]; + int64_t output_idx = (scatter_offset + output_elem_offset) / n_of_elem_t; + {% if element_func %} + output[output_idx] = {{element_func}}(tmp_v); + {% else %} + output[output_idx] = tmp_v; + {% endif %} + } +} + +enum class LoadVecType { + VT_HALF = 0, + VT_FLOAT, + VT_FLOAT2, + VT_FLOAT4 
+}; + +template +static inline LoadVecType get_vec_type(int64_t dim_size) { + {{index_type}} size_elem_t = sizeof(ELEM_T); + +#define HANDLE_ONE_VEC_TYPE(load_vec_type, vec_type) \\ + if (sizeof(vec_type) % size_elem_t == 0) { \\ + {{index_type}} n_of_elem_t = sizeof(vec_type) / size_elem_t; \\ + if (dim_size % n_of_elem_t == 0) { \\ + return load_vec_type; \\ + } \\ + } + + HANDLE_ONE_VEC_TYPE(LoadVecType::VT_FLOAT4, float4) + HANDLE_ONE_VEC_TYPE(LoadVecType::VT_FLOAT2, float2) + HANDLE_ONE_VEC_TYPE(LoadVecType::VT_FLOAT, float) + HANDLE_ONE_VEC_TYPE(LoadVecType::VT_HALF, half) + +#undef HANDLE_ONE_VEC_TYPE + throw std::runtime_error( + "Cannot resolve LoadVecType." + ); +} + +template +static LoadVecType get_input_vec_type( + const int64_t *output_strides, + const ELEM_T *input, + const int64_t *input_shape, + const int64_t *input_strides, + const int64_t *slice_start_indices, + const int64_t *slice_end_indices, + {{index_type}} scatter_dim, + {{index_type}} scatter_offset, + {{index_type}} dim_size) { + // get the outermost index where we continuous element accesses + {{index_type}} flatten_index = Rank - 1; + for (; flatten_index >= 0; flatten_index--) { + if (slice_end_indices[flatten_index] - slice_start_indices[flatten_index] != + input_shape[flatten_index]) { + break; + } + } + int64_t input_start_offset = + compute_input_linear_index(input_strides, + slice_start_indices, + slice_end_indices, + /*linear_idx*/0); + LoadVecType slice_vec_type1 = + get_vec_type(input_start_offset); + LoadVecType slice_vec_type2; + if (Rank == 1) { + int64_t continuous_read_size = slice_end_indices[0] - slice_start_indices[0]; + slice_vec_type2 = get_vec_type(continuous_read_size); + } else { + int64_t continuous_read_size = + (slice_end_indices[flatten_index] - slice_start_indices[flatten_index]) * + input_strides[flatten_index]; + LoadVecType vec_type1 = get_vec_type(continuous_read_size); + continuous_read_size = + (input_shape[flatten_index] - slice_end_indices[flatten_index]) * + input_strides[flatten_index]; + LoadVecType vec_type2 = get_vec_type(continuous_read_size); + // find the smaller alignment reqirement between the sliced piece + // and the rest along the flattened dimensions + slice_vec_type2 = vec_type1 < vec_type2 ? vec_type1 : vec_type2; + } + LoadVecType slice_min_vec_type = slice_vec_type1 < slice_vec_type2 ? + slice_vec_type1 : slice_vec_type2; + + LoadVecType scatter_vec_type1 = get_vec_type(dim_size); + LoadVecType scatter_vec_type2 = get_vec_type(scatter_offset); + LoadVecType scatter_min_vec_type = scatter_vec_type1 < scatter_vec_type2 ? + scatter_vec_type1 : scatter_vec_type2; + + LoadVecType min_vec_type = slice_min_vec_type < scatter_min_vec_type ? 
+ slice_min_vec_type : scatter_min_vec_type; + return min_vec_type; +} + +template +void prepare_one_meta_data( + {{index_type}} input_idx, + SliceMetaData &slice_meta_data, + ScatterMetaData &scatter_meta_data, + const ELEM_T *input, + const int64_t *input_shape, + const int64_t *slice_start_indices, + const int64_t *slice_end_indices, + {{index_type}} scatter_dim, + {{index_type}} scatter_dim_offset) { + slice_meta_data.inputs[input_idx] = input; + slice_meta_data.input_strides[input_idx][Rank-1] = 1; + for ({{index_type}} i = Rank - 2; i >= 0; i--) { + slice_meta_data.input_strides[input_idx][i] = + slice_meta_data.input_strides[input_idx][i+1] * input_shape[i+1]; + } + + slice_meta_data.num_elems[input_idx] = 1; + for ({{index_type}} i = 0; i < Rank; i++) { + assert(slice_start_indices[i] >= 0 && + slice_start_indices[i] <= input_shape[i]); + assert(slice_end_indices[i] >= 0 && slice_end_indices[i] <= input_shape[i]); + assert(slice_start_indices[i] <= slice_end_indices[i]); + + slice_meta_data.num_elems[input_idx] *= + slice_end_indices[i] - slice_start_indices[i]; + slice_meta_data.slice_start_indices[input_idx][i] = slice_start_indices[i]; + slice_meta_data.slice_end_indices[input_idx][i] = slice_end_indices[i]; + } + + slice_meta_data.dim_sizes[input_idx] = + slice_end_indices[scatter_dim] - slice_start_indices[scatter_dim]; + slice_meta_data.offsets[input_idx] = + scatter_dim_offset * scatter_meta_data.output_strides[scatter_dim]; +} + +template +void slice_scatter_kernel_launcher( + ELEM_T *output, + const int64_t *output_shape, + const ELEM_T *inputs[], + const int64_t *input_shapes[], + const std::vector> &slice_start_indices, + const std::vector> &slice_end_indices, + {{index_type}} scatter_dim, + {{prefix}}Stream_t stream +) { + SliceMetaData slice_meta_data; + ScatterMetaData scatter_meta_data; + + // meta data for placing sliced output + scatter_meta_data.output_strides[Rank-1] = 1; + scatter_meta_data.output_shape[Rank-1] = output_shape[Rank-1]; + for ({{index_type}} i = Rank - 2; i >= 0; i--) { + scatter_meta_data.output_strides[i] = + scatter_meta_data.output_strides[i+1] * output_shape[i+1]; + scatter_meta_data.output_shape[i] = output_shape[i]; + } + + {{index_type}} scatter_dim_offset = 0; + slice_meta_data.dim = scatter_dim; + for ({{index_type}} i = 0; i < NumInputs; i++) { + prepare_one_meta_data(i, slice_meta_data, scatter_meta_data, + inputs[i], input_shapes[i], + slice_start_indices[i].data(), + slice_end_indices[i].data(), + scatter_dim, scatter_dim_offset); + scatter_dim_offset += slice_meta_data.dim_sizes[i]; + } + + LoadVecType min_vec_type = LoadVecType::VT_FLOAT4; + for ({{index_type}} i = 0; i < NumInputs; i++) { + LoadVecType vec_type = get_input_vec_type( + scatter_meta_data.output_strides, + inputs[i], + input_shapes[i], + slice_meta_data.input_strides[i], + slice_start_indices[i].data(), + slice_end_indices[i].data(), + scatter_dim, + slice_meta_data.offsets[i], + slice_meta_data.dim_sizes[i]); + min_vec_type = vec_type < min_vec_type ? 
vec_type : min_vec_type; + } + + // setup kernel configs + int64_t max_num_elems = 0; + for ({{index_type}} i = 0; i < NumInputs; i++) { + if (slice_meta_data.num_elems[i] > max_num_elems) { + max_num_elems = slice_meta_data.num_elems[i]; + } + } + + {{index_type}} m = max_num_elems % (ThreadsPerBlock * ElemsPerThread) != 0; + {{index_type}} num_blocks_x = + (max_num_elems / (ThreadsPerBlock * ElemsPerThread)) + m; + dim3 grid_config = dim3(num_blocks_x, NumInputs); + +#define HANDLE_ONE_VEC_TYPE(load_vec_type, vec_type) \\ + case load_vec_type: { \\ + if (ElemsPerThread * sizeof(ELEM_T) < sizeof(vec_type)) { \\ + throw std::runtime_error( \\ + std::string("No valid kernel available for ") + #vec_type); \\ + } \\ + slice_scatter_kernel \\ + <<>>( \\ + output, \\ + slice_meta_data, \\ + scatter_meta_data); \\ + LAUNCH_CHECK_SLICE(); \\ + break; \\ + } + + switch (min_vec_type) { + HANDLE_ONE_VEC_TYPE(LoadVecType::VT_FLOAT4, float4) + HANDLE_ONE_VEC_TYPE(LoadVecType::VT_FLOAT2, float2) + HANDLE_ONE_VEC_TYPE(LoadVecType::VT_FLOAT, float) + HANDLE_ONE_VEC_TYPE(LoadVecType::VT_HALF, half) + default: + throw std::runtime_error("Invalid LoadVecType\\n"); + } + +#undef HANDLE_ONE_VEC_TYPE +} + +std::tuple, std::vector> +normalize_slice_indices( + const int64_t *input_shape, + const int64_t *orig_slice_start_indices, + const int64_t *orig_slice_end_indices, + {{index_type}} rank) { + std::vector slice_start_indices(rank); + std::vector slice_end_indices(rank); + for ({{index_type}} i = 0; i < rank; i++) { + slice_start_indices[i] = orig_slice_start_indices[i] < 0 ? + input_shape[i] + orig_slice_start_indices[i]: + orig_slice_start_indices[i]; + // make it compatible with PyTorch + slice_start_indices[i] = slice_start_indices[i] < 0 ? + 0 : slice_start_indices[i]; + if (slice_start_indices[i] < 0) { + slice_start_indices[i] = 0; + } + if (slice_start_indices[i] > input_shape[i]) { + slice_start_indices[i] = input_shape[i]; + } + + slice_end_indices[i] = orig_slice_end_indices[i] < 0 ? + input_shape[i] + orig_slice_end_indices[i]: + orig_slice_end_indices[i]; + // make it compatible with PyTorch + slice_end_indices[i] = slice_end_indices[i] < 0 ? 
+ 0 : slice_end_indices[i]; + if (slice_end_indices[i] < 0) { + slice_end_indices[i] = 0; + } + if (slice_end_indices[i] > input_shape[i]) { + slice_end_indices[i] = input_shape[i]; + } + + // make it compatible with PyTorch + if (slice_start_indices[i] > slice_end_indices[i]) { + slice_start_indices[i] = slice_end_indices[i]; + } + } + + return {slice_start_indices, slice_end_indices}; +} +} // namespace + +""" +) + + +EXEC_COND_TEMPLATE = jinja2.Template( + """ +{{indent}}if (rank == {{rank}} && num_inputs == {{num_inputs}}) { +{{indent}} int64_t local_output_shape[{{rank}}]; +{% for rank_idx in range(rank) %} +{{indent}} local_output_shape[{{rank_idx}}] = *output_shape[{{rank_idx}}]; +{% endfor %} +{{indent}} slice_scatter_kernel_launcher<{{elem_type}}, +{{indent}} {{rank}}/*Rank*/, +{{indent}} {{num_inputs}}/*NumInputs*/, +{{indent}} {{elems_per_thread}}/*ElemsPerThread*/, +{{indent}} {{threads_per_block}}/*ThreadsPerBlock*/>( +{{indent}} output, local_output_shape, inputs, input_shapes, +{{indent}} slice_start_indices, slice_end_indices, scatter_dim, stream); +{{indent}} return; +{{indent}}} +""" +) + + +SRC_TEMPLATE = jinja2.Template( + """ +{{kernel_src}} + +void {{func_name}}( + {{elem_output_type}} *output, + int64_t *output_shape[], + const {{elem_input_type}} *inputs[], + const int64_t *input_shapes[], + const int64_t *orig_slice_start_indices[], + const int64_t *orig_slice_end_indices[], + {{index_type}} scatter_dim, + {{index_type}} rank, + {{index_type}} num_inputs, + {{prefix}}Stream_t stream + ) { + + if (rank <= 0) { + throw std::runtime_error("rank must > 0!"); + } + if (scatter_dim >= rank) { + throw std::runtime_error("scatter_dim must < rank!"); + } + + // clip slip start and end indices + std::vector> slice_start_indices(num_inputs); + std::vector> slice_end_indices(num_inputs); + std::vector output_dim_sizes; + for ({{index_type}} i = 0; i < num_inputs; i++) { + std::vector start_indices; + std::vector end_indices; + std::tie(start_indices, end_indices) = + normalize_slice_indices(input_shapes[i], + orig_slice_start_indices[i], + orig_slice_end_indices[i], + rank); + slice_start_indices[i] = start_indices; + slice_end_indices[i] = end_indices; + } + +{{shape_function}} + + // If all input tensors are empty, we are done + bool empty = true; + for ({{index_type}} i = 0; i < num_inputs; i++) { + if (get_num_elems(input_shapes[i], rank) != 0) { + empty = false; + // make sure input is valid for each non-zero-size tensor + if (!inputs[i]) { + throw std::runtime_error("NULL input is found at: " + std::to_string(i)); + } + } + } + + if (empty) + return; + + // if we output has any zero dim size, we are done + for ({{index_type}} i = 0; i < rank; i++) { + if (*output_shape[i] == 0) + return; + } + // make sure we have a valid output pointer + if (!output) { + throw std::runtime_error("output is NULL!"); + } + +{{exec_paths}} + + throw std::runtime_error( + "Unsupported cat kernel specialization!" 
+ ); +} +""" +) + + +DEFAULT_OUTPUT_SHAPE_DEF_TEMPLATE = jinja2.Template( + """ +{{indent}} int64_t *{{output_name}}_shape[] = { +{{indent}} {{output_dim_refs}} +{{indent}} }; +""" +) + + +FUNC_CALL_TEMPLATE = jinja2.Template( + """ +{{indent}}{ +{{output_shape_def}} + +{{indent}} const half *inputs[] = { +{{indent}} {{inputs}} +{{indent}} }; + +{{input_shape_defs}} + +{{indent}} const int64_t *input_shapes[] = { +{{indent}} {{input_shapes}} +{{indent}} }; + +{{start_indices_defs}} + +{{indent}} const int64_t *slice_start_indices[] = { +{{indent}} {{slice_start_indices}} +{{indent}} }; + +{{end_indices_defs}} + +{{indent}} const int64_t *slice_end_indices[] = { +{{indent}} {{slice_end_indices}} +{{indent}} }; + +{{indent}} {{func_name}}( +{{indent}} {{output_ptr}}, +{{indent}} {{output_name}}_shape, +{{indent}} inputs, +{{indent}} input_shapes, +{{indent}} slice_start_indices, +{{indent}} slice_end_indices, +{{indent}} {{scatter_dim}}/*scatter_dim*/, +{{indent}} {{rank}}/*rank*/, +{{indent}} {{num_inputs}}/*num_inputs*/, +{{indent}} stream +{{indent}} ); +{{indent}}} +""" +) + + +INPUT_SHAPE_DEF_TEMPLATE = jinja2.Template( + """ +{{indent}}int64_t {{input_shape_name}}[] = { +{{indent}} {{input_dims}} +{{indent}}}; +""" +) + + +INPUT_INDICES_DEF_TEMPLATE = jinja2.Template( + """ +{{indent}}int64_t {{input_indices_name}}[] = { +{{indent}} {{input_indices}} +{{indent}}}; +""" +) + + +def gen_function_decl(func_attrs, backend_spec): + """Generate function declaration. + + Parameters + ---------- + func_attrs : Dict[str, Any] + Stores the operation attributes. + backend_spec: dataclass + Backend specification. + + Returns + ------- + str + Rendered function declaration. + """ + x = func_attrs["inputs"][0] + y = func_attrs["outputs"][0] + input_type = backend_spec.dtype_to_backend_type(x._attrs["dtype"]) + output_type = backend_spec.dtype_to_backend_type(y._attrs["dtype"]) + return FUNC_DECL_TEMPLATE.render( + func_name=func_attrs["name"], + elem_output_type=output_type, + elem_input_type=input_type, + index_type=backend_spec.index_type, + prefix=backend_spec.prefix, + ) + + +def gen_function( + func_attrs, + backend_spec, + elems_per_thread=8, + update_output_shape=True, + element_func=None, + element_func_def=None, + extra_header_template=None, +): + """Generates function body. + + Parameters + ---------- + func_attrs : Dict[str, Any] + Stores the operation attributes. + backend_spec: dataclass + Backend specification. + elems_per_thread: int + Per thread elements. + update_output_shape: bool + Whether to update output shape, by default True. + element_func: str + Attributes for ease of tanh concatenate fusion, default is None. + element_func_def: str + Implmentation for fast_tanh, default is None. + extra_header_template: str + Header for fast_tanh, default is None. + + + Returns + ------- + str + Rendered function body. 
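The slice-scatter behaviour generated above can be summarized with a small NumPy reference (NumPy assumed; illustration only, not part of the generated sources): every input is sliced with its normalized start/end indices, the slices must agree on all dimensions except scatter_dim, and they are concatenated along scatter_dim.

import numpy as np

def slice_scatter_reference(inputs, start_indices, end_indices, scatter_dim):
    slices = []
    for x, starts, ends in zip(inputs, start_indices, end_indices):
        slices.append(x[tuple(slice(s, e) for s, e in zip(starts, ends))])
    return np.concatenate(slices, axis=scatter_dim)

a = np.arange(24, dtype=np.float16).reshape(2, 3, 4)
b = np.arange(24, 48, dtype=np.float16).reshape(2, 3, 4)
y = slice_scatter_reference(
    [a, b],
    start_indices=[[0, 0, 1], [0, 0, 1]],
    end_indices=[[2, 2, 3], [2, 3, 3]],
    scatter_dim=1,
)
assert y.shape == (2, 5, 2)  # (2, (2 - 0) + (3 - 0), 3 - 1)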
+ """ + inputs = func_attrs["inputs"] + x = inputs[0] + y = func_attrs["outputs"][0] + x_shape = x._attrs["shape"] + + input_type = backend_spec.dtype_to_backend_type(x._attrs["dtype"]) + output_type = backend_spec.dtype_to_backend_type(y._attrs["dtype"]) + + # TODO: consider to add profiling paths for tuning + # elems_per_thread and threads_per_block + exec_paths = EXEC_COND_TEMPLATE.render( + indent=" ", + num_inputs=len(inputs), + rank=len(x_shape), + elem_type=input_type, + elems_per_thread=elems_per_thread, + threads_per_block=128, + ) + + shape_func = SHAPE_UPDATE_FUNC.render( + indent=" ", + update_output_shape=update_output_shape, + index_type=backend_spec.index_type, + ) + extra_header = ( + extra_header_template.render(element_func_def=element_func_def) + if extra_header_template is not None + else "" + ) + header_src = backend_spec.header_src_template.render(extra_header=extra_header) + kernel_src = KERNEL_SRC_TEMPLATE.render( + element_func=element_func, + element_func_def=element_func_def, + index_type=backend_spec.index_type, + prefix=backend_spec.prefix, + header_src=header_src, + ) + return SRC_TEMPLATE.render( + kernel_src=kernel_src, + func_name=func_attrs["name"], + elem_input_type=input_type, + elem_output_type=output_type, + shape_function=shape_func, + exec_paths=exec_paths, + index_type=backend_spec.index_type, + prefix=backend_spec.prefix, + header_src=header_src, + ) + + +def gen_function_call( + backend_spec, + func_name, + inputs, + outputs, + start_indices, + end_indices, + dim=0, + indent=" ", + output_shape_def=None, +): + """Generates function call. + + Parameters + ---------- + backend_spec: dataclass + Backend specification. + func_name : str + Function neame + inputs : List[Tensor] + Input tensors. + outputs : List[Tensor] + Output tensors. + start_indices : List[List[int]] + each input has its own list of indices + end_indices : List[List[int]] + Each input has its own list of indices + dim : int + Specify the concat dim if we concat outputs of all inputs, by default 0. + indent : str, optional + Indent for template, by default " ". + output_shape_def: jinja2.Template + output shape template, by default None. + + Returns + ------- + str + Rendered function call. 
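The start/end indices passed here are normalized at runtime by normalize_slice_indices in the kernel source above. A plain-Python sketch of those PyTorch-compatible rules (illustration only, not part of the generated sources): negative indices wrap around, out-of-range indices are clamped to [0, dim], and start is clamped to end so an inverted range yields an empty slice.

def normalize_slice_indices_reference(shape, starts, ends):
    norm_starts, norm_ends = [], []
    for dim, s, e in zip(shape, starts, ends):
        s = s + dim if s < 0 else s
        e = e + dim if e < 0 else e
        s = min(max(s, 0), dim)
        e = min(max(e, 0), dim)
        norm_starts.append(min(s, e))  # clamp start to end
        norm_ends.append(e)
    return norm_starts, norm_ends

# dim 0: start -2 wraps to 2, end 10 clamps to 4; dim 1: end -1 wraps to 4
assert normalize_slice_indices_reference([4, 5], [-2, 0], [10, -1]) == ([2, 0], [4, 4])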
+ """ + assert len(inputs) == len(start_indices) == len(end_indices) + x = inputs[0] + y = outputs[0] + + input_names = ",\n ".join( + [ + backend_spec.cast_to_const_half_ptr_template.render(name=i._attrs["name"]) + for i in inputs + ] + ) + + input_shape_defs = [] + input_shape_names = [] + start_indices_defs = [] + start_indices_names = [] + end_indices_defs = [] + end_indices_names = [] + + for idx, (i, s_indices, e_indices) in enumerate( + zip(inputs, start_indices, end_indices) + ): + input_shape_name = "{}_shape".format(i._attrs["name"]) + s_indices_name = "{}_slice_start_indices_{}".format(i._attrs["name"], idx) + e_indices_name = "{}_slice_end_indices_{}".format(i._attrs["name"], idx) + if input_shape_name not in input_shape_names: + dims = ", ".join([dim._attrs["name"] for dim in i._attrs["shape"]]) + one_shape_def = INPUT_SHAPE_DEF_TEMPLATE.render( + indent=" ", input_shape_name=input_shape_name, input_dims=dims + ) + input_shape_defs.append(one_shape_def) + + s_indices_str = ", ".join([str(i) for i in s_indices]) + one_s_indices_def = INPUT_INDICES_DEF_TEMPLATE.render( + indent=" ", + input_indices_name=s_indices_name, + input_indices=s_indices_str, + ) + start_indices_defs.append(one_s_indices_def) + + e_indices_str = ", ".join([str(i) for i in e_indices]) + one_e_indices_def = INPUT_INDICES_DEF_TEMPLATE.render( + indent=" ", + input_indices_name=e_indices_name, + input_indices=e_indices_str, + ) + end_indices_defs.append(one_e_indices_def) + + input_shape_names.append(input_shape_name) + start_indices_names.append(s_indices_name) + end_indices_names.append(e_indices_name) + + if output_shape_def is None: + y_dim_refs = ", ".join(["&" + dim._attrs["name"] for dim in y._attrs["shape"]]) + output_shape_def = DEFAULT_OUTPUT_SHAPE_DEF_TEMPLATE.render( + indent=indent, output_name=y._attrs["name"], output_dim_refs=y_dim_refs + ) + + casted_y_ptr = backend_spec.cast_to_half_ptr_template.render(name=y._attrs["name"]) + + return FUNC_CALL_TEMPLATE.render( + indent=indent, + func_name=func_name, + output_elem_type=backend_spec.dtype_to_backend_type(y._attrs["dtype"]), + output_name=y._attrs["name"], + output_ptr=casted_y_ptr, + output_shape_def=output_shape_def, + inputs=input_names, + input_shape_defs="".join(input_shape_defs), + input_shapes=", ".join(input_shape_names), + start_indices_defs="".join(start_indices_defs), + slice_start_indices=", ".join(start_indices_names), + end_indices_defs="".join(end_indices_defs), + slice_end_indices=", ".join(end_indices_names), + scatter_dim=dim, + rank=len(x._attrs["shape"]), + num_inputs=len(inputs), + ) diff --git a/python/aitemplate/backend/common/tensor/slice_reshape_scatter_common.py b/python/aitemplate/backend/common/tensor/slice_reshape_scatter_common.py new file mode 100644 index 000000000..b8901a062 --- /dev/null +++ b/python/aitemplate/backend/common/tensor/slice_reshape_scatter_common.py @@ -0,0 +1,149 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +""" +Slice reshape backend common implementation. +""" +import functools + +import jinja2 + +from . import slice_common + +OUTPUT_DIM_DEF_TEMPLATE = jinja2.Template( + """ +{{indent}}int64_t {{dim_name}} = {{dim_value}}; +""" +) + +OUTPUT_SHAPE_DEF_TEMPLATE = jinja2.Template( + """ +{{dim_defs}} +{{indent}} int64_t *{{output_name}}_shape[] = { +{{indent}} {{output_dim_refs}} +{{indent}} }; +""" +) + + +def gen_function_decl(func_attrs, backend_spec): + """Generate function declaration. + + Parameters + ---------- + func_attrs : Dict[str, Any] + Stores the operation attributes. + backend_spec: dataclass + Backend specification. + + Returns + ------- + str + Rendered function declaration. + """ + return slice_common.gen_function_decl(func_attrs, backend_spec=backend_spec) + + +def gen_function( + func_attrs, backend_spec, tanh_def, element_func=None, extra_header_template=None +): + """Generates function body. + + Parameters + ---------- + func_attrs : Dict[str, Any] + Stores the operation attributes. + backend_spec: dataclass + Backend specification. + element_func: str + Attributes for ease of tanh concatenate fusion, default is None. + extra_header_template: str + Header for fast_tanh, default is None. + + + Returns + ------- + str + Rendered function body. + """ + # TODO: consider to profile elems_per_thread + elems_per_thread = 8 if len(func_attrs["inputs"]) == 1 else 256 + element_func_def = None if element_func is None else tanh_def.render() + return slice_common.gen_function( + func_attrs, + backend_spec=backend_spec, + elems_per_thread=elems_per_thread, + update_output_shape=False, + element_func=element_func, + element_func_def=element_func_def, + extra_header_template=extra_header_template, + ) + + +def gen_function_call(func_attrs, backend_spec, indent=" "): + """Generates function call. + + Parameters + ---------- + func_attrs : Dict[str, Any] + Stores the operation attributes. + backend_spec: dataclass + Backend specification. + indent : str, optional + Indent for template, by default " ". + + Returns + ------- + str + Rendered function call. 
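As a small illustration of what the function body below does before delegating to slice_common.gen_function_call (not part of the generated sources; the shape values are hypothetical): all output dimensions from scatter_dim onward are folded into a single dimension, so the slice-scatter kernel sees a lower-rank output shape.

import functools

dims = [2, 3, 4, 5]  # hypothetical static output shape
scatter_dim = 1
new_dims = dims[:scatter_dim]
new_dims.append(functools.reduce(lambda a, b: a * b, dims[scatter_dim:]))
assert new_dims == [2, 60]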
+ """ + slice_ops = func_attrs["slice_ops"] + assert len(slice_ops) >= 1 + start_indices = [op._attrs["start_indices"] for op in slice_ops] + end_indices = [op._attrs["end_indices"] for op in slice_ops] + + y = func_attrs["outputs"][0] + dims = [d._attrs["values"][0] for d in y._attrs["shape"]] + scatter_dim = func_attrs["scatter_dim"] + output_shape_dims = [] + output_shape_dim_defs = [] + new_dims = dims[:scatter_dim] + remaining_dim = functools.reduce(lambda a, b: a * b, dims[scatter_dim:]) + new_dims.append(remaining_dim) + for i, dim in enumerate(new_dims): + dim_name = "output_dim_{}".format(i) + output_shape_dims.append(dim_name) + dim_def = OUTPUT_DIM_DEF_TEMPLATE.render( + indent=indent, dim_name=dim_name, dim_value=dim + ) + output_shape_dim_defs.append(dim_def) + y_dim_refs = ", ".join(["&" + dim for dim in output_shape_dims]) + output_shape_def = OUTPUT_SHAPE_DEF_TEMPLATE.render( + indent=indent, + dim_defs="".join(output_shape_dim_defs), + output_name=y._attrs["name"], + output_dim_refs=y_dim_refs, + ) + + return slice_common.gen_function_call( + backend_spec, + func_attrs["name"], + func_attrs["inputs"], + func_attrs["outputs"], + start_indices, + end_indices, + dim=scatter_dim, + indent=indent, + output_shape_def=output_shape_def, + ) diff --git a/python/aitemplate/backend/common/tensor/topk_common.py b/python/aitemplate/backend/common/tensor/topk_common.py new file mode 100644 index 000000000..6b82ef531 --- /dev/null +++ b/python/aitemplate/backend/common/tensor/topk_common.py @@ -0,0 +1,769 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +topk kernel codegen. +""" + +import os +from typing import Any, Dict, List, Tuple + +import jinja2 + +from ... 
import builder +from ...target import Target + +# pylint: disable=C0301 + +FUNC_CALL_INT64_PARAM_TEMPLATE = jinja2.Template("reinterpret_cast({{name}})") + +FUNC_TEMPLATE = jinja2.Template( + """ +{{header_files}} + +namespace { + +{{kernel}} + +} // namespace + +{{func_signature}} +{ + topk_launcher(stream, elem_cnt, instance_size, instance_num, top_k, input, workspace, output); +} + """ +) + +PROFILER_TEMPLATE = jinja2.Template( + """ +#include +{{header_files}} + +size_t GLOBAL_WORKSPACE_SIZE = 0; + +namespace { + +{{kernel}} + +} // namespace + +int main(int argc, char** argv) { + int elem_cnt = std::stoi(argv[1]); + int instance_size = std::stoi(argv[2]); + int instance_num = std::stoi(argv[3]); + + float runtime_ms = 0; + const int64_t sorted_in_aligned_bytes = GetAlignedSize(elem_cnt * sizeof(half)); + const int64_t indices_aligned_bytes = GetAlignedSize(elem_cnt * sizeof(int64_t)); + const int64_t sorted_indices_aligned_bytes = indices_aligned_bytes; + int64_t temp_storage_bytes = InferTempStorageForSortPairsDescending(instance_size, instance_num); + GLOBAL_WORKSPACE_SIZE = GetAlignedSize(sorted_in_aligned_bytes + indices_aligned_bytes + sorted_indices_aligned_bytes + temp_storage_bytes); + std::cout << "TIME:" << runtime_ms << std::endl; + std::cout << "WS:" << GLOBAL_WORKSPACE_SIZE << std::endl; +} + """ +) + +FUNC_SIGNATURE = jinja2.Template( + """ +void {{func_name}}(int64_t* output, + const half* input, + const {{index_type}} elem_cnt, + const {{index_type}} instance_size, + const {{index_type}} instance_num, + const {{index_type}} top_k, + uint8_t* workspace, + {{prefix}}Stream_t stream) + """ +) + +FUNC_DECL = jinja2.Template( + """ + {{func_signature}}; + """ +) + +FUNC_CALL_TEMPLATE = jinja2.Template( + """ +{{indent}}{{func_name}}( +{{indent}} {{output}}, {{input}}, +{{indent}} {{elem_cnt}}, +{{indent}} {{instance_size}}, +{{indent}} {{instance_num}}, +{{indent}} {{top_k}}, +{{indent}} global_workspace, stream /* default stream */ +{{indent}}); + """ +) + +KERNEL_TEMPLATE = jinja2.Template( + """ +const int32_t kThreadsNumPerBlock = 256; +const int32_t kMaxBlocksNum = 8192; + +#define GPU_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +inline size_t GetAlignedSize(size_t size) { + const size_t kAlignSize = 512; + return (size + kAlignSize - 1) / kAlignSize * kAlignSize; +} + +template +T GetZeroVal() { + return static_cast(0); +} + +template +T GetOneVal() { + return static_cast(1); +} + +template +T GetMinVal() { + uint16_t ret = 0xfbff; + return *(T*)&ret; +} + +template +T GetMaxVal() { + uint16_t ret = 0x7bff; + return *(T*)&ret; +} + +template +T PowOf2Floor(T val, int64_t max_power) { + T max_floor = static_cast(std::pow(2, max_power)); + val = std::min(val, max_floor); + T ret = GetOneVal(); + while (true) { + ret *= 2; + if (ret >= val) { + return ret == val ? 
ret : ret / 2; + } + } +} + +template +T PowOf2Ceil(T val, int64_t max_power) { + T max_ceil = static_cast(std::pow(2, max_power)); + val = std::min(val, max_ceil); + T ret = GetOneVal(); + while (true) { + ret *= 2; + if (ret >= val) { + return ret; + } + } +} + +template +__device__ void BitonicSwap( + T* data, + const int64_t i, + const int64_t j, + const bool dir, + const Compare& comp) { + if (comp(data[i], data[j]) == dir) { + T tmp = data[i]; + data[i] = data[j]; + data[j] = tmp; + } +} + +class MultiplyFunctor final { + public: + MultiplyFunctor(int32_t num_col) : num_col_(num_col) {} + __host__ __device__ __forceinline__ int32_t operator()(int32_t idx) const { + return idx * num_col_; + } + + private: + int32_t num_col_; +}; + +template +size_t InferTempStorageForSortPairsDescending( + int32_t num_row, + int32_t num_col) { + using SegmentOffsetIter = {{cub}}::TransformInputIterator< + int32_t, + MultiplyFunctor, + {{cub}}::CountingInputIterator>; + + {{cub}}::CountingInputIterator counting_iter(0); + MultiplyFunctor multiply_functor(num_col); + SegmentOffsetIter segment_offset_iter(counting_iter, multiply_functor); + + size_t temp_storage_bytes = 0; + auto err = {{cub}}::DeviceSegmentedRadixSort:: + SortPairsDescending( + /* d_temp_storage */ nullptr, + /* temp_storage_bytes */ temp_storage_bytes, + /* d_keys_in */ nullptr, + /* d_keys_out */ nullptr, + /* d_values_in */ nullptr, + /* d_values_out */ nullptr, + /* num_items */ num_row * num_col, + /* num_segments */ num_row, + /* d_begin_offsets */ segment_offset_iter, + /* d_end_offsets */ segment_offset_iter + 1, + /* begin_bit */ 0, + /* end_bit */ sizeof(KeyType) * 8, + /* stream */ 0); + + return temp_storage_bytes; +} + +template +void SortPairsDescending( + const KeyType* keys_ptr, + const ValueType* values_ptr, + int32_t num_row, + int32_t num_col, + void* temp_storage_ptr, + int32_t temp_storage_bytes, + KeyType* sorted_keys_ptr, + ValueType* sorted_values_ptr, + {{prefix}}Stream_t stream) { + size_t rt_inferred_temp_storage_bytes = + InferTempStorageForSortPairsDescending( + num_row, num_col); + + using SegmentOffsetIter = {{cub}}::TransformInputIterator< + int32_t, + MultiplyFunctor, + {{cub}}::CountingInputIterator>; + + {{cub}}::CountingInputIterator counting_iter(0); + MultiplyFunctor multiply_functor(num_col); + SegmentOffsetIter segment_offset_iter(counting_iter, multiply_functor); + + auto err = {{cub}}::DeviceSegmentedRadixSort::SortPairsDescending( + /* d_temp_storage */ temp_storage_ptr, + /* temp_storage_bytes */ rt_inferred_temp_storage_bytes, + /* d_keys_in */ keys_ptr, + /* d_keys_out */ sorted_keys_ptr, + /* d_values_in */ values_ptr, + /* d_values_out */ sorted_values_ptr, + /* num_items */ num_row * num_col, + /* num_segments */ num_row, + /* d_begin_offsets */ segment_offset_iter, + /* d_end_offsets */ segment_offset_iter + 1, + /* begin_bit */ 0, + /* end_bit */ sizeof(KeyType) * 8, + /* stream */ stream); +} + +template +__device__ void +BitonicSort(T* data, const int64_t elem_cnt, const Compare& comp) { + // The element count of instance should be pow-of-2 + assert(elem_cnt > 0 && !(elem_cnt & (elem_cnt - 1))); + + // Generate a bitonic sequence from input + for (int64_t size = 2; size <= elem_cnt / 2; size *= 2) { + // Merge 2 bitonic sequences of length 'size' into a bitonic sequence of + // length '2 * size' + for (int64_t stride = size / 2; stride > 0; stride /= 2) { + for (int64_t swap_id = threadIdx.x; swap_id < elem_cnt / 2; + swap_id += blockDim.x) { + // Change dir at intervals of 'size / 2' 
swaps + const bool dir = swap_id & (size / 2); + // Locate the pair {pos, pos + stride} which is going te be swaped if + // needed + const int pos = 2 * swap_id - (swap_id & (stride - 1)); + + BitonicSwap(data, pos, pos + stride, dir, comp); + + __syncthreads(); + } + } + } + + // Sort the bitonic sequence + for (int64_t stride = elem_cnt / 2; stride > 0; stride /= 2) { + for (int64_t swap_id = threadIdx.x; swap_id < elem_cnt / 2; + swap_id += blockDim.x) { + // Locate the pair {pos, pos + stride} which is going te be swaped if + // needed + const int pos = 2 * swap_id - (swap_id & (stride - 1)); + + BitonicSwap(data, pos, pos + stride, false, comp); + + __syncthreads(); + } + } +} + +template +class Entry final { + public: + __device__ __forceinline__ Entry(int64_t index, T value) + : index_(index), value_(value) {} + + __device__ __forceinline__ int64_t GetIndex() const { + return index_; + } + __device__ __forceinline__ T GetValue() const { + return value_; + } + __device__ __forceinline__ void SetIndex(int64_t index) { + index_ = index; + } + __device__ __forceinline__ void SetValue(T value) { + value_ = value; + } + + __device__ __forceinline__ bool operator<(const Entry& entry) const { + return (value_ < entry.GetValue()) || + (value_ == entry.GetValue() && index_ > entry.GetIndex()); + } + __device__ __forceinline__ bool operator>(const Entry& entry) const { + return (value_ > entry.GetValue()) || + (value_ == entry.GetValue() && index_ < entry.GetIndex()); + } + + private: + int64_t index_; + T value_; +}; + +template +class MinHeap final { + public: + __device__ __forceinline__ MinHeap( + Entry* data, + const int64_t heap_size, + const int64_t init_index, + const T init_value) + : data_(data), heap_size_(heap_size) { + for (int64_t i = 0; i < heap_size; ++i) { + data_[i].SetIndex(init_index); + data_[i].SetValue(init_value); + } + } + __device__ __forceinline__ Entry& Top() { + return data_[0]; + } + __device__ __forceinline__ void Swap(const int64_t i, const int64_t j) { + auto tmp = data_[j]; + data_[j] = data_[i]; + data_[i] = tmp; + } + __device__ __forceinline__ void MinHeapify(int64_t index) { + while (true) { + const int64_t left = 2 * index + 1; + const int64_t right = 2 * index + 2; + int64_t min = index; + if (left < heap_size_ && data_[left] < data_[min]) { + min = left; + } + if (right < heap_size_ && data_[right] < data_[min]) { + min = right; + } + if (min == index) { + return; + } + Swap(min, index); + index = min; + } + } + + private: + Entry* data_; + int64_t heap_size_; +}; + +template +class TmpBufferManager final { + public: + TmpBufferManager(int64_t capacity, void* ptr, const int64_t N) + : capacity_{capacity}, + sorted_in_elem_cnt_{N}, + indices_elem_cnt_{sorted_in_elem_cnt_}, + sorted_indices_elem_cnt_{sorted_in_elem_cnt_} { + const int64_t sorted_in_aligned_bytes = + GetAlignedSize(sorted_in_elem_cnt_ * sizeof(T)); + const int64_t indices_aligned_bytes = + GetAlignedSize(indices_elem_cnt_ * sizeof(int64_t)); + const int64_t sorted_indices_aligned_bytes = indices_aligned_bytes; + sorted_in_ptr_ = reinterpret_cast(ptr); + indices_ptr_ = reinterpret_cast( + reinterpret_cast(sorted_in_ptr_) + sorted_in_aligned_bytes); + sorted_indices_ptr_ = reinterpret_cast( + reinterpret_cast(indices_ptr_) + indices_aligned_bytes); + temp_storage_ptr_ = reinterpret_cast( + reinterpret_cast(sorted_indices_ptr_) + + sorted_indices_aligned_bytes); + temp_storage_bytes_ = capacity_ - sorted_in_aligned_bytes - + indices_aligned_bytes - sorted_indices_aligned_bytes; + } + 
~TmpBufferManager() = default; + + T* SortedInPtr() const { + return sorted_in_ptr_; + } + int64_t* IndicesPtr() const { + return indices_ptr_; + } + int64_t* SortedIndicesPtr() const { + return sorted_indices_ptr_; + } + void* TempStoragePtr() const { + return temp_storage_ptr_; + } + + int64_t TempStorageBytes() const { + return temp_storage_bytes_; + } + + private: + int64_t capacity_; + + T* sorted_in_ptr_; + int64_t* indices_ptr_; + int64_t* sorted_indices_ptr_; + void* temp_storage_ptr_; + + int64_t sorted_in_elem_cnt_; + int64_t indices_elem_cnt_; + int64_t sorted_indices_elem_cnt_; + int64_t temp_storage_bytes_; +}; + +__global__ void InitializeIndices( + int64_t elem_cnt, + int64_t* indices_ptr, + int64_t instance_size) { + GPU_KERNEL_LOOP(i, elem_cnt) { + indices_ptr[i] = i % instance_size; + }; +} + +template +__global__ void GetOutput( + int64_t top_k, + int64_t instance_num, + int64_t instance_size, + int64_t* indices_ptr, + T* output) { + for (int j = blockIdx.y * blockDim.y + threadIdx.y; j < instance_num; + j += blockDim.y * gridDim.y) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < top_k; + i += blockDim.x * gridDim.x) { + output[top_k * j + i] = indices_ptr[instance_size * j + i]; + } + } +} + +template +__global__ void HeapTopKKernel( + const T* in_ptr, + const int64_t instance_num, + const int64_t instance_size, + const int64_t k, + const int64_t heap_size, + const int64_t init_index, + const T init_value, + int64_t* out_ptr) { + extern __shared__ char smem[]; + auto* shared_entries = reinterpret_cast*>(smem); + + // Divide elements to be sorted into disjoint sets (# of sets == # of heaps). + // Each thread in the thread block manipulates one heap to select top + // heap_size entries from corresponding set + const T* input = in_ptr + blockIdx.x * instance_size; + auto heap = MinHeap( + shared_entries + threadIdx.x * heap_size, + heap_size, + init_index, + init_value); + for (int64_t i = threadIdx.x; i < instance_size; i += blockDim.x) { + auto entry = Entry(i, input[i]); + if (entry > heap.Top()) { + heap.Top() = entry; + heap.MinHeapify(0); + } + } + + __syncthreads(); + + // Merge all heaps into a unified, sorted array + BitonicSort( + shared_entries, + blockDim.x * heap_size, + [](const Entry& x, const Entry& y) { return x > y; }); + + // Write top_k elements in sorted array to output + for (int64_t i = threadIdx.x; i < k; i += blockDim.x) { + (out_ptr + blockIdx.x * k)[i] = shared_entries[i].GetIndex(); + } +} +// ALIGNPTR +int64_t* alignPtr(int64_t* ptr, uintptr_t to) { + uintptr_t addr = (uintptr_t)ptr; + if (addr % to) { + addr += to - addr % to; + } + return (int64_t*)addr; +} + +inline int32_t BlocksNum4ThreadsNum(const int32_t n) { + return std::min( + (n + kThreadsNumPerBlock - 1) / kThreadsNumPerBlock, + kMaxBlocksNum); +} + +template +void topk_launcher( + {{prefix}}Stream_t stream, + const int elem_cnt, + const int instance_size, + const int instance_num, + const int top_k, + const void* input, + void* workspace, + void* output) { + const int32_t k = std::min(top_k, instance_size); + + if (top_k < 100) { + const int32_t kMaxSharedMemoryByteSize = 48 << 10; + + // Use as many heaps as possible (# of heaps == # of threads used in thread + // block). 
Limitation 1: size of shared memory We also need heap_size * + // num_heap to be pow-of-2 which is necessary for bitonic sort + const int64_t heap_size = PowOf2Ceil(k, 16); + int32_t num_heap = PowOf2Floor( + kMaxSharedMemoryByteSize / (heap_size * sizeof(Entry)), 16); + // Limitation 2: # of threads in thread block + num_heap = std::min(num_heap, kThreadsNumPerBlock); + + HeapTopKKernel + <<), + stream>>>( + (const T*)input, + instance_num, + instance_size, + k, + heap_size, + GetMaxVal(), + GetMinVal(), + (int64_t*)output); + + } else { + const uintptr_t ALIGNMENT = 32; + int64_t* vworkspace = alignPtr((int64_t*)workspace, ALIGNMENT); + T* tmp_buffer = (T*)vworkspace; + + TmpBufferManager buf_manager( + static_cast(elem_cnt), tmp_buffer, elem_cnt); + + InitializeIndices<<< + BlocksNum4ThreadsNum(elem_cnt), + kThreadsNumPerBlock, + 0, + stream>>>(elem_cnt, buf_manager.IndicesPtr(), instance_size); + + SortPairsDescending( + (const T*)input, + buf_manager.IndicesPtr(), + instance_num, + instance_size, + buf_manager.TempStoragePtr(), + buf_manager.TempStorageBytes(), + buf_manager.SortedInPtr(), + buf_manager.SortedIndicesPtr(), + stream); + + {{prefix}}Memcpy2DAsync( + (int64_t*)output, + k * sizeof(int64_t), + buf_manager.SortedIndicesPtr(), + instance_size * sizeof(int64_t), + k * sizeof(int64_t), + instance_num, + {{prefix}}MemcpyDefault, + stream); + } +} + """ +) + + +def gen_function(func_attrs: Dict[str, Any], header_files: str, backend_spec) -> str: + """Generates function. + + Parameters + ---------- + func_attrs : Dict[str, Any] + Stores the operation attributes. + header_files : str + Includes the header files for a backend. + backend_spec : class + Specifies the backend configurations. + + Returns + ------- + str + Rendered function. + """ + index_type = backend_spec.index_type + prefix = backend_spec.prefix + return FUNC_TEMPLATE.render( + header_files=header_files, + func_signature=FUNC_SIGNATURE.render( + func_name=func_attrs["name"], index_type=index_type, prefix=prefix + ), + kernel=KERNEL_TEMPLATE.render(cub=backend_spec.cub, prefix=prefix), + ) + + +def gen_function_decl(func_attrs: Dict[str, Any], backend_spec) -> str: + """Generates function decl. + + Parameters + ---------- + func_attrs : Dict[str, Any] + Stores the operation attributes. + backend_spec : class + Specifies the backend configurations. + + Returns + ------- + str + Rendered function decl. + """ + return FUNC_DECL.render( + func_signature=FUNC_SIGNATURE.render( + func_name=func_attrs["name"], + index_type=backend_spec.index_type, + prefix=backend_spec.prefix, + ), + ).strip() + + +def gen_function_call(func_attrs: Dict[str, Any], backend_spec, indent=" ") -> str: + """Generates function call. + + Parameters + ---------- + func_attrs : Dict[str, Any] + Stores the operation attributes. + backend_spec : class + Specifies the backend configurations. + indent : str, optional + Indent for template, by default " ". + + Returns + ------- + str + Rendered function call. 
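+
+    Note
+    ----
+    Top-k is applied along the innermost (last) input dimension: instance_size
+    is the extent of that dimension and instance_num is the product of all
+    remaining dimensions.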
+ """ + output_name = "" + assert len(func_attrs["outputs"]) == 1 + assert len(func_attrs["inputs"]) == 1 + + output_name = FUNC_CALL_INT64_PARAM_TEMPLATE.render( + name=func_attrs["outputs"][0]._attrs["name"] + ) + input_name = backend_spec.cast_to_half_ptr_template.render( + name=func_attrs["inputs"][0]._attrs["name"] + ) + + x = func_attrs["inputs"][0] + xshape = x._attrs["shape"] + + elem_cnt = 1 + for shape in xshape: + elem_cnt *= shape._attrs["values"][0] + instance_size = xshape[-1]._attrs["values"][0] + instance_num = elem_cnt // instance_size + + return FUNC_CALL_TEMPLATE.render( + func_name=func_attrs["name"], + output=output_name, + input=input_name, + elem_cnt=elem_cnt, + instance_size=instance_size, + instance_num=instance_num, + top_k=func_attrs["topK"], + indent=indent, + ) + + +def add_profiler( + file_pairs: List[Tuple[str, str]], + workdir: str, + op_type: str, + output_name: str, + code: str, +): + prefix = os.path.join(workdir, "profiler", op_type) + if not os.path.exists(prefix): + os.makedirs(prefix) + src_path = os.path.join(prefix, output_name + ".cu") + obj_path = os.path.join(prefix, output_name) + if os.path.exists(obj_path): + return + with open(src_path, "w") as f: + f.write(code) + file_pairs.append((src_path, obj_path)) + + +def gen_profiler( + func_attrs: Dict[str, Any], workdir: str, header_files: str, backend_spec +): + """Generates code for topk profiling. + + Parameters + ---------- + func_attrs : Dict[str, Any] + Stores the operation attributes. + workdir: str + Target directory for generated C++ source code files + header_files : str + Includes the header files for a backend. + backend_spec : class + Specifies the backend configurations. + + Returns + ------- + None + """ + # If topK is less than 100, disable profiling since our implementation does not need it. + if func_attrs["topK"] < 100: + func_attrs["has_profiler"] = False + return + + op_type = func_attrs["op"] + file_pairs = [] + index_type = backend_spec.index_type + prefix = backend_spec.prefix + code = PROFILER_TEMPLATE.render( + header_files=header_files, + func_signature=FUNC_SIGNATURE.render( + func_name=func_attrs["name"], index_type=index_type, prefix=prefix + ), + kernel=KERNEL_TEMPLATE.render(cub=backend_spec.cub, prefix=prefix), + ) + op_name = func_attrs["op"] + add_profiler(file_pairs, workdir, op_type, op_name, code) + # build + target = Target.current() + compile_engine = builder.Builder() + compile_engine.build_objs(file_pairs, target.compile_cmd(executable=True)) diff --git a/python/aitemplate/backend/common/tensor_accessor.cuh b/python/aitemplate/backend/common/tensor_accessor.cuh new file mode 100644 index 000000000..64179da02 --- /dev/null +++ b/python/aitemplate/backend/common/tensor_accessor.cuh @@ -0,0 +1,110 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +#ifndef AIT_TENSOR_ACCESSOR_CUH +#define AIT_TENSOR_ACCESSOR_CUH + +// Returns a strided address based on a base pointer, an index and strided +// information. +// DATA_T: tensor data type. +// READ_T: actual data type used when reading data. e.g. for a "half" +// tensor, READ_T could be uint4 when all data is aligned. +// data: A base pointer in READ_T type. +// idx: read index in terms of READ_T. +// offset, original_total_elements_from_stride_dim and +// actual_total_elements_from_stride_dim are the corresponding data member +// values of TensorAccessor. +template +__device__ __forceinline__ READ_T* get_strided_address( + READ_T* data, + int64_t idx, + int64_t offset, + int64_t original_total_elements_from_stride_dim, + int64_t actual_total_elements_from_stride_dim) { + (void)original_total_elements_from_stride_dim; // Suppress incorrect declared + // but never referenced warning + // from nvcc. + (void)actual_total_elements_from_stride_dim; // Ditto. + if constexpr (is_contiguous) { + return reinterpret_cast(reinterpret_cast(data) + offset) + + idx; + } else { + constexpr int N_ELEMENTS_PER_READ = sizeof(READ_T) / sizeof(DATA_T); + int64_t data_idx = idx * N_ELEMENTS_PER_READ; + int64_t num_rows = data_idx / original_total_elements_from_stride_dim; + int64_t row_offset = data_idx % original_total_elements_from_stride_dim; + data_idx = + num_rows * actual_total_elements_from_stride_dim + row_offset + offset; + return reinterpret_cast( + reinterpret_cast(data) + data_idx); + } + return nullptr; // Suppress incorrect warning about missing return statement + // from nvcc. +} + +static inline uint64_t max_power2_divisor(uint64_t n) { + // max power of 2 which divides n + return n & (~(n - 1)); +} + +// A TensorAccessor which handles strided tensor access underneath. +struct TensorAccessor { + int64_t offset{0}; + bool is_contiguous{true}; + + int stride_dim{-1}; + int64_t original_total_elements_from_stride_dim{-1}; + int64_t actual_total_elements_from_stride_dim{-1}; + + // Returns an address based on a base pointer and an index. + + // DATA_T: tensor data type. + // READ_T: actual data type used when reading data. e.g. for a "half" + // tensor, READ_T could be uint4 when all data is aligned. + // data: A base pointer in READ_T type. + // idx: read index in terms of READ_T. + template + __device__ inline READ_T* get(READ_T* data, int64_t idx) const { + return is_contiguous ? get_strided_address( + data, + idx, + offset, + original_total_elements_from_stride_dim, + actual_total_elements_from_stride_dim) + : get_strided_address( + data, + idx, + offset, + original_total_elements_from_stride_dim, + actual_total_elements_from_stride_dim); + } + + uint64_t max_alignment() const { + // gcd of max alignments + auto alignment = max_power2_divisor(offset); + if (!is_contiguous) { + alignment |= max_power2_divisor(original_total_elements_from_stride_dim); + alignment |= max_power2_divisor(actual_total_elements_from_stride_dim); + } + return max_power2_divisor(alignment); + } + + bool is_valid_alignment(uint64_t n) const { + // n is a power of 2; return whether tensor accessor alignment is divisible + // by n. + return !(max_alignment() & (n - 1)); + } +}; + +#endif diff --git a/python/aitemplate/backend/common/tensor_accessor_codegen.py b/python/aitemplate/backend/common/tensor_accessor_codegen.py new file mode 100644 index 000000000..e2e873647 --- /dev/null +++ b/python/aitemplate/backend/common/tensor_accessor_codegen.py @@ -0,0 +1,163 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Tensor accessor related codegens. +""" + +import os +from typing import List + +import jinja2 + +from ...compiler.tensor_accessor import TensorAccessor +from ..target import Target + +# Template used to transform a Python TensorAccessor object +# to a C++ TensorAccessor struct. +TENSOR_ACCESSOR_TEMPLATE = jinja2.Template( + """ + TensorAccessor {{name}} = { + {{tensor_accessor.offset}}, + {% if tensor_accessor.is_contiguous %} + true + {% else %} + false + {% endif %} + {% if not tensor_accessor.is_contiguous %} + , + {{tensor_accessor.stride_dim}}, + {{tensor_accessor.original_total_elements_from_stride_dim}}, + {{tensor_accessor.actual_total_elements_from_stride_dim}} + {% endif %} + }; +""" +) + +STRIDED_ADDRESS_AT_IDX_FUNC_TEMPLATE = jinja2.Template( + """ +template +__device__ __forceinline__ READ_T* get_strided_address_at_idx( + READ_T *data, int64_t data_idx) { +{%if output_accessor.is_contiguous %} + return get_strided_address( + data, data_idx, {{output_accessor.offset}}, 0, 0); +{% else %} + return get_strided_address( + data, data_idx, + {{output_accessor.offset}}, + {{output_accessor.original_total_elements_from_stride_dim}}, + {{output_accessor.actual_total_elements_from_stride_dim}}); +{% endif %} +} +""" +) + + +def get_libs() -> str: + return Target.current().get_custom_libs( + os.path.dirname(__file__), "tensor_accessor.cuh" + ) + + +# Currently read4, add2 is best for both backend, so two backend seems identical. +# They may diverge when we got deeper understanding / further optimization. +ALIGNMENTS = [ + 8, + 4, + 2, + 1, +] + + +def _find_max_alignment(number: int) -> int: + """ + Return the first alignment value that meets the alignment requirement + for accessing the `number` of elements. 
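+
+    For example, with ALIGNMENTS = [8, 4, 2, 1], number = 12 yields 4 and
+    number = 6 yields 2; any odd number falls through to 1.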
+ """ + for alignment in ALIGNMENTS: + if number % alignment == 0: + return alignment + return 1 + + +def find_max_alignment_for_accessor(accessor: TensorAccessor) -> int: + """the max alignment value that meets the requirement specified by + the accessor + + Parameters + ---------- + accessors: TensorAccessor + + Returns + ---------- + int + the max alignment value + """ + alignment = _find_max_alignment(accessor.offset) + if not accessor.is_contiguous: + alignment = min( + alignment, + _find_max_alignment(accessor.original_total_elements_from_stride_dim), + ) + alignment = min( + alignment, + _find_max_alignment(accessor.actual_total_elements_from_stride_dim), + ) + return alignment + + +def find_max_alignment_for_accessors(accessors: List[TensorAccessor]) -> int: + """the max alignment value that meets the requirement specified by + the accessors + + Parameters + ---------- + accessors: List[TensorAccessor] + TensorAccessor(s) attached to the relevant tensor being accessed + + Returns + ---------- + int + the max alignment value + """ + alignment = max(ALIGNMENTS) + # Handle accessors + for accessor in accessors: + alignment = min(alignment, find_max_alignment_for_accessor(accessor)) + return alignment + + +def find_max_alignment(num_elements: int, accessors: List[TensorAccessor]) -> int: + """find the max alignment value that meets the requirement of accessing + num_elements of data with access patterns (strides and offsets) + specified by accessors + + Parameters + ---------- + num_elements: int + specify the number of elements being accessed + + accessors: List[TensorAccessor] + TensorAccessor(s) attached to the relevant tensor being accessed + + Returns + ---------- + int + the max alignment value + """ + # get initial alignment based on the number of elements being accessed + alignment = _find_max_alignment(num_elements) + accessor_alignment = find_max_alignment_for_accessors(accessors) + return min(alignment, accessor_alignment) diff --git a/python/aitemplate/backend/common/upsampling2d_common.py b/python/aitemplate/backend/common/upsampling2d_common.py new file mode 100644 index 000000000..6d7aadd3c --- /dev/null +++ b/python/aitemplate/backend/common/upsampling2d_common.py @@ -0,0 +1,425 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Backend-agnostic function templates for upsampling2d. 
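+Both bilinear and nearest modes operate on NHWC fp16 tensors and can
+optionally fuse a residual add (bias_add) into the upsampled output.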
+""" + +import jinja2 + +# pylint: disable=C0103,C0415,W0613,C0301,W0612 + + +EXEC_TEMPLATE = jinja2.Template( + """ +{{indent}}bilinear_upsampling_luncher( +{{indent}} in_ptr, +{% if bias_add %} + {{indent}} res_ptr, +{% endif %} +{{indent}} out_ptr, +{{indent}} NI, +{{indent}} HI, +{{indent}} WI, +{{indent}} CI, +{{indent}} HO, +{{indent}} WO, +{{indent}} stream +{{indent}}); +{{indent}}return; +""" +) + +SRC_TEMPLATE = jinja2.Template( + """ +{{header_files}} + +namespace { +#define GPU_1D_KERNEL_LOOP(i, n) \ + for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) + +{% if mode == "bilinear"%} +__global__ void bilinear_upsampling_f16_nhwc_kernel(const half2* input, + {% if bias_add %} + const half2* input_res, + {% endif %} + half2* output, + const {{index_type}} batch, + const {{index_type}} in_height, + const {{index_type}} in_width, + const {{index_type}} channels, + const {{index_type}} out_height, + const {{index_type}} out_width) { + + const float height_scale = in_height / static_cast(out_height); + const float width_scale = in_width / static_cast(out_width); + const int64_t num_threads = out_height * out_width * channels * batch; + +GPU_1D_KERNEL_LOOP(out_idx, num_threads) { + int64_t idx = out_idx; + const int64_t c = idx % channels; + idx /= channels; + const int64_t x = idx % out_width; + idx /= out_width; + const int64_t y = idx % out_height; + const int64_t b = idx / out_height; + + const float in_y = (static_cast(y) + 0.5f) * height_scale - 0.5f; + const int64_t top_y_index = in_y > 0.0 ? floorf(in_y) : 0; + const int64_t bottom_y_index = + (in_y < in_height - 1) ? ceilf(in_y) : in_height - 1; + const float y_lerp = in_y - floorf(in_y); + + const float in_x = (static_cast(x) + 0.5f) * width_scale - 0.5f; + const int64_t left_x_index = in_x > 0.0 ? floorf(in_x) : 0; + const int64_t right_x_index = + (in_x < in_width - 1) ? 
ceilf(in_x) : in_width - 1; + const float x_lerp = in_x - floorf(in_x); + + const half2 top_left = __ldg( + input + ((b * in_height + top_y_index) * in_width + left_x_index) * + channels + + c); + + const half2 top_right = __ldg( + input + ((b * in_height + top_y_index) * in_width + right_x_index) * + channels + + c); + const half2 bottom_left = __ldg( + input + ((b * in_height + bottom_y_index) * in_width + left_x_index) * + channels + + c); + const half2 bottom_right = __ldg( + input + ((b * in_height + bottom_y_index) * in_width + right_x_index) * + channels + + c); + + float top_x = __half2float(top_left{{half2_data_ref}}.x) + (__half2float(top_right{{half2_data_ref}}.x) - __half2float(top_left{{half2_data_ref}}.x)) * x_lerp; + float top_y = __half2float(top_left{{half2_data_ref}}.y) + (__half2float(top_right{{half2_data_ref}}.y) - __half2float(top_left{{half2_data_ref}}.y)) * x_lerp; + + float bottom_x = __half2float(bottom_left{{half2_data_ref}}.x) + (__half2float(bottom_right{{half2_data_ref}}.x) - __half2float(bottom_left{{half2_data_ref}}.x)) * x_lerp;; + float bottom_y = __half2float(bottom_left{{half2_data_ref}}.y) + (__half2float(bottom_right{{half2_data_ref}}.y) - __half2float(bottom_left{{half2_data_ref}}.y)) * x_lerp;; + + float2 out = {0.f, 0.f}; + out.x = top_x + (bottom_x - top_x) * y_lerp; + out.y = top_y + (bottom_y - top_y) * y_lerp; + + {% if bias_add %} + output[out_idx] = __hadd2(__float22half2_rn(out), __ldg(input_res + out_idx)); + {% else %} + output[out_idx] = __float22half2_rn(out); + {% endif %} + } + +} + +{% else %} +template +__global__ void nearest_upsampling_f16_nhwc_kernel(const T* input, + {% if bias_add %} + const T* input_res, + {% endif %} + T* output, + const {{index_type}} batch, + const {{index_type}} in_height, + const {{index_type}} in_width, + const {{index_type}} channels, + const {{index_type}} out_height, + const {{index_type}} out_width) { + + const float height_scale = in_height / static_cast(out_height); + const float width_scale = in_width / static_cast(out_width); + const int64_t nthreads = out_height * out_width * channels * batch; + +GPU_1D_KERNEL_LOOP(index, nthreads) { + int n = index; + int c = n % channels; + n /= channels; + int out_x = n % out_width; + n /= out_width; + int out_y = n % out_height; + n /= out_height; + + const T* bottom_data_n = input + n * channels * in_height * in_width; + const int in_y = + max(min(static_cast( + floorf((static_cast(out_y) + 0.5f) * height_scale)), + static_cast(in_height) - 1), + 0); + const int in_x = + max(min(static_cast( + floorf((static_cast(out_x) + 0.5f) * width_scale)), + static_cast(in_width) - 1), + 0); + const int idx = (in_y * in_width + in_x) * channels + c; + + + {% if bias_add %} + T input_val = __ldg(bottom_data_n + idx); + T input_res_val = __ldg(input_res + index); + {% if tsize == 1 %} + output[index] = input_val + input_res_val; + + {% elif tsize == 8 %} + T output_val; + Telement* pack_y = reinterpret_cast(&output_val); + Telement* pack_x = reinterpret_cast(&input_val); + Telement* pack_res = reinterpret_cast(&input_res_val); + for (int k = 0 ; k < element_in_Tio ; k++) + pack_y[k] = pack_x[k] + pack_res[k]; + output[index] = output_val; + + {% else %} + T output_val; + output_val{{half2_data_ref}}.x = input_val{{half2_data_ref}}.x + input_res_val{{half2_data_ref}}.x; + output_val{{half2_data_ref}}.y = input_val{{half2_data_ref}}.y + input_res_val{{half2_data_ref}}.y; + output[index] = output_val; + {% endif %} + {% else %} + output[index] = __ldg(bottom_data_n + idx); + 
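+    // No residual input: nearest upsampling is a pure gather from the source
+    // location (in_y, in_x) computed above.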
{% endif %} + + } +} + +{% endif %} + +template +constexpr __host__ __device__ inline integer ceil_div(integer n, integer m) { + return (n + m - 1) / m; +} + +void bilinear_upsampling_luncher({{elem_input_type}}* input, + {% if bias_add %} + {{elem_input_type}}* input_res, + {% endif %} + {{elem_output_type}}* output, + const {{index_type}} N, + const {{index_type}} H, + const {{index_type}} W, + const {{index_type}} C, + const {{index_type}} HO, + const {{index_type}} WO, + {{prefix}}Stream_t stream) { + const int64_t output_size = N * (C) * HO * WO; + dim3 grid(std::min( + ceil_div(static_cast(output_size), static_cast(512)), + static_cast(4096))); + dim3 block(512); + +{% if mode == "bilinear" %} + bilinear_upsampling_f16_nhwc_kernel<<>>( + (const half2 *)input, + {% if bias_add %} + (const half2 *)input_res, + {% endif %} + (half2 *)output, + N, H, W, C/2, HO, WO); +{% else %} + {% if tsize == 1 %} + nearest_upsampling_f16_nhwc_kernel<<>>( + (const half *)input, + {% if bias_add %} + (const half *)input_res, + {% endif %} + (half *)output, + N, H, W, C, HO, WO); + {% elif tsize == 8 %} + nearest_upsampling_f16_nhwc_kernel<<>>( + (const float4 *)input, + {% if bias_add %} + (const float4 *)input_res, + {% endif %} + (float4 *)output, + N, H, W, C/8, HO, WO); + {% else %} + nearest_upsampling_f16_nhwc_kernel<<>>( + (const half2 *)input, + {% if bias_add %} + (const half2 *)input_res, + {% endif %} + (half2 *)output, + N, H, W, C/2, HO, WO); + {% endif %} +{% endif %} +} +} // namespace + +void {{function_name}} ( + {{elem_input_type}}* in_ptr, + {% if bias_add %} + {{elem_input_type}}* res_ptr, + {% endif %} + {{elem_output_type}}* out_ptr, + {{index_type}}* batch, + {{index_type}}* in_h, + {{index_type}}* in_w, + {{index_type}}* in_ch, + {{index_type}}* out_batch, + {{index_type}}* out_h, + {{index_type}}* out_w, + {{prefix}}Stream_t stream +) { + {{shape_function}} + {{exec_paths}} + throw std::runtime_error( + "Unsupported workload for this bilinear upsampling specialization." 
+ ); +} +""" +) + + +FUNC_DECL_TEMPLATE = jinja2.Template( + """ +void {{func_name}}( + {{elem_input_type}}*, + {% if bias_add %} + {{elem_input_type}}*, + {% endif %} + {{elem_output_type}}*, + {{index_type}}*, + {{index_type}}*, + {{index_type}}*, + {{index_type}}*, + {{index_type}}*, + {{index_type}}*, + {{index_type}}*, + {{prefix}}Stream_t +); +""" +) + +FUNC_CALL_TEMPLATE = jinja2.Template( + """ +{{indent}}{{func_name}}( +{{indent}} static_cast<{{elem_input_type}}*>({{in_ptr}}), +{% if bias_add %} + {{indent}} static_cast<{{elem_input_type}}*>({{res_ptr}}), +{% endif %} +{{indent}} static_cast<{{elem_output_type}}*>({{out_ptr}}), +{{indent}} {{p_batch}}, +{{indent}} {{p_in_h}}, +{{indent}} {{p_in_w}}, +{{indent}} {{p_in_ch}}, +{{indent}} {{p_out_batch}}, +{{indent}} {{p_out_h}}, +{{indent}} {{p_out_w}}, +{{indent}} stream +{{indent}}); +""" +) + + +def gen_function_decl(func_attrs, backend_spec, bias_add=False): + """Function declaration generation + + Parameters + ---------- + func_attrs : Dict[str, Any] + It describes the operation attributes + backend_spec : custom class + It specifies the corresponding backend dtypes of pytorch dtypes for many operations + + Returns + ------- + str + Rendered function declaration stmt + """ + x = func_attrs["inputs"][0] + y = func_attrs["outputs"][0] + input_type = backend_spec.dtype_to_lib_type(x._attrs["dtype"]) + output_type = backend_spec.dtype_to_lib_type(y._attrs["dtype"]) + return FUNC_DECL_TEMPLATE.render( + index_type=backend_spec.index_type, + prefix=backend_spec.prefix, + func_name=func_attrs["name"], + elem_input_type=input_type, + elem_output_type=output_type, + bias_add=bias_add, + ) + + +def gen_alignment(x): + in_channel = x.shape()[-1].value() + if in_channel % 8 == 0: + tsize = 8 + elif in_channel % 4 == 0: + tsize = 4 + elif in_channel % 2 == 0: + tsize = 2 + else: + tsize = 1 + return tsize + + +def gen_function_call(func_attrs, backend_spec, indent=" ", bias_add=False): + """Function call generation + + Parameters + ---------- + func_attrs : Dict[str, Any] + It describes the operation attributes + indent : str, optional + Indent for template, by default " " + + Returns + ------- + str + Rendered function call + """ + x = func_attrs["inputs"][0] + xshape = x._attrs["shape"] + y = func_attrs["outputs"][0] + yshape = y._attrs["shape"] + input_type = backend_spec.dtype_to_lib_type(x._attrs["dtype"]) + output_type = backend_spec.dtype_to_lib_type(y._attrs["dtype"]) + if bias_add: + r = func_attrs["inputs"][1] + return FUNC_CALL_TEMPLATE.render( + func_name=func_attrs["name"], + elem_input_type=input_type, + elem_output_type=output_type, + index_type=backend_spec.index_type, + in_ptr=x._attrs["name"], + res_ptr=r._attrs["name"], + out_ptr=y._attrs["name"], + p_batch="&" + xshape[0]._attrs["name"], + p_in_ch="&" + xshape[3]._attrs["name"], + p_in_h="&" + xshape[1]._attrs["name"], + p_in_w="&" + xshape[2]._attrs["name"], + p_out_batch="&" + yshape[0]._attrs["name"], + p_out_h="&" + yshape[1]._attrs["name"], + p_out_w="&" + yshape[2]._attrs["name"], + indent=indent, + bias_add=bias_add, + ) + else: + return FUNC_CALL_TEMPLATE.render( + func_name=func_attrs["name"], + elem_input_type=input_type, + elem_output_type=output_type, + index_type=backend_spec.index_type, + in_ptr=x._attrs["name"], + out_ptr=y._attrs["name"], + p_batch="&" + xshape[0]._attrs["name"], + p_in_ch="&" + xshape[3]._attrs["name"], + p_in_h="&" + xshape[1]._attrs["name"], + p_in_w="&" + xshape[2]._attrs["name"], + p_out_batch="&" + yshape[0]._attrs["name"], + 
p_out_h="&" + yshape[1]._attrs["name"], + p_out_w="&" + yshape[2]._attrs["name"], + indent=indent, + bias_add=bias_add, + ) diff --git a/python/aitemplate/backend/common/vision_ops/efficient_nms_common.py b/python/aitemplate/backend/common/vision_ops/efficient_nms_common.py new file mode 100644 index 000000000..8431e5d87 --- /dev/null +++ b/python/aitemplate/backend/common/vision_ops/efficient_nms_common.py @@ -0,0 +1,250 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +nms kernel codegen for CUDA. +""" + +import os +from typing import Any, Dict + +import jinja2 + +from ... import builder +from ...target import Target +from .efficient_nms_kernel import kernel + +# pylint: disable=C0301 + +FUNC_CALL_INT64_PARAM_TEMPLATE = jinja2.Template("reinterpret_cast({{name}})") + +FUNC_TEMPLATE = jinja2.Template( + """ +{{header_files}} + +namespace { + +{{kernel}} + +} // namespace + +{{func_signature}} +{ + + const int N = *batch; + const int R = *num_rois; + const int C = *num_classes; + + EfficientNMSParameters mParam; + mParam.iouThreshold = iouThreshold; + mParam.scoreThreshold = 0.001; + mParam.boxDecoder = false; + mParam.numOutputBoxesPerClass = nmsMaxOut; + mParam.numOutputBoxes = nmsMaxOut; + mParam.batchSize = N; + mParam.numBoxElements = R * C * 4; + mParam.numScoreElements = R * C; + mParam.numAnchors = R; + mParam.numClasses = C; + mParam.shareLocation = (C == 1) ? 
true : false; + mParam.outputONNXIndices = false; + mParam.scoreSigmoid = false; + mParam.numSelectedBoxes = 5000; + + const void* const boxesInput = proposals; + const void* const scoresInput = fgScores; + const void* const anchorsInput = nullptr; + + void* numDetectionsOutput = num_detections; + void* nmsBoxesOutput = detection_boxes; + void* nmsScoresOutput = detection_scores; + void* nmsClassesOutput = detection_classe; + + return EfficientNMSInference(mParam, boxesInput, scoresInput, anchorsInput, numDetectionsOutput, + nmsBoxesOutput, nmsScoresOutput, nmsClassesOutput, nullptr, workspace, stream); + + +} + """ +) + +PROFILER_TEMPLATE = jinja2.Template( + """ +#include +{{header_files}} +size_t GLOBAL_WORKSPACE_SIZE = 0; + +namespace { + +{{kernel}} + +} // namespace + +int main(int argc, char** argv) { + float runtime_ms = 0; + int batchSize = std::stoi(argv[1]); + int numScoreElements = std::stoi(argv[2]); + int numClasses = std::stoi(argv[3]); + GLOBAL_WORKSPACE_SIZE = EfficientNMSWorkspaceSize(batchSize, numScoreElements, numClasses); + + std::cout << "TIME:" << runtime_ms << std::endl; + std::cout << "WS:" << GLOBAL_WORKSPACE_SIZE << std::endl; +} + """ +) + +FUNC_SIGNATURE = jinja2.Template( + """ +void {{func_name}}(int64_t* num_detections, + half* detection_boxes, + half* detection_scores, + int64_t* detection_classe, + const half* proposals, + const half* fgScores, + int64_t* batch, + int64_t* num_rois, + int64_t* num_classes, + const int preNmsTop, + const int nmsMaxOut, + const float iouThreshold, + const float minBoxSize, + uint8_t* workspace, + {{prefix}}Stream_t stream) + """ +) + +FUNC_DECL = jinja2.Template( + """ + {{func_signature}}; + """ +) + + +FUNC_CALL_TEMPLATE = jinja2.Template( + """ +{{indent}}{{func_name}}( +{{indent}} {{num_detections}}, +{{indent}} {{detection_boxes}}, +{{indent}} {{detection_scores}}, +{{indent}} {{detection_classe}}, +{{indent}} {{proposals}}, +{{indent}} {{fgScores}}, +{{indent}} {{p_batch}}, +{{indent}} {{num_rois}}, +{{indent}} {{num_classes}}, +{{indent}} {{preNmsTop}}, +{{indent}} {{nmsMaxOut}}, +{{indent}} {{iouThreshold}}, +{{indent}} {{minBoxSize}}, +{{indent}} global_workspace, stream /* default stream */ +{{indent}}); + """ +) + + +def gen_function(func_attrs: Dict[str, Any], header_files, backend_spec) -> str: + """the function for generating nms kernel""" + return FUNC_TEMPLATE.render( + header_files=header_files, + kernel=kernel.render(prefix=backend_spec.prefix, cub=backend_spec.cub), + func_signature=FUNC_SIGNATURE.render( + func_name=func_attrs["name"], prefix=backend_spec.prefix + ), + ) + + +def gen_function_decl(func_attrs: Dict[str, Any], backend_spec): + return FUNC_DECL.render( + func_signature=FUNC_SIGNATURE.render( + func_name=func_attrs["name"], prefix=backend_spec.prefix + ).strip() + ) + + +def gen_function_call(func_attrs, backend_spec, indent=" "): + """the function for generating a function call for nms op""" + + assert len(func_attrs["outputs"]) == 4 + assert len(func_attrs["inputs"]) == 2 + + num_detections = FUNC_CALL_INT64_PARAM_TEMPLATE.render( + name=func_attrs["outputs"][0]._attrs["name"] + ) + detection_boxes = backend_spec.cast_to_half_ptr_template.render( + name=func_attrs["outputs"][1]._attrs["name"] + ) + detection_scores = backend_spec.cast_to_half_ptr_template.render( + name=func_attrs["outputs"][2]._attrs["name"] + ) + detection_classes = FUNC_CALL_INT64_PARAM_TEMPLATE.render( + name=func_attrs["outputs"][3]._attrs["name"] + ) + (input_name, score_name) = ( + 
backend_spec.cast_to_half_ptr_template.render(name=input_tensor._attrs["name"]) + for input_tensor in func_attrs["inputs"] + ) + + x = func_attrs["inputs"][0] + xshape = x._attrs["shape"] + + return FUNC_CALL_TEMPLATE.render( + func_name=func_attrs["name"], + num_detections=num_detections, + detection_boxes=detection_boxes, + detection_scores=detection_scores, + detection_classe=detection_classes, + proposals=input_name, + fgScores=score_name, + p_batch="&" + xshape[0]._attrs["name"], + num_rois="&" + xshape[1]._attrs["name"], + num_classes="&" + xshape[2]._attrs["name"], + preNmsTop=func_attrs["preNmsTop"], + nmsMaxOut=func_attrs["nmsMaxOut"], + iouThreshold=func_attrs["iouThreshold"], + minBoxSize=func_attrs["minBoxSize"], + indent=indent, + ) + + +def add_profiler(file_pairs, workdir, op_type, output_name, code): + """generate nms kernel for profiling""" + prefix = os.path.join(workdir, "profiler", op_type) + if not os.path.exists(prefix): + os.makedirs(prefix) + src_path = os.path.join(prefix, output_name + ".cu") + obj_path = os.path.join(prefix, output_name) + if os.path.exists(obj_path): + return + with open(src_path, "w") as f: + f.write(code) + file_pairs.append((src_path, obj_path)) + + +def gen_profiler(func_attrs, workdir, header_files, backend_spec): + """the function for generating profiler for nms op""" + op_type = func_attrs["op"] + file_pairs = [] + code = PROFILER_TEMPLATE.render( + header_files=header_files, + kernel=kernel.render(prefix=backend_spec.prefix, cub=backend_spec.cub), + func_signature=FUNC_SIGNATURE.render( + func_name=func_attrs["name"], prefix=backend_spec.prefix + ), + ) + op_name = func_attrs["op"] + add_profiler(file_pairs, workdir, op_type, op_name, code) + # build + target = Target.current() + compile_engine = builder.Builder() + compile_engine.build_objs(file_pairs, target.compile_cmd(executable=True)) diff --git a/python/aitemplate/backend/common/vision_ops/efficient_nms_kernel.py b/python/aitemplate/backend/common/vision_ops/efficient_nms_kernel.py new file mode 100644 index 000000000..5d5631f14 --- /dev/null +++ b/python/aitemplate/backend/common/vision_ops/efficient_nms_kernel.py @@ -0,0 +1,1160 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +efficient_nms function gpu kernel. +""" +import jinja2 + +kernel = jinja2.Template( + """ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#define NMS_TILES 5 + +#define CSC(call, err) \ + do { \ + {{prefix}}Error_t {{prefix}}Status = call; \ + if ({{prefix}}Status != {{prefix}}Success) { \ + return err; \ + } \ + } while (0) + +#ifndef TRT_EFFICIENT_NMS_INFERENCE_CUH +#define TRT_EFFICIENT_NMS_INFERENCE_CUH + +// FP32 Intrinsics + +float __device__ __inline__ exp_mp(const float a) { + return __expf(a); +} +float __device__ __inline__ sigmoid_mp(const float a) { + return __frcp_rn(__fadd_rn(1.f, __expf(-a))); +} +float __device__ __inline__ add_mp(const float a, const float b) { + return __fadd_rn(a, b); +} +float __device__ __inline__ sub_mp(const float a, const float b) { + return __fsub_rn(a, b); +} +float __device__ __inline__ mul_mp(const float a, const float b) { + return __fmul_rn(a, b); +} +bool __device__ __inline__ gt_mp(const float a, const float b) { + return a > b; +} +bool __device__ __inline__ lt_mp(const float a, const float b) { + return a < b; +} +bool __device__ __inline__ lte_mp(const float a, const float b) { + return a <= b; +} +bool __device__ __inline__ gte_mp(const float a, const float b) { + return a >= b; +} + +#if __CUDA_ARCH__ >= 530 + +// FP16 Intrinsics + +__half __device__ __inline__ exp_mp(const __half a) { + return hexp(a); +} +__half __device__ __inline__ sigmoid_mp(const __half a) { + return hrcp(__hadd((__half)1, hexp(__hneg(a)))); +} +__half __device__ __inline__ add_mp(const __half a, const __half b) { + return __hadd(a, b); +} +__half __device__ __inline__ sub_mp(const __half a, const __half b) { + return __hsub(a, b); +} +__half __device__ __inline__ mul_mp(const __half a, const __half b) { + return __hmul(a, b); +} +bool __device__ __inline__ gt_mp(const __half a, const __half b) { + return __hgt(a, b); +} +bool __device__ __inline__ lt_mp(const __half a, const __half b) { + return __hlt(a, b); +} +bool __device__ __inline__ lte_mp(const __half a, const __half b) { + return __hle(a, b); +} +bool __device__ __inline__ gte_mp(const __half a, const __half b) { + return __hge(a, b); +} + +#else + +// FP16 Fallbacks on older architectures that lack support + +__half __device__ __inline__ exp_mp(const __half a) { + return __float2half(exp_mp(__half2float(a))); +} +__half __device__ __inline__ sigmoid_mp(const __half a) { + return __float2half(sigmoid_mp(__half2float(a))); +} +__half __device__ __inline__ add_mp(const __half a, const __half b) { + return __float2half(add_mp(__half2float(a), __half2float(b))); +} +__half __device__ __inline__ sub_mp(const __half a, const __half b) { + return __float2half(sub_mp(__half2float(a), __half2float(b))); +} +__half __device__ __inline__ mul_mp(const __half a, const __half b) { + return __float2half(mul_mp(__half2float(a), __half2float(b))); +} +bool __device__ __inline__ gt_mp(const __half a, const __half b) { + return __float2half(gt_mp(__half2float(a), __half2float(b))); +} +bool __device__ __inline__ lt_mp(const __half a, const __half b) { + return __float2half(lt_mp(__half2float(a), __half2float(b))); +} +bool __device__ __inline__ lte_mp(const __half a, const __half b) { + return __float2half(lte_mp(__half2float(a), __half2float(b))); +} +bool __device__ __inline__ gte_mp(const __half a, const __half b) { + return __float2half(gte_mp(__half2float(a), __half2float(b))); +} + +#endif + +typedef enum { + STATUS_SUCCESS = 0, + STATUS_FAILURE = 1, + STATUS_BAD_PARAM = 2, + STATUS_NOT_SUPPORTED = 3, + STATUS_NOT_INITIALIZED = 4 +} 
pluginStatus_t; + +struct EfficientNMSParameters { + // Related to NMS Options + float iouThreshold = 0.5f; + float scoreThreshold = 0.5f; + int numOutputBoxes = 100; + int numOutputBoxesPerClass = -1; + bool padOutputBoxesPerClass = false; + int backgroundClass = -1; + bool scoreSigmoid = false; + bool clipBoxes = false; + int boxCoding = 0; // BoxCorner + + // Related to NMS Internals + int numSelectedBoxes = 4096; + int scoreBits = -1; + bool outputONNXIndices = false; + + // Related to Tensor Configuration + // (These are set by the various plugin configuration methods, no need to + // define them during plugin creation.) + int batchSize = -1; + int numClasses = 1; + int numBoxElements = -1; + int numScoreElements = -1; + int numAnchors = -1; + bool shareLocation = true; + bool shareAnchors = true; + bool boxDecoder = false; + // DataType datatype = DataType::kFLOAT; +}; + +template +struct __align__(4 * sizeof(T)) BoxCorner; + +template +struct __align__(4 * sizeof(T)) BoxCenterSize; + +template +struct __align__(4 * sizeof(T)) BoxCorner { + // For NMS/IOU purposes, YXYX coding is identical to XYXY + T y1, x1, y2, x2; + + __device__ void reorder() { + if (gt_mp(y1, y2)) { + // Swap values, so y1 < y2 + y1 = sub_mp(y1, y2); + y2 = add_mp(y1, y2); + y1 = sub_mp(y2, y1); + } + if (gt_mp(x1, x2)) { + // Swap values, so x1 < x2 + x1 = sub_mp(x1, x2); + x2 = add_mp(x1, x2); + x1 = sub_mp(x2, x1); + } + } + + __device__ BoxCorner clip(T low, T high) const { + return { + lt_mp(y1, low) ? low : (gt_mp(y1, high) ? high : y1), + lt_mp(x1, low) ? low : (gt_mp(x1, high) ? high : x1), + lt_mp(y2, low) ? low : (gt_mp(y2, high) ? high : y2), + lt_mp(x2, low) ? low : (gt_mp(x2, high) ? high : x2)}; + } + + __device__ BoxCorner decode(BoxCorner anchor) const { + return { + add_mp(y1, anchor.y1), + add_mp(x1, anchor.x1), + add_mp(y2, anchor.y2), + add_mp(x2, anchor.x2)}; + } + + __device__ float area() const { + T w = sub_mp(x2, x1); + T h = sub_mp(y2, y1); + if (lte_mp(h, (T)0)) { + return 0; + } + if (lte_mp(w, (T)0)) { + return 0; + } + return (float)h * (float)w; + } + + __device__ operator BoxCenterSize() const { + T w = sub_mp(x2, x1); + T h = sub_mp(y2, y1); + return BoxCenterSize{ + add_mp(y1, mul_mp((T)0.5, h)), add_mp(x1, mul_mp((T)0.5, w)), h, w}; + } + + __device__ static BoxCorner intersect(BoxCorner a, BoxCorner b) { + return { + gt_mp(a.y1, b.y1) ? a.y1 : b.y1, + gt_mp(a.x1, b.x1) ? a.x1 : b.x1, + lt_mp(a.y2, b.y2) ? a.y2 : b.y2, + lt_mp(a.x2, b.x2) ? 
a.x2 : b.x2}; + } +}; + +template +struct __align__(4 * sizeof(T)) BoxCenterSize { + // For NMS/IOU purposes, YXHW coding is identical to XYWH + T y, x, h, w; + + __device__ void reorder() {} + + __device__ BoxCenterSize clip(T low, T high) const { + return BoxCenterSize(BoxCorner(*this).clip(low, high)); + } + + __device__ BoxCenterSize decode(BoxCenterSize anchor) const { + return { + add_mp(mul_mp(y, anchor.h), anchor.y), + add_mp(mul_mp(x, anchor.w), anchor.x), + mul_mp(anchor.h, exp_mp(h)), + mul_mp(anchor.w, exp_mp(w))}; + } + + __device__ float area() const { + if (h <= (T)0) { + return 0; + } + if (w <= (T)0) { + return 0; + } + return (float)h * (float)w; + } + + __device__ operator BoxCorner() const { + T h2 = mul_mp(h, (T)0.5); + T w2 = mul_mp(w, (T)0.5); + return BoxCorner{ + sub_mp(y, h2), sub_mp(x, w2), add_mp(y, h2), add_mp(x, w2)}; + } + __device__ static BoxCenterSize intersect( + BoxCenterSize a, BoxCenterSize b) { + return BoxCenterSize( + BoxCorner::intersect(BoxCorner(a), BoxCorner(b))); + } +}; + +#endif + +template +__device__ float +IOU(EfficientNMSParameters param, BoxCorner box1, BoxCorner box2) { + // Regardless of the selected box coding, IOU is always performed in BoxCorner + // coding. The boxes are copied so that they can be reordered without + // affecting the originals. + BoxCorner b1 = box1; + BoxCorner b2 = box2; + b1.reorder(); + b2.reorder(); + float intersectArea = BoxCorner::intersect(b1, b2).area(); + if (intersectArea <= 0.f) { + return 0.f; + } + float unionArea = b1.area() + b2.area() - intersectArea; + if (unionArea <= 0.f) { + return 0.f; + } + return intersectArea / unionArea; +} + +template +__device__ BoxCorner DecodeBoxes( + EfficientNMSParameters param, + int boxIdx, + int anchorIdx, + const Tb* __restrict__ boxesInput, + const Tb* __restrict__ anchorsInput) { + // The inputs will be in the selected coding format, as well as the decoding + // function. But the decoded box will always be returned as BoxCorner. + Tb box = boxesInput[boxIdx]; + if (!param.boxDecoder) { + return BoxCorner(box); + } + Tb anchor = anchorsInput[anchorIdx]; + box.reorder(); + anchor.reorder(); + return BoxCorner(box.decode(anchor)); +} + +template +__device__ void MapNMSData( + EfficientNMSParameters param, + int idx, + int imageIdx, + const Tb* __restrict__ boxesInput, + const Tb* __restrict__ anchorsInput, + const int* __restrict__ topClassData, + const int* __restrict__ topAnchorsData, + const int* __restrict__ topNumData, + const T* __restrict__ sortedScoresData, + const int* __restrict__ sortedIndexData, + T& scoreMap, + int& classMap, + BoxCorner& boxMap, + int& boxIdxMap) { + // idx: Holds the NMS box index, within the current batch. + // idxSort: Holds the batched NMS box index, which indexes the (filtered, but + // sorted) score buffer. scoreMap: Holds the score that corresponds to the + // indexed box being processed by NMS. + if (idx >= topNumData[imageIdx]) { + return; + } + int idxSort = imageIdx * param.numScoreElements + idx; + scoreMap = sortedScoresData[idxSort]; + + // idxMap: Holds the re-mapped index, which indexes the (filtered, but + // unsorted) buffers. classMap: Holds the class that corresponds to the idx'th + // sorted score being processed by NMS. anchorMap: Holds the anchor that + // corresponds to the idx'th sorted score being processed by NMS. 
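+  // Two indirections in total: sortedIndexData maps the sorted position back
+  // into the filtered (unsorted) buffers, and the class/anchor stored there
+  // index the raw boxes/anchors inputs below.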
+ int idxMap = imageIdx * param.numScoreElements + sortedIndexData[idxSort]; + classMap = topClassData[idxMap]; + int anchorMap = topAnchorsData[idxMap]; + + // boxIdxMap: Holds the re-re-mapped index, which indexes the (unfiltered, and + // unsorted) boxes input buffer. + boxIdxMap = -1; + if (param.shareLocation) // Shape of boxesInput: [batchSize, numAnchors, 1, 4] + { + boxIdxMap = imageIdx * param.numAnchors + anchorMap; + } else // Shape of boxesInput: [batchSize, numAnchors, numClasses, 4] + { + int batchOffset = imageIdx * param.numAnchors * param.numClasses; + int anchorOffset = anchorMap * param.numClasses; + boxIdxMap = batchOffset + anchorOffset + classMap; + } + // anchorIdxMap: Holds the re-re-mapped index, which indexes the (unfiltered, + // and unsorted) anchors input buffer. + int anchorIdxMap = -1; + if (param.shareAnchors) // Shape of anchorsInput: [1, numAnchors, 4] + { + anchorIdxMap = anchorMap; + } else // Shape of anchorsInput: [batchSize, numAnchors, 4] + { + anchorIdxMap = imageIdx * param.numAnchors + anchorMap; + } + // boxMap: Holds the box that corresponds to the idx'th sorted score being + // processed by NMS. + boxMap = DecodeBoxes( + param, boxIdxMap, anchorIdxMap, boxesInput, anchorsInput); +} + +template +__device__ void WriteNMSResult( + EfficientNMSParameters param, + int64_t* __restrict__ numDetectionsOutput, + T* __restrict__ nmsScoresOutput, + int64_t* __restrict__ nmsClassesOutput, + BoxCorner* __restrict__ nmsBoxesOutput, + T threadScore, + int threadClass, + BoxCorner threadBox, + int imageIdx, + unsigned int resultsCounter) { + int outputIdx = imageIdx * param.numOutputBoxes + resultsCounter - 1; + if (param.scoreSigmoid) { + nmsScoresOutput[outputIdx] = sigmoid_mp(threadScore); + } else if (param.scoreBits > 0) { + nmsScoresOutput[outputIdx] = add_mp(threadScore, (T)-1); + } else { + nmsScoresOutput[outputIdx] = threadScore; + } + nmsClassesOutput[outputIdx] = (int64_t)threadClass; + if (param.clipBoxes) { + nmsBoxesOutput[outputIdx] = threadBox.clip((T)0, (T)1); + } else { + nmsBoxesOutput[outputIdx] = threadBox; + } + numDetectionsOutput[imageIdx] = (int64_t)resultsCounter; +} + +__device__ void WriteONNXResult( + EfficientNMSParameters param, + int* outputIndexData, + int* __restrict__ nmsIndicesOutput, + int imageIdx, + int threadClass, + int boxIdxMap) { + int index = boxIdxMap % param.numAnchors; + int idx = atomicAdd((unsigned int*)&outputIndexData[0], 1); + nmsIndicesOutput[idx * 3 + 0] = imageIdx; + nmsIndicesOutput[idx * 3 + 1] = threadClass; + nmsIndicesOutput[idx * 3 + 2] = index; +} + +__global__ void PadONNXResult( + EfficientNMSParameters param, + int* outputIndexData, + int* __restrict__ nmsIndicesOutput) { + if (threadIdx.x > 0) { + return; + } + int pidx = outputIndexData[0] - 1; + if (pidx < 0) { + return; + } + for (int idx = pidx + 1; idx < param.batchSize * param.numOutputBoxes; + idx++) { + nmsIndicesOutput[idx * 3 + 0] = nmsIndicesOutput[pidx * 3 + 0]; + nmsIndicesOutput[idx * 3 + 1] = nmsIndicesOutput[pidx * 3 + 1]; + nmsIndicesOutput[idx * 3 + 2] = nmsIndicesOutput[pidx * 3 + 2]; + } +} + +template +__global__ void EfficientNMS( + EfficientNMSParameters param, + const int* topNumData, + int* outputIndexData, + int* outputClassData, + const int* sortedIndexData, + const T* __restrict__ sortedScoresData, + const int* __restrict__ topClassData, + const int* __restrict__ topAnchorsData, + const Tb* __restrict__ boxesInput, + const Tb* __restrict__ anchorsInput, + int64_t* __restrict__ numDetectionsOutput, + T* __restrict__ 
nmsScoresOutput, + int64_t* __restrict__ nmsClassesOutput, + int* __restrict__ nmsIndicesOutput, + BoxCorner* __restrict__ nmsBoxesOutput) { + unsigned int thread = threadIdx.x; + unsigned int imageIdx = blockIdx.y; + unsigned int tileSize = blockDim.x; + if (imageIdx >= param.batchSize) { + return; + } + + int numSelectedBoxes = min(topNumData[imageIdx], param.numSelectedBoxes); + int numTiles = (numSelectedBoxes + tileSize - 1) / tileSize; + if (thread >= numSelectedBoxes) { + return; + } + + __shared__ int blockState; + __shared__ unsigned int resultsCounter; + if (thread == 0) { + blockState = 0; + resultsCounter = 0; + } + + int threadState[NMS_TILES]; + unsigned int boxIdx[NMS_TILES]; + T threadScore[NMS_TILES]; + int threadClass[NMS_TILES]; + BoxCorner threadBox[NMS_TILES]; + int boxIdxMap[NMS_TILES]; + for (int tile = 0; tile < numTiles; tile++) { + threadState[tile] = 0; + boxIdx[tile] = thread + tile * blockDim.x; + MapNMSData( + param, + boxIdx[tile], + imageIdx, + boxesInput, + anchorsInput, + topClassData, + topAnchorsData, + topNumData, + sortedScoresData, + sortedIndexData, + threadScore[tile], + threadClass[tile], + threadBox[tile], + boxIdxMap[tile]); + } + + // Iterate through all boxes to NMS against. + for (int i = 0; i < numSelectedBoxes; i++) { + int tile = i / tileSize; + + if (boxIdx[tile] == i) { + // Iteration lead thread, figure out what the other threads should do, + // this will be signaled via the blockState shared variable. + if (threadState[tile] == -1) { + // Thread already dead, this box was already dropped in a previous + // iteration, because it had a large IOU overlap with another lead + // thread previously, so it would never be kept anyway, therefore it can + // safely be skip all IOU operations in this iteration. + blockState = -1; // -1 => Signal all threads to skip iteration + } else if (threadState[tile] == 0) { + // As this box will be kept, this is a good place to find what index in + // the results buffer it should have, as this allows to perform an early + // loop exit if there are enough results. + if (resultsCounter >= param.numOutputBoxes) { + blockState = -2; // -2 => Signal all threads to do an early loop exit. + } else { + // Thread is still alive, because it has not had a large enough IOU + // overlap with any other kept box previously. Therefore, this box + // will be kept for sure. However, we need to check against all other + // subsequent boxes from this position onward, to see how those other + // boxes will behave in future iterations. + blockState = 1; // +1 => Signal all (higher index) threads to + // calculate IOU against this box + threadState[tile] = 1; // +1 => Mark this box's thread to be kept and + // written out to results + + // If the numOutputBoxesPerClass check is enabled, write the result + // only if the limit for this class on this image has not been reached + // yet. Other than (possibly) skipping the write, this won't affect + // anything else in the NMS threading. + bool write = true; + if (param.numOutputBoxesPerClass >= 0) { + int classCounterIdx = + imageIdx * param.numClasses + threadClass[tile]; + write = + (outputClassData[classCounterIdx] < + param.numOutputBoxesPerClass); + outputClassData[classCounterIdx]++; + } + if (write) { + // This branch is visited by one thread per iteration, so it's safe + // to do non-atomic increments. 
+ resultsCounter++; + if (param.outputONNXIndices) { + WriteONNXResult( + param, + outputIndexData, + nmsIndicesOutput, + imageIdx, + threadClass[tile], + boxIdxMap[tile]); + } else { + WriteNMSResult( + param, + numDetectionsOutput, + nmsScoresOutput, + nmsClassesOutput, + nmsBoxesOutput, + threadScore[tile], + threadClass[tile], + threadBox[tile], + imageIdx, + resultsCounter); + } + } + } + } else { + // This state should never be reached, but just in case... + blockState = 0; // 0 => Signal all threads to not do any updates, + // nothing happens. + } + } + + __syncthreads(); + + if (blockState == -2) { + // This is the signal to exit from the loop. + return; + } + + if (blockState == -1) { + // This is the signal for all threads to just skip this iteration, as no + // IOU's need to be checked. + continue; + } + + // Grab a box and class to test the current box against. The test box + // corresponds to iteration i, therefore it will have a lower index than the + // current thread box, and will therefore have a higher score than the + // current box because it's located "before" in the sorted score list. + T testScore; + int testClass; + BoxCorner testBox; + int testBoxIdxMap; + MapNMSData( + param, + i, + imageIdx, + boxesInput, + anchorsInput, + topClassData, + topAnchorsData, + topNumData, + sortedScoresData, + sortedIndexData, + testScore, + testClass, + testBox, + testBoxIdxMap); + + for (int tile = 0; tile < numTiles; tile++) { + // IOU + if (boxIdx[tile] > i && // Make sure two different boxes are being tested, + // and that it's a higher index; + boxIdx[tile] < numSelectedBoxes && // Make sure the box is within + // numSelectedBoxes; + blockState == 1 && // Signal that allows IOU checks to be performed; + threadState[tile] == 0 && // Make sure this box hasn't been either + // dropped or kept already; + threadClass[tile] == + testClass && // Compare only boxes of matching classes; + lte_mp(threadScore[tile], testScore) && // Make sure the sorting order + // of scores is as expected; + IOU(param, threadBox[tile], testBox) >= + param.iouThreshold) // And... IOU overlap. + { + // Current box overlaps with the box tested in this iteration, this box + // will be skipped. + threadState[tile] = -1; // -1 => Mark this box's thread to be dropped. + } + } + } +} + +template +{{prefix}}Error_t EfficientNMSLauncher( + EfficientNMSParameters& param, + int* topNumData, + int* outputIndexData, + int* outputClassData, + int* sortedIndexData, + T* sortedScoresData, + int* topClassData, + int* topAnchorsData, + const void* boxesInput, + const void* anchorsInput, + int64_t* numDetectionsOutput, + T* nmsScoresOutput, + int64_t* nmsClassesOutput, + int* nmsIndicesOutput, + void* nmsBoxesOutput, + {{prefix}}Stream_t stream) { + unsigned int tileSize = param.numSelectedBoxes / NMS_TILES; + if (param.numSelectedBoxes <= 512) { + tileSize = 512; + } + if (param.numSelectedBoxes <= 256) { + tileSize = 256; + } + + const dim3 blockSize = {tileSize, 1, 1}; + const dim3 gridSize = {1, (unsigned int)param.batchSize, 1}; + + if (param.boxCoding == 0) { + EfficientNMS><<>>( + param, + topNumData, + outputIndexData, + outputClassData, + sortedIndexData, + sortedScoresData, + topClassData, + topAnchorsData, + (BoxCorner*)boxesInput, + (BoxCorner*)anchorsInput, + numDetectionsOutput, + nmsScoresOutput, + nmsClassesOutput, + nmsIndicesOutput, + (BoxCorner*)nmsBoxesOutput); + } else if (param.boxCoding == 1) { + // Note that nmsBoxesOutput is always coded as BoxCorner, regardless of + // the input coding type. 
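+    // The launch below therefore reads BoxCenterSize inputs, while
+    // DecodeBoxes converts every box to corner form before IOU tests and
+    // before it is written to the output.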
+    EfficientNMS<T, BoxCenterSize<T>><<<gridSize, blockSize, 0, stream>>>(
+        param,
+        topNumData,
+        outputIndexData,
+        outputClassData,
+        sortedIndexData,
+        sortedScoresData,
+        topClassData,
+        topAnchorsData,
+        (BoxCenterSize<T>*)boxesInput,
+        (BoxCenterSize<T>*)anchorsInput,
+        numDetectionsOutput,
+        nmsScoresOutput,
+        nmsClassesOutput,
+        nmsIndicesOutput,
+        (BoxCorner<T>*)nmsBoxesOutput);
+  }
+
+  if (param.outputONNXIndices) {
+    PadONNXResult<<<1, 1, 0, stream>>>(
+        param, outputIndexData, nmsIndicesOutput);
+  }
+
+  return {{prefix}}GetLastError();
+}
+
+__global__ void EfficientNMSFilterSegments(
+    EfficientNMSParameters param,
+    const int* __restrict__ topNumData,
+    int* __restrict__ topOffsetsStartData,
+    int* __restrict__ topOffsetsEndData) {
+  int imageIdx = threadIdx.x;
+  if (imageIdx > param.batchSize) {
+    return;
+  }
+  topOffsetsStartData[imageIdx] = imageIdx * param.numScoreElements;
+  topOffsetsEndData[imageIdx] =
+      imageIdx * param.numScoreElements + topNumData[imageIdx];
+}
+
+template <typename T>
+__global__ void EfficientNMSFilter(
+    EfficientNMSParameters param,
+    const T* __restrict__ scoresInput,
+    int* __restrict__ topNumData,
+    int* __restrict__ topIndexData,
+    int* __restrict__ topAnchorsData,
+    T* __restrict__ topScoresData,
+    int* __restrict__ topClassData) {
+  int elementIdx = blockDim.x * blockIdx.x + threadIdx.x;
+  int imageIdx = blockDim.y * blockIdx.y + threadIdx.y;
+
+  // Boundary Conditions
+  if (elementIdx >= param.numScoreElements || imageIdx >= param.batchSize) {
+    return;
+  }
+
+  // Shape of scoresInput: [batchSize, numAnchors, numClasses]
+  int scoresInputIdx = imageIdx * param.numScoreElements + elementIdx;
+
+  // For each class, check its corresponding score if it crosses the threshold,
+  // and if so select this anchor, and keep track of the maximum score and the
+  // corresponding (argmax) class id
+  T score = scoresInput[scoresInputIdx];
+  if (gte_mp(score, (T)param.scoreThreshold)) {
+    // Unpack the class and anchor index from the element index
+    int classIdx = elementIdx % param.numClasses;
+    int anchorIdx = elementIdx / param.numClasses;
+
+    // If this is a background class, ignore it.
+    if (classIdx == param.backgroundClass) {
+      return;
+    }
+
+    // Use an atomic to find an open slot where to write the selected anchor
+    // data.
+    if (topNumData[imageIdx] >= param.numScoreElements) {
+      return;
+    }
+    int selectedIdx = atomicAdd((unsigned int*)&topNumData[imageIdx], 1);
+    if (selectedIdx >= param.numScoreElements) {
+      topNumData[imageIdx] = param.numScoreElements;
+      return;
+    }
+
+    // Shape of topScoresData / topClassData: [batchSize, numScoreElements]
+    int topIdx = imageIdx * param.numScoreElements + selectedIdx;
+
+    if (param.scoreBits > 0) {
+      score = add_mp(score, (T)1);
+      if (gt_mp(score, (T)(2.f - 1.f / 1024.f))) {
+        // Ensure the incremented score fits in the mantissa without changing
+        // the exponent
+        score = (2.f - 1.f / 1024.f);
+      }
+    }
+
+    topIndexData[topIdx] = selectedIdx;
+    topAnchorsData[topIdx] = anchorIdx;
+    topScoresData[topIdx] = score;
+    topClassData[topIdx] = classIdx;
+  }
+}
+
+template <typename T>
+__global__ void EfficientNMSDenseIndex(
+    EfficientNMSParameters param,
+    int* __restrict__ topNumData,
+    int* __restrict__ topIndexData,
+    int* __restrict__ topAnchorsData,
+    int* __restrict__ topOffsetsStartData,
+    int* __restrict__ topOffsetsEndData,
+    T* __restrict__ topScoresData,
+    int* __restrict__ topClassData) {
+  int elementIdx = blockDim.x * blockIdx.x + threadIdx.x;
+  int imageIdx = blockDim.y * blockIdx.y + threadIdx.y;
+
+  if (elementIdx >= param.numScoreElements || imageIdx >= param.batchSize) {
+    return;
+  }
+
+  int dataIdx = imageIdx * param.numScoreElements + elementIdx;
+  int anchorIdx = elementIdx / param.numClasses;
+  int classIdx = elementIdx % param.numClasses;
+  if (param.scoreBits > 0) {
+    T score = topScoresData[dataIdx];
+    if (lt_mp(score, (T)param.scoreThreshold)) {
+      score = (T)1;
+    } else if (classIdx == param.backgroundClass) {
+      score = (T)1;
+    } else {
+      score = add_mp(score, (T)1);
+      if (gt_mp(score, (T)(2.f - 1.f / 1024.f))) {
+        // Ensure the incremented score fits in the mantissa without changing
+        // the exponent
+        score = (2.f - 1.f / 1024.f);
+      }
+    }
+    topScoresData[dataIdx] = score;
+  } else {
+    T score = topScoresData[dataIdx];
+    if (lt_mp(score, (T)param.scoreThreshold)) {
+      topScoresData[dataIdx] = -(1 << 15);
+    } else if (classIdx == param.backgroundClass) {
+      topScoresData[dataIdx] = -(1 << 15);
+    }
+  }
+
+  topIndexData[dataIdx] = elementIdx;
+  topAnchorsData[dataIdx] = anchorIdx;
+  topClassData[dataIdx] = classIdx;
+
+  if (elementIdx == 0) {
+    // Saturate counters
+    topNumData[imageIdx] = param.numScoreElements;
+    topOffsetsStartData[imageIdx] = imageIdx * param.numScoreElements;
+    topOffsetsEndData[imageIdx] = (imageIdx + 1) * param.numScoreElements;
+  }
+}
+
+template <typename T>
+{{prefix}}Error_t EfficientNMSFilterLauncher(
+    EfficientNMSParameters& param,
+    const T* scoresInput,
+    int* topNumData,
+    int* topIndexData,
+    int* topAnchorsData,
+    int* topOffsetsStartData,
+    int* topOffsetsEndData,
+    T* topScoresData,
+    int* topClassData,
+    {{prefix}}Stream_t stream) {
+  const unsigned int elementsPerBlock = 512;
+  const unsigned int imagesPerBlock = 1;
+  const unsigned int elementBlocks =
+      (param.numScoreElements + elementsPerBlock - 1) / elementsPerBlock;
+  const unsigned int imageBlocks =
+      (param.batchSize + imagesPerBlock - 1) / imagesPerBlock;
+  const dim3 blockSize = {elementsPerBlock, imagesPerBlock, 1};
+  const dim3 gridSize = {elementBlocks, imageBlocks, 1};
+
+  float kernelSelectThreshold = 0.007f;
+  if (param.scoreSigmoid) {
+    // Inverse Sigmoid
+    if (param.scoreThreshold <= 0.f) {
+      param.scoreThreshold = -(1 << 15);
+    } else {
+      param.scoreThreshold =
+          logf(param.scoreThreshold / (1.f - param.scoreThreshold));
+    }
+    kernelSelectThreshold =
+        logf(kernelSelectThreshold / (1.f - kernelSelectThreshold));
+    // Disable Score Bits Optimization
+    param.scoreBits = -1;
+  }
+
+  if (param.scoreThreshold < kernelSelectThreshold) {
+    // A full copy of the buffer is necessary because sorting will scramble the
+    // input data otherwise.
+    {{prefix}}MemcpyAsync(
+        topScoresData,
+        scoresInput,
+        param.batchSize * param.numScoreElements * sizeof(T),
+        {{prefix}}MemcpyDeviceToDevice,
+        stream);
+
+    EfficientNMSDenseIndex<T><<<gridSize, blockSize, 0, stream>>>(
+        param,
+        topNumData,
+        topIndexData,
+        topAnchorsData,
+        topOffsetsStartData,
+        topOffsetsEndData,
+        topScoresData,
+        topClassData);
+  } else {
+    EfficientNMSFilter<T><<<gridSize, blockSize, 0, stream>>>(
+        param,
+        scoresInput,
+        topNumData,
+        topIndexData,
+        topAnchorsData,
+        topScoresData,
+        topClassData);
+
+    EfficientNMSFilterSegments<<<1, param.batchSize, 0, stream>>>(
+        param, topNumData, topOffsetsStartData, topOffsetsEndData);
+  }
+
+  return {{prefix}}GetLastError();
+}
+
+template <typename T>
+size_t EfficientNMSSortWorkspaceSize(int batchSize, int numScoreElements) {
+  size_t sortedWorkspaceSize = 0;
+  {{cub}}::DoubleBuffer<T> keysDB(nullptr, nullptr);
+  {{cub}}::DoubleBuffer<int> valuesDB(nullptr, nullptr);
+  {{cub}}::DeviceSegmentedRadixSort::SortPairsDescending(
+      nullptr,
+      sortedWorkspaceSize,
+      keysDB,
+      valuesDB,
+      numScoreElements,
+      batchSize,
+      (const int*)nullptr,
+      (const int*)nullptr);
+  return sortedWorkspaceSize;
+}
+
+template <typename T>
+size_t
+EfficientNMSWorkspaceSize(int batchSize, int numScoreElements, int numClasses) {
+  size_t total = 0;
+  const size_t align = 256;
+  // Counters
+  // 3 for Filtering
+  // 1 for Output Indexing
+  // C for Max per Class Limiting
+  size_t size = (3 + 1 + numClasses) * batchSize * sizeof(int64_t);
+  total += size + (size % align ? align - (size % align) : 0);
+  // Int Buffers
+  for (int i = 0; i < 4; i++) {
+    size = batchSize * numScoreElements * sizeof(int64_t);
+    total += size + (size % align ? align - (size % align) : 0);
+  }
+  // Float Buffers
+  for (int i = 0; i < 2; i++) {
+    size = batchSize * numScoreElements * sizeof(T);
+    total += size + (size % align ? align - (size % align) : 0);
+  }
+  // Sort Workspace
+  size = EfficientNMSSortWorkspaceSize<T>(batchSize, numScoreElements);
+  total += size + (size % align ? align - (size % align) : 0);
+  return total;
+}
+
+template <typename T>
+T* EfficientNMSWorkspace(void* workspace, size_t& offset, size_t elements) {
+  T* buffer = (T*)((size_t)workspace + offset);
+  size_t align = 256;
+  size_t size = elements * sizeof(T);
+  size_t sizeAligned = size + (size % align ? align - (size % align) : 0);
+  offset += sizeAligned;
+  return buffer;
+}
+
+template <typename T>
+pluginStatus_t EfficientNMSDispatch(
+    EfficientNMSParameters param,
+    const void* boxesInput,
+    const void* scoresInput,
+    const void* anchorsInput,
+    void* numDetectionsOutput,
+    void* nmsBoxesOutput,
+    void* nmsScoresOutput,
+    void* nmsClassesOutput,
+    void* nmsIndicesOutput,
+    void* workspace,
+    {{prefix}}Stream_t stream) {
+  // Clear Outputs (not all elements will get overwritten by the kernels, so
+  // safer to clear everything out)
+  if (param.outputONNXIndices) {
+    {{prefix}}MemsetAsync(
+        nmsIndicesOutput,
+        0xFF,
+        param.batchSize * param.numOutputBoxes * 3 * sizeof(int),
+        stream);
+  } else {
+    {{prefix}}MemsetAsync(
+        numDetectionsOutput, 0x00, param.batchSize * sizeof(int64_t), stream);
+    {{prefix}}MemsetAsync(
+        nmsScoresOutput,
+        0x00,
+        param.batchSize * param.numOutputBoxes * sizeof(T),
+        stream);
+    {{prefix}}MemsetAsync(
+        nmsBoxesOutput,
+        0x00,
+        param.batchSize * param.numOutputBoxes * 4 * sizeof(T),
+        stream);
+    {{prefix}}MemsetAsync(
+        nmsClassesOutput,
+        0x00,
+        param.batchSize * param.numOutputBoxes * sizeof(int64_t),
+        stream);
+  }
+
+  // Empty Inputs
+  if (param.numScoreElements < 1) {
+    return STATUS_SUCCESS;
+  }
+
+  // Counters Workspace
+  size_t workspaceOffset = 0; // 1 << 20;
+  int countersTotalSize = (3 + 1 + param.numClasses) * param.batchSize;
+  int* topNumData =
+      EfficientNMSWorkspace<int>(workspace, workspaceOffset, countersTotalSize);
+  int* topOffsetsStartData = topNumData + param.batchSize;
+  int* topOffsetsEndData = topNumData + 2 * param.batchSize;
+  int* outputIndexData = topNumData + 3 * param.batchSize;
+  int* outputClassData = topNumData + 4 * param.batchSize;
+  {{prefix}}MemsetAsync(topNumData, 0x00, countersTotalSize * sizeof(int), stream);
+  {{prefix}}Error_t status = {{prefix}}GetLastError();
+  CSC(status, STATUS_FAILURE);
+
+  // Other Buffers Workspace
+  int* topIndexData = EfficientNMSWorkspace<int>(
+      workspace, workspaceOffset, param.batchSize * param.numScoreElements);
+  int* topClassData = EfficientNMSWorkspace<int>(
+      workspace, workspaceOffset, param.batchSize * param.numScoreElements);
+  int* topAnchorsData = EfficientNMSWorkspace<int>(
+      workspace, workspaceOffset, param.batchSize * param.numScoreElements);
+  int* sortedIndexData = EfficientNMSWorkspace<int>(
+      workspace, workspaceOffset, param.batchSize * param.numScoreElements);
+  T* topScoresData = EfficientNMSWorkspace<T>(
+      workspace, workspaceOffset, param.batchSize * param.numScoreElements);
+  T* sortedScoresData = EfficientNMSWorkspace<T>(
+      workspace, workspaceOffset, param.batchSize * param.numScoreElements);
+  size_t sortedWorkspaceSize =
+      EfficientNMSSortWorkspaceSize<T>(param.batchSize, param.numScoreElements);
+  char* sortedWorkspaceData = EfficientNMSWorkspace<char>(
+      workspace, workspaceOffset, sortedWorkspaceSize);
+  {{cub}}::DoubleBuffer<T> scoresDB(topScoresData, sortedScoresData);
+  {{cub}}::DoubleBuffer<int> indexDB(topIndexData, sortedIndexData);
+
+  // Kernels
+  status = EfficientNMSFilterLauncher<T>(
+      param,
+      (T*)scoresInput,
+      topNumData,
+      topIndexData,
+      topAnchorsData,
+      topOffsetsStartData,
+      topOffsetsEndData,
+      topScoresData,
+      topClassData,
+      stream);
+  CSC(status, STATUS_FAILURE);
+
+  status = {{cub}}::DeviceSegmentedRadixSort::SortPairsDescending(
+      sortedWorkspaceData,
+      sortedWorkspaceSize,
+      scoresDB,
+      indexDB,
+      param.batchSize * param.numScoreElements,
+      param.batchSize,
+      topOffsetsStartData,
+      topOffsetsEndData,
+      param.scoreBits > 0 ? (10 - param.scoreBits) : 0,
+      param.scoreBits > 0 ?
10 : sizeof(T) * 8, + stream, + false); + CSC(status, STATUS_FAILURE); + + status = EfficientNMSLauncher( + param, + topNumData, + outputIndexData, + outputClassData, + indexDB.Current(), + scoresDB.Current(), + topClassData, + topAnchorsData, + boxesInput, + anchorsInput, + (int64_t*)numDetectionsOutput, + (T*)nmsScoresOutput, + (int64_t*)nmsClassesOutput, + (int*)nmsIndicesOutput, + nmsBoxesOutput, + stream); + CSC(status, STATUS_FAILURE); + + return STATUS_SUCCESS; +} + +void EfficientNMSInference( + EfficientNMSParameters param, + const void* boxesInput, + const void* scoresInput, + const void* anchorsInput, + void* numDetectionsOutput, + void* nmsBoxesOutput, + void* nmsScoresOutput, + void* nmsClassesOutput, + void* nmsIndicesOutput, + void* workspace, + {{prefix}}Stream_t stream) { + if (param.scoreBits <= 0 || param.scoreBits > 10) { + param.scoreBits = -1; + } + EfficientNMSDispatch<__half>( + param, + boxesInput, + scoresInput, + anchorsInput, + numDetectionsOutput, + nmsBoxesOutput, + nmsScoresOutput, + nmsClassesOutput, + nmsIndicesOutput, + workspace, + stream); +} +""" +) diff --git a/python/aitemplate/backend/common/vision_ops/multi_level_roi_align_common.py b/python/aitemplate/backend/common/vision_ops/multi_level_roi_align_common.py new file mode 100644 index 000000000..19f8bd6cd --- /dev/null +++ b/python/aitemplate/backend/common/vision_ops/multi_level_roi_align_common.py @@ -0,0 +1,464 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +multi-level roi align common functions for all backends. 
+""" + +import jinja2 + +# pylint: disable=C0103,C0415,W0613,C0301,W0612 + + +EXEC_TEMPLATE = jinja2.Template( + """ +{{indent}}FPNRoiAlign( +{{indent}} in_ptr_p2, +{{indent}} in_ptr_p3, +{{indent}} in_ptr_p4, +{{indent}} in_ptr_p5, +{{indent}} rois_ptr, +{{indent}} out_ptr, +{{indent}} batchSize, +{{indent}} featureCount, +{{indent}} imageSize, +{{indent}} P2dims, +{{indent}} P3dims, +{{indent}} P4dims, +{{indent}} P5dims, +{{indent}} sampling_ratio, +{{indent}} spatial_scale, +{{indent}} position_sensitive, +{{indent}} continuous_coordinate, +{{indent}} stream +{{indent}}); +{{indent}}return; +""" +) + +SRC_TEMPLATE = jinja2.Template( + """ +{{header_files}} + +namespace { +// customized roi align kernel + +struct xy_t { + int64_t y; + int64_t x; + + xy_t() : y(0), x(0) {} + xy_t(int64_t y_, int64_t x_) : y(y_), x(x_) {} +}; + +template +__device__ inline T interpolateBilinear( + const T* src, + xy_t srcDims, + float y, + float x, + const int channels) { + // deal with cases that inverse elements are out of feature map boundary + int height = srcDims.y; + int width = srcDims.x; + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + return 0; + } + + if (y <= 0) { + y = 0; + } + if (x <= 0) { + x = 0; + } + + int y_low = static_cast(y); + int x_low = static_cast(x); + int y_high; + int x_high; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = T(1.0) - ly, hx = T(1.0) - lx; + // do bilinear interpolation + T v1 = src[channels * (y_low * width + x_low)]; + T v2 = src[channels * (y_low * width + x_high)]; + T v3 = src[channels * (y_high * width + x_low)]; + T v4 = src[channels * (y_high * width + x_high)]; + T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + return val; +} + +template +__global__ void roiAlign_kernel( + xy_t imageSize, + int featureCount, + int roiCount, + float threshold, + int samplingRatio, + const Trois* rois, + const Tfeat* P2, + const xy_t P2dims, + const Tfeat* P3, + const xy_t P3dims, + const Tfeat* P4, + const xy_t P4dims, + const Tfeat* P5, + const xy_t P5dims, + Tfeat* pooled, + const xy_t poolDims) { + const int batch = blockIdx.x; + const int feature = blockIdx.y; + const int roiIdx = blockIdx.z; + + const Trois* roi = rois + 5 * (batch * roiCount + roiIdx); + float hw; + float x1 = __half2float(roi[1]); + float y1 = __half2float(roi[2]); + float x2 = __half2float(roi[3]); + float y2 = __half2float(roi[4]); + + y1 = max(0.f, min((float)imageSize.y, y1)) / imageSize.y; + x1 = max(0.f, min((float)imageSize.x, x1)) / imageSize.x; + y2 = max(0.f, min((float)imageSize.y, y2)) / imageSize.y; + x2 = max(0.f, min((float)imageSize.x, x2)) / imageSize.x; + + hw = (y2 - y1) * (x2 - x1); + + const Tfeat* src = P2; + xy_t srcDims = P2dims; + int iP = 2; + + if (hw > threshold) { + src = P3; + srcDims = P3dims; + ++iP; + } + threshold *= 4; + + if (hw > threshold) { + src = P4; + srcDims = P4dims; + ++iP; + } + threshold *= 4; + + if (hw > threshold) { + src = P5; + srcDims = P5dims; + ++iP; + } + + src += batch * srcDims.x * srcDims.y * featureCount + feature; + // batch, roiCount, poolx, pooly, featureCount + Tfeat* dst = pooled + + poolDims.x * poolDims.y * + (batch * roiCount * featureCount + roiIdx * featureCount) + + feature; + + float samplingOffset = 0.5f; + float 
inputOffset = 0.5f; + + float yStart = y1 * srcDims.y - inputOffset; + float xStart = x1 * srcDims.x - inputOffset; + + float yEnd = y2 * srcDims.y - inputOffset; + float xEnd = x2 * srcDims.x - inputOffset; + + float yDelta = (yEnd - yStart) / poolDims.y; + float xDelta = (xEnd - xStart) / poolDims.x; + + const int samplingRatioX = samplingRatio > 0 + ? samplingRatio + : max(1, (int)ceilf((xEnd - xStart) / poolDims.x)); + const int samplingRatioY = samplingRatio > 0 + ? samplingRatio + : max(1, (int)ceilf((yEnd - yStart) / poolDims.y)); + const int samplingCount = samplingRatioX * samplingRatioY; + + for (int outIdx = threadIdx.x; outIdx < poolDims.x * poolDims.y; + outIdx += blockDim.x) { + int xx = outIdx % poolDims.x; + int yy = outIdx / poolDims.x; + Tfeat* out = dst + (poolDims.x * yy + xx) * featureCount; + Tfeat result = 0; + for (int iy = 0; iy < samplingRatioY; iy++) { + float ySample = yStart + yDelta * yy; + ySample += yDelta * (iy + samplingOffset) / samplingRatioY; + ySample = min(max(ySample, 0.f), srcDims.y - 1.0f); + + for (int ix = 0; ix < samplingRatioX; ix++) { + float xSample = xStart + xDelta * xx; + xSample += xDelta * (ix + samplingOffset) / samplingRatioX; + xSample = min(max(xSample, 0.f), srcDims.x - 1.0f); + + result += + interpolateBilinear(src, srcDims, ySample, xSample, featureCount); + } + } + *out = result / __float2half_rn(samplingCount); + } +} + +template +void FPNRoiAlign( + {{elem_input_type}}* P2, + {{elem_input_type}}* P3, + {{elem_input_type}}* P4, + {{elem_input_type}}* P5, + {{elem_input_type}}* rois, + {{elem_output_type}}* output, + const int batchSize, + const int featureCount, + const xy_t imageSize, + const xy_t P2dims, + const xy_t P3dims, + const xy_t P4dims, + const xy_t P5dims, + const int samplingRatio, + const float spatial_scale, + const bool position_sensitive, + const bool continuous_coordinate, + {{prefix}}Stream_t stream) { + float mFPNScale = 224; + float normScale = sqrtf(mFPNScale * mFPNScale / (imageSize.x * imageSize.y)); + float firstThreshold = normScale * normScale / 4.f; + + const dim3 blocks(batchSize, featureCount, roiCount); + const int threads(min(256, pool_size * pool_size)); + + roiAlign_kernel<<>>( + imageSize, + featureCount, + roiCount, + firstThreshold, + samplingRatio, + (const half*)rois, + (const half*)P2, + P2dims, + (const half*)P3, + P3dims, + (const half*)P4, + P4dims, + (const half*)P5, + P5dims, + (half*)output, + {pool_size, pool_size}); +} + +} // namespace + +void {{function_name}} ( + {{elem_input_type}}* in_ptr_p2, + {{elem_input_type}}* in_ptr_p3, + {{elem_input_type}}* in_ptr_p4, + {{elem_input_type}}* in_ptr_p5, + {{elem_input_type}}* rois_ptr, + {{elem_output_type}}* out_ptr, + {{index_type}}* batch, {{index_type}}* in_ch, + {{index_type}}* p2_h, {{index_type}}* p2_w, + {{index_type}}* p3_h, {{index_type}}* p3_w, + {{index_type}}* p4_h, {{index_type}}* p4_w, + {{index_type}}* p5_h, {{index_type}}* p5_w, + const int im_h, const int im_w, + int sampling_ratio, + const float spatial_scale, + const bool position_sensitive, + const bool continuous_coordinate, + {{prefix}}Stream_t stream +) { + {{shape_function}} + + const xy_t imageSize = {im_h, im_w}; + const xy_t P2dims = {*p2_h, *p2_w}; + const xy_t P3dims = {*p3_h, *p3_w}; + const xy_t P4dims = {*p4_h, *p4_w}; + const xy_t P5dims = {*p5_h, *p5_w}; + const int featureCount = *in_ch; + const int batchSize = *batch; + + {{exec_paths}} + throw std::runtime_error( + "Unsupported workload for this bilinear upsampling specialization." 
+ ); +} +""" +) + +FUNC_DECL_TEMPLATE = jinja2.Template( + """ +void {{func_name}}( + {{elem_input_type}}*, + {{elem_input_type}}*, + {{elem_input_type}}*, + {{elem_input_type}}*, + {{elem_input_type}}*, + {{elem_output_type}}*, + {{index_type}}*, + {{index_type}}*, + {{index_type}}*, + {{index_type}}*, + {{index_type}}*, + {{index_type}}*, + {{index_type}}*, + {{index_type}}*, + {{index_type}}*, + {{index_type}}*, + int, + int, + int, + float, + bool, + bool, + {{prefix}}Stream_t +); +""" +) + + +FUNC_CALL_TEMPLATE = jinja2.Template( + """ +{{indent}}{{func_name}}( +{{indent}} static_cast<{{elem_input_type}}*>({{in_ptr_p2}}), +{{indent}} static_cast<{{elem_input_type}}*>({{in_ptr_p3}}), +{{indent}} static_cast<{{elem_input_type}}*>({{in_ptr_p4}}), +{{indent}} static_cast<{{elem_input_type}}*>({{in_ptr_p5}}), +{{indent}} static_cast<{{elem_input_type}}*>({{rois_ptr}}), +{{indent}} static_cast<{{elem_output_type}}*>({{out_ptr}}), +{{indent}} {{p_batch}}, +{{indent}} {{p_in_ch}}, +{{indent}} {{p2_h}}, {{p2_w}}, +{{indent}} {{p3_h}}, {{p3_w}}, +{{indent}} {{p4_h}}, {{p4_w}}, +{{indent}} {{p5_h}}, {{p5_w}}, +{{indent}} {{im_h}}, {{im_w}}, +{{indent}} {{sampling_ratio}}, +{{indent}} {{spatial_scale}}, +{{indent}} {{position_sensitive}}, +{{indent}} {{continuous_coordinate}}, +{{indent}} stream +{{indent}}); +""" +) + + +def gen_function_decl(func_attrs, backend_spec): + """Function declaration generation + + Parameters + ---------- + func_attrs : Dict[str, Any] + It describes the operation attributes + backend_spec : custom class + It specifies the corresponding backend dtypes of pytorch dtypes for many operations + + Returns + ------- + str + Rendered function declaration stmt + """ + x = func_attrs["inputs"][0] + y = func_attrs["outputs"][0] + input_type = backend_spec.dtype_to_lib_type(x._attrs["dtype"]) + output_type = backend_spec.dtype_to_lib_type(y._attrs["dtype"]) + return FUNC_DECL_TEMPLATE.render( + index_type=backend_spec.index_type, + prefix=backend_spec.prefix, + func_name=func_attrs["name"], + elem_input_type=input_type, + elem_output_type=output_type, + ) + + +def gen_function_call(func_attrs, backend_spec, indent=" "): + """Function call generation + + Parameters + ---------- + func_attrs : Dict[str, Any] + It describes the operation attributes + indent : str, optional + Indent for template, by default " " + + Returns + ------- + str + Rendered function call + """ + p2 = func_attrs["inputs"][0] + p3 = func_attrs["inputs"][1] + p4 = func_attrs["inputs"][2] + p5 = func_attrs["inputs"][3] + rois = func_attrs["inputs"][4] + xshape = p2._attrs["shape"] + y = func_attrs["outputs"][0] + yshape = y._attrs["shape"] + + input_type = backend_spec.dtype_to_lib_type(p2._attrs["dtype"]) + output_type = backend_spec.dtype_to_lib_type(y._attrs["dtype"]) + + return FUNC_CALL_TEMPLATE.render( + func_name=func_attrs["name"], + in_ptr_p2=p2._attrs["name"], + in_ptr_p3=p3._attrs["name"], + in_ptr_p4=p4._attrs["name"], + in_ptr_p5=p5._attrs["name"], + rois_ptr=rois._attrs["name"], + out_ptr=y._attrs["name"], + p_batch="&" + xshape[0]._attrs["name"], + p_in_ch="&" + xshape[3]._attrs["name"], + p_in_h="&" + xshape[1]._attrs["name"], + p_in_w="&" + xshape[2]._attrs["name"], + p_out_batch="&" + yshape[0]._attrs["name"], + p_out_h="&" + yshape[1]._attrs["name"], + p_out_w="&" + yshape[2]._attrs["name"], + im_h=func_attrs["im_shape"][0], + im_w=func_attrs["im_shape"][1], + p2_h="&" + p2._attrs["shape"][1]._attrs["name"], + p2_w="&" + p2._attrs["shape"][2]._attrs["name"], + p3_h="&" + 
p3._attrs["shape"][1]._attrs["name"], + p3_w="&" + p3._attrs["shape"][2]._attrs["name"], + p4_h="&" + p4._attrs["shape"][1]._attrs["name"], + p4_w="&" + p4._attrs["shape"][2]._attrs["name"], + p5_h="&" + p5._attrs["shape"][1]._attrs["name"], + p5_w="&" + p5._attrs["shape"][2]._attrs["name"], + sampling_ratio=func_attrs["sampling_ratio"], + spatial_scale=func_attrs["spatial_scale"], + position_sensitive="true" if func_attrs["position_sensitive"] else "false", + continuous_coordinate="true" + if func_attrs["continuous_coordinate"] + else "false", + backend_spec=backend_spec, + elem_input_type=input_type, + elem_output_type=output_type, + indent=indent, + ) diff --git a/python/aitemplate/backend/common/vision_ops/nms_common.py b/python/aitemplate/backend/common/vision_ops/nms_common.py new file mode 100644 index 000000000..50cc5e356 --- /dev/null +++ b/python/aitemplate/backend/common/vision_ops/nms_common.py @@ -0,0 +1,235 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +nms kernel codegen. +""" + +import os +from typing import Any, Dict, List + +import jinja2 + +from ... import builder +from ...target import Target +from .nms_kernel import KERNEL_TEMPLATE + +# pylint: disable=C0301 + +FUNC_TEMPLATE = jinja2.Template( + """ +{{header_files}} + +namespace { + +const int T_SIZE = {{T_SIZE}}; //(preNmsTopN + blockSize - 1) / blockSize - 1; +{{kernel}} + +} // namespace + +{{func_signature}} +{ + + const int N = *batch; + const int R = *num_rois; + nmsGpu(stream, N, R, preNmsTop, nmsMaxOut, iouThreshold, minBoxSize, fgScores, proposals, workspace, rois); +} + """ +) + +PROFILER_TEMPLATE = jinja2.Template( + """ +#include +{{header_files}} + + +size_t GLOBAL_WORKSPACE_SIZE = 0; + +namespace { + +const int T_SIZE = {{T_SIZE}}; //(preNmsTopN + blockSize - 1) / blockSize - 1; +{{kernel}} + +} // namespace + +int main(int argc, char** argv) { + int instance_num = std::stoi(argv[1]); // batch + int instance_size = std::stoi(argv[2]); // num_rois + int elem_cnt = instance_size * instance_num; + + float runtime_ms = 0; + const int64_t offsets_bytes = GetCudaAlignedSize((instance_num+1) * sizeof(int64_t)); + const int64_t scores_bytes = GetCudaAlignedSize(elem_cnt * sizeof(half)); + const int64_t boxes_bytes = GetCudaAlignedSize(elem_cnt * 4 * sizeof(half)); + int64_t temp_storage_bytes = InferTempStorageForSortPairsDescending(instance_num, instance_size); + + GLOBAL_WORKSPACE_SIZE = GetCudaAlignedSize(offsets_bytes + scores_bytes + boxes_bytes + temp_storage_bytes); + + std::cout << "TIME:" << runtime_ms << std::endl; + std::cout << "WS:" << GLOBAL_WORKSPACE_SIZE << std::endl; +} + """ +) + +FUNC_SIGNATURE = jinja2.Template( + """ +void {{func_name}}(half* rois, + const half* proposals, + const half* fgScores, + int64_t* batch, + int64_t* num_rois, + const {{index_type}} preNmsTop, + const {{index_type}} nmsMaxOut, + const float iouThreshold, + const float minBoxSize, + uint8_t* workspace, + {{prefix}}Stream_t stream) + """ +) + +FUNC_DECL = 
jinja2.Template( + """ + {{func_signature}}; + """ +) + +FUNC_CALL_TEMPLATE = jinja2.Template( + """ +{{indent}}{{func_name}}( +{{indent}} {{rois}}, {{proposals}}, {{fgScores}}, +{{indent}} {{p_batch}}, +{{indent}} {{num_rois}}, +{{indent}} {{preNmsTop}}, +{{indent}} {{nmsMaxOut}}, +{{indent}} {{iouThreshold}}, +{{indent}} {{minBoxSize}}, +{{indent}} global_workspace, stream /* default stream */ +{{indent}}); + """ +) + + +def gen_function(func_attrs: Dict[str, Any], header_files: str, backend_spec) -> str: + """the function for generating nms kernel""" + blockSize = 1024 + t_size = int((func_attrs["preNmsTop"] + blockSize - 1) / blockSize) + if backend_spec.backend_name == "cuda": + cuda_hmaxmin = True + else: + cuda_hmaxmin = False + + return FUNC_TEMPLATE.render( + T_SIZE=t_size, + header_files=header_files, + kernel=KERNEL_TEMPLATE.render( + prefix=backend_spec.prefix, cub=backend_spec.cub, cuda_hmaxmin=cuda_hmaxmin + ), + func_signature=FUNC_SIGNATURE.render( + func_name=func_attrs["name"], + prefix=backend_spec.prefix, + index_type=backend_spec.index_type, + ), + ) + + +def gen_function_decl(func_attrs: Dict[str, Any], backend_spec) -> str: + return FUNC_DECL.render( + func_signature=FUNC_SIGNATURE.render( + func_name=func_attrs["name"], + prefix=backend_spec.prefix, + index_type=backend_spec.index_type, + ).strip() + ) + + +def gen_function_call(func_attrs: Dict[str, Any], backend_spec, indent: str) -> str: + """ "The function for generating a function call to nms""" + output_name = "" + assert len(func_attrs["outputs"]) == 1 + assert len(func_attrs["inputs"]) == 2 + + output_name = backend_spec.cast_to_half_ptr_template.render( + name=func_attrs["outputs"][0]._attrs["name"] + ) + (input_name, score_name) = ( + backend_spec.cast_to_half_ptr_template.render(name=input_tensor._attrs["name"]) + for input_tensor in func_attrs["inputs"] + ) + + x = func_attrs["inputs"][0] + xshape = x._attrs["shape"] + + return FUNC_CALL_TEMPLATE.render( + func_name=func_attrs["name"], + rois=output_name, + proposals=input_name, + fgScores=score_name, + p_batch="&" + xshape[0]._attrs["name"], + num_rois="&" + xshape[1]._attrs["name"], + preNmsTop=func_attrs["preNmsTop"], + nmsMaxOut=func_attrs["nmsMaxOut"], + iouThreshold=func_attrs["iouThreshold"], + minBoxSize=func_attrs["minBoxSize"], + indent=indent, + ) + + +def add_profiler( + file_pairs: List[Any], workdir: str, op_type, output_name: str, code: str +) -> None: + """generate code for profiling""" + prefix = os.path.join(workdir, "profiler", op_type) + if not os.path.exists(prefix): + os.makedirs(prefix) + src_path = os.path.join(prefix, output_name + ".cu") + obj_path = os.path.join(prefix, output_name) + if os.path.exists(obj_path): + return + with open(src_path, "w") as f: + f.write(code) + file_pairs.append((src_path, obj_path)) + + +def gen_profiler( + func_attrs: Dict[str, Any], workdir: str, header_files: str, backend_spec +) -> None: + """generate and build code for NMS profiling""" + op_type = func_attrs["op"] + file_pairs = [] + blockSize = 1024 + t_size = int((func_attrs["preNmsTop"] + blockSize - 1) / blockSize) + + if backend_spec.backend_name == "cuda": + cuda_hmaxmin = True + else: + cuda_hmaxmin = False + + code = PROFILER_TEMPLATE.render( + T_SIZE=t_size, + header_files=header_files, + kernel=KERNEL_TEMPLATE.render( + prefix=backend_spec.prefix, cub=backend_spec.cub, cuda_hmaxmin=cuda_hmaxmin + ), + func_signature=FUNC_SIGNATURE.render( + func_name=func_attrs["name"], + prefix=backend_spec.prefix, + 
index_type=backend_spec.index_type, + ), + ) + op_name = func_attrs["op"] + add_profiler(file_pairs, workdir, op_type, op_name, code) + # build + target = Target.current() + compile_engine = builder.Builder() + compile_engine.build_objs(file_pairs, target.compile_cmd(executable=True)) diff --git a/python/aitemplate/backend/common/vision_ops/nms_kernel.py b/python/aitemplate/backend/common/vision_ops/nms_kernel.py new file mode 100644 index 000000000..1eb8a51bd --- /dev/null +++ b/python/aitemplate/backend/common/vision_ops/nms_kernel.py @@ -0,0 +1,565 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +nms kernel template. +""" +import jinja2 + +KERNEL_TEMPLATE = jinja2.Template( + """ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// code adapted from +// https://github.com/NVIDIA/TensorRT/blob/main/plugin/common/kernels/nmsLayer.cu +//------------------------------------------------------------------------ +// GPU kernel parameters. 
+ +template < + typename Key, + int BLOCK_THREADS, + int ITEMS_PER_THREAD> +__launch_bounds__(BLOCK_THREADS) __global__ void BlockSortKernel( + Key* d_in, // Tile of input + Key* d_out) // Elapsed cycle count of block scan +{ + enum { TILE_SIZE = BLOCK_THREADS * ITEMS_PER_THREAD }; + + // Specialize BlockLoad type for our thread block (uses warp-striped loads for + // coalescing, then transposes in shared memory to a blocked arrangement) + typedef {{cub}}::BlockLoad< + Key, + BLOCK_THREADS, + ITEMS_PER_THREAD, + {{cub}}::BLOCK_LOAD_WARP_TRANSPOSE> + BlockLoadT; + + // Specialize BlockRadixSort type for our thread block + typedef {{cub}}::BlockRadixSort + BlockRadixSortT; + + // Shared memory + __shared__ union TempStorage { + typename BlockLoadT::TempStorage load; + typename BlockRadixSortT::TempStorage sort; + } temp_storage; + + // Per-thread tile items + Key items[ITEMS_PER_THREAD]; + + // Our current block's offset + int block_offset = blockIdx.x * TILE_SIZE; + + // Load items into a blocked arrangement + BlockLoadT(temp_storage.load).Load(d_in + block_offset, items); + + // Barrier for smem reuse + __syncthreads(); + + // Start cycle timer + clock_t start = clock(); + + // Sort keys + BlockRadixSortT(temp_storage.sort).SortBlockedToStriped(items); + + // Stop cycle timer + clock_t stop = clock(); + + // Store output in striped fashion + {{cub}}::StoreDirectStriped( + threadIdx.x, d_out + block_offset, items); + + // // Store elapsed clocks + // if (threadIdx.x == 0) + // { + // d_elapsed[blockIdx.x] = (start > stop) ? start - stop : stop - start; + // } +} + +typedef enum { + STATUS_SUCCESS = 0, + STATUS_FAILURE = 1, + STATUS_BAD_PARAM = 2, + STATUS_NOT_SUPPORTED = 3, + STATUS_NOT_INITIALIZED = 4 +} pluginStatus_t; + +typedef enum { NCHW = 0, NC4HW = 1, NC32HW = 2 } DLayout_t; + +#define CSC(call, err) \ + do { \ + {{prefix}}Error_t {{prefix}}Status = call; \ + if ({{prefix}}Status != {{prefix}}Success) { \ + return err; \ + } \ + } while (0) + +template +struct Bbox { + T xmin, ymin, xmax, ymax; + Bbox(T xmin, T ymin, T xmax, T ymax) + : xmin(xmin), ymin(ymin), xmax(xmax), ymax(ymax) {} + Bbox() = default; +}; + +// HASH +unsigned int hash(const void* array_, size_t size) { + // Apply hashing only when debugging RPN codes. + if (0) { + const char* array_const; + char* array; + {{prefix}}MallocHost((void**)&array, size); + {{prefix}}Memcpy(array, array_, size, {{prefix}}MemcpyDeviceToHost); + array_const = array; + unsigned int hash = 45599; + for (size_t i = 0; i < size; i++) { + unsigned int value = array_const[i]; + hash = hash * 1487 + value; + hash = hash * 317; + hash = hash % 105359; + } + return hash; + } else { + return 0; + } +} + +// ALIGNPTR +int8_t* alignPtr(int8_t* ptr, uintptr_t to) { + uintptr_t addr = (uintptr_t)ptr; + if (addr % to) { + addr += to - addr % to; + } + return (int8_t*)addr; +} + +#define ASSERT_PARAM(exp) \ + do { \ + if (!(exp)) \ + return STATUS_BAD_PARAM; \ + } while (0) + +// CUB's bug workaround: +// To work properly for large batch size CUB segmented sort needs ridiculous +// workspace alignment. 
+const uintptr_t ALIGNMENT = 1 << 20; + +// IOU +// template +// __device__ __host__ inline float IoU(const Bbox& a, const +// Bbox& b) +// { +// TFloat left = max(a.xmin, b.xmin), right = min(a.xmax, b.xmax); +// TFloat top = max(a.ymin, b.ymin), bottom = min(a.ymax, b.ymax); +// TFloat width = max((TFloat)(right - left + (TFloat) 1.0), (TFloat) 0.0); +// TFloat height = max((TFloat)(bottom - top + (TFloat) 1.0), (TFloat) 0.0); +// TFloat interS = width * height; +// TFloat Sa = (a.xmax - a.xmin + (TFloat) 1) * (a.ymax - a.ymin + (TFloat) +// 1); TFloat Sb = (b.xmax - b.xmin + (TFloat) 1) * (b.ymax - b.ymin + +// (TFloat) 1); return (float) interS / (float) (Sa + Sb - interS); +// } + +__device__ inline half hmax(const half a, const half b) { +{% if cuda_hmaxmin %} +#if __CUDA_ARCH__ >= 800 + return __hmax(a, b); +#else + return a > b ? a : b; +#endif +{% else %} + return a > b ? a : b; +{% endif %} +} + +__device__ inline half hmin(const half a, const half b) { +{% if cuda_hmaxmin %} +#if __CUDA_ARCH__ >= 800 + return __hmin(a, b); +#else + return a < b ? a : b; +#endif +{% else %} + return a < b ? a : b; +{% endif %} +} + +template +__device__ __host__ inline float IoU(const Bbox& a, const Bbox& b) { + T left = hmax(a.xmin, b.xmin), right = hmin(a.xmax, b.xmax); + T top = hmax(a.ymin, b.ymin), bottom = hmin(a.ymax, b.ymax); + T width = hmax(T(right - left + T(1.0)), T(0.0)); + T height = hmax(T(bottom - top + T(1.0)), T(0.0)); + float interS = __half2float(width) * __half2float(height); + float Sa = __half2float(a.xmax - a.xmin + T(1.0)) * + __half2float(a.ymax - a.ymin + T(1.0)); + float Sb = __half2float(b.xmax - b.xmin + T(1.0)) * + __half2float(b.ymax - b.ymin + T(1.0)); + + return interS / (Sa + Sb - interS); +} + +// NMS KERNEL FOR SMALL BATCH SIZE +template +__global__ __launch_bounds__(DIM) void nmsKernel1( + const int propSize, + Bbox const* __restrict__ preNmsProposals, + T_ROIS* __restrict__ afterNmsProposals, + const int preNmsTopN, + const float nmsThres, + const int afterNmsTopN) { + __shared__ bool kept_boxes[TSIZE * DIM]; + int kept = 0; + int batch_offset = blockIdx.x * propSize; + int max_box_idx = batch_offset + preNmsTopN; + int batch_offset_out = blockIdx.x * afterNmsTopN; + + int flag_idx[TSIZE]; + int boxes_idx[TSIZE]; + Bbox cur_boxes[TSIZE]; + +// initialize kept_boxes +#pragma unroll + for (int i = 0; i < TSIZE; i++) { + boxes_idx[i] = threadIdx.x + batch_offset + DIM * i; + flag_idx[i] = threadIdx.x + DIM * i; + + if (boxes_idx[i] < max_box_idx) { + cur_boxes[i] = preNmsProposals[boxes_idx[i]]; + kept_boxes[flag_idx[i]] = true; + } else { + kept_boxes[flag_idx[i]] = false; + boxes_idx[i] = -1.0f; + flag_idx[i] = -1.0f; + } + } + + int ref_box_idx = 0 + batch_offset; + + // remove the overlapped boxes + while ((kept < afterNmsTopN) && (ref_box_idx < max_box_idx)) { + Bbox ref_box; + ref_box = preNmsProposals[ref_box_idx]; + +#pragma unroll + for (int i = 0; i < TSIZE; i++) { + if (boxes_idx[i] > ref_box_idx) { + if (IoU(ref_box, cur_boxes[i]) > nmsThres) { + kept_boxes[flag_idx[i]] = false; + } + } else if (boxes_idx[i] == ref_box_idx) { + afterNmsProposals[(batch_offset_out + kept) * 4 + 0] = ref_box.xmin; + afterNmsProposals[(batch_offset_out + kept) * 4 + 1] = ref_box.ymin; + afterNmsProposals[(batch_offset_out + kept) * 4 + 2] = ref_box.xmax; + afterNmsProposals[(batch_offset_out + kept) * 4 + 3] = ref_box.ymax; + } + } + __syncthreads(); + + do { + ref_box_idx++; + } while (!kept_boxes[ref_box_idx - batch_offset] && + ref_box_idx < max_box_idx); + + 
kept++; + } +} + +// NMS KERNEL FOR LARGE BATCH SIZE +template +__global__ __launch_bounds__(DIM) void nmsKernel2( + const int propSize, + Bbox const* __restrict__ proposals, + T_ROIS* __restrict__ filtered, + const int preNmsTopN, + const float nmsThres, + const int afterNmsTopN) { + Bbox const* cProposals = proposals + blockIdx.x * propSize; + + Bbox t[TSIZE]; + uint64_t del = 0; + + for (int i = 0; i < TSIZE; i++) { + if (i < TSIZE - 1 || i * DIM + threadIdx.x < preNmsTopN) { + t[i] = cProposals[i * DIM + threadIdx.x]; + } + } + + __shared__ Bbox last; + __shared__ bool kept; + __shared__ int foundBatch; + if (threadIdx.x == 0) + foundBatch = 0; + + for (int i = 0; i < TSIZE; i++) { + for (int j = 0; j < DIM; j++) { + int offset = i * DIM; + int index = offset + j; + if (index >= preNmsTopN) + break; + + __syncthreads(); + + if (threadIdx.x == j) { + kept = 0 == (del & ((uint64_t)1 << i)); + last = t[i]; + + if (kept) { + int cnt = blockIdx.x * afterNmsTopN + foundBatch; + filtered[cnt * 4 + 0] = t[i].xmin; + filtered[cnt * 4 + 1] = t[i].ymin; + filtered[cnt * 4 + 2] = t[i].xmax; + filtered[cnt * 4 + 3] = t[i].ymax; + foundBatch++; + } + } + + __syncthreads(); + + if (foundBatch == afterNmsTopN) { + return; + } + + if (kept) { + Bbox test = last; + + for (int k = 0; k < TSIZE; k++) { + if (index < k * DIM + threadIdx.x && + IoU(test, t[k]) > nmsThres) { + del |= (uint64_t)1 << k; + } + } + } + } + } +} + +// NMS LAUNCH +template +pluginStatus_t nmsLaunch( + {{prefix}}Stream_t stream, + const int batch, + const int propSize, + void* proposals, + void* filtered, + const int preNmsTopN, + const float nmsThres, + const int afterNmsTopN) { + const int blockSize = 1024; + + // #define P1(tsize) nmsKernel1 + // #define P2(tsize) nmsKernel2 + + // void (*kernel[64])( + // int, Bbox const*, T_ROIS*, int, float, int) = { + // P1(1), P1(2), P1(3), P1(4), P1(5), P1(6), P1(7), P1(8), + // P1(9), P1(10), P1(11), P1(12), P2(13), P2(14), P2(15), P2(16), + // P2(17), P2(18), P2(19), P2(20), P2(21), P2(22), P2(23), P2(24), + // P2(25), P2(26), P2(27), P2(28), P2(29), P2(30), P2(31), P2(32), + // P2(33), P2(34), P2(35), P2(36), P2(37), P2(38), P2(39), P2(40), + // P2(41), P2(42), P2(43), P2(44), P2(45), P2(46), P2(47), P2(48), + // P2(49), P2(50), P2(51), P2(52), P2(53), P2(54), P2(55), P2(56), + // P2(57), P2(58), P2(59), P2(60), P2(61), P2(62), P2(63), P2(64)}; + +#if T_SZIE <= 12 +#define nmsKernel nmsKernel1 +#else +#define nmsKernel nmsKernel2 +#endif + + ASSERT_PARAM(preNmsTopN < 64 * blockSize); + + CSC({{prefix}}MemsetAsync( + filtered, 0x00, batch * afterNmsTopN * 4 * sizeof(T_ROIS), stream), + STATUS_FAILURE); + + nmsKernel<<>>( + propSize, + (Bbox*)proposals, + (T_ROIS*)filtered, + preNmsTopN, + nmsThres, + afterNmsTopN); + + CSC({{prefix}}GetLastError(), STATUS_FAILURE); + + return STATUS_SUCCESS; +} + +// SET OFFSET +// Works for up to 2Gi elements (cub's limitation)! +__global__ void setOffset(int stride, int size, int* output) { + // One block, because batch size shouldn't be too large. 
+ for (int i = threadIdx.x; i < size; i += blockDim.x) { + output[i] = i * stride; + } +} + +// BBFilter KERNEL +__global__ void bboxFilter_kernel( + int N, + const float minSize, + const half* proposals, + half* scores) { + if (minSize == 0) + return; + int tid = threadIdx.x + blockIdx.x * blockDim.x; + uint16_t bits = 0x3c00u; + half one = reinterpret_cast(bits); + + if (tid < N) { + int ininf = 0xff800000; + float ninf = *(float*)&ininf; + + if (__hsub(proposals[tid * 4 + 2], proposals[tid * 4 + 0]) < + half(minSize) || + __hsub(proposals[tid * 4 + 3], proposals[tid * 4 + 1]) < + half(minSize)) { + scores[tid] = half(ninf); + } + } +} + +inline size_t GetCudaAlignedSize(size_t size) { + const size_t kCudaAlignSize = 1 << 20; + return (size + kCudaAlignSize - 1) / kCudaAlignSize * kCudaAlignSize; +} + +class MultiplyFunctor final { + public: + MultiplyFunctor(int32_t num_col) : num_col_(num_col) {} + __host__ __device__ __forceinline__ int32_t operator()(int32_t idx) const { + return idx * num_col_; + } + + private: + int32_t num_col_; +}; + +template +size_t InferTempStorageForSortPairsDescending( + int32_t num_row, + int32_t num_col) { + using SegmentOffsetIter = {{cub}}::TransformInputIterator< + int32_t, + MultiplyFunctor, + {{cub}}::CountingInputIterator>; + + {{cub}}::CountingInputIterator counting_iter(0); + MultiplyFunctor multiply_functor(num_col); + SegmentOffsetIter segment_offset_iter(counting_iter, multiply_functor); + + size_t temp_storage_bytes = 0; + auto err = {{cub}}::DeviceSegmentedRadixSort:: + SortPairsDescending( + /* d_temp_storage */ nullptr, + /* temp_storage_bytes */ temp_storage_bytes, + /* d_keys_in */ nullptr, + /* d_keys_out */ nullptr, + /* d_values_in */ nullptr, + /* d_values_out */ nullptr, + /* num_items */ num_row * num_col, + /* num_segments */ num_row, + /* d_begin_offsets */ segment_offset_iter, + /* d_end_offsets */ segment_offset_iter + 1, + /* begin_bit */ 0, + /* end_bit */ sizeof(KeyType) * 8, + /* stream */ 0); + // OF_CUDA_CHECK(err); + + return temp_storage_bytes; +} + +// NMS GPU +template +pluginStatus_t nmsGpu( + {{prefix}}Stream_t stream, + const int N, + const int R, + const int preNmsTop, + const int nmsMaxOut, + const float iouThreshold, + const float minBoxSize, + const void* fgScores, + const void* proposals, + void* workspace, + void* rois) { + const int BS = 32; + const int GS = ((R) + BS - 1) / BS; + bboxFilter_kernel<<>>( + R, minBoxSize, (T_ROIS*)proposals, (T_ROIS*)fgScores); + + int8_t* vworkspace = alignPtr((int8_t*)workspace, 32); + + pluginStatus_t error; + + int* offsets = (int*)vworkspace; + setOffset<<<1, 1024, 0, stream>>>(R, N + 1, offsets); + CSC({{prefix}}GetLastError(), STATUS_FAILURE); + + vworkspace = vworkspace + N + 1; + vworkspace = alignPtr(vworkspace, ALIGNMENT); + + std::size_t tempStorageBytes = + InferTempStorageForSortPairsDescending(N, R); + + CSC({{prefix}}GetLastError(), STATUS_FAILURE); + + T_SCORES* scoresOut = (T_SCORES*)vworkspace; + vworkspace = (int8_t*)(scoresOut + N * R); + vworkspace = alignPtr(vworkspace, ALIGNMENT); + Bbox* proposalsOut = (Bbox*)vworkspace; + vworkspace = (int8_t*)(proposalsOut + N * R); + vworkspace = alignPtr(vworkspace, ALIGNMENT); + + {{cub}}::DeviceSegmentedRadixSort::SortPairsDescending( + vworkspace, + tempStorageBytes, + (T_SCORES*)fgScores, + (T_SCORES*)scoresOut, + (Bbox*)proposals, + (Bbox*)proposalsOut, + N * R, + N, + offsets, + offsets + 1, + 0, + 8 * sizeof(T_SCORES), + stream); + + CSC({{prefix}}GetLastError(), STATUS_FAILURE); + + error = nmsLaunch( + 
stream, N, R, proposalsOut, rois, preNmsTop, iouThreshold, nmsMaxOut); + + if (error != STATUS_SUCCESS) { + return error; + } + return STATUS_SUCCESS; +} + """ +) diff --git a/python/aitemplate/backend/common/vision_ops/roi_align_common.py b/python/aitemplate/backend/common/vision_ops/roi_align_common.py new file mode 100644 index 000000000..b658b711f --- /dev/null +++ b/python/aitemplate/backend/common/vision_ops/roi_align_common.py @@ -0,0 +1,392 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +roi align common functions for all backends. +""" + +import jinja2 + +# pylint: disable=C0103,C0415,W0613,C0301,W0612 + + +EXEC_TEMPLATE = jinja2.Template( + """ +{{indent}}roi_align_launcher( +{{indent}} in_ptr, +{{indent}} rois_ptr, +{{indent}} out_ptr, +{{indent}} NI, +{{indent}} HI, +{{indent}} WI, +{{indent}} CI, +{{indent}} HO, +{{indent}} WO, +{{indent}} sampling_ratio, +{{indent}} spatial_scale, +{{indent}} position_sensitive, +{{indent}} continuous_coordinate, +{{indent}} stream +{{indent}}); +{{indent}}return; +""" +) + +SRC_TEMPLATE = jinja2.Template( + """ +{{header_files}} + +namespace { +#define CUDA_KERNEL_LOOP(i, n) \ + for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) + +template +__device__ float2 bilinear_interpolate(const half2* bottom_data, + const int height, + const int width, + T y, + T x, + const int channels, + const int index /* index for debug only*/) { + // deal with cases that inverse elements are out of feature map boundary + float2 val = {0.f, 0.f}; + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + return val; + } + + y = y <= 0 ? 0 : y; + x = x <= 0 ? 0 : x; + + int y_low = static_cast(y); + int x_low = static_cast(x); + int y_high; + int x_high; + + y_high = y_low >= height - 1 ? height - 1 : y_low + 1; + y_low = y_low >= height - 1 ? height - 1 : y_low; + y = y_low >= height - 1 ? (T)y_low : y; + + x_high = x_low >= width - 1 ? width - 1 : x_low + 1; + x_low = x_low >= width - 1 ? width - 1 : x_low; + x = x_low >= width - 1 ? (T)x_low : x; + + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. 
- lx; + // do bilinear interpolation + const half2 v1 = __ldg(bottom_data + (y_low * width + x_low) * channels); + const half2 v2 = __ldg(bottom_data + (y_low * width + x_high) * channels); + const half2 v3 = __ldg(bottom_data + (y_high * width + x_low) * channels); + const half2 v4 = __ldg(bottom_data + (y_high * width + x_high) * channels); + + T v1_x = __half2float(v1{{half2_data_ref}}.x); + T v2_x = __half2float(v2{{half2_data_ref}}.x); + T v3_x = __half2float(v3{{half2_data_ref}}.x); + T v4_x = __half2float(v4{{half2_data_ref}}.x); + + T v1_y = __half2float(v1{{half2_data_ref}}.y); + T v2_y = __half2float(v2{{half2_data_ref}}.y); + T v3_y = __half2float(v3{{half2_data_ref}}.y); + T v4_y = __half2float(v4{{half2_data_ref}}.y); + + T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + val.x = (w1 * v1_x + w2 * v2_x + w3 * v3_x + w4 * v4_x); + val.y = (w1 * v1_y + w2 * v2_y + w3 * v3_y + w4 * v4_y); + + return val; +} + +template +__global__ void roi_align_f16_nhwc_kernel(const half2* bottom_data, + const half* bottom_rois, + half2* top_data, + const int64_t N, + const int64_t height, + const int64_t width, + const int64_t channels, + const int64_t pooled_height, + const int64_t pooled_width, + const int sampling_ratio, + const float spatial_scale, + const bool position_sensitive, + const bool continuous_coordinate) { + + const int64_t nthreads = num_rois * channels * pooled_width * pooled_height; + + CUDA_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + // index = c + channels * (x + out_width * (y + out_height * b)) + int64_t idx = index; + const int c = idx % channels; + idx /= channels; + const int pw = idx % pooled_width; + idx /= pooled_width; + const int ph = idx % pooled_height; + const int n = idx / pooled_height; + + + const half* offset_bottom_rois = bottom_rois + n * 5; + int roi_batch_ind = static_cast(__half2float(offset_bottom_rois[0])); + + float2 output_val = {0.f, 0.f}; + if (roi_batch_ind < 0) { + top_data[index] = __float22half2_rn(output_val); + continue; + } + + // Do not using rounding; this implementation detail is critical + T roi_offset = continuous_coordinate ? static_cast(0.5) : static_cast(0); + T roi_start_w = __half2float(offset_bottom_rois[1]) * spatial_scale - roi_offset; + T roi_start_h = __half2float(offset_bottom_rois[2]) * spatial_scale - roi_offset; + T roi_end_w = __half2float(offset_bottom_rois[3]) * spatial_scale - roi_offset; + T roi_end_h = __half2float(offset_bottom_rois[4]) * spatial_scale - roi_offset; + + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + if (!continuous_coordinate) { // backward compatiblity + // Force malformed ROIs to be 1x1 + roi_width = max(roi_width, (T)1.); + roi_height = max(roi_height, (T)1.); + } + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + int c_unpooled = c; + int channels_unpooled = channels; + if (position_sensitive) { + c_unpooled = c * pooled_height * pooled_width + ph * pooled_width + pw; + channels_unpooled = channels * pooled_height * pooled_width; + } + + const half2* offset_bottom_data = + bottom_data + (roi_batch_ind * height * width * channels_unpooled + c_unpooled); + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = (sampling_ratio > 0) ? 
sampling_ratio : ceil(roi_width / pooled_width); + + // We do average (integral) pooling inside a bin + const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4 + + // T output_val = 0.; + for (int iy = 0; iy < roi_bin_grid_h; iy++) { // e.g., iy = 0, 1 + const T y = + roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / static_cast(roi_bin_grid_w); + + float2 val = bilinear_interpolate(offset_bottom_data, height, width, y, x, channels, index); + output_val.x += val.x; + output_val.y += val.y; + } + } + output_val.x /= count; + output_val.y /= count; + + top_data[index] = __float22half2_rn(output_val); + } + +} + + +template +constexpr __host__ __device__ inline integer ceil_div(integer n, integer m) { + return (n + m - 1) / m; +} + + +template +void roi_align_launcher({{elem_input_type}}* input, + {{elem_input_type}}* rois, + {{elem_output_type}}* output, + const {{index_type}} N, + const {{index_type}} H, + const {{index_type}} W, + const {{index_type}} C, + const {{index_type}} HO, + const {{index_type}} WO, + const int sampling_ratio, + const float spatial_scale, + const bool position_sensitive, + const bool continuous_coordinate, + {{prefix}}Stream_t stream) { + + const int64_t output_size = num_rois * C * HO * WO; + + dim3 grid(std::min( + ceil_div(static_cast(output_size), static_cast(512)), + static_cast(4096))); + dim3 block(512); + + roi_align_f16_nhwc_kernel<<>>( + (const half2*)input, (const half*)rois, (half2*)output, N, H, W, C / 2, HO, WO, + sampling_ratio, spatial_scale, position_sensitive, continuous_coordinate); + +} +} // namespace + +void {{function_name}} ( + {{elem_input_type}}* in_ptr, + {{elem_input_type}}* rois_ptr, + {{elem_output_type}}* out_ptr, + {{index_type}}* batch, + {{index_type}}* in_h, + {{index_type}}* in_w, + {{index_type}}* in_ch, + {{index_type}}* out_batch, + {{index_type}}* out_h, + {{index_type}}* out_w, + int sampling_ratio, + const float spatial_scale, + const bool position_sensitive, + const bool continuous_coordinate, + {{prefix}}Stream_t stream +) { + {{shape_function}} + {{exec_paths}} + throw std::runtime_error( + "Unsupported workload for this avg pool2d specialization." 
+ ); +} + +""" +) + + +FUNC_DECL_TEMPLATE = jinja2.Template( + """ +void {{func_name}}( + {{elem_input_type}}*, + {{elem_input_type}}*, + {{elem_output_type}}*, + {{index_type}}*, + {{index_type}}*, + {{index_type}}*, + {{index_type}}*, + {{index_type}}*, + {{index_type}}*, + {{index_type}}*, + int, + float, + bool, + bool, + {{prefix}}Stream_t +); +""" +) + +FUNC_CALL_TEMPLATE = jinja2.Template( + """ +{{indent}}{{func_name}}( +{{indent}} static_cast<{{elem_input_type}}*>({{in_ptr}}), +{{indent}} static_cast<{{elem_input_type}}*>({{rois_ptr}}), +{{indent}} static_cast<{{elem_output_type}}*>({{out_ptr}}), +{{indent}} {{p_batch}}, +{{indent}} {{p_in_h}}, +{{indent}} {{p_in_w}}, +{{indent}} {{p_in_ch}}, +{{indent}} {{p_out_batch}}, +{{indent}} {{p_out_h}}, +{{indent}} {{p_out_w}}, +{{indent}} {{sampling_ratio}}, +{{indent}} {{spatial_scale}}, +{{indent}} {{position_sensitive}}, +{{indent}} {{continuous_coordinate}}, +{{indent}} stream +{{indent}}); +""" +) + + +def gen_function_decl(func_attrs, backend_spec): + """Function declaration generation + + Parameters + ---------- + func_attrs : Dict[str, Any] + It describes the operation attributes + backend_spec : custom class + It specifies the corresponding backend dtypes of pytorch dtypes for many operations + + Returns + ------- + str + Rendered function declaration stmt + """ + x = func_attrs["inputs"][0] + y = func_attrs["outputs"][0] + input_type = backend_spec.dtype_to_lib_type(x._attrs["dtype"]) + output_type = backend_spec.dtype_to_lib_type(y._attrs["dtype"]) + return FUNC_DECL_TEMPLATE.render( + index_type=backend_spec.index_type, + prefix=backend_spec.prefix, + func_name=func_attrs["name"], + elem_input_type=input_type, + elem_output_type=output_type, + ) + + +def gen_function_call(func_attrs, backend_spec, indent=" "): + """Function call generation + + Parameters + ---------- + func_attrs : Dict[str, Any] + It describes the operation attributes + indent : str, optional + Indent for template, by default " " + + Returns + ------- + str + Rendered function call + """ + x = func_attrs["inputs"][0] + rois = func_attrs["inputs"][1] + xshape = x._attrs["shape"] + y = func_attrs["outputs"][0] + yshape = y._attrs["shape"] + + input_type = backend_spec.dtype_to_lib_type(x._attrs["dtype"]) + output_type = backend_spec.dtype_to_lib_type(y._attrs["dtype"]) + + return FUNC_CALL_TEMPLATE.render( + func_name=func_attrs["name"], + in_ptr=x._attrs["name"], + rois_ptr=rois._attrs["name"], + out_ptr=y._attrs["name"], + p_batch="&" + xshape[0]._attrs["name"], + p_in_ch="&" + xshape[3]._attrs["name"], + p_in_h="&" + xshape[1]._attrs["name"], + p_in_w="&" + xshape[2]._attrs["name"], + p_out_batch="&" + yshape[0]._attrs["name"], + p_out_h="&" + yshape[1]._attrs["name"], + p_out_w="&" + yshape[2]._attrs["name"], + sampling_ratio=func_attrs["sampling_ratio"], + spatial_scale=func_attrs["spatial_scale"], + position_sensitive="true" if func_attrs["position_sensitive"] else "false", + continuous_coordinate="true" + if func_attrs["continuous_coordinate"] + else "false", + backend_spec=backend_spec, + elem_input_type=input_type, + elem_output_type=output_type, + indent=indent, + ) diff --git a/python/aitemplate/backend/cuda/__init__.py b/python/aitemplate/backend/cuda/__init__.py new file mode 100644 index 000000000..38586aab5 --- /dev/null +++ b/python/aitemplate/backend/cuda/__init__.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# flake8: noqa +""" +CUDA backend codegen functions. +""" +from . import cuda_common, lib_template, target_def, utils +from .common import * +from .conv2d import * +from .elementwise import * +from .embedding import * +from .gemm_special import * +from .gemm_universal import * +from .gemm_epilogue_vistor import * +from .layernorm_sigmoid_mul import * +from .padding import * +from .pool2d import * +from .reduce import * +from .softmax import * +from .tensor import * +from .upsample import * +from .view_ops import * +from .vision_ops import * +from .attention import * +from .groupnorm import * diff --git a/python/aitemplate/backend/cuda/attention/__init__.py b/python/aitemplate/backend/cuda/attention/__init__.py new file mode 100644 index 000000000..61a47c3ad --- /dev/null +++ b/python/aitemplate/backend/cuda/attention/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +cuda flash_attention module init +""" +from . import flash_attention + +__all__ = ["flash_attention"] diff --git a/python/aitemplate/backend/cuda/attention/flash_attention.py b/python/aitemplate/backend/cuda/attention/flash_attention.py new file mode 100644 index 000000000..b2fe5c0ca --- /dev/null +++ b/python/aitemplate/backend/cuda/attention/flash_attention.py @@ -0,0 +1,319 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +attention kernel codegen for CUDA. +""" +from typing import Any, Dict + +import jinja2 + +from ... 
import registry + +# pylint: disable=C0301 + +FUNC_CALL_FP16_PARAM_TEMPLATE = jinja2.Template( + "reinterpret_cast(&({{name}}->raw()))" +) + +FUNC_CALL_INT32_PARAM_TEMPLATE = jinja2.Template("reinterpret_cast({{name}})") + +FUNC_CALL_FP32_PARAM_TEMPLATE = jinja2.Template("reinterpret_cast({{name}})") + +FUNC_TEMPLATE = jinja2.Template( + """ +#include +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" + +#include "fmha.h" +#include "fmha_fprop_kernel_1xN.h" + +namespace { + +template +__global__ void fmha_fprop_fp16_sm80_loop_kernel(Fused_multihead_attention_fprop_params params) { + fmha::device_1xN_loop(params); +} + +template +void run_fmha_fp16_sm80_loop_(Launch_params &launch_params, + const bool configure) { + bool is_causal = launch_params.params.is_causal; + auto kernel = (is_causal + ? (&fmha_fprop_fp16_sm80_loop_kernel) + : (&fmha_fprop_fp16_sm80_loop_kernel)); + + constexpr int N = Kernel_traits::Cta_tile_p::N; + const int loop_steps = (launch_params.params.s + N - 1) / N; + constexpr int smem_size_softmax_lse = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE; + // Don't need smem_size_softmax_lse if we're not looping + const int smem_size = fmha::get_dynamic_smem_size() + + (loop_steps > 1 ? smem_size_softmax_lse : 0); + + if( smem_size >= 48 * 1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + + if (configure) { + using Mma_tile_p = fmha::Hmma_tile; + constexpr int M = Kernel_traits::Cta_tile_p::M; + size_t STEPS = (launch_params.params.s + M - 1) / M; + constexpr size_t MMAS_M = Mma_tile_p::MMAS_M; + constexpr size_t MMAS_N = Mma_tile_p::MMAS_N; + size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8 * loop_steps; + launch_params.elts_per_thread = elts_per_head; + return; + } + + dim3 grid(launch_params.params.h, launch_params.params.b); + kernel<<>>( + launch_params.params); + + FMHA_CHECK_CUDA(cudaPeekAtLastError()); +} + +void run_fmha_fp16_sm80(Launch_params &launch_params, + const bool configure) { +{{custom_kernel}} +} + +void set_params(Fused_multihead_attention_fprop_params ¶ms, + // sizes + const size_t b, + const size_t s, + const size_t h, + const size_t d, + // device pointers + void *qkv_packed_d, + void *cu_seqlens_d, + void *o_packed_d, + void *o_tmp_d, + void *do_packed_d, + void *s_d, + void *softmax_lse_d, + void *dsoftmax_sum_d, + float p_dropout, + float softmax_scale, + bool is_causal) { + + Data_type acc_type = DATA_TYPE_FP32; + Data_type data_type = DATA_TYPE_FP16; + + // Reset the parameters + memset(¶ms, 0, sizeof(params)); + + // Set the pointers and strides. + params.q_ptr = qkv_packed_d; + params.k_ptr = qkv_packed_d + get_size_in_bytes(h * d, data_type); + params.v_ptr = qkv_packed_d + 2 * get_size_in_bytes(h * d, data_type); + params.q_row_stride_in_elts = 3 * h * d; + params.k_row_stride_in_elts = 3 * h * d; + params.v_row_stride_in_elts = 3 * h * d; + params.q_head_stride_in_elts = d; + params.k_head_stride_in_elts = d; + params.v_head_stride_in_elts = d; + params.o_ptr = o_packed_d; + params.o_row_stride_in_elts = h * d; + params.o_head_stride_in_elts = d; + params.do_ptr = do_packed_d; + params.o_tmp_ptr = o_tmp_d; + + params.cu_seqlens = static_cast(cu_seqlens_d); + + // S = softmax(P) + params.s_ptr = s_d; + params.s_stride_in_bytes = get_size_in_bytes(b * h * s, data_type); + + // Softmax sum + params.softmax_lse_ptr = softmax_lse_d; + params.dsoftmax_sum = dsoftmax_sum_d; + + // Set the dimensions. 
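+  // Here b = batch size, s = (maximum) sequence length, h = number of heads
+  // and d = per-head size, matching the (batch_size, seq_len, num_heads,
+  // head_size) arguments passed by the generated caller below.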
+ params.b = b; + params.h = h; + params.s = s; + params.d = d; + + // Set the different scale values. + // const float scale_bmm1 = 1.f / sqrtf(d); + const float scale_bmm1 = softmax_scale; + constexpr float scale_softmax = 1.f; + constexpr float scale_bmm2 = 1.f; + + params.scale_bmm1f = scale_bmm1; + set_alpha(params.scale_bmm1, scale_bmm1, data_type); + set_alpha(params.scale_softmax, scale_softmax, acc_type); + set_alpha(params.scale_bmm2, scale_bmm2, data_type); + + // Set this to probability of keeping an element to simplify things. + params.p_dropout = 1.f - p_dropout; + // Convert p from float to int so we don't have to convert the random uint to float to compare. + // [Minor] We want to round down since when we do the comparison we use <= instead of < + params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0)); + params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0)); + params.rp_dropout = 1.f / params.p_dropout; + set_alpha(params.scale_dropout, params.rp_dropout, data_type); + + params.is_causal = is_causal; +} +} // namespace + +{{func_signature}} +{ + bool is_dropout = p_dropout > 0.0; + bool return_softmax = false; + + Launch_params launch_params(stream, is_dropout, return_softmax); + + set_params(launch_params.params, + batch_size, // b + seq_len, // s + num_heads, // h + head_size, // d + (void*)qkv, + (void*)cu_seqlens, + (void*)output, + loop ? (void*)o_tmp : nullptr, + nullptr, + nullptr, // return softmax + (void*)softmax_lse, + nullptr, + p_dropout, + softmax_scale, + is_causal); + + run_fmha_fp16_sm80(launch_params, /*configure=*/ false); +} + """ +) + + +FUNC_SIGNATURE = jinja2.Template( + """ +void {{func_name}}(half* output, + const half* qkv, + const int* cu_seqlens, + float* softmax_lse, + float* o_tmp, + int batch_size, + int seq_len, + int num_heads, + int head_size, + float p_dropout, + float softmax_scale, + bool is_causal, + bool loop, + cudaStream_t stream) + """ +) + +FUNC_DECL = jinja2.Template( + """ + {{func_signature}}; + """ +) + +FUNC_CALL_TEMPLATE = jinja2.Template( + """ +{{indent}}{{func_name}}( +{{indent}} {{output}}, {{qkv}}, {{cu_seqlens}}, +{{indent}} {{softmax_lse}}, {{o_tmp}}, +{{indent}} {{batch_size}}, +{{indent}} {{seq_len}}, +{{indent}} {{num_heads}}, +{{indent}} {{head_size}}, +{{indent}} {{p_dropout}}, +{{indent}} {{softmax_scale}}, +{{indent}} {{is_causal}}, {{loop}}, stream /* default stream */ +{{indent}}); + """ +) + +ATT_KERNEL_TEMPLATE = jinja2.Template( + """ + using Kernel_traits = FMHA_kernel_traits<{{s1}}, {{s2}}, 16, 1, 4, 0x08u>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + """ +) + + +@registry.reg("cuda.flash_attention.gen_function") +def flash_attention_gen_function(func_attrs: Dict[str, Any]) -> str: + """the function for generating attention kernel""" + return FUNC_TEMPLATE.render( + custom_kernel=ATT_KERNEL_TEMPLATE.render( + s1=128 if func_attrs["seq_len"] == 128 else 256, + s2=func_attrs["head_size"], + ), + func_signature=FUNC_SIGNATURE.render(func_name=func_attrs["name"]), + ) + + +@registry.reg("cuda.flash_attention.func_decl") +def flash_attention_gen_function_decl(func_attrs: Dict[str, Any]): + return FUNC_DECL.render( + func_signature=FUNC_SIGNATURE.render(func_name=func_attrs["name"]).strip() + ) + + +@registry.reg("cuda.flash_attention.func_call") +def flash_attention_gen_function_call(func_attrs, indent=" "): + """the function for generating a function call for attention""" + output_name = "" + assert len(func_attrs["outputs"]) == 1 + assert 
len(func_attrs["inputs"]) == 2 + + output_name = FUNC_CALL_FP16_PARAM_TEMPLATE.render( + name=func_attrs["outputs"][0]._attrs["name"] + ) + + qkv_name = FUNC_CALL_FP16_PARAM_TEMPLATE.render( + name=func_attrs["inputs"][0]._attrs["name"] + ) + + seqlens_name = FUNC_CALL_INT32_PARAM_TEMPLATE.render( + name=func_attrs["inputs"][1]._attrs["name"] + ) + + x = func_attrs["inputs"][0] + + batch_size = func_attrs["batch_size"] + seq_len = func_attrs["seq_len"] + + num_heads = x._attrs["shape"][2]._attrs["values"][0] + head_size = x._attrs["shape"][3]._attrs["values"][0] + p_dropout = func_attrs["dropout"] + is_causal = func_attrs["causal"] + softmax_scale = head_size ** (-0.5) + + return FUNC_CALL_TEMPLATE.render( + func_name=func_attrs["name"], + output=output_name, + qkv=qkv_name, + cu_seqlens=seqlens_name, + softmax_lse="reinterpret_cast(global_workspace)", + o_tmp="reinterpret_cast(global_workspace + {} * sizeof(float))".format( + batch_size * num_heads * seq_len + ), + batch_size=batch_size, + seq_len=seq_len, + num_heads=num_heads, + head_size=head_size, + p_dropout=p_dropout, + softmax_scale=softmax_scale, + is_causal="true" if is_causal else "false", + loop="true" if seq_len > 256 else "false", + indent=indent, + ) diff --git a/python/aitemplate/backend/cuda/attention/src/fmha.h b/python/aitemplate/backend/cuda/attention/src/fmha.h new file mode 100644 index 000000000..9cc516722 --- /dev/null +++ b/python/aitemplate/backend/cuda/attention/src/fmha.h @@ -0,0 +1,211 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include +#include + +#include + +constexpr int TOTAL_DIM = 0; +constexpr int THREE_DIM = 1; +constexpr int H_DIM = 2; +constexpr int D_DIM = 3; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct PhiloxCudaState { + PhiloxCudaState() = default; + // Called if graph capture is not underway + PhiloxCudaState(uint64_t seed, uint64_t offset) { + seed_ = seed; + offset_.val = offset; + } + // Called if graph capture is underway + PhiloxCudaState( + uint64_t seed, + int64_t* offset_extragraph, + uint32_t offset_intragraph) { + seed_ = seed; + offset_.ptr = offset_extragraph; + offset_intragraph_ = offset_intragraph; + captured_ = true; + } + + // Public members, directly accessible by at::cuda::philox::unpack. + // If we made them private with getters/setters, the getters/setters + // would have to be __device__, and we can't declare __device__ in ATen. + union Payload { + uint64_t val; + int64_t* ptr; + }; + + uint64_t seed_ = 0; + Payload offset_; + uint32_t offset_intragraph_ = 0; + bool captured_ = false; +}; + +struct Qkv_params { + // The QKV matrices. + void* __restrict__ q_ptr; + void* __restrict__ k_ptr; + void* __restrict__ v_ptr; + + // The stride between rows of the Q, K and V matrices. + // size_t qkv_stride_in_elts; + // size_t qkv_stride_in_bytes; + // TD [2022-04-16]: We're using 32-bit indexing to save registers. + // The code probably won't work for arrays larger than 2GB. + uint32_t q_row_stride_in_elts; + uint32_t k_row_stride_in_elts; + uint32_t v_row_stride_in_elts; + uint32_t q_head_stride_in_elts; + uint32_t k_head_stride_in_elts; + uint32_t v_head_stride_in_elts; + + // The number of heads. + int h; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Fused_multihead_attention_fprop_params : public Qkv_params { + // The dQKV matrices. + void* __restrict__ dqkv_ptr; + + // The O matrix (output). + void* __restrict__ o_ptr; + + // The stride between rows of O. + // size_t o_stride_in_elts; + // size_t o_stride_in_bytes; + uint32_t o_row_stride_in_elts; + uint32_t o_head_stride_in_elts; + + // The pointer to the O_tmp matrix, which holds O intermediate value during + // the loop; + void* __restrict__ o_tmp_ptr; + + // The dO matrix . + void* __restrict__ do_ptr; + + // The pointer to the S matrix, overwritten by the dP matrix (bwd). + void* __restrict__ s_ptr; + // The stride between rows of the S matrix. + // int64_t s_stride_in_bytes; + uint32_t s_stride_in_bytes; + + // The pointer to the softmax sum. + void* __restrict__ softmax_lse_ptr; + + // The pointer to the softmax d sum. + void* __restrict__ dsoftmax_sum; + + // The dimensions. + int b, s, d; + + // The scaling factors for the kernel. + float scale_bmm1f; + uint32_t scale_bmm1, scale_softmax, scale_bmm2; + + // array of length b+1 holding starting offset of each sequence. 
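+  // e.g. for a batch of three sequences with lengths s0, s1, s2:
+  //   cu_seqlens = {0, s0, s0 + s1, s0 + s1 + s2}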
+ int* __restrict__ cu_seqlens; + + int* __restrict__ blockmask; + + // The dropout probability (probability of keeping an activation). + float p_dropout; + uint32_t p_dropout_in_uint; + uint16_t p_dropout_in_uint16_t; + + // Scale factor of 1 / (1 - p_dropout). + float rp_dropout; + + // Scale factor of 1 / (1 - p_dropout), in half2. + uint32_t scale_dropout; + + // Random state. + PhiloxCudaState philox_args; + + bool is_causal; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Launch_params { + Launch_params(cudaStream_t stream_, bool is_dropout_, bool return_softmax_) + : elts_per_thread(0), + stream(stream_), + is_dropout(is_dropout_), + return_softmax(return_softmax_) {} + + size_t elts_per_thread; + + cudaStream_t stream; + + bool is_dropout; + bool return_softmax; + + Kernel_params params; + int num_full_heads; + int num_main_groups; + int heads_last_wave; + int main_steps; + int rest_steps; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// void run_fmha_fp16_sm80(Launch_params +// &launch_params, const bool configure); + +// void run_fmha_dgrad_fp16_sm80(const Fused_multihead_attention_fprop_params +// ¶ms, cudaStream_t stream); + +// void +// run_fmha_block_fp16_sm80(Launch_params +// &launch_params, const bool configure); + +// void run_fmha_block_dgrad_fp16_sm80(const +// Fused_multihead_attention_fprop_params ¶ms, cudaStream_t stream); diff --git a/python/aitemplate/backend/cuda/attention/src/fmha/gemm.h b/python/aitemplate/backend/cuda/attention/src/fmha/gemm.h new file mode 100644 index 000000000..433676370 --- /dev/null +++ b/python/aitemplate/backend/cuda/attention/src/fmha/gemm.h @@ -0,0 +1,482 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include + +#include +#include +#include +#include "cutlass/cutlass.h" +#include "cutlass/gemm/warp/default_mma_tensor_op.h" +#include "cutlass/layout/layout.h" + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_base_ { + // The data type. + using Data_type = Data_type_; + // default input type + using Input_type_ = Data_type_; + // Does it store the array of elements. + static constexpr bool HAS_ELTS = BITS_PER_ELT_ >= 8; + // The number of elements. + static constexpr int NUM_ELTS = NUM_ELTS_; + // The size of element in bits. + static constexpr int BITS_PER_ELT = BITS_PER_ELT_; + // The size of byte of a single register. + static constexpr int BYTES_PER_REG = 4; + // The size in bits. + static constexpr int BITS_PER_REG = BYTES_PER_REG * 8; + // The number of registers needed to store the fragment. + static constexpr int NUM_REGS = + DivUpConstexpr(NUM_ELTS * BITS_PER_ELT, BITS_PER_REG); + // The size in bytes (as returned by sizeof(Fragment_base<>). + static constexpr int SIZE_IN_BYTES = NUM_REGS * BYTES_PER_REG; + // The alignment. + static constexpr int ALIGNMENT = ALIGNMENT_ > 0 + ? ALIGNMENT_ + : MinConstexpr(NUM_REGS* BYTES_PER_REG, 16); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The type of the elements. + typename Data_type_, + // The number of elements. + int NUM_ELTS_, + // The alignment if you want to force a value -- use 0 otherwise. + int ALIGNMENT_ = 0, + // The base class. + typename Base_ = Fragment_base_< + Data_type_, + NUM_ELTS_, + 8 * sizeof(Data_type_), + ALIGNMENT_>> +struct alignas(static_cast(Base_::ALIGNMENT)) Fragment : public Base_ { + // The size of a load/store. + static constexpr int BYTES_PER_LOAD_STORE = + Base_::NUM_REGS * sizeof(uint32_t); + + // Clear the fragment. Using PTX in that code seems to produce better SASS... + inline __device__ void clear() { +#pragma unroll + for (int ii = 0; ii < Base_::NUM_REGS; ++ii) { + asm volatile("mov.u32 %0, 0; \n" : "=r"(this->reg(ii)) :); + } + } + + // Immutable access to a register. + inline __device__ const uint32_t& reg(int ii) const { + return this->regs_[ii]; + } + + // Mutable access to a register. + inline __device__ uint32_t& reg(int ii) { + return this->regs_[ii]; + } + + uint32_t regs_[Base_::NUM_REGS]; + + // Immutable access to the elements. + inline __device__ const Data_type_& elt(int ii) const { + return reinterpret_cast(&this->regs_[0])[ii]; + } + + // Mutable access to the elements. + inline __device__ Data_type_& elt(int ii) { + return reinterpret_cast(&this->regs_[0])[ii]; + } + + // Immutable access to the elements with a cast. + template + inline __device__ const Cast_type& elt_as(int ii) const { + return reinterpret_cast(&this->regs_[0])[ii]; + } + + // Mutable access to the elements. 
+ template + inline __device__ Cast_type& elt_as(int ii) { + return reinterpret_cast(&this->regs_[0])[ii]; + } + + // Add another fragment. + inline __device__ void add(const Fragment& other) { +// TODO (TD 2022-04-09): Shouldn't this be NUM_REGS instead of NUM_ELTS? +// Also are we doing int addition or __half2 addition? +#pragma unroll + for (int ii = 0; ii < NUM_ELTS_; ++ii) { + this->elt(ii) += other.elt(ii); + } + } + + // Multiply by another fragment. + inline __device__ void hmul(const Fragment& other) { +#pragma unroll + for (int ii = 0; ii < Base_::NUM_REGS; ++ii) { + this->reg(ii) = fmha::hmul2(this->reg(ii), other.reg(ii)); + } + } + + inline __device__ void hrelu_() { +#pragma unroll + for (int ii = 0; ii < Base_::NUM_REGS; ++ii) { + this->reg(ii) = fmha::hrelu2(this->reg(ii)); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_a : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_b : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Fragment_accumulator : public Fragment { + // The base class. + using Base = Fragment; + + // Add two fragments. + template + inline __device__ void add(const Other_fragment_& other) { + for (int ii = 0; ii < Base::NUM_ELTS; ++ii) { + this->elt(ii) = this->elt(ii) + other.elt(ii); + } + } + + inline __device__ void mul_(const float other) { + for (int ii = 0; ii < Base::NUM_ELTS; ++ii) { + this->elt(ii) *= other; + } + } + + // Do the HMMA. + template + inline __device__ void mma( + const Fragment_a& a, + const Fragment_b& b) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5, %6, %7}, \n" + " {%8, %9}, \n" + " {%0, %1, %2, %3}; \n" + : "+f"(elt(0)), "+f"(elt(1)), "+f"(elt(2)), "+f"(elt(3)) + : "r"(a.reg(0)), + "r"(a.reg(1)), + "r"(a.reg(2)), + "r"(a.reg(3)), + "r"(b.reg(0)), + "r"(b.reg(1))); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5, %6, %7}, \n" + " {%8, %9}, \n" + " {%0, %1, %2, %3}; \n" + : "+f"(elt(4)), "+f"(elt(5)), "+f"(elt(6)), "+f"(elt(7)) + : "r"(a.reg(0)), + "r"(a.reg(1)), + "r"(a.reg(2)), + "r"(a.reg(3)), + "r"(b.reg(2)), + "r"(b.reg(3))); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void clear(Fragment (&frag)[M][N]) { +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ni = 0; ni < N; ++ni) { + frag[mi][ni].clear(); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Clear_accumulator {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Clear_accumulator { + template + static inline __device__ void apply(Acc (&acc)[M][N], bool = false) { + fmha::clear(acc); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void gemm( + Acc (&acc)[M][N], + const A (&a)[M], + const B (&b)[N]) { +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ni = 0; ni < N; ++ni) { + acc[mi][ni].mma(a[mi], b[ni]); + } + } +} + 
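+// Illustrative usage sketch (placeholder names, not part of the kernel): a
+// warp-level K-step with MMAS_M x MMAS_N accumulator tiles typically looks
+// roughly like
+//
+//   Fragment_accumulator acc_p[MMAS_M][MMAS_N];
+//   fmha::clear(acc_p);                  // zero every accumulator register
+//   /* ...load A fragments frag_q[MMAS_M] and B fragments frag_k[MMAS_N]
+//      from shared memory for this K-step... */
+//   fmha::gemm(acc_p, frag_q, frag_k);   // MMAS_M * MMAS_N mma() calls
+//
+// Each Fragment_accumulator::mma() above issues two m16n8k16 mma.sync
+// instructions, i.e. one full 16x16x16 tensor-core MMA per (mi, ni) pair.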
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void gemm_cl( + Acc (&acc)[M][N], + const A (&a)[M], + const B (&b)[N]) { + using Shape = cutlass::gemm::GemmShape<16 * M, 16 * N, 16>; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; +#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; +#else + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + // TD [2022-06-02] We don't support Volta (SM70) yet. + assert(0); +#endif + using Element = cutlass::half_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + + using WarpMma = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor, + cutlass::arch::OpMultiplyAdd, + 1, + true>::Type; + + constexpr int kIters = Shape::kK / InstructionShape::kK; + // using FragmentA = typename WarpMma::FragmentA; + // using FragmentB = typename WarpMma::FragmentB; + using FragmentA = typename WarpMma::ArchMmaOperator::FragmentA; + using FragmentB = typename WarpMma::ArchMmaOperator::FragmentB; + using FragmentC = typename WarpMma::FragmentC; + + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y) == 0) { + // printf("FragmentA::kStorageElements = %d\n", + // FragmentA::kStorageElements); + // printf("Archmma::FragmentA::kStorageElements = %d\n", + // WarpMma::ArchMmaOperator::FragmentA::kStorageElements); + // printf("FragmentB::kStorageElements = %d\n", + // FragmentB::kStorageElements); + // printf("Archmma::FragmentB::kStorageElements = %d\n", + // WarpMma::ArchMmaOperator::FragmentB::kStorageElements); + // printf("FragmentC::kStorageElements = %d\n", + // FragmentC::kStorageElements); + // printf("Archmma::FragmentC::kStorageElements = %d\n", + // WarpMma::ArchMmaOperator::FragmentC::kStorageElements); + // } + + // static_assert(FragmentA::kStorageElements == M * a[0].NUM_REGS); + // static_assert(FragmentB::kStorageElements == N * b[0].NUM_REGS); + static_assert(FragmentA::kStorageElements * kIters == a[0].NUM_REGS); + static_assert( + FragmentB::kStorageElements * kIters * 16 / InstructionShape::kN == + b[0].NUM_REGS); + static_assert(FragmentC::kStorageElements == M * N * acc[0][0].NUM_REGS); + // const FragmentA a_cl = reinterpret_cast(a); + // const FragmentB b_cl = reinterpret_cast(b); + FragmentC c_cl = reinterpret_cast(acc); + FragmentA a_cl[kIters][M]; + FragmentA b_cl[kIters][N]; + constexpr int kRegs = InstructionShape::kK == 16 ? 4 : 2; +#pragma unroll + for (int iter = 0; iter < kIters; iter++) { +#pragma unroll + for (int mi = 0; mi < M; mi++) { + uint32_t* a_ptr = a_cl[iter][mi].raw_data(); +#pragma unroll + for (int ki = 0; ki < kRegs; ki++) { + a_ptr[ki] = a[mi].regs_[iter * kRegs + ki]; + } + } + } +#pragma unroll + for (int iter = 0; iter < kIters; iter++) { +#pragma unroll + for (int ni = 0; ni < N; ni++) { + uint32_t* b_ptr = b_cl[iter][ni].raw_data(); +#pragma unroll + for (int ki = 0; ki < kRegs; ki++) { + // b_ptr[ki] = b[ni].regs_[iter * kRegs + ki]; + // TD [2022-06-02] For some reason the order for frag_b is different. + b_ptr[ki] = b[ni].regs_ + [InstructionShape::kK == 16 ? 
iter * kRegs + ki + : ki * kRegs + iter]; + } + } + } + + WarpMma mma_op; +// mma_op(c_cl, a_cl, b_cl, c_cl); +#pragma unroll + for (int iter = 0; iter < kIters; iter++) { + mma_op( + c_cl, + reinterpret_cast(a_cl[iter]), + reinterpret_cast(b_cl[iter]), + c_cl); + } + +// The modified c_cl is not copied back into acc, idk why +#pragma unroll + for (int mi = 0; mi < M; mi++) { +#pragma unroll + for (int ni = 0; ni < N; ni++) { +#pragma unroll + for (int i = 0; i < 8; i++) { + acc[mi][ni].elt(i) = c_cl[mi * N * 8 + ni * 8 + i]; + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The number of rows in the CTA tile. + int M_, + // The number of cols in the CTA tile. + int N_, + // The number of elements in the the K dimension of the GEMM loop. + int K_, + // The number of rows of warps. + int WARPS_M_, + // The number of cols of warps. + int WARPS_N_, + // The number of warps in the K dimension of the GEMM loop. + int WARPS_K_> +struct Cta_tile_ { + static constexpr int M = M_, N = N_, K = K_; + // The number of warps. + static constexpr int WARPS_M = WARPS_M_, WARPS_N = WARPS_N_, + WARPS_K = WARPS_K_; + // The number of warps per CTA. + static constexpr int WARPS_PER_CTA = WARPS_M * WARPS_N * WARPS_K; + // The number of threads per warp. + static constexpr int THREADS_PER_WARP = 32; + // The number of threads per CTA. + static constexpr int THREADS_PER_CTA = WARPS_PER_CTA * THREADS_PER_WARP; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hmma_tile { + // The number of elements computed with a single warp-MMA. + static constexpr int M_PER_MMA = 16, N_PER_MMA = 16, K_PER_MMA = 16; + + // The number of elements computed with a single CTA-MMA. + static constexpr int M_PER_MMA_PER_CTA = M_PER_MMA * Cta_tile::WARPS_M, + N_PER_MMA_PER_CTA = N_PER_MMA * Cta_tile::WARPS_N, + K_PER_MMA_PER_CTA = K_PER_MMA * Cta_tile::WARPS_K; + + // The number of MMAs needed to compute the GEMM. + static constexpr int MMAS_M = DivUpConstexpr(Cta_tile::M, M_PER_MMA_PER_CTA), + MMAS_N = DivUpConstexpr(Cta_tile::N, N_PER_MMA_PER_CTA), + MMAS_K = DivUpConstexpr(Cta_tile::K, K_PER_MMA_PER_CTA); + + // // The number of elements computed per warp. 
+ // static constexpr int M_PER_WARP = MMAS_M * M_PER_MMA, + // N_PER_WARP = MMAS_N * N_PER_MMA, + // K_PER_WARP = MMAS_K * K_PER_MMA; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using A_type = uint16_t; +using B_type = uint16_t; +using C_type = uint16_t; +using Accumulator_type = float; +using Epilogue_type = float; + +constexpr int BITS_PER_ELEMENT_A = sizeof(A_type) * 8; +constexpr int BITS_PER_ELEMENT_B = sizeof(B_type) * 8; +constexpr int BITS_PER_ELEMENT_C = sizeof(C_type) * 8; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +using Cta_tile_extd = Cta_tile_; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +using Cta_tile_with_k_with_padding = Cta_tile_extd< + Cta_tile_::M, + Cta_tile_::N, + Next_power_of_two::VALUE, + Cta_tile_::WARPS_M, + Cta_tile_::WARPS_N, + Cta_tile_::WARPS_K>; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/python/aitemplate/backend/cuda/attention/src/fmha/gmem_tile.h b/python/aitemplate/backend/cuda/attention/src/fmha/gmem_tile.h new file mode 100644 index 000000000..119ac6a6f --- /dev/null +++ b/python/aitemplate/backend/cuda/attention/src/fmha/gmem_tile.h @@ -0,0 +1,608 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile_, + // The number of bits per element. + int BITS_PER_ELEMENT, + // The number of rows of Q, K or V loaded by this tile. + int ROWS_, + // The number of columns. + int COLS> +struct Gmem_tile_qkv { + using Cta_tile = Cta_tile_; + + static constexpr int BYTES_PER_ELEMENT = BITS_PER_ELEMENT / 8; + // The size of each LDG. + static constexpr int BYTES_PER_LDG = 16; + // The size of a row in bytes. + static constexpr int BYTES_PER_ROW = COLS * BITS_PER_ELEMENT / 8; + + // The number of threads to load a "row" of the matrix. + static constexpr int THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDG; + + static constexpr int ROWS = ROWS_; + // The number of "rows" loaded per LDG. + static constexpr int ROWS_PER_LDG = + Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW; + // The number of LDGs needed to load a chunk of the Q matrix. + static constexpr int LDGS = DivUpConstexpr(ROWS, ROWS_PER_LDG); + + // Ctor. + template + inline __device__ Gmem_tile_qkv( + void* ptr_, + const uint32_t row_stride_in_elts, + const uint32_t head_stride_in_elts, + const BInfo& binfo, + const int tidx) + : row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT), + actual_seqlen(binfo.actual_seqlen), + ptr(reinterpret_cast(ptr_)), + tidx_(tidx) { + // Compute the position in the sequence (within the CTA for the moment). + int row = tidx / THREADS_PER_ROW; + // Compute the position of the thread in the row. + int col = tidx % THREADS_PER_ROW; + + // Store the row as we need it to disable the loads. + // TD [2022-04-16]: To minimize registers, we'll recompute row_ instead of + // storing it row_ = row; + + // The row offset in the batched GEMM. For each seq element, we store QKV in + // that order. int64_t row_offset = (int64_t)row * + // params.qkv_stride_in_bytes; + uint32_t row_offset = (uint32_t)((binfo.sum_s + row) * row_stride_in_bytes); + // Add the block index. + // row_offset += (int64_t)((binfo.sum_s * NUM_MATS + qkv_offset) * binfo.h + + // binfo.bidh) * BYTES_PER_ROW; + row_offset += + (uint32_t)(binfo.bidh * head_stride_in_elts * BYTES_PER_ELEMENT); + + // Assemble the final pointer. + ptr += row_offset + col * BYTES_PER_LDG; + } + + // Store data to shared memory. 
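+  // commit() simply forwards the fetched registers (fetch_) into the shared
+  // memory tile. Sizing example for fp16 with COLS = 64 (a 64-wide head):
+  //   BYTES_PER_ROW = 64 * 16 / 8 = 128, THREADS_PER_ROW = 128 / 16 = 8,
+  //   and with a 128-thread CTA, ROWS_PER_LDG = 128 / 8 = 16.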
+ template + inline __device__ void commit(Smem_tile& smem_tile) { + smem_tile.store(fetch_); + } + + inline __device__ void load() { + int row_ = tidx_ / THREADS_PER_ROW; + const void* ptrs[LDGS]; + uint32_t preds[LDGS]; +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + // ptrs[ii] = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes; + ptrs[ii] = ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes; + preds[ii] = ((row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen)); + fetch_[ii] = make_uint4(0, 0, 0, 0); + } + + // not packing predicates removes restrictions (e.g. FP16 384, 4 warps) + Ldg_functor fct(fetch_, ptrs); +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + fct.load(ii, preds[ii]); + } + } + + // Store data to memory. + inline __device__ void store(const uint4 (&data)[LDGS]) { + int row_ = tidx_ / THREADS_PER_ROW; +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + // char *ptr_ = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes; + char* ptr_ = ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes; + if ((row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen)) { + fmha::stg(ptr_, data[ii]); + } + } + } + + inline __device__ void move(const int steps = 1) { + // ptr += (int64_t)ROWS * row_stride_in_bytes * steps; + ptr += (uint32_t)ROWS * row_stride_in_bytes * steps; + actual_seqlen -= ROWS * steps; + } + + // The stride between rows for the QKV matrice. + // int64_t row_stride_in_bytes; + const uint32_t row_stride_in_bytes; + // The pointer. + char* ptr; + // The fetch registers. + uint4 fetch_[LDGS]; + // Keep track of the row the thread is processing as we move the tile. + // int row_; + const int tidx_; + // The length of the sequence loaded by that memory tile. + int actual_seqlen; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o { + static_assert(BYTES_PER_ELEMENT == 2 || BYTES_PER_ELEMENT == 4); + + // The mma tile. + using Mma_tile = fmha::Hmma_tile; + + // The size of each element. + // static constexpr int BYTES_PER_ELEMENT = 2; + // The size of each STG. + static constexpr int BYTES_PER_STG = BYTES_PER_ELEMENT * 4; + static constexpr int COLS = Cta_tile::N; + // The size of a row in bytes. + static constexpr int BYTES_PER_ROW = COLS * BYTES_PER_ELEMENT; + + // The number of threads to store a "row" of the matrix. + static constexpr int THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_STG; + // The number of "rows" stored per iteration of the loop. The output of 1 MMA. + static constexpr int ROWS = Cta_tile::M; + // The number of "rows" stored per iteration of the loop. The output of 1 MMA. + static constexpr int ROWS_PER_LOOP = + ROWS <= 64 ? ROWS : (int)Mma_tile::M_PER_MMA_PER_CTA; + // The number of outter loop for the stores. + static constexpr int LOOPS = ROWS / ROWS_PER_LOOP; + + // The number of "rows" stored per STG. + static constexpr int ROWS_PER_STG = + Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW; + // Do we have to guard against partial writes/reads. + static constexpr bool HAS_INCOMPLETE_STG = Cta_tile::M % ROWS_PER_STG != 0; + // The number of STGs needed to store a chunk of the Q matrix. + static constexpr int STGS_PER_LOOP = + DivUpConstexpr(ROWS_PER_LOOP, ROWS_PER_STG); + // The number of STGs needed to store a chunk of the Q matrix in total. + static constexpr int STGS = STGS_PER_LOOP * LOOPS; + + // Ctor. 
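+  // The constructor derives this thread's base output pointer from the
+  // cumulative sequence offset and the head index:
+  //   ptr_ += (binfo.sum_s + row) * row_stride_in_bytes
+  //         + binfo.bidh * head_stride_in_elts * BYTES_PER_ELEMENT
+  //         + col * BYTES_PER_STG;
+  // where row / col are this thread's position within the store tile.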
+ template + // inline __device__ Gmem_tile_o(void *ptr, const size_t row_stride_in_elts, + // const BInfo &binfo, const int tidx) + inline __device__ Gmem_tile_o( + void* ptr, + const uint32_t row_stride_in_elts, + const uint32_t head_stride_in_elts, + const BInfo& binfo, + const int tidx) + : row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT), + actual_seqlen(binfo.actual_seqlen), + ptr_(reinterpret_cast(ptr)), + tidx_(tidx) { + // Compute the position in the sequence (within the CTA for the moment). + int row = tidx / THREADS_PER_ROW; + // Compute the position of the thread in the row. + int col = tidx % THREADS_PER_ROW; + + // Store the row as we need it to disable loads. + // row_ = row; + + // The row offset in the batched GEMM. + // int64_t row_offset = (int64_t)row * row_stride_in_bytes + binfo.bidx * + // BYTES_PER_ROW; + uint32_t row_offset = (uint32_t)((binfo.sum_s + row) * row_stride_in_bytes); + row_offset += + (uint32_t)(binfo.bidh * head_stride_in_elts * BYTES_PER_ELEMENT); + // Assemble the final pointer. + ptr_ += row_offset + col * BYTES_PER_STG; + + // Is that thread active on the last STG? + if (HAS_INCOMPLETE_STG) { + is_active_for_last_stg_ = row + (STGS - 1) * ROWS_PER_STG < Cta_tile::M; + } + } + + // Store data to global memory. + inline __device__ void store(const uint4 (&src)[STGS_PER_LOOP], int mi) { + int row_ = tidx_ / THREADS_PER_ROW; +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + int jj = mi * STGS_PER_LOOP + ii; + if (row_ + jj * ROWS_PER_STG >= this->actual_seqlen) { + break; + } + + if (BYTES_PER_ELEMENT == 4) { + if (!HAS_INCOMPLETE_STG || + (jj < STGS - 1 || this->is_active_for_last_stg_)) { + fmha::stg( + this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes, + src[ii]); + } + } else if (BYTES_PER_ELEMENT == 2) { + float x = reinterpret_cast(src[ii].x); + float y = reinterpret_cast(src[ii].y); + float z = reinterpret_cast(src[ii].z); + float w = reinterpret_cast(src[ii].w); + uint2 out = float4_to_half4(x, y, z, w); + if (!HAS_INCOMPLETE_STG || + (jj < STGS - 1 || this->is_active_for_last_stg_)) { + fmha::stg( + this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes, out); + } + } + } + } + + // Store data to global memory. + inline __device__ void load(uint4 (&dst)[STGS_PER_LOOP], int mi) { + static_assert(BYTES_PER_ELEMENT == 4); + int row_ = tidx_ / THREADS_PER_ROW; +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + int jj = mi * STGS_PER_LOOP + ii; + if (row_ + jj * ROWS_PER_STG >= this->actual_seqlen) { + break; + } + + if (!HAS_INCOMPLETE_STG || + (jj < STGS - 1 || this->is_active_for_last_stg_)) { + fmha::ldg( + dst[ii], + this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes); + } + } + } + + inline __device__ void move(const int steps = 1) { + // row_ += ROWS * steps; + // ptr_ += (int64_t)ROWS * row_stride_in_bytes * steps; + ptr_ += (uint32_t)ROWS * row_stride_in_bytes * steps; + actual_seqlen -= ROWS * steps; + } + + // The stride between rows for the QKV matrice. + // int64_t row_stride_in_bytes; + const uint32_t row_stride_in_bytes; + // The pointer. + char* ptr_; + // Is the thread active for the last STG? + int is_active_for_last_stg_; + // The length of the sequence loaded by that memory tile. + int actual_seqlen; + const int tidx_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_mma_sd { + // The mma tile. + using Mma_tile = fmha::Hmma_tile; + + // Each STG stores 8 elements. 
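+  // (16 bytes per STG when the elements are fp16, i.e. one uint4.)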
+ static constexpr int BYTES_PER_STG = BYTES_PER_ELEMENT * 8; + // The number of MMAs in the M dimension. + static constexpr int MMAS_M = Mma_tile::MMAS_M; + // The number of MMAs in the N dimension. + static constexpr int MMAS_N = Mma_tile::MMAS_N; + // The number of rows computed per MMA per thread block. + static constexpr int M_PER_MMA_PER_CTA = Mma_tile::M_PER_MMA_PER_CTA; + // The number of cols computed per MMA per thread block. + static constexpr int N_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA; + // The number of threads per block. + static constexpr int THREADS_PER_CTA = Cta_tile::THREADS_PER_CTA; + // The size of each row in bytes. I.e. how many bytes are stored per STG. + static constexpr int BYTES_PER_ROW = THREADS_PER_CTA * BYTES_PER_STG; + // The distance between elements stored per loop (in bytes). + static constexpr int LOOP_STRIDE_BYTES = MMAS_M * MMAS_N * BYTES_PER_ROW; + + // The type of elements stored per STG. + using Type = typename fmha::Uint_from_size_in_bytes::Type; + + // Ctor. + template + inline __device__ Gmem_tile_mma_sd( + void* ptr, + const Params& params, + const int bidb, + const int bidh, + const int tidx) + : ptr_(static_cast(ptr)) { + // The block index. + // size_t bidx = bidb * params.h + bidh; + uint32_t bidx = bidb * params.h + bidh; + + // The distance between two blocks (in bytes). + // const size_t block_stride_bytes = params.s * params.s * + // BYTES_PER_ELEMENT; + const uint32_t block_stride_bytes = params.s * params.s * BYTES_PER_ELEMENT; + // Set store location for each thread at the beginning of the loop + ptr_ += bidx * block_stride_bytes + tidx * BYTES_PER_STG; + } + + // Store to global memory. + inline __device__ void store(const Type& data, const int mi, const int ni) { + // size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW; + uint32_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW; + fmha::stg(ptr_ + offset, data); + } + + // Load from global memory. + inline __device__ void load(Type& data, const int mi, const int ni) { + // size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW; + uint32_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW; + fmha::ldg(data, ptr_ + offset); + } + + // Move to the next tile. + inline __device__ void move(const int steps = 1) { + ptr_ += LOOP_STRIDE_BYTES * steps; + } + + // The pointer in global memory. + char* ptr_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Cta_tile, + typename Base = Gmem_tile_mma_sd> +struct Gmem_tile_mma_s : public Base { + // The number of mmas in the vertical dimension. + static constexpr int M = Base::MMAS_M; + // The number of mmas in the horizontal dimension. + static constexpr int N = Base::MMAS_N; + // The type of the vectors stored by each STG. + using Type = typename Base::Type; + + // Ctor. + template + inline __device__ Gmem_tile_mma_s( + const Params& params, + const Block_info& binfo, + const int tidx) + : Base(params.s_ptr, params, binfo.bidb, binfo.bidh, tidx) {} + + // Store to global memory. 
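+  // Packs each 2x4 block of fp32 softmax values into a uint4 of four half2
+  // pairs and writes it only when the upper-left element of that (mi, ni)
+  // MMA tile is valid according to the mask.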
+ template + inline __device__ void store( + const float (&softmax)[2 * M][4 * N], + const Mask& mask) { +#pragma unroll + for (int mi = 0; mi < M; mi++) { +#pragma unroll + for (int ni = 0; ni < N; ni++) { + float tmp00 = softmax[2 * mi + 0][4 * ni + 0]; + float tmp01 = softmax[2 * mi + 0][4 * ni + 1]; + float tmp02 = softmax[2 * mi + 0][4 * ni + 2]; + float tmp03 = softmax[2 * mi + 0][4 * ni + 3]; + + float tmp10 = softmax[2 * mi + 1][4 * ni + 0]; + float tmp11 = softmax[2 * mi + 1][4 * ni + 1]; + float tmp12 = softmax[2 * mi + 1][4 * ni + 2]; + float tmp13 = softmax[2 * mi + 1][4 * ni + 3]; + + uint4 dst; + dst.x = fmha::float2_to_half2(tmp00, tmp01); + dst.y = fmha::float2_to_half2(tmp02, tmp03); + dst.z = fmha::float2_to_half2(tmp10, tmp11); + dst.w = fmha::float2_to_half2(tmp12, tmp13); + if (mask.is_valid(mi, ni, 0, 0)) { + Base::store(dst, mi, ni); + } + } + } + } + + // Store to global memory. + template + inline __device__ void store(const Fragment (&frag)[N][M], const Mask& mask) { +#pragma unroll + for (int mi = 0; mi < M; mi++) { +#pragma unroll + for (int ni = 0; ni < N; ni++) { + uint4 dst; + dst.x = frag[ni][mi].reg(0); + dst.y = frag[ni][mi].reg(2); + dst.z = frag[ni][mi].reg(1); + dst.w = frag[ni][mi].reg(3); + if (mask.any_valid(mi, ni)) { + Base::store(dst, mi, ni); + } + } + } + } + + // Load from global memory. + template + inline __device__ void load(uint4 (®s)[M][N], const Mask& mask) { +#pragma unroll + for (int mi = 0; mi < M; mi++) { +#pragma unroll + for (int ni = 0; ni < N; ni++) { + regs[mi][ni] = make_uint4(0, 0, 0, 0); + if (mask.any_valid(mi, ni)) { + Base::load(regs[mi][ni], mi, ni); + } + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile> +struct Gmem_summary_stats { + // The Mma tile. + using Mma_tile = fmha::Hmma_tile; + + // The number of MMAs in M/N dimensions. + static constexpr int MMAS_M = Mma_tile::MMAS_M; + + // The size of each element. + static constexpr int BYTES_PER_ELEMENT = 4; + static constexpr int BYTES_PER_MMA = + (Cta_tile::THREADS_PER_WARP / 4) * 2 * BYTES_PER_ELEMENT; + static constexpr int ROWS = Cta_tile::M; + + // Ctor. + template + inline __device__ Gmem_summary_stats( + void* ptr, + const Params& params, + const int tidx) + : ptr_(reinterpret_cast(ptr)), tidx_(tidx) { + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.x; + // The block index. + // size_t bidx = bidb * params.h + bidh; + uint32_t bidx = bidb * params.h + bidh; + + // Extract the position in the warp. + int warp = tidx / Cta_tile::THREADS_PER_WARP; + int lane = tidx % Cta_tile::THREADS_PER_WARP; + + // The distance between two blocks (in bytes). + // size_t block_stride_bytes = params.s * BYTES_PER_ELEMENT; + uint32_t block_stride_bytes = params.s * BYTES_PER_ELEMENT; + + // Set store location for each thread at the beginning of the loop + ptr_row_ = ptr_ + bidx * block_stride_bytes; + ptr_ += bidx * block_stride_bytes + (lane / 4) * BYTES_PER_ELEMENT; + } + + // Store data to global memory. 
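+  // Only lanes 0, 4, 8, ... of warp 0 write: each MMA row index mi stores two
+  // fp32 statistics, the second placed 8 * BYTES_PER_ELEMENT bytes after the
+  // first.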
+ inline __device__ void store(const uint32_t (&data)[MMAS_M * 2]) { + int warp = tidx_ / Cta_tile::THREADS_PER_WARP; + int lane = tidx_ % Cta_tile::THREADS_PER_WARP; + if ((warp == 0) && (lane % 4 == 0)) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // TODO: Not sure if it's right for MMAS_M > 1 + fmha::stg( + ptr_ + mi * BYTES_PER_MMA + 0 * BYTES_PER_ELEMENT, + data[mi * 2 + 0]); + fmha::stg( + ptr_ + mi * BYTES_PER_MMA + 8 * BYTES_PER_ELEMENT, + data[mi * 2 + 1]); + } + } + } + + // Store data to global memory. + inline __device__ void store_row( + const uint32_t (&data)[MMAS_M], + const int row) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // TODO: Not sure if it's right for MMAS_M > 1 + fmha::stg( + ptr_row_ + mi * BYTES_PER_MMA + row * BYTES_PER_ELEMENT, data[mi]); + } + } + + // Load from global memory. + inline __device__ void load(uint32_t (&data)[MMAS_M * 2]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // TODO: Not sure if it's right for MMAS_M > 1 + fmha::ldg( + data[mi * 2 + 0], ptr_ + mi * BYTES_PER_MMA + 0 * BYTES_PER_ELEMENT); + fmha::ldg( + data[mi * 2 + 1], ptr_ + mi * BYTES_PER_MMA + 8 * BYTES_PER_ELEMENT); + } + } + + // Load from global memory. + inline __device__ void load_next( + uint32_t (&data)[MMAS_M * 2], + int move_steps = 1) { + char* ptr_next = ptr_ + move_steps * ROWS * BYTES_PER_ELEMENT; +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // TODO: Not sure if it's right for MMAS_M > 1 + fmha::ldg( + data[mi * 2 + 0], + ptr_next + mi * BYTES_PER_MMA + 0 * BYTES_PER_ELEMENT); + fmha::ldg( + data[mi * 2 + 1], + ptr_next + mi * BYTES_PER_MMA + 8 * BYTES_PER_ELEMENT); + } + } + + // Store data to global memory. + template + inline __device__ void load_row(uint32_t (&data)[N], const int row[N]) { +#pragma unroll + for (int ni = 0; ni < N; ++ni) { + fmha::ldg(data[ni], ptr_row_ + row[ni] * BYTES_PER_ELEMENT); + } + } + + // Move the pointer to the next location. + inline __device__ void move() { + ptr_ += ROWS * BYTES_PER_ELEMENT; + ptr_row_ += ROWS * BYTES_PER_ELEMENT; + } + + // Move the pointer to the next location. + inline __device__ void move(const int steps) { + ptr_ += ROWS * BYTES_PER_ELEMENT * steps; + ptr_row_ += ROWS * BYTES_PER_ELEMENT * steps; + } + + // The pointer. + char* ptr_; + char* ptr_row_; + const int tidx_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/python/aitemplate/backend/cuda/attention/src/fmha/kernel_traits.h b/python/aitemplate/backend/cuda/attention/src/fmha/kernel_traits.h new file mode 100644 index 000000000..27aad1b80 --- /dev/null +++ b/python/aitemplate/backend/cuda/attention/src/fmha/kernel_traits.h @@ -0,0 +1,143 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once +#include +#include +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int S, + int D, + int STEP, + int WARPS_M, + int WARPS_N, + uint32_t FLAGS = 0x08u> +struct FMHA_kernel_traits { + // The CTA description for the 1st GEMM. + using Cta_tile_p = fmha::Cta_tile_extd; + // The CTA description for the 2nd GEMM. + using Cta_tile_o = fmha::Cta_tile_extd; + + // Do we use one buffer for K and V. + static constexpr bool SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x08u) != 0u; + // Do we keep K in registers. + static constexpr bool K_IN_REGS = (FLAGS & 0x10u) == 0u; + // Do we keep V in registers. + static constexpr bool V_IN_REGS = (FLAGS & 0x100u) == 0u; + + // The global memory tile to load Q. + using Gmem_tile_q = + fmha::Gmem_tile_qkv; + + // The shared memory tile to swizzle Q. + // using Smem_tile_q = fmha::Smem_tile_a; + using Smem_tile_q = + fmha::Smem_tile_a; + + // The global memory tile to load K. + using Gmem_tile_k = + fmha::Gmem_tile_qkv; + // The shared memory tile to swizzle K. + using Smem_tile_k = fmha::Smem_tile_b; + + // The global memory tile to load V. + using Gmem_tile_v = + fmha::Gmem_tile_qkv; + // The shared memory tile to swizzle V. + using Smem_tile_v = fmha::Smem_tile_v; + + // The global memory tile to store O. + using Gmem_tile_o = fmha::Gmem_tile_o; + // The shared memory tile for O. + using Smem_tile_o = fmha::Smem_tile_o; + ; + + // The global memory tile to load/store S. + using Gmem_tile_s = fmha::Gmem_tile_mma_s; + + // The shared memory tile to transpose S. + using Smem_tile_st = fmha::Smem_tile_mma_transposed; + + using Gmem_tile_do = + fmha::Gmem_tile_qkv; + + // The global memory tile to store the softmax sum. + using Gmem_softmax_sum = fmha::Gmem_summary_stats; + + // The shared memory tile to store dp sum. + using Smem_dp_sum = fmha::Smem_tile_dp_sum; + + // Make sure the number of threads match. 
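+  // (The global O tile and the shared-memory O tile must agree on how many
+  // threads cooperate on one output row, otherwise the O epilogue layout
+  // would not line up.)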
+ static_assert( + (int)Gmem_tile_o::THREADS_PER_ROW == (int)Smem_tile_o::THREADS_PER_ROW, + ""); + + // The number of threads. + static constexpr int THREADS = Cta_tile_p::THREADS_PER_CTA; + // Make sure the number of threads matches both CTAs. + static_assert(THREADS == Cta_tile_o::THREADS_PER_CTA, ""); + + // The amount of shared memory needed to load Q and K. + static constexpr int BYTES_PER_SMEM_QK = + Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE; + // The extra amount of shared memory needed to load V. + static constexpr int BYTES_PER_SMEM_V = + SHARE_SMEM_FOR_K_AND_V ? 0u : Smem_tile_v::BYTES_PER_TILE; + // The amount of shared memory needed for Q, K and V.. + static constexpr int BYTES_PER_SMEM_QKV = + BYTES_PER_SMEM_QK + BYTES_PER_SMEM_V; + // The amount of shared memory needed to load Q and store O. + static constexpr int BYTES_PER_SMEM_QO = + Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE; + + // The amount of shared memory needed for Q, K, V and O. + static constexpr int BYTES_PER_SMEM = + fmha::MaxConstexpr(BYTES_PER_SMEM_QKV, BYTES_PER_SMEM_QO); + // Make sure we have enough shared memory. + static_assert( + Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE <= + BYTES_PER_SMEM, + ""); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/python/aitemplate/backend/cuda/attention/src/fmha/mask.h b/python/aitemplate/backend/cuda/attention/src/fmha/mask.h new file mode 100644 index 000000000..ec07012af --- /dev/null +++ b/python/aitemplate/backend/cuda/attention/src/fmha/mask.h @@ -0,0 +1,117 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +namespace fmha { + +template +struct Mask { + using Mma_tile = fmha::Hmma_tile; + + template + __device__ Mask( + const BInfo& blockInfo, + int tidx, + const int loop_step_idx_ = 0) + : actual_seqlen(blockInfo.actual_seqlen - loop_step_idx_ * Cta_tile::N), + loop_step_idx(loop_step_idx_) { + const int warp = tidx / Cta_tile::THREADS_PER_WARP; + const int lane = tidx % Cta_tile::THREADS_PER_WARP; + + static_assert(Cta_tile::WARPS_K == 1, ""); + + // find the warp in the Cta tile + const int warp_n = (warp / Cta_tile::WARPS_M); + const int warp_m = (warp % Cta_tile::WARPS_M); + // decompose warp into 8x4 tile + const int quad = lane / 4; + const int tid = (lane % 4) * 2; + row = warp_m * 16 + quad; + col = warp_n * 16 + tid; + } + + inline __device__ bool is_valid( + const int mi, + const int ni, + const int ii, + const int jj) const { + // ii and jj iterate over the 2x4 fragment + // const int current_col = (Is_causal ? loop_step_idx * Cta_tile::N : 0) + + // ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1); + const int current_col = + ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1); + const int current_row = row_offset + ii * 8; + const bool col_valid = current_col < actual_seqlen; + // const bool col_valid = (ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) + // * 4 + (jj & 1)) < actual_seqlen; + //&& (row + mi * Mma_tile::M_PER_MMA_PER_CTA + ii * 8) < actual_seqlen; + bool all_valid = Is_causal ? col_valid && + (current_col + loop_step_idx * Cta_tile::N <= current_row) + : col_valid; + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("current_col=%d, current_row=%d, actual_seqlen=%d, + // col_valid=%d, all_valid=%d\n", current_col, current_row, + // actual_seqlen, col_valid, all_valid); + // } + return Is_causal ? col_valid && + (current_col + loop_step_idx * Cta_tile::N <= current_row) + : col_valid; + // return row_valid && col_valid; + } + + // BERT Mask: if upper left is invalid, none are valid + inline __device__ bool any_valid(const int mi, const int ni) const { + return is_valid(mi, ni, 0, 0) || is_valid(mi, ni, 1, 0); + } + + inline __device__ void load(const int it) { + row_offset = it * Cta_tile::M + row; + } + int row_offset; + + int row; + int col; + const int loop_step_idx; + const int actual_seqlen; +}; + +} // namespace fmha diff --git a/python/aitemplate/backend/cuda/attention/src/fmha/smem_tile.h b/python/aitemplate/backend/cuda/attention/src/fmha/smem_tile.h new file mode 100644 index 000000000..0bb8285d2 --- /dev/null +++ b/python/aitemplate/backend/cuda/attention/src/fmha/smem_tile.h @@ -0,0 +1,1843 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include +#include +#include "utils.h" + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The description of the tile computed by this CTA. + typename Cta_tile, + // The number of rows in the 2D shared memory buffer. + int M_, + // The number of cols. + int N_, + // The size in bits of each element. + int BITS_PER_ELEMENT_, + // The number of bytes per STS. + int BYTES_PER_STS_ = 16, + // The number of buffers. (Used in multistage and double buffer cases.) + int BUFFERS_PER_TILE_ = 1, + // Do we enable the fast path for LDS.128 and friends. + int ENABLE_LDS_FAST_PATH_ = 0, + // The number of rows that are used for the XOR swizzling to allow fast + // STS/LDS. + int ROWS_PER_XOR_PATTERN_ = 8, + // The number of cols that are used for the XOR swizzling to allow fast + // STS/LDS. + int COLS_PER_XOR_PATTERN_ = 1, + // Use or not predicates + bool USE_PREDICATES_ = true> +struct Smem_tile_without_skews { + // The size in bits of each element. + enum { BITS_PER_ELEMENT = BITS_PER_ELEMENT_ }; + // The size in bytes of a single STS. + enum { BYTES_PER_STS = BYTES_PER_STS_ }; + // The number of elements per STS. + enum { ELEMENTS_PER_STS = BYTES_PER_STS * 8 / BITS_PER_ELEMENT }; + // To support arbitrary N, we pad some values to a power-of-2. 
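+  // As a rough worked example (values assumed purely for illustration): with
+  // 16-bit elements and N_ = 32, N_WITH_PADDING below is 32, so
+  // BYTES_PER_ROW_BEFORE_PACKING = 32 * 16 / 8 = 64. Because we want rows of
+  // at least 128B, BYTES_PER_ROW = 128 and ROWS = M_ * 64 / 128 = M_ / 2,
+  // i.e. two matrix rows get packed into a single row of shared memory.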
+ enum { N_WITH_PADDING = Next_power_of_two::VALUE }; + // The number of bytes per row without packing of rows. + enum { BYTES_PER_ROW_BEFORE_PACKING = N_WITH_PADDING * BITS_PER_ELEMENT / 8 }; + // The number of bytes per row -- we want at least 128B per row. + enum { BYTES_PER_ROW = Max::VALUE }; + // The number of rows in shared memory (two rows may be packed into a single + // one). + enum { ROWS = M_ * BYTES_PER_ROW_BEFORE_PACKING / BYTES_PER_ROW }; + + // The number of threads per row. + enum { THREADS_PER_ROW_UNBOUNDED = BYTES_PER_ROW / BYTES_PER_STS }; + // The number of threads per row. + enum { + THREADS_PER_ROW = + Min::VALUE + }; + + // The number of STS per row. + enum { STS_PER_ROW = BYTES_PER_ROW / THREADS_PER_ROW / BYTES_PER_STS }; + // It must be at least one. + static_assert(STS_PER_ROW >= 1, ""); + // The number of rows written with a single STS. + enum { ROWS_PER_STS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + // Make sure we write to at least one row per STS. Thanks Dr. Obvious ;) + static_assert(ROWS_PER_STS >= 1, ""); + // The number of STS needed to store all rows. + enum { STS_PER_COL = Div_up::VALUE }; + // The number of STS in total. + enum { STS = STS_PER_COL * STS_PER_ROW }; + + // TD [2022-06-02] In the case of Q (16 x 64) in the backward pass with 256 + // threads, we only need to store 16 * 64 * 2 = 2KB instead of 4KB. + static constexpr bool PARTIAL_STORE = ROWS_PER_STS > ROWS; + static constexpr int STORING_THREADS = + PARTIAL_STORE ? ROWS * THREADS_PER_ROW : Cta_tile::THREADS_PER_CTA; + + // The size of one buffer in bytes in shared memory. + // enum { BYTES_PER_BUFFER = STS * BYTES_PER_STS * Cta_tile::THREADS_PER_CTA + // }; + enum { BYTES_PER_BUFFER = STS * BYTES_PER_STS * STORING_THREADS }; + // The number of buffers. + enum { BUFFERS_PER_TILE = BUFFERS_PER_TILE_ }; + // The size in bytes of total buffers. + enum { BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE }; + // The boundary for smem_read_offset and smem_write_offset increment. + enum { BYTES_PER_TILE_INC_BOUNDARY = BYTES_PER_TILE - BYTES_PER_BUFFER }; + + // Do we enable the LDS.128 fast path? + enum { ENABLE_LDS_FAST_PATH = ENABLE_LDS_FAST_PATH_ }; + static_assert(ENABLE_LDS_FAST_PATH == 0); + // The number of rows that are used for the XOR swizzling to allow fast + // STS/LDS. + enum { ROWS_PER_XOR_PATTERN = ROWS_PER_XOR_PATTERN_ }; + // The number of cols that are used for the XOR swizzling to allow fast + // STS/LDS. + enum { COLS_PER_XOR_PATTERN = COLS_PER_XOR_PATTERN_ * 16 / BYTES_PER_STS }; + // Use or not predicates + enum { USE_PREDICATES = USE_PREDICATES_ }; + + // The type of elements that are stored in shared memory by each thread. + using Store_type = typename Uint_from_size_in_bytes::Type; + + // Ctor. + inline __device__ Smem_tile_without_skews(void* smem, int tidx) + : smem_(__nvvm_get_smem_pointer(smem)), tidx_(tidx) { + // The row written by a thread. See doc/mma_smem_layout.xlsx. + int smem_write_row = tidx / THREADS_PER_ROW; + + // The XOR pattern. + int smem_write_xor = + smem_write_row % ROWS_PER_XOR_PATTERN * COLS_PER_XOR_PATTERN; + // Compute the column and apply the XOR pattern. + int smem_write_col = (tidx % THREADS_PER_ROW) ^ smem_write_xor; + + // The offset. + this->smem_write_offset_ = + smem_write_row * BYTES_PER_ROW + smem_write_col * BYTES_PER_STS; + + // TODO: Why not merge it with the read offset? 
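+    // Write-offset sketch (assumed values, not a specific instantiation): with
+    // THREADS_PER_ROW = 8, ROWS_PER_XOR_PATTERN = 8, COLS_PER_XOR_PATTERN = 1
+    // and BYTES_PER_STS = 16, thread tidx = 17 writes row 17 / 8 = 2, gets the
+    // XOR pattern 2 % 8 * 1 = 2 and column (17 % 8) ^ 2 = 3, i.e. the offset
+    // 2 * BYTES_PER_ROW + 3 * 16. The XOR rotates which 16B slot each row
+    // uses, so STS/LDS accesses from a warp avoid shared memory bank conflicts.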
+ // this->smem_read_buffer_ = __shfl_sync(0xffffffff, 0, 0); + // this->smem_write_buffer_ = __shfl_sync(0xffffffff, 0, 0); + } + + // Compute the store pointers. + template + inline __device__ void compute_store_pointers(uint32_t (&ptrs)[N]) { +#pragma unroll + for (int ii = 0; ii < N; ++ii) { + // Decompose the STS into row/col. + int row = ii / STS_PER_ROW; + int col = ii % STS_PER_ROW; + + // Assemble the offset. + int offset = smem_write_offset_ + row * ROWS_PER_STS * BYTES_PER_ROW; + + // Take the column into account. + if (STS_PER_ROW > 1) { + offset += col * THREADS_PER_ROW * BYTES_PER_STS; + } + + // Apply the XOR pattern if needed. + if (ROWS_PER_STS < ROWS_PER_XOR_PATTERN) { + const int m = row * ROWS_PER_STS % ROWS_PER_XOR_PATTERN; + offset ^= m * COLS_PER_XOR_PATTERN * BYTES_PER_STS; + } + + // Assemble the final pointer :) + // ptrs[ii] = smem_ + offset + smem_write_buffer_; + // smem_write_buffer_ is already merged with smem_write_offset_ + ptrs[ii] = smem_ + offset; + } + } + + inline __device__ void debug_reset() { + for (int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER) { + for (int row = 0; row < ROWS; ++row) { + for (int col = 0; col < BYTES_PER_ROW; col += 4) { + if (threadIdx.x == 0) { + uint32_t val = 0x0; + sts(val, smem_ + row * BYTES_PER_ROW + col + buffer); + } + } + } + } + } + + // Print the content of the tile (only for debug ;)). + inline __device__ void debug_print() const { + for (int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER) { + for (int row = 0; row < ROWS; ++row) { + for (int col = 0; col < BYTES_PER_ROW; col += 4) { + if (threadIdx.x == 0) { + uint32_t val; + lds(val, smem_ + row * BYTES_PER_ROW + col + buffer); + printf( + "block=(x=%2d, y=%2d, z=%2d) (smem_=%2d, buffer=%2d, row=%2d, byte=%4d)=0x%08x\n", + blockIdx.x, + blockIdx.y, + blockIdx.z, + smem_, + buffer, + row, + col, + val); + } + } + } + } + } + + // Move the read offset to next buffer. + inline __device__ void move_to_next_read_buffer() { + // if( BUFFERS_PER_TILE > 1 && smem_read_buffer_ >= + // BYTES_PER_TILE_INC_BOUNDARY ) { + // this->smem_read_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY; + // } else if( BUFFERS_PER_TILE > 1 ) { + // this->smem_read_buffer_ += BYTES_PER_BUFFER; + // } + if (BUFFERS_PER_TILE > 1 && + smem_read_offset_ >= BYTES_PER_TILE_INC_BOUNDARY) { + this->smem_read_offset_ -= BYTES_PER_TILE_INC_BOUNDARY; + } else if (BUFFERS_PER_TILE > 1) { + this->smem_read_offset_ += BYTES_PER_BUFFER; + } + } + + // Move the read offset to next buffer. TODO: Remove this member function!!! + inline __device__ void move_next_read_buffer() { + this->move_to_next_read_buffer(); + } + + // Move the read offset to next N buffer (circular-buffer). + inline __device__ void move_to_next_read_buffer(int N) { + if (BUFFERS_PER_TILE > 1) { + // this->smem_read_buffer_ += N * BYTES_PER_BUFFER; + // this->smem_read_buffer_ -= smem_read_buffer_ >= BYTES_PER_TILE ? + // BYTES_PER_TILE : 0; + this->smem_read_offset_ += N * BYTES_PER_BUFFER; + this->smem_read_offset_ -= + smem_read_offset_ >= BYTES_PER_TILE ? BYTES_PER_TILE : 0; + } + } + + // Move the read offset to next N buffer (circular-buffer). TODO: Remove this + // member function!!! + inline __device__ void move_next_read_buffer(int N) { + this->move_to_next_read_buffer(N); + } + + // Move the write offset to next buffer. 
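+  // Example (a sketch assuming double buffering, BUFFERS_PER_TILE = 2): the
+  // boundary BYTES_PER_TILE_INC_BOUNDARY equals BYTES_PER_BUFFER, so these
+  // move_to_next_*_buffer() helpers add BYTES_PER_BUFFER to hop from buffer 0
+  // to buffer 1 and subtract the boundary to wrap back, cycling between the
+  // two buffers without ever indexing past the end of the tile.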
+ inline __device__ void move_to_next_write_buffer() { + // if( BUFFERS_PER_TILE > 1 && smem_write_buffer_ >= + // BYTES_PER_TILE_INC_BOUNDARY ) { + // this->smem_write_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY; + // } else if( BUFFERS_PER_TILE > 1 ) { + // this->smem_write_buffer_ += BYTES_PER_BUFFER; + // } + if (BUFFERS_PER_TILE > 1 && + smem_write_offset_ >= BYTES_PER_TILE_INC_BOUNDARY) { + this->smem_write_offset_ -= BYTES_PER_TILE_INC_BOUNDARY; + } else if (BUFFERS_PER_TILE > 1) { + this->smem_write_offset_ += BYTES_PER_BUFFER; + } + } + + // Move the write offset to next buffer. TODO: Remove that member function! + inline __device__ void move_next_write_buffer() { + this->move_to_next_write_buffer(); + } + + // Move the read offset. + inline __device__ void move_read_offset(int delta) { + this->smem_read_offset_ += delta; + } + + // Move the write offset. + inline __device__ void move_write_offset(int delta) { + this->smem_write_offset_ += delta; + } + + // Store to the tile in shared memory. + template + inline __device__ void store(const Store_type (&data)[N], uint64_t = 0) { + uint32_t smem_ptrs[N]; + this->compute_store_pointers(smem_ptrs); + // Trying to reduce the shared mem for Q from 4KB per buffer to 2KB per + // buffer. + if (!PARTIAL_STORE || (tidx_ / THREADS_PER_ROW < ROWS)) { + sts(smem_ptrs, data); + } + } + + // Store to the tile in shared memory. + template + inline __device__ void store( + const Store_type (&data)[N], + uint32_t (&preds)[M], + uint64_t = 0) { + uint32_t smem_ptrs[N]; + this->compute_store_pointers(smem_ptrs); + sts(smem_ptrs, data, preds); + } + + // Store to the tile in shared memory. + template + inline __device__ void store( + const Store_type (&data)[N], + uint32_t preds, + uint64_t = 0) { + this->store(data, preds); + } + + // Store to the tile in shared memory. + template + inline __device__ void store( + const void* (&gmem_ptrs)[N], + uint32_t preds, + uint64_t = 0) { + uint32_t tmp[1] = {preds}; + this->store(gmem_ptrs, tmp); + } + + // The shared memory pointer. + const uint32_t smem_; + // The read offset. Reserve 4 offsets if needed. + int smem_read_offset_; + // The write offset. + int smem_write_offset_; + // The buffer base offset for read. + // int smem_read_buffer_; + // The buffer base offset for write. + // int smem_write_buffer_; + const int tidx_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The layout of the tile. + typename Layout, + // The size of the STS. + int BYTES_PER_STS = 16, + // The number of buffers per tile. + int BUFFERS_PER_TILE = 1, + // Use or not predicates + bool USE_PREDICATES = true> +struct Smem_tile_a {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Compute_reset_mask { + // The potential mask. + enum { HALF = MMAS_K_WITH_PADDING / 2 }; + // The remainder. + enum { MOD = MMAS_K % HALF }; + // The final value. + enum { + VALUE = (MMAS_K == MOD ? 
0 : HALF) | Compute_reset_mask::VALUE + }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Compute_reset_mask<0, MMAS_K_WITH_PADDING> { + enum { VALUE = 0 }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Compute_reset_mask { + enum { VALUE = MMAS_K - 1 }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Rows_per_xor_pattern_a { + // The size in bits. + enum { N_IN_BITS = N * fmha::BITS_PER_ELEMENT_A }; + // The number of rows. + enum { VALUE = N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8) }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Rows_per_xor_pattern_row_a : public Rows_per_xor_pattern_a {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE, + // How many rows to use for the XOR pattern to avoid bank conflicts? + int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_row_a::VALUE> +struct Smem_tile_row_a : public Smem_tile_without_skews< + Cta_tile, + Cta_tile::M, + Cta_tile::K, + fmha::BITS_PER_ELEMENT_A, + BYTES_PER_STS, + BUFFERS_PER_TILE, + 0, + ROWS_PER_XOR_PATTERN_, + 1> { + // The MMA tile. + using Mma_tile = fmha::Hmma_tile; + // The base class. + using Base = Smem_tile_without_skews< + Cta_tile, + Cta_tile::M, + Cta_tile::K, + fmha::BITS_PER_ELEMENT_A, + BYTES_PER_STS, + BUFFERS_PER_TILE, + 0, + ROWS_PER_XOR_PATTERN_, + 1>; + // The fragment. + using Fragment = Fragment_a; + + // When we use padding to reach a power of two, special care has to be taken. + using Cta_tile_with_padding = Cta_tile_with_k_with_padding; + // The number of MMAs. + using Mma_tile_with_padding = fmha::Hmma_tile; + + // The size of a single LDS in bytes. + enum { BYTES_PER_LDS = 16 }; + + // Ctor. + inline __device__ Smem_tile_row_a(void* smem, int tidx) : Base(smem, tidx) { + // For documentation on the layout, see doc/mma_smem_layout.xlsx. + + // The number of warps. + const int WARPS_M = Cta_tile::WARPS_M; + const int WARPS_N = Cta_tile::WARPS_N; + const int WARPS_K = Cta_tile::WARPS_K; + + static_assert(WARPS_M == 1); + static_assert(WARPS_N == 4 || WARPS_N == 8); + static_assert(WARPS_K == 1); + static_assert( + Base::ROWS_PER_XOR_PATTERN == 2 || Base::ROWS_PER_XOR_PATTERN == 4 || + Base::ROWS_PER_XOR_PATTERN == 8); + + // The row and column read by the thread. + int smem_read_row = (tidx & 0x0f); + constexpr int ROWS_PER_PACKING = + Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING; + int smem_read_col = + ((smem_read_row / ROWS_PER_PACKING) % Base::ROWS_PER_XOR_PATTERN) * + Base::COLS_PER_XOR_PATTERN; + smem_read_col ^= (tidx & 0x10) / 16; + + // The shared memory offset. + this->smem_read_offset_ = + smem_read_row * Base::BYTES_PER_ROW_BEFORE_PACKING + + smem_read_col * BYTES_PER_LDS; + } + + // Rewind smem_read_offset for last LDS phase in main loop. + inline __device__ void reverse_smem_read_offset(int ki = 0) { + // Undo the pointer increment for the next ni. + // Should match the load function below for ki = 0. 
+ if (Mma_tile_with_padding::MMAS_K >= 2) { + this->smem_read_offset_ ^= BYTES_PER_LDS * 2; + } + } + + // Load from shared memory. + inline __device__ void load(Fragment (&a)[Mma_tile::MMAS_M], int ki) { +#pragma unroll + for (int mi = 0; mi < Mma_tile::MMAS_M; ++mi) { + // Jump by as many matrix rows as needed (a row in smem may pack multiple + // matrix rows). + int offset = + mi * Mma_tile::M_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING; + + // Load using LDSM.M88.4. + uint4 tmp; + // ldsm(tmp, this->smem_ + this->smem_read_offset_ + + // this->smem_read_buffer_ + offset); + ldsm(tmp, this->smem_ + this->smem_read_offset_ + offset); + + // Store the value into the fragment. + a[mi].reg(0) = tmp.x; + a[mi].reg(1) = tmp.y; + a[mi].reg(2) = tmp.z; + a[mi].reg(3) = tmp.w; + } + + // Move the offset to the next possition. See doc/mma_smem_layout.xlsx. + static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented"); + if (Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15) { + this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2; + } else if (Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7) { + this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2; + } else if (Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3) { + this->smem_read_offset_ ^= 7 * BYTES_PER_LDS * 2; + } else if (Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1) { + this->smem_read_offset_ ^= 3 * BYTES_PER_LDS * 2; + } else if (Mma_tile_with_padding::MMAS_K >= 2) { + this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2; + } + } + + // Reset the read offset. + inline __device__ void reset_read_offset() { + // The number of MMAs in the K dimension. + enum { MMAS_K = Mma_tile::MMAS_K }; + // The number of MMAs in the K dimension when we include padding. + enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K }; + // Assemble the mask. + enum { MASK = Compute_reset_mask::VALUE }; + + // Reset the read offset. + this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_a + : public Smem_tile_row_a { + // The base class. + using Base = Smem_tile_row_a; + + // Ctor. + inline __device__ Smem_tile_a(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The layout of the tile. + typename Layout, + // The size of the STS. + int BYTES_PER_STS = 16, + // The number of buffers per tile. + int BUFFERS_PER_TILE = 1, + // Use or not predicates + bool USE_PREDICATES = true> +struct Smem_tile_b {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Rows_per_xor_pattern_b { + // The size in bits. + enum { N_IN_BITS = N * fmha::BITS_PER_ELEMENT_B }; + // The number of rows. + enum { VALUE = N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 
4 : 8) }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Rows_per_xor_pattern_col_b : public Rows_per_xor_pattern_b {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE, + // How many rows to use for the XOR pattern to avoid bank conflicts? + int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_col_b::VALUE> +struct Smem_tile_col_b : public Smem_tile_without_skews< + Cta_tile, + Cta_tile::N, + Cta_tile::K, + fmha::BITS_PER_ELEMENT_B, + BYTES_PER_STS, + BUFFERS_PER_TILE, + 0, + ROWS_PER_XOR_PATTERN_, + 1> { + // The MMA tile. + using Mma_tile = fmha::Hmma_tile; + // The base class. + using Base = Smem_tile_without_skews< + Cta_tile, + Cta_tile::N, + Cta_tile::K, + fmha::BITS_PER_ELEMENT_B, + BYTES_PER_STS, + BUFFERS_PER_TILE, + 0, + ROWS_PER_XOR_PATTERN_, + 1>; + // The fragment. + using Fragment = Fragment_b; + + // When we use padding to reach a power of two, special care has to be taken. + using Cta_tile_with_padding = Cta_tile_with_k_with_padding; + // The number of MMAs. + using Mma_tile_with_padding = fmha::Hmma_tile; + + // The size of a single LDS in bytes. + enum { BYTES_PER_LDS = 16 }; + + // The number of STS per thread + enum { + STS_PER_THREAD_ = + Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA + }; + // The number of STS per thread must be at least 1. + enum { STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE }; + + // Ctor. + inline __device__ Smem_tile_col_b(void* smem, int tidx) : Base(smem, tidx) { + // For documentation on the layout, see doc/mma_smem_layout.xlsx. + + // The number of warps. + const int WARPS_M = Cta_tile::WARPS_M; + const int WARPS_N = Cta_tile::WARPS_N; + const int WARPS_K = Cta_tile::WARPS_K; + static_assert( + Base::ROWS_PER_XOR_PATTERN == 2 || Base::ROWS_PER_XOR_PATTERN == 4 || + Base::ROWS_PER_XOR_PATTERN == 8); + static_assert(WARPS_M == 1); + static_assert(WARPS_N == 4 || WARPS_N == 8); + static_assert(WARPS_K == 1); + + // The masks to select the warps. + const int WARP_MASK_N = Warp_masks::N; + + // The divisor for the warps. + const int WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP; + + // The row and column read by the thread. + int smem_read_row = + (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA + + (tidx & 0x07) + (tidx & 0x10) / 2; + constexpr int ROWS_PER_PACKING = + Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING; + int smem_read_col = + ((smem_read_row / ROWS_PER_PACKING) % Base::ROWS_PER_XOR_PATTERN) * + Base::COLS_PER_XOR_PATTERN; + smem_read_col ^= (tidx & 0x08) / 8; + // The shared memory offset. + this->smem_read_offset_ = + smem_read_row * Base::BYTES_PER_ROW_BEFORE_PACKING + + smem_read_col * BYTES_PER_LDS; + } + + // Rewind smem_read_offset for last LDS phase in main loop. + inline __device__ void reverse_smem_read_offset(int ki = 0) { + // Undo the pointer increment for the next ni. + // Should match the load function below for ki = 0. + if (Mma_tile_with_padding::MMAS_K >= 2) { + this->smem_read_offset_ ^= BYTES_PER_LDS * 2; + } + } + + // Load from shared memory. 
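+  // Worked illustration (assuming Mma_tile_with_padding::MMAS_K == 4; the
+  // numbers are only an example): BYTES_PER_LDS * 2 = 32, so the load below
+  // XORs the read offset by 32 after even ki and by 3 * 32 = 96 after odd ki.
+  // Starting from some offset o, the four K-slices are therefore read at
+  // o, o ^ 32, o ^ 64 and o ^ 96, and the final XOR returns the offset to o.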
+ inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) { +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { + // Jump by as many matrix rows as needed (a row in smem may pack multiple + // matrix rows). + int offset = + ni * Mma_tile::N_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING; + + // Load using LDSM.M88.4. + uint4 tmp; + // ldsm(tmp, this->smem_ + this->smem_read_offset_ + + // this->smem_read_buffer_ + offset); + ldsm(tmp, this->smem_ + this->smem_read_offset_ + offset); + + // Store the value into the fragment. + b[ni].reg(0) = tmp.x; + b[ni].reg(1) = tmp.y; + b[ni].reg(2) = tmp.z; + b[ni].reg(3) = tmp.w; + } + + // Move the offset to the next possition. See doc/mma_smem_layout.xlsx. + static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented"); + if (Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15) { + this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2; + } else if (Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7) { + this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2; + } else if (Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3) { + this->smem_read_offset_ ^= 7 * BYTES_PER_LDS * 2; + } else if (Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1) { + this->smem_read_offset_ ^= 3 * BYTES_PER_LDS * 2; + } else if (Mma_tile_with_padding::MMAS_K >= 2) { + this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2; + } + } + + // Reset the read offset. + inline __device__ void reset_read_offset() { + // The number of MMAs in the K dimension. + enum { MMAS_K = Mma_tile::MMAS_K }; + // The number of MMAs in the K dimension when we include padding. + enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K }; + // Assemble the mask. + enum { MASK = Compute_reset_mask::VALUE }; + + // Reset the read offset. + this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_b + : public Smem_tile_col_b { + // The base class. + using Base = Smem_tile_col_b; + + // Ctor. + inline __device__ Smem_tile_b(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Rows_per_xor_pattern_row_b : public Rows_per_xor_pattern_b {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE, + // How many rows to use for the XOR pattern to avoid bank conflicts? + int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_row_b::VALUE, + // How many cols to use for the XOR pattern to avoid bank conflicts? + int COLS_PER_XOR_PATTERN_ = 1> +struct Smem_tile_row_b : public Smem_tile_without_skews< + Cta_tile, + Cta_tile::K, + Cta_tile::N, + fmha::BITS_PER_ELEMENT_B, + BYTES_PER_STS, + BUFFERS_PER_TILE, + 0, + ROWS_PER_XOR_PATTERN_, + COLS_PER_XOR_PATTERN_> { + // The MMA tile. + using Mma_tile = fmha::Hmma_tile; + // The base class. 
+ using Base = Smem_tile_without_skews< + Cta_tile, + Cta_tile::K, + Cta_tile::N, + fmha::BITS_PER_ELEMENT_B, + BYTES_PER_STS, + BUFFERS_PER_TILE, + 0, + ROWS_PER_XOR_PATTERN_, + COLS_PER_XOR_PATTERN_>; + // The fragment. + using Fragment = Fragment_b; + + // Can we use LDSM? No if the data type is 32-bit large. + enum { USE_LDSMT = fmha::BITS_PER_ELEMENT_B == 16 }; + // The size of a single LDS in bytes. + enum { BYTES_PER_LDS = USE_LDSMT ? 16 : 4 }; + // The number of elements per LDS. + enum { ELEMENTS_PER_LDS = BYTES_PER_LDS * 8 / fmha::BITS_PER_ELEMENT_B }; + + // The number of STS per thread + enum { + STS_PER_THREAD_ = + Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA + }; + // The number of STS per thread must be at least 1. + enum { STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE }; + + // Ctor. + inline __device__ Smem_tile_row_b(void* smem, int tidx) : Base(smem, tidx) { + // The number of warps. + const int WARPS_M = Cta_tile::WARPS_M; + const int WARPS_N = Cta_tile::WARPS_N; + const int WARPS_K = Cta_tile::WARPS_K; + static_assert(WARPS_K == 1); + static_assert(WARPS_M == 4 || WARPS_M == 8); + static_assert(WARPS_N == 1); + + // The masks to select the warps. + const int WARP_MASK_N = Warp_masks::N; + const int WARP_MASK_K = Warp_masks::K; + + // The divisor for the warps. + const int WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP; + const int WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP; + + static_assert(USE_LDSMT); + static_assert( + Base::ROWS_PER_XOR_PATTERN == 2 || Base::ROWS_PER_XOR_PATTERN == 4 || + Base::ROWS_PER_XOR_PATTERN == 8); + + // The row/col read by the thread. + int smem_read_row = + (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile::MMAS_K * 16 + + (tidx & 0x07) + (tidx & 0x08); + constexpr int ROWS_PER_PACKING = + Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING; + int smem_read_col = + ((smem_read_row / ROWS_PER_PACKING) % Base::ROWS_PER_XOR_PATTERN) * + Base::COLS_PER_XOR_PATTERN; + smem_read_col ^= (tidx & WARP_MASK_N) / WARP_DIV_N * 2 + (tidx & 0x10) / 16; + + // The shared memory offset. + this->smem_read_offset_ = + smem_read_row * Base::BYTES_PER_ROW_BEFORE_PACKING + + smem_read_col * BYTES_PER_LDS; + + // Fill zeroes for group conv + } + + // Rewind smem_read_offset for last LDS phase in main loop. + inline __device__ void reverse_smem_read_offset(int ki = 0) { + // The size of each element in bits. + const int BITS_PER_ELT = fmha::BITS_PER_ELEMENT_B; + // The size in bytes of the data needed to compute an MMA per CTA. + const int BYTES_PER_MMA_PER_CTA = + Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8; + +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { + // Undo the pointer increment for the next ni. + // Should match the load function below for ki = 0. + if (BYTES_PER_MMA_PER_CTA >= 128) { + // Nothing to do! + } else if (BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1) { + this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA; + } else if (BYTES_PER_MMA_PER_CTA == 64) { + // Nothing to do! + } else if (BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4) { + this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6); + } else if (BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2) { + this->smem_read_offset_ ^= BYTES_PER_LDS * 2; + } + } + + // Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels) + if (BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 && + Mma_tile::MMAS_N % 2 == 1) { + this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA; + } + } + + // Load from shared memory. 
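+  // Rough example (assumed sizes, for illustration only): when
+  // BYTES_PER_MMA_PER_CTA == 32 and Mma_tile::MMAS_N == 4, the loop below
+  // toggles smem_read_offset_ by 2 * BYTES_PER_LDS after even ni and by
+  // 6 * BYTES_PER_LDS after odd ni, so the four MMAs read from offsets
+  // o, o ^ 32, o ^ 64 and o ^ 96 and the pointer ends up back at o.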
+ inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) { + // The size of each element in bits. + const int BITS_PER_ELT = fmha::BITS_PER_ELEMENT_B; + // The size in bytes of the data needed to compute an MMA per CTA. + const int BYTES_PER_MMA_PER_CTA = + Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8; + +// uint32_t smem_read_og = this->smem_ + this->smem_read_offset_; +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { + // Prepare the offset. + int offset = ki * Base::ROWS_PER_XOR_PATTERN * 2 * + Base::BYTES_PER_ROW_BEFORE_PACKING; + if (BYTES_PER_MMA_PER_CTA == 32) { + offset += this->smem_read_offset_; + } else if (BYTES_PER_MMA_PER_CTA == 64) { + offset += + this->smem_read_offset_ + (ni / 2) * BYTES_PER_MMA_PER_CTA * 2; + } else { + offset += this->smem_read_offset_ + (ni)*BYTES_PER_MMA_PER_CTA; + } + + // Load the data using LDSM.MT88.2. + // uint32_t ptr = this->smem_ + this->smem_read_buffer_ + offset; + uint32_t ptr = this->smem_ + offset; + uint4 tmp; + if (USE_LDSMT) { + ldsmt(tmp, ptr); + } else { + lds(tmp.x, (ptr) + 0 * Base::BYTES_PER_ROW_BEFORE_PACKING); + lds(tmp.y, (ptr) + 4 * Base::BYTES_PER_ROW_BEFORE_PACKING); + lds(tmp.z, (ptr ^ 32) + 0 * Base::BYTES_PER_ROW_BEFORE_PACKING); + lds(tmp.w, (ptr ^ 32) + 4 * Base::BYTES_PER_ROW_BEFORE_PACKING); + } + + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("BYTES_PER_MMA_PER_CTA=%d, ni = %d, smem_read diff = %d\n", + // BYTES_PER_MMA_PER_CTA, ni, ptr - smem_read_og); + // } + // Store those values in the fragment. + b[ni].reg(0) = tmp.x; + b[ni].reg(1) = tmp.y; + b[ni].reg(2) = tmp.z; + b[ni].reg(3) = tmp.w; + + // Move the pointer for the next ni. I expect the compiler to not + // recompute those. + if (BYTES_PER_MMA_PER_CTA >= 128) { + // Nothing to do! + } else if (BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1) { + this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA; + } else if (BYTES_PER_MMA_PER_CTA == 64) { + // Nothing to do! + } else if (BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 8) { + this->smem_read_offset_ ^= + BYTES_PER_LDS * (ni % 4 == 3 ? 14 : (ni % 2 == 1 ? 6 : 2)); + } else if (BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4) { + this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6); + } else if (BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2) { + this->smem_read_offset_ ^= BYTES_PER_LDS * 2; + } + } + + // Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels) + if (BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 && + Mma_tile::MMAS_N % 2 == 1) { + this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA; + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_b + : public Smem_tile_row_b { + // The base class. + using Base = Smem_tile_row_b; + + // Ctor. + inline __device__ Smem_tile_b(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v : public fmha::Smem_tile_without_skews< + Cta_tile, + Cta_tile::K, + Cta_tile::N, + 16, + 16, + 1, + 0, + Rows_per_xor_pattern_col_b::VALUE, + 1> { + // The base class. 
+ using Base = Smem_tile_without_skews< + Cta_tile, + Cta_tile::K, + Cta_tile::N, + 16, + 16, + 1, + 0, + Rows_per_xor_pattern_col_b::VALUE, + 1>; + // The MMA tile. + using Mma_tile = fmha::Hmma_tile; + // The fragment. + using Fragment = Fragment_b; + + // The size of a single LDS in bytes. + enum { BYTES_PER_LDS = 16 }; + + // Ctor. + inline __device__ Smem_tile_v(void* smem, int tidx) : Base(smem, tidx) { + // The row/col read by the thread. + int read_row, read_col; + + static_assert( + Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && + (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8)); + + read_row = (tidx & 0xe0) / 2 + (tidx & 0x0f); + constexpr int ROWS_PER_PACKING = + Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING; + read_col = ((read_row / ROWS_PER_PACKING) % Base::ROWS_PER_XOR_PATTERN) * + Base::COLS_PER_XOR_PATTERN; + read_col ^= (tidx & 0x10) / 16; + + // The shared memory offset. + this->smem_read_offset_ = read_row * Base::BYTES_PER_ROW_BEFORE_PACKING + + read_col * BYTES_PER_LDS; + } + + // Load from shared memory. + inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) { +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { + // Jump by 16 * #warps row. + int row = ki * 16 * Cta_tile::WARPS_K; + + // Load the data using LDSM.MT88.2. + uint4 tmp; + fmha::ldsmt( + tmp, + this->smem_ + this->smem_read_offset_ + + row * Base::BYTES_PER_ROW_BEFORE_PACKING); + b[ni].reg(0) = tmp.x; + b[ni].reg(1) = tmp.y; + b[ni].reg(2) = tmp.z; + b[ni].reg(3) = tmp.w; + + // Move the pointer for the next ni. I expect the compiler to not + // recompute those. + if (Mma_tile::MMAS_N == 1) { + // noop + } else if (Mma_tile::MMAS_N == 2) { + this->smem_read_offset_ ^= BYTES_PER_LDS * 2; + } else if (Mma_tile::MMAS_N == 4) { + this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6); + } else if (Mma_tile::MMAS_N == 8) { + this->smem_read_offset_ ^= + BYTES_PER_LDS * (ni % 4 == 3 ? 14 : (ni % 2 == 1 ? 6 : 2)); + } else { + assert(false); // Not implemented! + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_o { + // The MMA tile. + using Mma_tile = fmha::Hmma_tile; + // The accumulators. + using Accumulator = fmha::Fragment_accumulator; + // The accumulators. + using Data_type = typename Accumulator::Data_type; + + // The size of each element. + static constexpr int BYTES_PER_ELEMENT = sizeof(Data_type); + // The size of each STS. + static constexpr int BYTES_PER_STS = 8; + // The size of each row in shared memory. + static constexpr int BYTES_PER_ROW = + Cta_tile::N * Cta_tile::WARPS_K * BYTES_PER_ELEMENT; + + // The size of each LDS. + static constexpr int BYTES_PER_LDS = 16; + static constexpr int THREADS_PER_ROW = + Cta_tile::N * BYTES_PER_ELEMENT / BYTES_PER_LDS; + + // The number of rows. + static constexpr int ROWS = Cta_tile::M; + // The number of "rows" to process per loop iteration (in the "epilogue"). + static constexpr int ROWS_PER_LOOP = + ROWS <= 64 ? ROWS : (int)Mma_tile::M_PER_MMA_PER_CTA; + // The number of outer loops. + static constexpr int LOOPS = ROWS / ROWS_PER_LOOP; + // Make sure it matches our expectations. + static_assert(LOOPS == 1 || LOOPS == (int)Mma_tile::MMAS_M, ""); + + // The number of rows loaded per LDS. + static constexpr int ROWS_PER_LDS = + Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW; + // Do we have to guard against partial writes/reads. 
+ static constexpr bool HAS_INCOMPLETE_LDS = ROWS_PER_LOOP % ROWS_PER_LDS != 0; + // The total number of LDS per loop. + static constexpr int LDS_PER_LOOP = + fmha::DivUpConstexpr(ROWS_PER_LOOP, ROWS_PER_LDS); + + // The amount of shared memory. + static constexpr int BYTES_PER_TILE = ROWS_PER_LOOP * BYTES_PER_ROW; + + // The write pointer. + uint32_t smem_write_, smem_read_; + // Is the thread active for the last LDS of the series? + int is_active_for_last_lds_; + + // static_assert(BYTES_PER_ROW == 64 * 4 * Cta_tile::WARPS_K); + static_assert(LOOPS == 1 || LOOPS == (int)Mma_tile::MMAS_M, ""); + + // Ctor. + inline __device__ Smem_tile_o(void* smem, int tidx) { + // Get a 32-bit value for the shared memory address. + uint32_t smem_ = __nvvm_get_smem_pointer(smem); + + static_assert( + Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && + (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8)); + static_assert( + Cta_tile::N == 16 || Cta_tile::N == 32 || Cta_tile::N == 64 || + Cta_tile::N == 128); + + int write_row = (tidx & 0x1c) / 4; + + const int lane = tidx % 32; + const int warp = tidx / 32; + + constexpr int ELEMENTS_PER_STS = BYTES_PER_STS / BYTES_PER_ELEMENT; + constexpr int STS_PER_WARP = 16 * Mma_tile::MMAS_N / ELEMENTS_PER_STS; + int write_col = warp * STS_PER_WARP + lane % STS_PER_WARP; + + // if ((threadIdx.x == 16) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("write_row = %d, write_col = %d\n", write_row, write_col); + // } + + // if ((blockIdx.x == 0) && (blockIdx.y == 0) && (write_row == 0) && + // (write_col == 0)) { + // printf("threadIdx.x = %d\n", threadIdx.x); + // } + + // Assemble the write pointer. + smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS; + + // The element read by each thread. + int read_row = tidx / THREADS_PER_ROW; + int read_col = tidx % THREADS_PER_ROW; + + // Take the XOR pattern into account for the column. + read_col ^= + 2 * (read_row % (Cta_tile::N == 16 ? 2 : (Cta_tile::N == 32 ? 4 : 8))); + // read_col ^= 2 * (read_row % (Cta_tile::N == 16 ? 2 : (Cta_tile::N == 32 ? + // 4 : (Cta_tile::N == 128 ? 16 : 8)))); + + // if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("read_row = %d, read_col = %d\n", read_row, read_col); + // } + // if ((blockIdx.x == 0) && (blockIdx.y == 0) && (read_row == 0) && + // (read_col == 0)) { + // printf("threadIdx.x = %d\n", threadIdx.x); + // } + // Assemble the read pointer. + this->smem_read_ = + smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; + + // Is that thread active on the last LDS? + if (HAS_INCOMPLETE_LDS) { + this->is_active_for_last_lds_ = + read_row + (LDS_PER_LOOP - 1) * ROWS_PER_LDS < Cta_tile::M; + } + } + + // Load the output fragments. + template + inline __device__ void load(uint4 (&out)[LDS_PER_LOOP]) const { +#pragma unroll + for (int ii = 0; ii < LDS_PER_LOOP; ++ii) { + // Load the elements before the reduction (split-K). + uint4 tmp[Cta_tile::WARPS_K]; +#pragma unroll + for (int jj = 0; jj < Cta_tile::WARPS_K; ++jj) { + int imm = ii * ROWS_PER_LDS * BYTES_PER_ROW + + jj * Cta_tile::N * BYTES_PER_ELEMENT; + uint32_t smem_read = this->smem_read_ + imm; + // TD [2022-06-05] Ugly fix for d=128 in the forward pass, maybe there's + // a better way. 
+ if ((Cta_tile::N == 128) && (ROWS_PER_LDS == 4) && (ii % 2 == 1)) { + smem_read ^= 8 * BYTES_PER_LDS; + } + // if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("imm diff = %d\n", smem_read - this->smem_read_); + // } + if (!HAS_INCOMPLETE_LDS || + (ii < LDS_PER_LOOP - 1 || this->is_active_for_last_lds_)) { + // fmha::lds(tmp[jj], this->smem_read_ + imm); + fmha::lds(tmp[jj], smem_read); + } + } + + // Perform the reduction. + out[ii] = zero_init ? tmp[0] : fmha::fadd4(out[ii], tmp[0]); +// if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) { +// printf("out reduction: out = %.6f\n", reinterpret_cast(out[ii])[0]); +// } +#pragma unroll + for (int jj = 1; jj < Cta_tile::WARPS_K; ++jj) { + out[ii] = fmha::fadd4(out[ii], tmp[jj]); + // if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("out reduction tmp = %.6f, out = %.6f\n", + // reinterpret_cast(tmp[jj])[0], + // reinterpret_cast(out[ii])[0]); + // } + } + } + } + + // Store the accumulators. + template + inline __device__ void store(const Accumulator (&acc)[M][N], int mi) { + // uint32_t smem_write_og = this->smem_write_; + static constexpr int M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA; +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { + // The number of MMAs that are stored per loop iteration. + static constexpr int MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS; + +// Store 1st column of the different MMAs. +#pragma unroll + for (int mj = 0; mj < MMAS_M_PER_LOOP; ++mj) { + // Precompute the immediates to jump between rows. + int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW; + int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW; + uint2 tmp0, tmp1; + tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(0); + tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(1); + + tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(2); + tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(3); + + // Store. + fmha::sts(this->smem_write_ + row_0, tmp0); + fmha::sts(this->smem_write_ + row_1, tmp1); + } + // if ((threadIdx.x == 16) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("smem_write diff = %d\n", this->smem_write_ - + // smem_write_og); + // } + + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // uint4 read_tmp; + // fmha::lds(read_tmp, this->smem_read_); + // printf("smem_o = %.6f\n", reinterpret_cast(read_tmp)[0]); + // } + // Swizzle the write pointer using a XOR of 16B. + this->smem_write_ ^= 32; + +// Store 2nd column of the different MMAs. +#pragma unroll + for (int mj = 0; mj < MMAS_M_PER_LOOP; ++mj) { + // Precompute the immediates to jump between rows. + int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW; + int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW; + + uint2 tmp0, tmp1; + tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(4); + tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(5); + + tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(6); + tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(7); + // Store. + fmha::sts(this->smem_write_ + row_0, tmp0); + fmha::sts(this->smem_write_ + row_1, tmp1); + } + + // if ((threadIdx.x == 16) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("smem_write diff = %d\n", this->smem_write_ - + // smem_write_og); + // } + + // Cancel the previous XOR of 1 + swizzle the write pointer using a XOR of + // 32B or 64B. 
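+      // Sketch of the walk (assuming Mma_tile::MMAS_N == 4, purely as an
+      // example): the combination of the XOR applied between the two column
+      // stores and the XORs below (3 * 32 after even ni, 7 * 32 after odd ni)
+      // steps the write pointer through the eight slots w, w ^ 32, ...,
+      // w ^ 224 and returns it to w after ni = 3.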
+ static_assert(Mma_tile::MMAS_N <= 8, "Not implemented"); + if (Mma_tile::MMAS_N >= 8 && ni % 4 == 3) { + this->smem_write_ ^= 15 * 32; + } else if (Mma_tile::MMAS_N >= 4 && ni % 2 == 1) { + this->smem_write_ ^= 7 * 32; + } else if (Mma_tile::MMAS_N >= 2) { + this->smem_write_ ^= 3 * 32; + } else { + this->smem_write_ ^= 3 * 32; + } + // this->smem_write_ ^= (ni & 1) ? 7 * 32 : 3 * 32; + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // uint4 read_tmp; + // fmha::lds(read_tmp, this->smem_read_); + // printf("smem_o = %.6f\n", reinterpret_cast(read_tmp)[0]); + // } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_mma { + using Mma_tile = fmha::Hmma_tile; + using Fragment = fmha::Fragment_a; + + enum { COLS = Cta_tile::N }; + enum { BYTES_PER_ELT = 2 }; + enum { BYTES_PER_STS = 4 }; + enum { BYTES_PER_ROW = COLS * BYTES_PER_ELT }; // TODO + enum { BYTES_PER_TILE = Cta_tile::M * BYTES_PER_ROW }; + + enum { WARPS_M = Cta_tile::WARPS_M }; + enum { WARPS_N = Cta_tile::WARPS_N }; + enum { WARPS_K = Cta_tile::WARPS_K }; + + static_assert(WARPS_K == 1); + inline __device__ Smem_tile_mma(char* smem, int tidx) { + uint32_t smem_ = __nvvm_get_smem_pointer(smem); + + int write_col, write_row; + static_assert( + WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) || + (WARPS_M == 4 || WARPS_M == 8) || WARPS_N == 1); + if (WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8)) { + write_row = (tidx & 0x1c) / 4; + write_col = (tidx & 0xe0) / 4 + (tidx & 0x03); + write_col ^= (write_row & 0x07) * 4; + } else { + write_row = (tidx & 0xe0) / 2 + (tidx & 0x1c) / 4; + write_col = (tidx & 0x03); + // write_col ^= (write_row & (BYTES_PER_ROW == 32 ? 0x01 : (BYTES_PER_ROW + // == 64 ? 0x03 : (BYTES_PER_ROW == 128 ? 0x07 : 0x0f)))) * 4; + write_col ^= (write_row & + (BYTES_PER_ROW == 32 + ? 0x01 + : (BYTES_PER_ROW == 64 + ? 0x03 + : (BYTES_PER_ROW == 128 ? 
0x07 : 0x07)))) * + 4; + } + + // write_offset_ = write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS; + smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS; + } + + template + inline __device__ void store(const uint4 (®s)[M][N]) { + static_assert(COLS == Cta_tile::N); +#pragma unroll + for (int mi = 0; mi < M; mi++) { +#pragma unroll + for (int ni = 0; ni < N; ni++) { + // size_t offset = write_offset_ + mi * WARPS_M * 16 * BYTES_PER_ROW + + // ni * WARPS_N * 16 * BYTES_PER_ELT; fmha::sts(smem_ + offset + 0 * + // BYTES_PER_ROW, regs[mi][ni].x); fmha::sts(smem_ + offset + 8 * + // BYTES_PER_ROW, regs[mi][ni].z); offset ^= 4 * BYTES_PER_STS; + // fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].y); + // fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w); + // size_t offset = smem_write_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni + // * WARPS_N * 16 * BYTES_PER_ELT; + uint32_t offset = smem_write_ + mi * WARPS_M * 16 * BYTES_PER_ROW + + ni * WARPS_N * 16 * BYTES_PER_ELT; + fmha::sts(offset + 0 * BYTES_PER_ROW, regs[mi][ni].x); + fmha::sts(offset + 8 * BYTES_PER_ROW, regs[mi][ni].z); + offset ^= 4 * BYTES_PER_STS; + fmha::sts(offset + 0 * BYTES_PER_ROW, regs[mi][ni].y); + fmha::sts(offset + 8 * BYTES_PER_ROW, regs[mi][ni].w); + } + } + } + + template + inline __device__ void store(const Fragment (&frag)[N][M]) { + static_assert(COLS == Cta_tile::N); + uint4 regs[M][N]; +#pragma unroll + for (int mi = 0; mi < M; mi++) { +#pragma unroll + for (int ni = 0; ni < N; ni++) { + // Need to transpose ref(1) and reg(2) here since when we load it we + // transpose again. + regs[mi][ni] = make_uint4( + frag[ni][mi].reg(0), + frag[ni][mi].reg(2), + frag[ni][mi].reg(1), + frag[ni][mi].reg(3)); + } + } + this->store(regs); + } + + // uint32_t smem_; + // uint32_t write_offset_; + uint32_t smem_write_; +}; + +template > +struct Smem_tile_mma_transposed : public Base { + enum { BYTES_PER_LDS = 16 }; + enum { BYTES_PER_ROW = Base::BYTES_PER_ROW }; + enum { BYTES_PER_ELT = Base::BYTES_PER_ELT }; + enum { WARPS_M = Base::WARPS_M }; + enum { WARPS_N = Base::WARPS_N }; + static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8)); + using Fragment = typename Base::Fragment; + inline __device__ Smem_tile_mma_transposed(char* smem, int tidx) + : Base(smem, tidx) { + uint32_t smem_ = __nvvm_get_smem_pointer(smem); + static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8)); + int read_row, read_col; + read_row = (tidx & 0x0f); + read_col = (tidx & 0xe0) / 16 + (tidx & 0x1c) / 16; + + // read_col ^= (read_row & (Base::BYTES_PER_ROW == 32 ? 0x01 : + // (Base::BYTES_PER_ROW == 64 ? 0x03 : (Base::BYTES_PER_ROW == 128 ? 
0x07 : + // 0x0f)))); + read_col ^= (read_row & 0x07); + // read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; + smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; + } + + template + inline __device__ void load(Fragment (&frag)[M][N]) { + static_assert(Base::COLS == Cta_tile::N); + for (int mi = 0; mi < M; mi++) { + for (int ni = 0; ni < N; ni++) { + // size_t offset = read_offset_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni + // * WARPS_N * 16 * BYTES_PER_ELT; + uint4 dst; + // fmha::ldsmt(dst, this->smem_ + offset); + // size_t offset = smem_read_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * + // WARPS_N * 16 * BYTES_PER_ELT; + uint32_t offset = smem_read_ + mi * WARPS_M * 16 * BYTES_PER_ROW + + ni * WARPS_N * 16 * BYTES_PER_ELT; + fmha::ldsmt(dst, offset); + frag[mi][ni].reg(0) = dst.x; + frag[mi][ni].reg(1) = dst.z; // Fragment A regs col major! + frag[mi][ni].reg(2) = dst.y; + frag[mi][ni].reg(3) = dst.w; + } + } + } + + // uint32_t read_offset_; + uint32_t smem_read_; +}; + +template > +struct Smem_tile_mma_epilogue : public Base { + enum { BYTES_PER_LDS = 16 }; + enum { BYTES_PER_ROW = Base::BYTES_PER_ROW }; + enum { BYTES_PER_ELT = Base::BYTES_PER_ELT }; + enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDS }; + static_assert(THREADS_PER_ROW * BYTES_PER_LDS == BYTES_PER_ROW); + enum { ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + enum { NUM_LDS = Cta_tile::M / ROWS_PER_LDS }; + static_assert(NUM_LDS * ROWS_PER_LDS == Cta_tile::M); + enum { WARPS_M = Base::WARPS_M }; + enum { WARPS_N = Base::WARPS_N }; + static_assert((WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1); + + using Acc = fmha::Fragment_accumulator; + + inline __device__ Smem_tile_mma_epilogue(char* smem, int tidx) + : Base(smem, tidx) { + uint32_t smem_ = __nvvm_get_smem_pointer(smem); + const int read_row = tidx / THREADS_PER_ROW; + int read_col = tidx % THREADS_PER_ROW; + // read_col ^= (read_row & (Base::BYTES_PER_ROW == 32 ? 0x01 : + // (Base::BYTES_PER_ROW == 64 ? 0x03 : 0x07))); + static_assert( + Base::BYTES_PER_ROW == 32 || Base::BYTES_PER_ROW == 64 || + Base::BYTES_PER_ROW == 128 || Base::BYTES_PER_ROW == 256); + read_col ^= + (read_row & + (Base::BYTES_PER_ROW == 32 + ? 0x01 + : (Base::BYTES_PER_ROW == 64 + ? 0x03 + : (Base::BYTES_PER_ROW == 128 ? 0x07 : 0x07)))); + // read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; + smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; + } + + inline __device__ void load(uint4 (&data)[NUM_LDS]) { + for (int ii = 0; ii < NUM_LDS; ii++) { + // size_t offset = read_offset_ + ii * ROWS_PER_LDS * BYTES_PER_ROW; + // fmha::lds(data[ii], this->smem_ + offset); + // size_t offset = smem_read_ + ii * ROWS_PER_LDS * BYTES_PER_ROW; + uint32_t offset = smem_read_ + ii * ROWS_PER_LDS * BYTES_PER_ROW; + fmha::lds(data[ii], offset); + } + } + + template + inline __device__ void store(const Acc (&acc)[M][N]) { +#pragma unroll + for (int mi = 0; mi < M; mi++) { +#pragma unroll + for (int ni = 0; ni < N; ni++) { + // 1st row - 4 elements per row. + float tmp00 = acc[mi][ni].elt(0); + float tmp01 = acc[mi][ni].elt(1); + float tmp02 = acc[mi][ni].elt(4); + float tmp03 = acc[mi][ni].elt(5); + // 2nd row - 4 elements per row. 
+ float tmp10 = acc[mi][ni].elt(2); + float tmp11 = acc[mi][ni].elt(3); + float tmp12 = acc[mi][ni].elt(6); + float tmp13 = acc[mi][ni].elt(7); + + uint32_t x = fmha::float2_to_half2(tmp00, tmp01); + uint32_t y = fmha::float2_to_half2(tmp02, tmp03); + uint32_t z = fmha::float2_to_half2(tmp10, tmp11); + uint32_t w = fmha::float2_to_half2(tmp12, tmp13); + + // size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 + // * BYTES_PER_ROW; fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, + // x); fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, z); offset ^= + // 4 * Base::BYTES_PER_STS; fmha::sts(this->smem_ + offset + 0 * + // BYTES_PER_ROW, y); fmha::sts(this->smem_ + offset + 8 * + // BYTES_PER_ROW, w); size_t offset = (this->smem_write_ ^ (ni * 32)) + + // mi * WARPS_M * 16 * BYTES_PER_ROW; + uint32_t offset = + (this->smem_write_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("mi = %d, ni = %d, offset - smem_write_ = %d\n", mi, ni, + // offset - this->smem_write_); + // } + fmha::sts(offset + 0 * BYTES_PER_ROW, x); + fmha::sts(offset + 8 * BYTES_PER_ROW, z); + offset ^= 4 * Base::BYTES_PER_STS; + fmha::sts(offset + 0 * BYTES_PER_ROW, y); + fmha::sts(offset + 8 * BYTES_PER_ROW, w); + } + } + } + + template + inline __device__ void store(const uint4 (®s)[M][N]) { + for (int mi = 0; mi < M; mi++) { + for (int ni = 0; ni < N; ni++) { + // size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 + // * BYTES_PER_ROW; + uint32_t offset = (this->write_offset_ ^ (ni * 32)) + + mi * WARPS_M * 16 * BYTES_PER_ROW; + fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].x); + fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].z); + offset ^= 4 * Base::BYTES_PER_STS; + fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].y); + fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w); + } + } + } + + // uint32_t read_offset_; + uint32_t smem_read_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_transpose { + using Mma_tile = fmha::Hmma_tile; + using Fragment_write = fmha::Fragment_b; + using Fragment_read = fmha::Fragment_b; + + enum { COLS = Cta_tile::N }; + enum { BYTES_PER_ELT = 2 }; + enum { BYTES_PER_STS = 4 }; + enum { BYTES_PER_ROW = COLS * BYTES_PER_ELT }; // TODO + enum { BYTES_PER_TILE = Cta_tile::M * BYTES_PER_ROW }; + + enum { BYTES_PER_LDS = 16 }; + + enum { WARPS_M = Cta_tile::WARPS_M }; + enum { WARPS_N = Cta_tile::WARPS_N }; + enum { WARPS_K = Cta_tile::WARPS_K }; + + static_assert(WARPS_K == 1); + static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8)); + + inline __device__ Smem_tile_transpose(char* smem, int tidx) { + smem_ = __nvvm_get_smem_pointer(smem); + // uint32_t smem_ = __nvvm_get_smem_pointer(smem); + + int write_col, write_row; + static_assert( + WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) || + (WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1); + if (WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8)) { + write_row = (tidx & 0x1c) / 4; + write_col = (tidx & 0xe0) / 4 + (tidx & 0x03); + } else { + write_row = (tidx & 0xe0) / 2 + (tidx & 0x1c) / 4; + write_col = (tidx & 0x03); + } + write_col ^= (write_row & 0x07) * 4; + + write_offset_ = write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS; + // smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * + // BYTES_PER_STS; + + int read_row, read_col; + read_row = 
(tidx & 0x0f); + read_col = (tidx & 0xe0) / 16 + (tidx & 0x1c) / 16; + + read_col ^= (read_row & 0x07); + read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; + // smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; + } + + template + inline __device__ void store(const Fragment_write (&frag_w)[M][N], int mi) { +#pragma unroll + for (int ni = 0; ni < N; ni++) { + // size_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; + uint32_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; + fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(0)); + fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(2)); + offset ^= 4 * BYTES_PER_STS; + fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(1)); + fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(3)); + } + } + + template + inline __device__ void load(Fragment_read (&frag_r)[N]) { +#pragma unroll + for (int ni = 0; ni < N; ni++) { + // size_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; + uint32_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; + uint4 dst; + fmha::ldsmt(dst, this->smem_ + offset); + frag_r[ni].reg(0) = dst.x; + frag_r[ni].reg(1) = dst.y; // Fragment B regs col major! + frag_r[ni].reg(2) = dst.z; + frag_r[ni].reg(3) = dst.w; + } + } + + template + inline __device__ void transpose( + const Fragment_write (&frag_w)[M][N], + Fragment_read (&frag_r)[M], + int mi) { + static_assert(COLS == Cta_tile::N); +#pragma unroll + for (int ni = 0; ni < N; ni++) { + // size_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; + uint32_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; + fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(0)); + fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(2)); + offset ^= 4 * BYTES_PER_STS; + fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(1)); + fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(3)); + } +#pragma unroll + for (int ni = 0; ni < N; ni++) { + // size_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; + // size_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; + uint32_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; + uint4 dst; + fmha::ldsmt(dst, this->smem_ + offset); + frag_r[ni].reg(0) = dst.x; + frag_r[ni].reg(1) = dst.y; // Fragment B regs col major! + frag_r[ni].reg(2) = dst.z; + frag_r[ni].reg(3) = dst.w; + } + } + + uint32_t smem_; + uint32_t write_offset_; + uint32_t read_offset_; + // uint32_t smem_write_; + // uint32_t smem_read_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Gmem_tile, + // The number of buffers. (Used in multistage and double buffer cases.) + int BUFFERS_PER_TILE_ = 1> +struct Smem_tile_dp_sum { + using Cta_tile = typename Gmem_tile::Cta_tile; + using Mma_tile = fmha::Hmma_tile; + + // The size of each element. + static constexpr int BYTES_PER_ELEMENT = 4; + static constexpr int ROWS = Gmem_tile::ROWS; + static constexpr int THREADS_PER_ROW = Gmem_tile::THREADS_PER_ROW; + static constexpr int MMAS_M = Mma_tile::MMAS_M; + + static constexpr int ROWS_PER_LDG = Gmem_tile::ROWS_PER_LDG; + static constexpr int LDGS = Gmem_tile::LDGS; + + static constexpr int ROWS_PER_MMA = Mma_tile::M_PER_MMA; + + // The size of one buffer in bytes in shared memory. 
+ static constexpr int BYTES_PER_BUFFER = ROWS * BYTES_PER_ELEMENT; + // The number of buffers. + static constexpr int BUFFERS_PER_TILE = BUFFERS_PER_TILE_; + // The size in bytes of total buffers. + static constexpr int BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE; + // The boundary for smem_read_offset and smem_write_offset increment. + static constexpr int ROWS_PER_TILE_INC_BOUNDARY = + ROWS * BUFFERS_PER_TILE - ROWS; + + inline __device__ Smem_tile_dp_sum(float* smem, const int tidx) + : smem_(smem), + smem_read_buffer_(smem), + smem_write_buffer_(smem), + tidx_(tidx) {} + + // Move the read offset to next buffer. + inline __device__ void move_to_next_read_buffer() { + if (BUFFERS_PER_TILE > 1 && + (smem_read_buffer_ - smem_) >= ROWS_PER_TILE_INC_BOUNDARY) { + this->smem_read_buffer_ -= ROWS_PER_TILE_INC_BOUNDARY; + } else if (BUFFERS_PER_TILE > 1) { + this->smem_read_buffer_ += ROWS; + } + } + + // Move the write offset to next buffer. + inline __device__ void move_to_next_write_buffer() { + if (BUFFERS_PER_TILE > 1 && + (smem_write_buffer_ - smem_) >= ROWS_PER_TILE_INC_BOUNDARY) { + this->smem_write_buffer_ -= ROWS_PER_TILE_INC_BOUNDARY; + } else if (BUFFERS_PER_TILE > 1) { + this->smem_write_buffer_ += ROWS; + } + } + + inline __device__ void store(const float (&sum)[LDGS]) { + if (tidx_ % THREADS_PER_ROW == 0) { + int row = tidx_ / THREADS_PER_ROW; +#pragma unroll + for (int i = 0; i < LDGS; ++i) { + if (row + i * ROWS_PER_LDG < ROWS) { + smem_write_buffer_[row + i * ROWS_PER_LDG] = sum[i]; + } + } + } + } + + inline __device__ void store(const float sum, const int buffer_idx) { + float* smem_write = smem_ + buffer_idx * ROWS; + int row = tidx_ / THREADS_PER_ROW; + if ((row < ROWS) && (tidx_ % THREADS_PER_ROW == 0)) { + smem_write[row] = sum; + } + } + + inline __device__ void store(const float (&sum)[LDGS], const int buffer_idx) { + float* smem_write = smem_ + buffer_idx * ROWS; + if (tidx_ % THREADS_PER_ROW == 0) { + int row = tidx_ / THREADS_PER_ROW; +#pragma unroll + for (int i = 0; i < LDGS; ++i) { + if (row + i * ROWS_PER_LDG < ROWS) { + smem_write[row + i * ROWS_PER_LDG] = sum[i]; + } + } + } + } + + inline __device__ void store_pair( + const float (&sum)[MMAS_M * 2], + const int buffer_idx) { + float* smem_write = smem_ + buffer_idx * ROWS; + // Extract the position in the warp. 
+ int warp = tidx_ / Cta_tile::THREADS_PER_WARP; + int lane = tidx_ % Cta_tile::THREADS_PER_WARP; + int row = lane / 4; +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + smem_write[mi * ROWS_PER_MMA + row + 0] = sum[mi * 2 + 0]; + smem_write[mi * ROWS_PER_MMA + row + 8] = sum[mi * 2 + 1]; + } + } + + template + inline __device__ void load(float (&sum)[N], const int (&row)[N]) { +#pragma unroll + for (int ni = 0; ni < N; ni++) { + sum[ni] = smem_read_buffer_[row[ni]]; + } + } + + template + inline __device__ void load( + float (&sum)[N], + const int (&row)[N], + const int buffer_idx) { + float* smem_read = smem_ + buffer_idx * ROWS; +#pragma unroll + for (int ni = 0; ni < N; ni++) { + sum[ni] = smem_read[row[ni]]; + } + } + + static inline __device__ float reduce_warp(float sum) { + fmha::SumOp sum_op; + return fmha::Allreduce::run(sum, sum_op); + } + + const int tidx_; + float* const smem_; + float* smem_read_buffer_; + float* smem_write_buffer_; +}; + +} // namespace fmha diff --git a/python/aitemplate/backend/cuda/attention/src/fmha/softmax.h b/python/aitemplate/backend/cuda/attention/src/fmha/softmax.h new file mode 100644 index 000000000..02e82c427 --- /dev/null +++ b/python/aitemplate/backend/cuda/attention/src/fmha/softmax.h @@ -0,0 +1,708 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +#pragma once + +#include +#include + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Sum_ { + static constexpr bool IS_SUM = true; + static inline __device__ float apply(float x, float y) { + return x + y; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Max_ { + static constexpr bool IS_SUM = false; + static inline __device__ float apply(float x, float y) { + return x > y ? x : y; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float apply_exp_(float x, float max) { + return __expf(x - max); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float apply_exp2_(float x, float max) { + return exp2f(x - max); + // With fast-math, this produces the same PTX instruction as the assembly + // below float diff = x - max; float res; asm ("ex2.approx.ftz.f32 %0, + // %1;\n\t" : "=f"(res) : "f"(diff)); return res; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct ReadType {}; +template <> +struct ReadType<4> { + using T = float; +}; +template <> +struct ReadType<8> { + using T = float2; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_reduce { + // Helper class to distribute MMA tiles reduced over rows per warp over quads. + + // The Mma tile. + using Mma_tile = fmha::Hmma_tile; + + // The number of MMAs in M/N dimensions. + static constexpr int MMAS_M = Mma_tile::MMAS_M; + static constexpr int MMAS_N = Mma_tile::MMAS_N; + + static constexpr int WARPS_M = Cta_tile::WARPS_M; + static constexpr int WARPS_N = Cta_tile::WARPS_N; + + static constexpr int ROWS = WARPS_M * MMAS_M * 16; + static constexpr int COLS = WARPS_N; + static_assert(COLS == 4 || COLS == 8); + static constexpr int ROWS_PER_XOR_PATTERN = (COLS == 8) ? 4 : 8; + static constexpr int BYTES_PER_TILE = ROWS * COLS * sizeof(float); + static constexpr int ELTS_PER_TILE = ROWS * COLS; + + static constexpr int THREADS_PER_GROUP = + Kernel_traits::Gmem_tile_o::THREADS_PER_ROW; + // TD [2022-05-02]: No longer true if head_dim != 64 + // static_assert(THREADS_PER_GROUP == 16); // DEBUG + static constexpr int ROWS_PER_WARP = 32 / THREADS_PER_GROUP; + static constexpr int LOOPS = Kernel_traits::Gmem_tile_o::LOOPS; + static_assert(LOOPS == 1); + + using read_t = typename ReadType::T; + + __device__ inline Smem_tile_reduce(float* smem_, const int tidx) { + int lane = tidx % 32; + int warp = tidx / 32; + + int warp_m = warp % WARPS_M; + int warp_n = warp / WARPS_M; + + qid_ = lane % 4; + int qp = lane / 4; + + // Swizzle the column to avoid 2-fold bank conflicts when we have 8 warps. + // This won't affect reading as we assume commutative reduction ops. 
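+    // Worked example (illustrative): with WARPS_N == 8 a row holds 8 floats,
+    // so without the XOR the 8 writers of a warp (qp = 0..7, all sharing the
+    // same warp_n) hit banks (qp * 8 + warp_n) % 32, and qp and qp + 4 collide.
+    // XOR-ing the column with qp / ROWS_PER_XOR_PATTERN flips the low bit of
+    // the column for qp >= 4, so the two halves land in different banks.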
+ const int col = warp_n ^ (qp / ROWS_PER_XOR_PATTERN); + smem_write_ = &smem_[warp_m * 16 * MMAS_M * WARPS_N + qp * WARPS_N + col]; + smem_read_ = &reinterpret_cast( + smem_)[warp_m * 16 * MMAS_M * 4 + qp * 4 + qid_]; + smem_read_row_ = + &reinterpret_cast(smem_)[warp_m * 16 * MMAS_M * 4 + qid_]; + } + + __device__ inline void store(float (&frag)[2 * MMAS_M]) { + if (qid_ == 0) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; mi++) { + int offset = mi * 16 * WARPS_N; + smem_write_[offset + 0 * 8 * WARPS_N] = frag[mi * 2 + 0]; + smem_write_[offset + 1 * 8 * WARPS_N] = frag[mi * 2 + 1]; + } + } + } + + __device__ inline void load(read_t (&frag)[2 * MMAS_M]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; mi++) { + int offset = mi * 16 * 4; + frag[mi * 2 + 0] = smem_read_[offset + 0 * 8 * 4]; + frag[mi * 2 + 1] = smem_read_[offset + 1 * 8 * 4]; + } + } + + __device__ inline void load_row(read_t (&frag)[MMAS_M], int row) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; mi++) { + int offset = mi * 16 * 4; + frag[mi] = smem_read_row_[offset + 0 * 8 * 4 + row * 4]; + } + } + + int qid_; + float* smem_write_; + read_t* smem_read_; + read_t* smem_read_row_; +}; + +template +struct Softmax_base { + // The Mma tile. + using Mma_tile = fmha::Hmma_tile; + + // The number of MMAs in M/N dimensions. + static constexpr int MMAS_M = Mma_tile::MMAS_M; + static constexpr int MMAS_N = Mma_tile::MMAS_N; + + // The number of groups of warp such that we have at most 4 warps writing + // consecutive elements. + static constexpr int GROUPS = fmha::DivUpConstexpr(Cta_tile::WARPS_N, 4); + // The number of elements that we are going to store per row. + static constexpr int ELEMENTS_PER_ROW = Cta_tile::WARPS_N / GROUPS; + // The number of rows. + static constexpr int ROWS = Cta_tile::M * GROUPS; + // The total number of elements. + static constexpr int ELEMENTS = ROWS * ELEMENTS_PER_ROW; + + // Ctor. + template + inline __device__ Softmax_base(const Params& params, void* smem, int tidx) + : // packed_mask_ptr_(reinterpret_cast(params.packed_mask_ptr)), + smem_(reinterpret_cast(smem)), + tidx_(tidx) { + // Move to the 1st mask loaded by the thread+ tidx; + // packed_mask_ptr_ += bidb * params.packed_mask_stride_in_bytes + tidx * + // sizeof(uint32_t); + + // Extract the position in the warp. + int warp = tidx / Cta_tile::THREADS_PER_WARP; + int lane = tidx % Cta_tile::THREADS_PER_WARP; + + // Decompose the warp index into M and N. + int warp_m = warp % Cta_tile::WARPS_M; + int warp_n = warp / Cta_tile::WARPS_M; + + // Decompose the warp-n index into group/position-inside-the-group. + int warp_g = warp_n / ELEMENTS_PER_ROW; + int warp_i = warp_n % ELEMENTS_PER_ROW; + + // The location written by the threads. + int write_row = + warp_g * (ROWS / GROUPS) + warp_m * Mma_tile::M_PER_MMA + lane / 4; + int write_col = warp_i; + + // Assemble the write pointer. + smem_write_ = &smem_[write_row * ELEMENTS_PER_ROW + write_col]; + + // Assemble the read pointer. + smem_read_ = &smem_[warp_m * Mma_tile::M_PER_MMA + lane / 4]; + } + + template + inline __device__ void apply_mask(const Mask& mask) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int jj = 0; jj < 4; ++jj) { + if (!mask.is_valid(mi, ni, ii, jj)) { + elt_[2 * mi + ii][4 * ni + jj] = zero ? 0.f : -INFINITY; + } + } + } + } + } + } + + // Apply the exp to all the elements. 
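+  // Note: the rewrite used below relies on the identity
+  //   exp(x - max) == exp2(x * log2(e) - max * log2(e)),
+  // so each element reduces to one FFMA (x * log2e + (-max * log2e)) feeding a
+  // single ex2.approx, rather than a subtract followed by a separate exp.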
+ template + inline __device__ void apply_exp(const float (&max)[MMAS_M * 2]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + constexpr float kLog2e = M_LOG2E; + const float max_base2 = max_in_base2 ? max[mi] : max[mi] * kLog2e; +#pragma unroll + for (int ni = 0; ni < MMAS_N * 4; ++ni) { + // elt_[mi][ni] = apply_exp_(elt_[mi][ni], max[mi]); + elt_[mi][ni] = apply_exp2_( + elt_in_base2 ? elt_[mi][ni] : elt_[mi][ni] * kLog2e, max_base2); + } + } + } + + // Apply the exp to all the elements. + template + inline __device__ void scale_apply_exp( + const float (&max)[MMAS_M * 2], + const float scale_) { + const float max_scale = scale_max ? scale_ * M_LOG2E : M_LOG2E; + const float scale = scale_ * M_LOG2E; +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + const float max_scaled = max[mi] * max_scale; +#pragma unroll + for (int ni = 0; ni < MMAS_N * 4; ++ni) { + elt_[mi][ni] = apply_exp2_(elt_[mi][ni] * scale, max_scaled); + } + } + } + + // Apply the exp to all the elements. + template + inline __device__ void apply_exp_col(const float (&max)[MMAS_N * 4]) { +#pragma unroll + for (int ni = 0; ni < MMAS_N * 4; ++ni) { + constexpr float kLog2e = M_LOG2E; + const float max_base2 = max_in_base2 ? max[ni] : max[ni] * kLog2e; +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { + elt_[mi][ni] = apply_exp2_(elt_[mi][ni] * kLog2e, max_base2); + } + } + } + // inline __device__ void apply_exp_col(const float (&max)[MMAS_N]) { + // constexpr float kLog2e = M_LOG2E; + // #pragma unroll + // for( int ni = 0; ni < MMAS_N * 4; ++ni ) { + // float max_base2 = max_in_base2 ? max[ni / 4] : max[ni / 4] * + // kLog2e; max_base2 = __shfl_sync(0xffffffff, max_base2, (ni % 4) * 8 + // + threadIdx.x % 8); #pragma unroll for( int mi = 0; mi < MMAS_M * + // 2; ++mi ) { + // elt_[mi][ni] = apply_exp2_(elt_[mi][ni] * kLog2e, max_base2); + // } + // } + // } + + template + inline __device__ void apply_dropout(Philox& ph, uint32_t p_dropout_in_uint) { + // We encode the dropout pattern in the sign bit of the non-negative + // softmax to distinguish from pre-existing zeros + auto encode_dropout = [](bool keep, float val) { + return keep ? val : (encode_dropout_in_sign_bit ? 
-val : float(0)); + }; +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; mi++) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ni++) { + uint4 tmp = ph(); + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, tmp.x, tmp.y, + // tmp.z, tmp.w); + // } + elt_[mi][4 * ni + 0] = + encode_dropout(tmp.x <= p_dropout_in_uint, elt_[mi][4 * ni + 0]); + elt_[mi][4 * ni + 1] = + encode_dropout(tmp.y <= p_dropout_in_uint, elt_[mi][4 * ni + 1]); + elt_[mi][4 * ni + 2] = + encode_dropout(tmp.z <= p_dropout_in_uint, elt_[mi][4 * ni + 2]); + elt_[mi][4 * ni + 3] = + encode_dropout(tmp.w <= p_dropout_in_uint, elt_[mi][4 * ni + 3]); + } + } + } + + template + inline __device__ void apply_dropout( + Philox& ph0, + Philox& ph1, + uint32_t p_dropout_in_uint) { + // We encode the dropout pattern in the sign bit of the non-negative + // softmax to distinguish from pre-existing zeros + auto encode_dropout = [](bool keep, float val) { + return keep ? val : (encode_dropout_in_sign_bit ? -val : float(0)); + }; +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; mi++) { + static_assert(MMAS_N % 2 == 0); +#pragma unroll + for (int ni = 0; ni < MMAS_N; ni += 2) { + uint4 tmp = ph0(); + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("ni = %d, ph0, Philox: %u, %u, %u, %u\n", ni, tmp.x, + // tmp.y, tmp.z, tmp.w); + // } + elt_[mi][4 * ni + 0] = + encode_dropout(tmp.x <= p_dropout_in_uint, elt_[mi][4 * ni + 0]); + elt_[mi][4 * ni + 1] = + encode_dropout(tmp.y <= p_dropout_in_uint, elt_[mi][4 * ni + 1]); + elt_[mi][4 * ni + 2] = + encode_dropout(tmp.z <= p_dropout_in_uint, elt_[mi][4 * ni + 2]); + elt_[mi][4 * ni + 3] = + encode_dropout(tmp.w <= p_dropout_in_uint, elt_[mi][4 * ni + 3]); + tmp = ph1(); + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("ni = %d, ph1, Philox: %u, %u, %u, %u\n", ni + 1, tmp.x, + // tmp.y, tmp.z, tmp.w); + // } + elt_[mi][4 * (ni + 1) + 0] = encode_dropout( + tmp.x <= p_dropout_in_uint, elt_[mi][4 * (ni + 1) + 0]); + elt_[mi][4 * (ni + 1) + 1] = encode_dropout( + tmp.y <= p_dropout_in_uint, elt_[mi][4 * (ni + 1) + 1]); + elt_[mi][4 * (ni + 1) + 2] = encode_dropout( + tmp.z <= p_dropout_in_uint, elt_[mi][4 * (ni + 1) + 2]); + elt_[mi][4 * (ni + 1) + 3] = encode_dropout( + tmp.w <= p_dropout_in_uint, elt_[mi][4 * (ni + 1) + 3]); + } + } + } + + template + inline __device__ void apply_dropout_16bits( + Philox& ph, + uint16_t p_dropout_in_uint16_t) { + // We encode the dropout pattern in the sign bit of the non-negative + // softmax to distinguish from pre-existing zeros + auto encode_dropout = [](bool keep, float val) { + return keep ? val : (encode_dropout_in_sign_bit ? 
-val : float(0)); + }; +#pragma unroll + for (int mi = 0; mi < MMAS_M; mi++) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ni++) { + uint16_t tmp[8]; + fmha::uint4_to_ushort8(ph(), tmp); +// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { +// printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, tmp.x, tmp.y, tmp.z, +// tmp.w); +// } +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { +#pragma unroll + for (int jj = 0; jj < 4; ++jj) { + elt_[mi * 2 + ii][4 * ni + jj] = encode_dropout( + tmp[ii * 4 + jj] <= p_dropout_in_uint16_t, + elt_[mi * 2 + ii][4 * ni + jj]); + } + } + } + } + } + + template + inline __device__ void apply_dropout_16bits( + Philox& ph0, + Philox& ph1, + uint16_t p_dropout_in_uint16_t) { + // We encode the dropout pattern in the sign bit of the non-negative + // softmax to distinguish from pre-existing zeros + auto encode_dropout = [](bool keep, float val) { + return keep ? val : (encode_dropout_in_sign_bit ? -val : float(0)); + }; +#pragma unroll + for (int mi = 0; mi < MMAS_M; mi++) { + static_assert(MMAS_N % 2 == 0); +#pragma unroll + for (int ni = 0; ni < MMAS_N; ni += 2) { + uint16_t tmp[8]; + fmha::uint4_to_ushort8(ph0(), tmp); +// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { +// printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, tmp.x, tmp.y, tmp.z, +// tmp.w); +// } +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { +#pragma unroll + for (int jj = 0; jj < 4; ++jj) { + elt_[mi * 2 + ii][4 * ni + jj] = encode_dropout( + tmp[ii * 4 + jj] <= p_dropout_in_uint16_t, + elt_[mi * 2 + ii][4 * ni + jj]); + } + } + fmha::uint4_to_ushort8(ph1(), tmp); +// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { +// printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, tmp.x, tmp.y, tmp.z, +// tmp.w); +// } +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { +#pragma unroll + for (int jj = 0; jj < 4; ++jj) { + elt_[mi * 2 + ii][4 * (ni + 1) + jj] = encode_dropout( + tmp[ii * 4 + jj] <= p_dropout_in_uint16_t, + elt_[mi * 2 + ii][4 * (ni + 1) + jj]); + } + } + } + } + } + + // Scale all the elements. + inline __device__ void scale(const float (&sum)[MMAS_M * 2]) { + // Precompute the inverse sum to normalize. Without -use_fast_math, it makes + // a huge deal. + float inv_sum[MMAS_M * 2]; +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { + inv_sum[mi] = + (sum[mi] == 0.f || sum[mi] != sum[mi]) ? 1.f : 1.f / sum[mi]; + } + +// Update the values. +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N * 4; ++ni) { + elt_[mi][ni] *= inv_sum[mi]; + } + } + } + + // Subtract all elements by dp_sum + inline __device__ void subtract_dp_sum(const float (&dp_sum)[MMAS_M * 2]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N * 4; ++ni) { + elt_[mi][ni] -= dp_sum[mi]; + } + } + } + + // The pointer to the mask. + const char* packed_mask_ptr_; + // Shared memory for the CTA-wide reduction. + float *smem_, *smem_write_, *smem_read_; + // The current thread index. + int tidx_; + // The elements. + float elt_[MMAS_M * 2][MMAS_N * 4]; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax : public Softmax_base { + // The base class. + using Base = Softmax_base; + // The fragment. 
+ using Fragment_a = fmha::Fragment_a; + + static_assert(Fragment_a::NUM_REGS == 4); + + static constexpr int WARPS_M = Cta_tile::WARPS_M; + static constexpr int WARPS_N = Cta_tile::WARPS_N; + // The MMAs. + static constexpr int MMAS_M = Base::MMAS_M; + static constexpr int MMAS_N = Base::MMAS_N; + + // The accumulators. + using Accumulator = fmha::Fragment_accumulator; + using Accumulator_out = Fragment; + static_assert(Accumulator_out::NUM_REGS == 4); + + static_assert(std::is_same::value); + + using Smem_tile_red = Smem_tile_reduce; + static_assert(Smem_tile_red::ELTS_PER_TILE == Cta_tile::M * WARPS_N); + // Ctor. + template + inline __device__ Softmax(const Params& params, void* smem, int tidx) + : Base(params, smem, tidx), + params_scale_bmm1_(params.scale_bmm1), + smem_sum_(static_cast(smem), tidx), + smem_max_( + static_cast(smem) + Smem_tile_red::ELTS_PER_TILE, + tidx) {} + + // Pack the data to a fragment for the next GEMM. + template + inline __device__ void pack(Fragment_a (&dst)[K][M]) const { +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ki = 0; ki < K; ++ki) { + // 1st row - 4 elements per row. + float tmp_00 = this->elt_[2 * mi + 0][4 * ki + 0]; + float tmp_01 = this->elt_[2 * mi + 0][4 * ki + 1]; + float tmp_02 = this->elt_[2 * mi + 0][4 * ki + 2]; + float tmp_03 = this->elt_[2 * mi + 0][4 * ki + 3]; + + // 2nd row - 4 elements per row. + float tmp_10 = this->elt_[2 * mi + 1][4 * ki + 0]; + float tmp_11 = this->elt_[2 * mi + 1][4 * ki + 1]; + float tmp_12 = this->elt_[2 * mi + 1][4 * ki + 2]; + float tmp_13 = this->elt_[2 * mi + 1][4 * ki + 3]; + + // Pack to 4 registers. + dst[ki][mi].reg(0) = fmha::float2_to_half2(tmp_00, tmp_01); + dst[ki][mi].reg(1) = fmha::float2_to_half2(tmp_10, tmp_11); + dst[ki][mi].reg(2) = fmha::float2_to_half2(tmp_02, tmp_03); + dst[ki][mi].reg(3) = fmha::float2_to_half2(tmp_12, tmp_13); + } + } + } + + // Scale FP32 fragments + inline __device__ void unpack(const Accumulator (&acc)[MMAS_M][MMAS_N]) { + const float scalef = + reinterpret_cast(this->params_scale_bmm1_); + +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + // 1st row - 4 elements per row. + this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0) * scalef; + this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1) * scalef; + this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4) * scalef; + this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5) * scalef; + // 2nd row - 4 elements per row. + this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2) * scalef; + this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3) * scalef; + this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6) * scalef; + this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7) * scalef; + } + } + } + + // Scale FP32 fragments + inline __device__ void unpack_noscale( + const Accumulator (&acc)[MMAS_M][MMAS_N]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + // 1st row - 4 elements per row. + this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0); + this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1); + this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4); + this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5); + // 2nd row - 4 elements per row. 
+ this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2); + this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3); + this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6); + this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7); + } + } + } + + template + __device__ inline void thread_reduce_( + float (&frag)[2 * MMAS_M], + Operator& op) { +#pragma unroll + for (int mi = 0; mi < 2 * MMAS_M; mi++) { + frag[mi] = + zero_init ? this->elt_[mi][0] : op(frag[mi], this->elt_[mi][0]); +#pragma unroll + for (int ni = 1; ni < 4 * MMAS_N; ni++) { + frag[mi] = op(frag[mi], this->elt_[mi][ni]); + } + } + } + + template + __device__ inline void reduce_( + float (&frag)[2 * MMAS_M], + Operator& op, + Smem_tile_red& smem_red) { + thread_reduce_(frag, op); + quad_reduce(frag, frag, op); + smem_red.store(frag); + __syncthreads(); + typename Smem_tile_red::read_t tmp[2 * MMAS_M]; + smem_red.load(tmp); + quad_allreduce(frag, tmp, op); + } + + template + __device__ inline void reduce_max(float (&frag)[2 * MMAS_M]) { + MaxOp max; + reduce_(frag, max, smem_max_); + } + + __device__ inline void reduce_sum(float (&frag)[2 * MMAS_M]) { + SumOp sum; + reduce_(frag, sum, smem_sum_); + } + + template + __device__ inline void reduce_sum_before_sync_(float (&frag)[2 * MMAS_M]) { + SumOp sum; + thread_reduce_(frag, sum); + quad_reduce(frag, frag, sum); + smem_sum_.store(frag); + } + + template + __device__ inline void reduce_after_sync_( + float (&frag)[NROWS][MMAS_M], + const int (&rows)[NROWS], + Operator& op, + Smem_tile_red& smem_red) { +#pragma unroll + for (int ii = 0; ii < NROWS; ii++) { + typename Smem_tile_red::read_t tmp[MMAS_M]; + smem_red.load_row(tmp, rows[ii]); + quad_allreduce(frag[ii], tmp, op); + } + } + + template + __device__ inline void reduce_sum_after_sync_( + float (&frag)[NROWS][MMAS_M], + const int (&rows)[NROWS]) { + SumOp sum; + reduce_after_sync_(frag, rows, sum, smem_sum_); + } + + template + __device__ inline void reduce_max_after_sync_( + float (&frag)[NROWS][MMAS_M], + const int (&rows)[NROWS]) { + MaxOp max; + reduce_after_sync_(frag, rows, max, smem_max_); + } + + const uint32_t params_scale_bmm1_; + Smem_tile_red smem_max_; + Smem_tile_red smem_sum_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/python/aitemplate/backend/cuda/attention/src/fmha/utils.h b/python/aitemplate/backend/cuda/attention/src/fmha/utils.h new file mode 100644 index 000000000..7bc0b3df9 --- /dev/null +++ b/python/aitemplate/backend/cuda/attention/src/fmha/utils.h @@ -0,0 +1,1332 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include +#include +#include + +#include + +extern "C" __device__ uint32_t __nvvm_get_smem_pointer(void* ptr); + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Row {}; +struct Col {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Next_power_of_two {}; + +template +struct Next_power_of_two { + enum { VALUE = M }; +}; +template <> +struct Next_power_of_two<3, false> { + enum { VALUE = 4 }; +}; +template <> +struct Next_power_of_two<5, false> { + enum { VALUE = 8 }; +}; +template <> +struct Next_power_of_two<6, false> { + enum { VALUE = 8 }; +}; +template <> +struct Next_power_of_two<7, false> { + enum { VALUE = 8 }; +}; +template <> +struct Next_power_of_two<9, false> { + enum { VALUE = 16 }; +}; +template <> +struct Next_power_of_two<10, false> { + enum { VALUE = 16 }; +}; +template <> +struct Next_power_of_two<11, false> { + enum { VALUE = 16 }; +}; +template <> +struct Next_power_of_two<12, false> { + enum { VALUE = 16 }; +}; +template <> +struct Next_power_of_two<13, false> { + enum { VALUE = 16 }; +}; +template <> +struct Next_power_of_two<14, false> { + enum { VALUE = 16 }; +}; +template <> +struct Next_power_of_two<15, false> { + enum { VALUE = 16 }; +}; +template <> +struct Next_power_of_two<24, false> { + enum { VALUE = 32 }; +}; +template <> +struct Next_power_of_two<48, false> { + enum { VALUE = 64 }; +}; +template <> +struct Next_power_of_two<80, false> { + enum { VALUE = 128 }; +}; +template <> +struct Next_power_of_two<96, false> { + enum { VALUE = 128 }; +}; +template <> +struct Next_power_of_two<112, false> { + enum { VALUE = 128 }; +}; +template <> +struct Next_power_of_two<144, false> { + enum { VALUE = 256 }; +}; + 
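+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Illustrative compile-time checks of the helper above (a sketch for
+// exposition only, not part of the upstream fmha sources); they compile away.
+static_assert(Next_power_of_two<64, true>::VALUE == 64, "");
+static_assert(Next_power_of_two<80, false>::VALUE == 128, "");
+static_assert(Next_power_of_two<144, false>::VALUE == 256, "");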
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Prev_power_of_two {}; + +template +struct Prev_power_of_two { + enum { VALUE = N }; +}; +template <> +struct Prev_power_of_two<3, false> { + enum { VALUE = 2 }; +}; +template <> +struct Prev_power_of_two<5, false> { + enum { VALUE = 4 }; +}; +template <> +struct Prev_power_of_two<6, false> { + enum { VALUE = 4 }; +}; +template <> +struct Prev_power_of_two<7, false> { + enum { VALUE = 4 }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Div_up { + enum { VALUE = (M + N - 1) / N }; +}; + +constexpr int DivUpConstexpr(int M, int N) { + return (M + N - 1) / N; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Max { + enum { VALUE = A >= B ? A : B }; +}; + +constexpr int MaxConstexpr(int A, int B) { + return A >= B ? A : B; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Max_3 { + enum { VALUE = Max::VALUE, C>::VALUE }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Min { + enum { VALUE = A <= B ? A : B }; +}; + +constexpr int MinConstexpr(int A, int B) { + return A <= B ? A : B; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Uint_from_size_in_bytes {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Uint_from_size_in_bytes<1> { + using Type = uint8_t; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Uint_from_size_in_bytes<2> { + using Type = uint16_t; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Uint_from_size_in_bytes<4> { + using Type = uint32_t; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Uint_from_size_in_bytes<8> { + using Type = uint2; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Uint_from_size_in_bytes<16> { + using Type = uint4; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Warp_masks {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Warp_masks<8, 1, 1> { + enum { M = 0xe0, N = 0x00, K = 0x00 }; +}; +template <> +struct Warp_masks<4, 2, 1> { + enum { M = 0x60, N = 0x80, K = 0x00 }; +}; +template <> +struct Warp_masks<4, 1, 2> { + enum { M = 0x60, N = 0x00, K = 0x80 }; +}; +template <> +struct Warp_masks<4, 1, 1> { + enum { M = 0x60, N = 0x00, K = 0x00 }; +}; +template <> +struct Warp_masks<2, 4, 1> { + enum { M = 0x20, N = 0xc0, K = 0x00 }; +}; +template <> +struct Warp_masks<2, 2, 2> { + enum { M = 0x20, N = 0x40, K = 0x80 }; +}; +template <> +struct Warp_masks<2, 2, 1> { + enum { M = 0x20, N = 0x40, K = 0x00 }; +}; +template <> +struct Warp_masks<2, 1, 2> { + enum { M = 0x20, N = 0x00, K = 0x40 }; +}; +template <> +struct Warp_masks<2, 1, 1> { + enum { M = 0x20, N = 0x00, K = 0x00 }; +}; +template <> 
+struct Warp_masks<1, 8, 1> { + enum { M = 0x00, N = 0xe0, K = 0x00 }; +}; +template <> +struct Warp_masks<1, 4, 2> { + enum { M = 0x00, N = 0x60, K = 0x80 }; +}; +template <> +struct Warp_masks<1, 4, 1> { + enum { M = 0x00, N = 0x60, K = 0x00 }; +}; +template <> +struct Warp_masks<1, 2, 2> { + enum { M = 0x00, N = 0x20, K = 0x40 }; +}; +template <> +struct Warp_masks<1, 2, 1> { + enum { M = 0x00, N = 0x20, K = 0x00 }; +}; +template <> +struct Warp_masks<1, 1, 4> { + enum { M = 0x00, N = 0x00, K = 0x60 }; +}; +template <> +struct Warp_masks<1, 1, 2> { + enum { M = 0x00, N = 0x00, K = 0x20 }; +}; +template <> +struct Warp_masks<1, 1, 1> { + enum { M = 0x00, N = 0x00, K = 0x00 }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ __host__ T div_up(T m, T n) { + return (m + n - 1) / n; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline int clz(int x) { + for (int i = 31; i >= 0; --i) { + if ((1 << i) & x) { + return 31 - i; + } + } + return 32; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline int find_log_2(int x, bool round_up = false) { + int a = 31 - clz(x); + if (round_up) { + a += (x & (x - 1)) ? 1 : 0; + } + return a; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t hadd2(uint32_t a, uint32_t b) { + uint32_t c; + asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t hmin2(uint32_t a, uint32_t b) { + uint32_t c; + asm volatile("min.f16x2 %0, %1, %2;" : "=r"(c) : "r"(a), "r"(b)); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t hmul2(const uint32_t a, const uint32_t b) { + // uint32_t c; + // asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); + // return c; + __half2 result = __hmul2( + reinterpret_cast(a), + reinterpret_cast(b)); + return reinterpret_cast(result); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint2 hmul4(uint2 a, uint2 b) { + uint2 c; + c.x = hmul2(a.x, b.x); + c.y = hmul2(a.y, b.y); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint4 hmul8(uint4 a, uint4 b) { + uint4 c; + c.x = hmul2(a.x, b.x); + c.y = hmul2(a.y, b.y); + c.z = hmul2(a.z, b.z); + c.w = hmul2(a.w, b.w); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint4 hmul8(uint32_t a, uint4 b) { + uint4 c; + c.x = hmul2(a, b.x); + c.y = hmul2(a, b.y); + c.z = hmul2(a, b.z); + c.w = hmul2(a, b.w); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t hrelu2(uint32_t x, uint32_t lb = 0) { + uint32_t res; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(lb)); +#else + const uint32_t zero = 0u; + asm volatile( + "{\n" + "\t .reg .f16x2 sela;\n" + "\t set.gtu.u32.f16x2 
sela, %1, %2;\n" + "\t and.b32 %0, sela, %1;\n" + "}\n" + : "=r"(res) + : "r"(x), "r"(zero)); +#endif + return res; +} +static inline __device__ uint32_t habs2(uint32_t x) { + uint32_t res; + asm volatile("abs.f16x2 %0, %1;\n" : "=r"(res) : "r"(x)); + return res; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +template +static inline __device__ T clamp(T x, T lb, T ub) { + return x < lb ? lb : (x > ub ? ub : x); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint16_t clamp_to_zero(uint16_t x) { + uint16_t mask; + asm volatile("set.gtu %0, %1, 0;" : "=h"(mask) : "h"(x)); + return mask & x; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint16_t float_to_half(float f) { + uint16_t h; + asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(h) : "f"(f)); + return h; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t float2_to_half2(float a, float b) { + uint32_t c; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(c) : "f"(b), "f"(a)); +#else + uint16_t lo = float_to_half(a); + uint16_t hi = float_to_half(b); + asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(c) : "h"(lo), "h"(hi)); +#endif + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t float_to_half2(float a) { + return float2_to_half2(a, a); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t float2_to_half2(const float2& f) { + return float2_to_half2(f.x, f.y); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint2 +float4_to_half4(float x, float y, float z, float w) { + uint2 d; + d.x = float2_to_half2(x, y); + d.y = float2_to_half2(z, w); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t hfma2(uint32_t a, uint32_t b, uint32_t c) { + uint32_t d; + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(d) + : "r"(a), "r"(b), "r"(c)); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t +hfma2_relu(uint32_t a, uint32_t b, uint32_t c) { + uint32_t d; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("fma.rn.f16x2.relu %0, %1, %2, %3;" + : "=r"(d) + : "r"(a), "r"(b), "r"(c)); +#else + d = hrelu2(hfma2(a, b, c)); +#endif + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t h0_h0(uint32_t x) { + uint32_t y; + asm volatile( + "{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {lo, lo};}\n" + : "=r"(y) + : "r"(x)); + return y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ float h0_to_float(uint32_t h2) { + float f; + asm volatile( + "{\n" + ".reg .f16 lo, hi;\n" + "mov.b32 {lo, hi}, %1;\n" + "cvt.f32.f16 %0, lo;\n" + "}\n" + : "=f"(f) + : "r"(h2)); + return f; +} + 
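+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Illustrative sketch (hypothetical helper for exposition, not part of the
+// upstream fmha sources): exercises the packed-fp16 helpers above.
+// float2_to_half2(a, b) packs a into the low lane and b into the high lane,
+// hfma2 is a lane-wise f16x2 fused multiply-add, and h0_to_float reads back
+// the low lane.
+static inline __device__ float fp16x2_helpers_demo() {
+  uint32_t ab = float2_to_half2(2.f, 3.f);   // lanes {2, 3}
+  uint32_t cd = float2_to_half2(10.f, 20.f); // lanes {10, 20}
+  uint32_t ef = float2_to_half2(1.f, 4.f);   // lanes {1, 4}
+  uint32_t r = hfma2(ab, cd, ef);            // lanes {2*10+1, 3*20+4} = {21, 64}
+  return h0_to_float(r);                     // low lane -> 21.f
+}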
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t h1_h1(uint32_t x) { + uint32_t y; + asm volatile( + "{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {hi, hi};}\n" + : "=r"(y) + : "r"(x)); + return y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint16_t hadd(uint16_t a, uint16_t b) { + uint16_t d; + asm volatile("add.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b)); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t hadd(uint32_t a, uint32_t b) { + return hadd2(a, b); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint2 hadd4(uint2 a, uint2 b) { + uint2 c; + c.x = hadd2(a.x, b.x); + c.y = hadd2(a.y, b.y); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint2 hadd(uint2 a, uint2 b) { + return hadd4(a, b); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint4 hadd8(uint4 a, uint4 b) { + uint4 c; + c.x = hadd2(a.x, b.x); + c.y = hadd2(a.y, b.y); + c.z = hadd2(a.z, b.z); + c.w = hadd2(a.w, b.w); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Converted two half2's into float, then take their dot product. +// inline __device__ void hfma2_to_float(float &sum, const __half2 a, const +// __half2 b) { +static inline __device__ float hfma2_to_float( + const __half2 a, + const __half2 b) { + float2 af = __half22float2(a); + float2 bf = __half22float2(b); + return af.x * bf.x + af.y * bf.y; + // sum += af.x * bf.x + af.y * bf.y; + // sum = __fmaf_rn(sum, af.x, bf.x); + // sum = __fmaf_rn(sum, af.y, bf.y); + // float2 prod = __half22float2(__hmul2(a, b)); + // sum += prod.x + prod.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Converted two vectors of 8 half's into float, then take their dot product. 
+static inline __device__ float hmulsum8(const uint4 a, const uint4 b) { + float sum; + sum = fmha::hfma2_to_float( + reinterpret_cast(a.x), + reinterpret_cast(b.x)); + sum += fmha::hfma2_to_float( + reinterpret_cast(a.y), + reinterpret_cast(b.y)); + sum += fmha::hfma2_to_float( + reinterpret_cast(a.z), + reinterpret_cast(b.z)); + sum += fmha::hfma2_to_float( + reinterpret_cast(a.w), + reinterpret_cast(b.w)); + return sum; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint4 fadd4(uint4 a, uint4 b) { + float4 c; + c.x = + reinterpret_cast(a.x) + reinterpret_cast(b.x); + c.y = + reinterpret_cast(a.y) + reinterpret_cast(b.y); + c.z = + reinterpret_cast(a.z) + reinterpret_cast(b.z); + c.w = + reinterpret_cast(a.w) + reinterpret_cast(b.w); + return reinterpret_cast(c); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint4 fmul4(uint4 a, float b) { + float4 c; + c.x = reinterpret_cast(a.x) * b; + c.y = reinterpret_cast(a.y) * b; + c.z = reinterpret_cast(a.z) * b; + c.w = reinterpret_cast(a.w) * b; + return reinterpret_cast(c); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint4 hadd(uint4 a, uint4 b) { + return hadd8(a, b); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ float half_to_float(uint16_t h) { + float f; + asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); + return f; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ float2 half2_to_float2(uint32_t x) { + uint16_t lo, hi; + asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(x)); + return make_float2(half_to_float(lo), half_to_float(hi)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ void half2_to_float2(float& x, float& y, uint32_t h) { + float2 tmp = half2_to_float2(h); + x = tmp.x; + y = tmp.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint16_t hfma(uint16_t a, uint16_t b, uint16_t c) { + uint16_t d; + asm volatile("fma.rn.f16 %0, %1, %2, %3;" : "=h"(d) : "h"(a), "h"(b), "h"(c)); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint16_t hmul(uint16_t a, uint16_t b) { + uint16_t d; + asm volatile("mul.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b)); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ void uint4_to_ushort8( + const uint4 a, + uint16_t (&b)[8]) { + uint32_t* b_tmp = reinterpret_cast(&b[0]); + b_tmp[0] = a.x; + b_tmp[1] = a.y; + b_tmp[2] = a.z; + b_tmp[3] = a.w; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ float sigmoid(float x) { + return 1.f / (1.f + expf(-x)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void clear(uint16_t& dst) { + dst = uint16_t(0); +} + 
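+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Illustrative sketch (hypothetical helper for exposition, not part of the
+// upstream fmha sources): hmulsum8 above treats each uint4 as 8 packed fp16
+// values and returns their dot product accumulated in fp32.
+static inline __device__ float hmulsum8_demo() {
+  // a = (1, 1, 1, 1, 1, 1, 1, 1), b = (1, 2, 3, 4, 5, 6, 7, 8)
+  uint4 a = make_uint4(float2_to_half2(1.f, 1.f), float2_to_half2(1.f, 1.f),
+                       float2_to_half2(1.f, 1.f), float2_to_half2(1.f, 1.f));
+  uint4 b = make_uint4(float2_to_half2(1.f, 2.f), float2_to_half2(3.f, 4.f),
+                       float2_to_half2(5.f, 6.f), float2_to_half2(7.f, 8.f));
+  return hmulsum8(a, b); // 1 + 2 + ... + 8 == 36.f
+}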
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void clear(uint32_t& dst) { + dst = 0u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void clear(uint2& dst) { + dst = make_uint2(0u, 0u); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void clear(uint4& dst) { + dst = make_uint4(0u, 0u, 0u, 0u); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// P R E D I C A T E P A C K I N G +// +//////////////////////////////////////////////////////////////////////////////////////////////////// +enum { + BYTES_PER_REG = 4, + PREDS_PER_BYTE = 4, + PREDS_PER_REG = BYTES_PER_REG * PREDS_PER_BYTE +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// G E N E R I C P R E D I C A T E D L D G S T S +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void load_(Functor& fct, const uint32_t (&preds)[M]) { + // The number of complete bytes (where we use all the predicates in a byte). + enum { COMPLETE = N / PREDS_PER_BYTE }; + // Make sure we did allocate enough predicates. + static_assert(Div_up::VALUE <= M, ""); + // The remainder. + enum { REMAINDER = N - COMPLETE * PREDS_PER_BYTE }; + // Make sure we got the math right and the remainder is between 0 and 3. + static_assert(REMAINDER >= 0 && REMAINDER <= 3, ""); + // The mask to extract the predicates. + enum { COMPLETE_MASK = (1 << PREDS_PER_BYTE) - 1 }; + +// Clear the fetch registers. +#pragma unroll + for (int ii = 0; ii < N; ++ii) { + fct.clear(ii); + } + + // Run complete steps. + bool p[PREDS_PER_BYTE]; +#pragma unroll + for (int ii = 0; ii < COMPLETE; ++ii) { + // The predicate. + uint32_t reg = preds[ii / BYTES_PER_REG]; + +// Extract the predicates. +#pragma unroll + for (int jj = 0; jj < PREDS_PER_BYTE; ++jj) { + uint32_t mask = 1u << (ii % BYTES_PER_REG * 8 + jj); + p[jj] = (reg & mask) != 0u; + } + +// Issue the loads. +#pragma unroll + for (int jj = 0; jj < PREDS_PER_BYTE; ++jj) { + fct.load(ii * PREDS_PER_BYTE + jj, p[jj]); + } + } + + // Skip the rest of the code if we do not have a remainder. + if (REMAINDER > 0) { + // The mask to extract the predicates. + enum { REMAINDER_MASK = (1 << REMAINDER) - 1 }; + + // The predicate register. + uint32_t reg = preds[COMPLETE / BYTES_PER_REG]; + +// Extract the predicates. +#pragma unroll + for (int jj = 0; jj < PREDS_PER_BYTE; ++jj) { + uint32_t mask = 1u << (COMPLETE % BYTES_PER_REG * 8 + jj); + p[jj] = (reg & mask) != 0u; + } + +// Issue the loads. 
+#pragma unroll + for (int ii = 0; ii < REMAINDER; ++ii) { + fct.load(COMPLETE * PREDS_PER_BYTE + ii, p[ii]); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void load_(Functor& fct, uint32_t preds) { + uint32_t tmp[1] = {preds}; + load_(fct, tmp); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// L D G +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldg(uint8_t& dst, const void* ptr) { + dst = *reinterpret_cast(ptr); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldg(uint16_t& dst, const void* ptr) { + dst = *reinterpret_cast(ptr); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldg(uint32_t& dst, const void* ptr) { + dst = *reinterpret_cast(ptr); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldg(uint2& dst, const void* ptr) { + dst = *reinterpret_cast(ptr); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldg(uint4& dst, const void* ptr) { + dst = *reinterpret_cast(ptr); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Ldg_functor { + // Ctor. + inline __device__ Ldg_functor(Data_type (&fetch)[N], const void* (&ptrs)[N]) + : fetch_(fetch), ptrs_(ptrs) {} + + // Clear the element. + inline __device__ void clear(int ii) { + fmha::clear(fetch_[ii]); + } + + // Trigger the loads. + inline __device__ void load(int ii, bool p) { + if (p) { + ldg(fetch_[ii], ptrs_[ii]); + } + } + + // The fetch registers. + Data_type (&fetch_)[N]; + // The pointers. 
+ const void* (&ptrs_)[N]; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void ldg_( + Data_type (&fetch)[N], + const void* (&ptrs)[N], + uint32_t (&preds)[M]) { + Ldg_functor fct(fetch, ptrs); + load_(fct, preds); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void ldg( + uint8_t (&fetch)[N], + const void* (&ptrs)[N], + uint32_t (&preds)[M]) { + ldg_(fetch, ptrs, preds); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void ldg( + uint16_t (&fetch)[N], + const void* (&ptrs)[N], + uint32_t (&preds)[M]) { + ldg_(fetch, ptrs, preds); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void ldg( + uint32_t (&fetch)[N], + const void* (&ptrs)[N], + uint32_t (&preds)[M]) { + ldg_(fetch, ptrs, preds); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void ldg( + uint2 (&fetch)[N], + const void* (&ptrs)[N], + uint32_t (&preds)[M]) { + ldg_(fetch, ptrs, preds); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void ldg( + uint4 (&fetch)[N], + const void* (&ptrs)[N], + uint32_t (&preds)[M]) { + ldg_(fetch, ptrs, preds); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// L D S +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void lds(uint16_t& dst, uint32_t ptr) { + asm volatile("ld.shared.b16 %0, [%1];\n" : "=h"(dst) : "r"(ptr)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void lds(uint32_t& dst, uint32_t ptr) { + asm volatile("ld.shared.b32 %0, [%1];\n" : "=r"(dst) : "r"(ptr)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void lds(uint2& dst, uint32_t ptr) { + asm volatile("ld.shared.v2.b32 {%0, %1}, [%2];\n" + : "=r"(dst.x), "=r"(dst.y) + : "r"(ptr)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void lds(uint4& dst, uint32_t ptr) { + asm volatile("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) + : "r"(ptr)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// L D S M +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldsm(uint32_t& dst, uint32_t ptr) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 + asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n" + : "=r"(dst) + : "r"(ptr)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldsmt(uint32_t& dst, uint32_t ptr) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 + asm volatile("ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16 {%0}, [%1];\n" + : "=r"(dst) + : "r"(ptr)); +#endif +} + 
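// Illustrative sketch (not from the original source): how callers are expected
// to drive the predicated-load path above. Predicates are packed four per byte
// and sixteen per 32-bit register, matching the bit layout that load_()
// extracts (bit position = byte_in_reg * 8 + pred_in_byte). pack_predicates is
// a hypothetical helper written here only to document that layout.
template <int N, int M>
static inline __device__ void pack_predicates(
    const bool (&p)[N],
    uint32_t (&preds)[M]) {
  static_assert(N <= M * PREDS_PER_REG, "not enough predicate registers");
#pragma unroll
  for (int r = 0; r < M; ++r) {
    preds[r] = 0u;
  }
#pragma unroll
  for (int i = 0; i < N; ++i) {
    int byte_in_reg = (i / PREDS_PER_BYTE) % BYTES_PER_REG;
    int pred_in_byte = i % PREDS_PER_BYTE;
    if (p[i]) {
      preds[i / PREDS_PER_REG] |= 1u << (byte_in_reg * 8 + pred_in_byte);
    }
  }
}
// Usage sketch: guard four 16-byte global loads with the low predicate bits.
//   uint4 fetch[4];
//   const void* ptrs[4] = {...};   // one pointer per fetch register
//   bool in_bounds[4] = {...};     // e.g. row < actual_seqlen
//   uint32_t preds[1];
//   pack_predicates(in_bounds, preds);
//   fmha::ldg(fetch, ptrs, preds); // out-of-bounds slots are cleared to zero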
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldsm(uint2& dst, uint32_t ptr) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 + asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0, %1}, [%2];\n" + : "=r"(dst.x), "=r"(dst.y) + : "r"(ptr)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldsmt(uint2& dst, uint32_t ptr) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 + asm volatile( + "ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16 {%0, %1}, [%2];\n" + : "=r"(dst.x), "=r"(dst.y) + : "r"(ptr)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldsm(uint4& dst, uint32_t ptr) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) + : "r"(ptr)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldsmt(uint4& dst, uint32_t ptr) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) + : "r"(ptr)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// S T G +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void stg(void* ptr, uint8_t val) { + *reinterpret_cast(ptr) = val; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void stg(void* ptr, uint16_t val) { + *reinterpret_cast(ptr) = val; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void stg(void* ptr, uint32_t val) { + *reinterpret_cast(ptr) = val; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void stg(void* ptr, uint2 val) { + *reinterpret_cast(ptr) = val; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void stg(void* ptr, uint4 val) { + *reinterpret_cast(ptr) = val; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// S T S +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void sts(uint32_t ptr, uint16_t val) { + asm volatile("st.shared.b16 [%0], %1;\n" : : "r"(ptr), "h"(val)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void sts(uint32_t ptr, uint32_t val) { + asm volatile("st.shared.b32 [%0], %1;\n" : : "r"(ptr), "r"(val)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void sts(uint32_t ptr, uint2 val) { + asm volatile("st.shared.v2.b32 [%0], {%1, %2};\n" + : + : "r"(ptr), "r"(val.x), "r"(val.y)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void 
sts(uint32_t ptr, uint4 val) { + asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};\n" + : + : "r"(ptr), "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void sts_(uint32_t (&ptrs)[N], const Data_type (&data)[N]) { +#pragma unroll + for (int ii = 0; ii < N; ++ii) { + sts(ptrs[ii], data[ii]); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void sts(uint32_t (&ptrs)[N], const uint16_t (&data)[N]) { + sts_(ptrs, data); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void sts(uint32_t (&ptrs)[N], const uint32_t (&data)[N]) { + sts_(ptrs, data); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void sts(uint32_t (&ptrs)[N], const uint2 (&data)[N]) { + sts_(ptrs, data); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void sts(uint32_t (&ptrs)[N], const uint4 (&data)[N]) { + sts_(ptrs, data); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MaxOp { + __device__ inline T operator()(T const& x, T const& y) { + return x > y ? x : y; + } +}; + +template <> +struct MaxOp { + // This is slightly faster + __device__ inline float operator()(float const& x, float const& y) { + return max(x, y); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SumOp { + __device__ inline T operator()(T const& x, T const& y) { + return x + y; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ inline T run(T x, Operator& op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Allreduce<2> { + template + static __device__ inline T run(T x, Operator& op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ inline void quad_reduce( + float (&dst)[M], + float (&src)[M], + Operator& op) { +#pragma unroll + for (int mi = 0; mi < M; mi++) { + dst[mi] = src[mi]; + dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 2)); + dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 1)); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ inline void quad_reduce( + __half2 (&dst)[M], + __half2 (&src)[M], + Operator& op) { +#pragma unroll + for (int mi = 0; mi < M; mi++) { + dst[mi] = src[mi]; + dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 2)); + dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 1)); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + 
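// Illustrative usage (not part of the original source) of the shuffle-based
// reduction helpers above. Allreduce<4> performs a butterfly over lanes that
// differ in their low two bits (XOR offsets 2, then 1), so every thread in a
// group of four adjacent lanes ends up holding the reduced value; quad_reduce
// instead funnels the result down to the lowest lane of each quad.
// quad_max_example is a hypothetical name.
static inline __device__ float quad_max_example(float x) {
  MaxOp<float> max_op;
  // After this call, lanes {4k, 4k+1, 4k+2, 4k+3} all hold the same maximum.
  return Allreduce<4>::run(x, max_op);
}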
+template +__device__ inline void quad_reduce( + float (&dst)[M], + float2 (&src)[M], + Operator& op) { + float tmp[M]; +#pragma unroll + for (int mi = 0; mi < M; mi++) { + tmp[mi] = op(src[mi].x, src[mi].y); + } + quad_reduce(dst, tmp, op); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ inline void quad_reduce( + __half2 (&dst)[M], + float2 (&src)[M], + Operator& op) { + __half2 tmp[M]; +#pragma unroll + for (int mi = 0; mi < M; mi++) { + tmp[mi] = + op(reinterpret_cast(src[mi].x), + reinterpret_cast(src[mi].y)); + } + quad_reduce(dst, tmp, op); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ inline void quad_allreduce( + float (&dst)[M], + float (&src)[M], + Operator& op) { +#pragma unroll + for (int mi = 0; mi < M; mi++) { + dst[mi] = src[mi]; + dst[mi] = Allreduce<4>::run(dst[mi], op); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ inline void quad_allreduce( + __half2 (&dst)[M], + __half2 (&src)[M], + Operator& op) { +#pragma unroll + for (int mi = 0; mi < M; mi++) { + dst[mi] = src[mi]; + dst[mi] = Allreduce<4>::run(dst[mi], op); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ inline void quad_allreduce( + float (&dst)[M], + float2 (&src)[M], + Operator& op) { + float tmp[M]; +#pragma unroll + for (int mi = 0; mi < M; mi++) { + tmp[mi] = op(src[mi].x, src[mi].y); + } + quad_allreduce(dst, tmp, op); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ inline void quad_allreduce( + __half2 (&dst)[M], + float2 (&src)[M], + Operator& op) { + __half2 tmp[M]; +#pragma unroll + for (int mi = 0; mi < M; mi++) { + tmp[mi] = + op(reinterpret_cast(src[mi].x), + reinterpret_cast(src[mi].y)); + } + quad_allreduce(dst, tmp, op); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/python/aitemplate/backend/cuda/attention/src/fmha_block_fprop_fp16_kernel.sm80.cu b/python/aitemplate/backend/cuda/attention/src/fmha_block_fprop_fp16_kernel.sm80.cu new file mode 100644 index 000000000..46bddc48e --- /dev/null +++ b/python/aitemplate/backend/cuda/attention/src/fmha_block_fprop_fp16_kernel.sm80.cu @@ -0,0 +1,155 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#include "fmha.h" +#include "fmha_block_fprop_kernel_1xN.h" + +template < + typename Kernel_traits, + bool Is_dropout, + bool Is_causal, + bool Return_softmax> +__global__ void fmha_block_fprop_fp16_sm80_loop_kernel( + Fused_multihead_attention_fprop_params params) { + fmha::device_block_1xN_loop< + Kernel_traits, + Is_dropout, + Is_causal, + Return_softmax>(params); +} + +template +void run_fmha_block_fp16_sm80_loop_( + Launch_params& launch_params, + const bool configure) { + bool is_causal = launch_params.params.is_causal; + // TD [2022-04-27]: This case work is pretty ugly, maybe there's a better way? + auto kernel = launch_params.is_dropout + ? (is_causal ? (launch_params.return_softmax + ? &fmha_block_fprop_fp16_sm80_loop_kernel< + Kernel_traits, + true, + true, + true> + : &fmha_block_fprop_fp16_sm80_loop_kernel< + Kernel_traits, + true, + true, + false>) + : (launch_params.return_softmax + ? &fmha_block_fprop_fp16_sm80_loop_kernel< + Kernel_traits, + true, + false, + true> + : &fmha_block_fprop_fp16_sm80_loop_kernel< + Kernel_traits, + true, + false, + false>)) + : (is_causal ? (launch_params.return_softmax + ? &fmha_block_fprop_fp16_sm80_loop_kernel< + Kernel_traits, + false, + true, + true> + : &fmha_block_fprop_fp16_sm80_loop_kernel< + Kernel_traits, + false, + true, + false>) + : (launch_params.return_softmax + ? &fmha_block_fprop_fp16_sm80_loop_kernel< + Kernel_traits, + false, + false, + true> + : &fmha_block_fprop_fp16_sm80_loop_kernel< + Kernel_traits, + false, + false, + false>)); + + constexpr int N = Kernel_traits::Cta_tile_p::N; + const int loop_steps = (launch_params.params.s + N - 1) / N; + constexpr int smem_size_softmax_lse = + Kernel_traits::Smem_dp_sum::BYTES_PER_TILE; + // Don't need smem_size_softmax_lse if we're not looping + const int smem_size = fmha::get_dynamic_smem_size() + + (loop_steps > 1 ? 
smem_size_softmax_lse : 0); + + if (smem_size >= 48 * 1024) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + + if (configure) { + using Mma_tile_p = fmha::Hmma_tile; + constexpr int M = Kernel_traits::Cta_tile_p::M; + size_t STEPS = (launch_params.params.s + M - 1) / M; + constexpr size_t MMAS_M = Mma_tile_p::MMAS_M; + constexpr size_t MMAS_N = Mma_tile_p::MMAS_N; + size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8 * loop_steps; + launch_params.elts_per_thread = elts_per_head; + return; + } + + dim3 grid(launch_params.params.h, launch_params.params.b); + kernel<<>>( + launch_params.params); + + FMHA_CHECK_CUDA(cudaPeekAtLastError()); +} + +void run_fmha_block_fp16_sm80( + Launch_params& launch_params, + const bool configure) { + if (launch_params.params.d == 16) { + using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u>; + run_fmha_block_fp16_sm80_loop_(launch_params, configure); + } else if (launch_params.params.d == 32) { + using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u>; + run_fmha_block_fp16_sm80_loop_(launch_params, configure); + } else if (launch_params.params.d == 64) { + using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>; + run_fmha_block_fp16_sm80_loop_(launch_params, configure); + } +} diff --git a/python/aitemplate/backend/cuda/attention/src/fmha_block_fprop_kernel_1xN.h b/python/aitemplate/backend/cuda/attention/src/fmha_block_fprop_kernel_1xN.h new file mode 100644 index 000000000..89776414a --- /dev/null +++ b/python/aitemplate/backend/cuda/attention/src/fmha_block_fprop_kernel_1xN.h @@ -0,0 +1,661 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/*************************************************************************************************** + * Copyright (c) 2022, Tri Dao. + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include +#include +#include "fmha_blockmask.h" +#include "fmha_fprop_kernel_1xN.h" +#include "fmha_kernel.h" + +namespace fmha { + +template < + typename Kernel_traits, + bool Is_dropout, + bool Is_causal, + bool Return_softmax, + bool Is_first, + bool Is_last, + typename Params, + typename Prng> +inline __device__ void device_block_1xN_( + const Params& params, + const int bidb, + const int bidh, + int steps, + Prng& ph0, + Prng& ph1, + const int loop_step_idx) { + // The description of the CTA tile for the 1st batched GEMM. + using Cta_tile_p = typename Kernel_traits::Cta_tile_p; + // The description of the CTA tile for the 2nd batched GEMM. + using Cta_tile_o = typename Kernel_traits::Cta_tile_o; + + // The MMA tile for the 1st GEMM. + using Mma_tile_p = fmha::Hmma_tile; + // The MMA tile for the 2nd GEMM. + using Mma_tile_o = fmha::Hmma_tile; + + // The global memory tile to load Q. + using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q; + + // The global memory tile to load K. + using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k; + + // The global memory tile to load V. + using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v; + // The shared memory tile to swizzle V. + using Smem_tile_v = typename Kernel_traits::Smem_tile_v; + + // The global memory tile to store O. + using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o; + using Gmem_tile_o_tmp = fmha::Gmem_tile_o; + // The shared memory tile to swizzle O. + using Smem_tile_o = typename Kernel_traits::Smem_tile_o; + + using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s; + + using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum; + + using Smem_softmax_sum = typename Kernel_traits::Smem_dp_sum; + + using Gemm1 = Gemm_Q_K; + + using Softmax = fmha::Softmax; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. + const int tidx = threadIdx.x; + + const BlockInfoPadded binfo(params, bidb, bidh, tidx); + // if( binfo.stop_early() ) return; + if (binfo.stop_early(loop_step_idx * Cta_tile_p::N)) + return; + + Blockmask blockmask(params, loop_step_idx); + int block_row_idx = 0; + int mask_val = blockmask.mask_val(0); + if (mask_val == -1) + return; + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("mask_val = %d.\n", mask_val); + // } + + Gemm1 gemm_q_k(smem_, tidx); + // Allocate the global memory tile loader for Q. + Gmem_tile_q gmem_q( + params.q_ptr, + params.q_row_stride_in_elts, + params.q_head_stride_in_elts, + binfo, + tidx); + // Allocate the global memory tile loader for O. + Gmem_tile_o gmem_o( + params.o_ptr, + params.o_row_stride_in_elts, + params.o_head_stride_in_elts, + binfo, + tidx); + Gmem_tile_o_tmp gmem_o_tmp( + params.o_tmp_ptr, + params.o_row_stride_in_elts, + params.o_head_stride_in_elts, + binfo, + tidx); + // Allocate the global memory tile loader for S. 
+ Gmem_tile_s gmem_s(params, binfo, tidx); + Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx); + + // Wind gmem tiles to the correct position. + static_assert(Cta_tile_p::N % Cta_tile_p::M == 0); + int block_row_idx_next = mask_val / 4; + int block_row_idx_to_move = block_row_idx_next - block_row_idx; + gmem_q.move(block_row_idx_to_move); + gmem_o.move(block_row_idx_to_move); + gmem_o_tmp.move(block_row_idx_to_move); + if (Return_softmax) { + gmem_s.move(block_row_idx_to_move); + } + gmem_softmax_lse.move(block_row_idx_to_move); + block_row_idx = block_row_idx_next; + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("begin = %d, steps = %d\n", begin, steps); + // } + + fmha::Mask mask(binfo, tidx, loop_step_idx); + + // Allocate the global memory tile loader for K. + Gmem_tile_k gmem_k( + params.k_ptr, + params.k_row_stride_in_elts, + params.k_head_stride_in_elts, + binfo, + tidx); + // Allocate the global memory tile loader for V. + Gmem_tile_v gmem_v( + params.v_ptr, + params.v_row_stride_in_elts, + params.v_head_stride_in_elts, + binfo, + tidx); + // The base pointer of smem_v; + char* smem_v_ = &smem_[Gemm1::SMEM_OFFSET_V]; + + // Allocate the shared memory tile loader for V. We use the same as K so be + // careful!!! + Smem_tile_v smem_v(smem_v_, tidx); + + // Allocate the shared memory tile loader for O. We use the same as K so be + // careful!!! + Smem_tile_o smem_o(&smem_[Gemm1::SMEM_OFFSET_O], tidx); + + if (!Is_first) { + gmem_k.move(loop_step_idx); + gmem_v.move(loop_step_idx); + if (Return_softmax) { + gmem_s.move(loop_step_idx * steps); + } + } + + // Trigger the loads for K. + gmem_k.load(); + // Trigger the loads for Q. + gmem_q.load(); + // Trigger the loads for V. + gmem_v.load(); + + if (!Is_first) { + __syncthreads(); + } + + float p_prev_lse[Mma_tile_p::MMAS_M * 2]; + if (!(Is_first || mask_val % 2 == 1)) { + gmem_softmax_lse.load( + reinterpret_cast(p_prev_lse)); + } + + // Commit the data for Q and V to shared memory. + gmem_q.commit(gemm_q_k.smem_q); + gmem_v.commit(smem_v); + + // const uint32_t scale_bmm1 = reinterpret_cast(params.scale_bmm1); #pragma unroll for(int it=0;it < + // Gmem_tile_k::LDGS;it++){ + // gmem_k.fetch_[it] = fmha::hmul8(scale_bmm1, gmem_k.fetch_[it]); + // } + + // Commit the data for K to shared memory. + if (!Kernel_traits::SHARE_SMEM_FOR_K_AND_V) { + gmem_k.commit(gemm_q_k.smem_k); + } + + __syncthreads(); + + // Load the fragments for Q. + gemm_q_k.load_q(); + + // Load the fragments for V. We keep the data in registers during the entire + // kernel. + typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N]; +#pragma unroll + for (int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki) { + smem_v.load(frag_v[ki], ki); + } + + // Commit the data for V to shared memory if it has not been done already. + if (Kernel_traits::SHARE_SMEM_FOR_K_AND_V) { + // Make sure we are done loading the fragments for K. + __syncthreads(); + + // Commit the data to shared memory for V. + gmem_k.commit(gemm_q_k.smem_k); + + // Make sure the data is in shared memory. + __syncthreads(); + } + + // Load the fragments for K. + gemm_q_k.load_k(); + + // Create the object to do the softmax. + Softmax softmax(params, &smem_[Gemm1::SMEM_OFFSET_SOFTMAX], tidx); + + Smem_softmax_sum smem_softmax_lse( + reinterpret_cast(&smem_[Gemm1::SMEM_BYTES]), tidx); + + // Load over the entire sequence length. 
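/*
 * Illustrative note, not from the original source. The loop below implements
 * the streaming-softmax recurrence of FlashAttention for one block of keys:
 * each iteration computes a tile of S = Q * K^T, folds it into a running
 * row-wise max and sum, rescales the partial output accumulated so far, and
 * stores the running log-sum-exp (LSE) so later key blocks can keep rescaling.
 * In scalar form, the per-row update for a new chunk of scores looks like:
 *
 *   float m_new = max(m_prev, max(chunk));            // running max
 *   float scale = expf(m_prev - m_new);               // rescale old partials
 *   l = l * scale + sum_i(expf(chunk[i] - m_new));    // running sum
 *   o = o * scale + sum_i(expf(chunk[i] - m_new) * v[i]);
 *   m_prev = m_new;                                   // lse = m_new + logf(l)
 *
 * Judging by its uses in this function, mask_val from the blockmask appears to
 * encode the query block row in its upper bits (mask_val / 4), with bit 0
 * marking the first visit to that row (no previous LSE/partials to load) and
 * bit 1 marking the last visit (results can be written out directly).
 */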
+ for (int l = 0; l < steps; l++) { + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("block_row_idx = %d\n", block_row_idx); + // } + if (block_row_idx * Cta_tile_p::M >= binfo.actual_seqlen) + break; + + int mask_val_next = l < steps - 1 ? blockmask.mask_val(l + 1) : -1; + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("mask_val = %d, mask_val_next = %d\n", mask_val, + // mask_val_next); + // } + + // Declare the accumulators for the 1st gemm. + fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; + fmha::Clear_accumulator< + typename fmha::Accumulator_type, + Cta_tile_p::WARPS_K>::apply(acc_p); + + // Do this part of P = Q * K^T. + gemm_q_k(acc_p); + + uint4 out[Gmem_tile_o::STGS_PER_LOOP]; + bool is_first_read = Is_first || mask_val % 2 == 1; + // if (!Is_first) { gmem_o_tmp.load(out, 0); } + if (!is_first_read) { + gmem_o_tmp.load(out, 0); + } + + // Trigger the load for the next Q values. + bool not_last_iter = (l < steps - 1) && (mask_val_next != -1); + block_row_idx_next = mask_val_next / 4; + int block_row_idx_to_move = block_row_idx_next - block_row_idx; + if (not_last_iter) { + gemm_q_k.smem_q.move_to_next_write_buffer(); + gmem_q.move(block_row_idx_to_move); + gmem_q.load(); + } + + // Load the mask for that iteration. + mask.load(block_row_idx); + + // Convert from the accumulator type to FP32 for Softmax. + softmax.unpack_noscale(acc_p); + + // Apply the mask. + softmax.apply_mask(mask); + + // softmax.unpack_noscale_half_and_apply_mask(acc_p, mask); + + if (Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0) { + // if we share K and V, it could be that V was not fully read yet but we + // write into smem for reduction + __syncthreads(); + } + // if (!Is_first) { + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && + // (l == 0)) { + // printf("p_prev_lse=%.6f, %.6f\n", p_prev_lse[0], p_prev_lse[1]); + // } + // } + // Compute the max. + float p_max[Mma_tile_p::MMAS_M * 2]; + // if (!Is_first) { + if (!is_first_read) { + smem_softmax_lse.store_pair(p_prev_lse, l % 2); + // for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { p_max[mi] = + // p_prev_lse[mi]; } + for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { + p_max[mi] = p_prev_lse[mi] / params.scale_bmm1f; + } + } + + // Trigger the load for the next LSE values. + if (not_last_iter) { + // if (!Is_first) { + if (!(Is_first || mask_val_next % 2 == 1)) { + gmem_softmax_lse.load_next( + reinterpret_cast(p_prev_lse), + block_row_idx_to_move); + } + } + + // __half2 p_max[Mma_tile_p::MMAS_M]; + // softmax.template reduce_max(p_max); + is_first_read ? softmax.template reduce_max(p_max) + : softmax.template reduce_max(p_max); + + // if ((threadIdx.x == 0) && (l == 38)) { + // printf("loop_step_idx %d, p_max = %.6f, %.6f., p_prev_lse = %.6f, + // %.6f\n", loop_step_idx, p_max[0], p_max[1], Is_first ? -10000.f : + // p_prev_lse[0], Is_first ? -10000.f : p_prev_lse[1]); + // } + + // if (!Is_first) { + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && + // (l == 0)) { + // printf("after reduce_max=%.6f, %.6f\n", softmax.elt_[0][0], + // softmax.elt_[0][1]); + // } + // } + + // Compute the exponential value. 
+ // softmax.apply_exp(p_max); + softmax.scale_apply_exp(p_max, params.scale_bmm1f); + + // if (!Is_first) { + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && + // (l == 0)) { + // printf("after apply_exp=%.6f, %.6f\n", softmax.elt_[0][0], + // softmax.elt_[0][1]); + // } + // } + + // Compute the sum. + float p_sum[Mma_tile_p::MMAS_M * 2]; + // if (!Is_first) { + // int warp = tidx / Cta_tile_p::THREADS_PER_WARP; + // int lane = tidx % Cta_tile_p::THREADS_PER_WARP; + // for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { + // p_sum[mi] = ((warp == 0) && (lane % 4 == 0)) ? + // expf(p_prev_lse[mi] - p_max[mi]) : 0; + // } + // } + // softmax.reduce_sum(p_sum); + softmax.reduce_sum_before_sync_(p_sum); + // softmax.template reduce_sum_before_sync_(p_sum); + + // float p_sum_log[Mma_tile_p::MMAS_M * 2]; + // for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; ++mi) { + // float sum = p_sum[mi]; + // // p_sum_log[mi] = (sum == 0.f || sum != sum) ? INFINITY : p_max[mi] + // + __logf(sum); constexpr float kLog2e = M_LOG2E; p_sum_log[mi] = (sum + // == 0.f || sum != sum) ? INFINITY : p_max[mi] * kLog2e + __log2f(sum); + // } + // // gmem_softmax_lse.store(reinterpret_cast(p_sum)); + // gmem_softmax_lse.store(reinterpret_cast(p_sum_log)); gmem_softmax_lse.move(); + + // // Finalize softmax on the accumulators of P^T. + // softmax.scale(p_sum); + + constexpr bool encode_dropout_in_sign_bit = Return_softmax; + if (Is_dropout) { + // softmax.template apply_dropout(ph0, + // params.p_dropout_in_uint); softmax.template + // apply_dropout(ph0, ph1, + // params.p_dropout_in_uint); + softmax.template apply_dropout_16bits( + ph0, ph1, params.p_dropout_in_uint16_t); + } + + using Frag_p = fmha::Fragment_a; + Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M]; + static_assert(Mma_tile_o::MMAS_M == Mma_tile_p::MMAS_M); + static_assert(Mma_tile_o::MMAS_K == Mma_tile_p::MMAS_N); + softmax.pack(frag_p); + if (Return_softmax) { + gmem_s.store(frag_p, mask); + if (not_last_iter) { + gmem_s.move(block_row_idx_to_move); + } + } + + // Commit the values for Q into shared memory. + if (not_last_iter) { + gmem_q.commit(gemm_q_k.smem_q); + } + + if (Is_dropout && encode_dropout_in_sign_bit) { +#pragma unroll + for (int ki = 0; ki < Mma_tile_o::MMAS_K; ki++) { +#pragma unroll + for (int mi = 0; mi < Mma_tile_o::MMAS_M; mi++) { + frag_p[ki][mi].hrelu_(); + } + } + } + + // Declare the accumulators for the 2nd gemm. + fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N]; + fmha::Clear_accumulator< + typename fmha::Accumulator_type, + Cta_tile_o::WARPS_K>::apply(acc_o); + +// Do this part of O = P^T * V^T. +#pragma unroll + for (int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki) { + fmha::gemm_cl(acc_o, frag_p[ki], frag_v[ki]); + } + + // The mapping from tidx to rows changes between the softmax and the + // O-reduction. So we recalculate the max. 
+ float p_max_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M]; + // TODO: not sure if this is right for seqlen 128 or 256 + int rows[Gmem_tile_o::STGS_PER_LOOP]; + for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { + rows[jj] = + tidx / Gmem_tile_o::THREADS_PER_ROW + jj * Gmem_tile_o::ROWS_PER_STG; + } + softmax.reduce_max_after_sync_(p_max_o, rows); + static_assert(Mma_tile_o::MMAS_M == 1); + for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { + p_max_o[jj][0] *= params.scale_bmm1f; + } + float p_prev_scale_o[Gmem_tile_o::STGS_PER_LOOP]; + // if (!Is_first) { smem_softmax_lse.load(p_prev_scale_o, rows, l % 2); } + if (!is_first_read) { + smem_softmax_lse.load(p_prev_scale_o, rows, l % 2); + } + // if (!Is_first) { + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && + // (l == 0)) { + // printf("p_prev_scale_o=%.6f\n", p_prev_scale_o[0]); + // } + // } + + static_assert(Gmem_tile_o::LOOPS == 1); + + // Swizzle the elements and do the final reduction. + smem_o.store(acc_o, 0); + + // Make sure the data is in shared memory. + __syncthreads(); + + static_assert(Mma_tile_o::MMAS_M == 1); + float p_sum_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M]; + softmax.reduce_sum_after_sync_(p_sum_o, rows); + // if (!Is_first) { + if (!is_first_read) { + for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { + p_prev_scale_o[jj] = expf(p_prev_scale_o[jj] - p_max_o[jj][0]); + p_sum_o[jj][0] += p_prev_scale_o[jj]; + } + } + + float p_sum_log[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M]; +#pragma unroll + for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { + float sum = p_sum_o[jj][0]; + p_sum_log[jj][0] = + (sum == 0.f || sum != sum) ? -INFINITY : p_max_o[jj][0] + __logf(sum); + // if (sum == 0.f || sum != sum) { + // printf("loop_step_idx = %d, l = %d, tidx = %d, sum = %.6f, p_max_o + // = %.6f\n", loop_step_idx, l, tidx, sum, p_max_o[jj][0]); + // } + // if (Is_first) { + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && + // (l == 0)) { + // printf("p_sum_log=%.6f\n", p_sum_log[jj][0]); + // } + // } + if ((tidx % Gmem_tile_o::THREADS_PER_ROW == 0) && + (tidx / Gmem_tile_o::THREADS_PER_ROW < Gmem_tile_o::ROWS)) { + gmem_softmax_lse.store_row( + reinterpret_cast(p_sum_log[jj]), + rows[jj]); + } + } + if (not_last_iter) { + gmem_softmax_lse.move(block_row_idx_to_move); + } + + // Load from shared memory. + // if (!Is_first) { + if (!is_first_read) { + for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { + out[jj] = fmha::fmul4(out[jj], p_prev_scale_o[jj]); + } + } + // smem_o.template load(out); + is_first_read ? smem_o.template load(out) + : smem_o.template load(out); + + const bool is_final_write = Is_last || + ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen) || + ((mask_val & 0x2) != 0) || + ((Is_causal) && + (block_row_idx * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N)); +// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { +// printf("is_final_write = %d\n", is_final_write); +// } +#pragma unroll + for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { + float sum = p_sum_o[jj][0]; + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; + if (Is_dropout && is_final_write) { + inv_sum *= params.rp_dropout; + } + out[jj] = fmha::fmul4(out[jj], inv_sum); + } + + // if (Is_dropout && Is_last) { + // for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { + // out[jj] = fmha::fmul4(out[jj], params.rp_dropout); + // } + // } + + // Output the values. 
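/*
 * Illustrative note, not from the original source: on every visit except the
 * last, the normalized partial accumulators are spilled to o_tmp_ptr and the
 * running LSE is kept so that the next key block can rescale them (see the
 * fmul4(out[jj], p_prev_scale_o[jj]) above). Only the final visit for a query
 * block writes to the real output pointer, and the dropout correction
 * rp_dropout is applied exactly once, on that final write.
 */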
+ if (is_final_write) { + gmem_o.store(out, 0); + } else { + gmem_o_tmp.store(out, 0); + } + + // Move to the next part of the output. + gmem_o.move(block_row_idx_to_move); + if (!(Is_first && Is_last)) { + gmem_o_tmp.move(block_row_idx_to_move); + } + gemm_q_k.reload_k(); + + // Make sure we are reading from the correct buffer. + gemm_q_k.smem_q.move_to_next_read_buffer(); + // Trigger the load from shared memory for the next series of Q values. + if (not_last_iter) { + gemm_q_k.reload_q(); + } + + if (mask_val_next == -1) + break; + mask_val = mask_val_next; + block_row_idx += block_row_idx_to_move; + + } // Outer loop over the sequence length. +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Kernel_traits, + bool Is_dropout, + bool Is_causal, + bool Return_softmax, + typename Params> +inline __device__ void device_block_1xN_loop(const Params& params) { + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.x; + // The thread index. + const int tidx = threadIdx.x; + + const int tidx_global = (bidb * params.h + bidh) * blockDim.x * 2 + tidx; + // auto seeds = at::cuda::philox::unpack(params.philox_args); + auto seeds = std::make_tuple(0, 0); + Philox ph0(std::get<0>(seeds), tidx_global, std::get<1>(seeds)); + Philox ph1(std::get<0>(seeds), tidx_global + blockDim.x, std::get<1>(seeds)); + const int STEPS = params.s / Kernel_traits::Cta_tile_p::M; + + constexpr int N_per_loop = Kernel_traits::Cta_tile_p::N; + if (params.s == N_per_loop) { + fmha::device_block_1xN_< + Kernel_traits, + Is_dropout, + Is_causal, + Return_softmax, + true, + true>(params, bidb, bidh, STEPS, ph0, ph1, 0); + } else { + const int max_loop_steps = (params.s + N_per_loop - 1) / N_per_loop; + fmha::device_block_1xN_< + Kernel_traits, + Is_dropout, + Is_causal, + Return_softmax, + true, + false>(params, bidb, bidh, STEPS, ph0, ph1, 0); + for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; + loop_step_idx++) { + fmha::device_block_1xN_< + Kernel_traits, + Is_dropout, + Is_causal, + Return_softmax, + false, + false>(params, bidb, bidh, STEPS, ph0, ph1, loop_step_idx); + } + fmha::device_block_1xN_< + Kernel_traits, + Is_dropout, + Is_causal, + Return_softmax, + false, + true>(params, bidb, bidh, STEPS, ph0, ph1, max_loop_steps - 1); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/python/aitemplate/backend/cuda/attention/src/fmha_blockmask.h b/python/aitemplate/backend/cuda/attention/src/fmha_blockmask.h new file mode 100644 index 000000000..9de497e7f --- /dev/null +++ b/python/aitemplate/backend/cuda/attention/src/fmha_blockmask.h @@ -0,0 +1,69 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Blockmask { + template + __device__ Blockmask(const Params& params, int loop_step_idx) + : blockmask_ptr(params.blockmask + loop_step_idx * params.s / 16) {} + + __device__ int mask_val(int block_row_idx) const { + return blockmask_ptr[block_row_idx]; + } + + const int* blockmask_ptr; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/python/aitemplate/backend/cuda/attention/src/fmha_fprop_fp16_kernel.sm80.cu b/python/aitemplate/backend/cuda/attention/src/fmha_fprop_fp16_kernel.sm80.cu new file mode 100644 index 000000000..5031d81a0 --- /dev/null +++ b/python/aitemplate/backend/cuda/attention/src/fmha_fprop_fp16_kernel.sm80.cu @@ -0,0 +1,262 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +// #include "fmha.h" +// #include "fmha_fprop_kernel_1xN.h" + +template < + typename Kernel_traits, + bool Is_dropout, + bool Is_causal, + bool Return_softmax> +__global__ void fmha_fprop_fp16_sm80_loop_kernel( + Fused_multihead_attention_fprop_params params) { + fmha::device_1xN_loop( + params); +} + +template +void run_fmha_fp16_sm80_loop_( + Launch_params& launch_params, + const bool configure) { + bool is_causal = launch_params.params.is_causal; + // TD [2022-04-27]: This case work is pretty ugly, maybe there's a better way? + auto kernel = launch_params.is_dropout + ? (is_causal ? (launch_params.return_softmax + ? &fmha_fprop_fp16_sm80_loop_kernel< + Kernel_traits, + true, + true, + true> + : &fmha_fprop_fp16_sm80_loop_kernel< + Kernel_traits, + true, + true, + false>) + : (launch_params.return_softmax + ? &fmha_fprop_fp16_sm80_loop_kernel< + Kernel_traits, + true, + false, + true> + : &fmha_fprop_fp16_sm80_loop_kernel< + Kernel_traits, + true, + false, + false>)) + : (is_causal ? (launch_params.return_softmax + ? &fmha_fprop_fp16_sm80_loop_kernel< + Kernel_traits, + false, + true, + true> + : &fmha_fprop_fp16_sm80_loop_kernel< + Kernel_traits, + false, + true, + false>) + : (launch_params.return_softmax + ? &fmha_fprop_fp16_sm80_loop_kernel< + Kernel_traits, + false, + false, + true> + : &fmha_fprop_fp16_sm80_loop_kernel< + Kernel_traits, + false, + false, + false>)); + + constexpr int N = Kernel_traits::Cta_tile_p::N; + const int loop_steps = (launch_params.params.s + N - 1) / N; + constexpr int smem_size_softmax_lse = + Kernel_traits::Smem_dp_sum::BYTES_PER_TILE; + // Don't need smem_size_softmax_lse if we're not looping + const int smem_size = fmha::get_dynamic_smem_size() + + (loop_steps > 1 ? 
smem_size_softmax_lse : 0); + + if (smem_size >= 48 * 1024) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + + if (configure) { + using Mma_tile_p = fmha::Hmma_tile; + constexpr int M = Kernel_traits::Cta_tile_p::M; + size_t STEPS = (launch_params.params.s + M - 1) / M; + constexpr size_t MMAS_M = Mma_tile_p::MMAS_M; + constexpr size_t MMAS_N = Mma_tile_p::MMAS_N; + size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8 * loop_steps; + launch_params.elts_per_thread = elts_per_head; + return; + } + + dim3 grid(launch_params.params.h, launch_params.params.b); + kernel<<>>( + launch_params.params); + + FMHA_CHECK_CUDA(cudaPeekAtLastError()); +} + +void run_fmha_fp16_sm80( + Launch_params& launch_params, + const bool configure) { + if (launch_params.params.d == 16) { + if (launch_params.params.s == 128) { + using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 4, 0x08u>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } else if (launch_params.params.s == 256) { + using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } else { + // TD [2022-05-15] 512 gives wrong results rn + // using Kernel_traits = FMHA_kernel_traits<512, 16, 16, 1, 4, 0x08u>; + using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } + } else if (launch_params.params.d == 32) { + if (launch_params.params.s == 128) { + using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } else if (launch_params.params.s == 256) { + using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } else { + using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } + } else if (launch_params.params.d == 64) { + if (launch_params.params.s == 128) { + using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } else if (launch_params.params.s >= 256) { + // auto dprops = at::cuda::getCurrentDeviceProperties(); + // if (dprops->major == 8 && dprops->minor >= 0) { + using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + // } else if (dprops->major == 7 && dprops->minor == 5) { + // if (launch_params.is_dropout) { // Need to use the same block size + // as backward + // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, + // 0x08u>; run_fmha_fp16_sm80_loop_(launch_params, + // configure); + // } else { + // using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, + // 0x08u>; run_fmha_fp16_sm80_loop_(launch_params, + // configure); + // } + // } + } + } else if (launch_params.params.d == 128) { + if (launch_params.params.s == 128) { + using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } else { + // auto dprops = at::cuda::getCurrentDeviceProperties(); + // if (dprops->major == 8 && dprops->minor >= 0 && + // !launch_params.is_dropout) { + // // TD [2022-06-05] Keep K in registers to reduce register spilling + // // Gives about 6% speedup compared to using block size 128. 
+ using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + // } else { // Need to use the same block size as backward + // using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, + // 0x08u>; run_fmha_fp16_sm80_loop_(launch_params, + // configure); + // } + } + } + // if (launch_params.params.d == 64) { + // // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>; + // // using Kernel_traits = FMHA_kernel_traits<64, 64, 16, 1, 4, 0x08u>; + // // using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u>; + // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>; + // run_fmha_fp16_sm80_loop_(launch_params, configure); + // } + // if (launch_params.params.d == 64) { + // if( launch_params.params.s == 128 ) { + // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>; + // run_fmha_fp16_sm80_loop_(launch_params, configure); + // } else if( launch_params.params.s >= 256 ) { + // auto dprops = at::cuda::getCurrentDeviceProperties(); + // if (dprops->major == 8 && dprops->minor >= 0) { + // using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, + // 0x08u>; run_fmha_fp16_sm80_loop_(launch_params, + // configure); + // } else if (dprops->major == 7 && dprops->minor == 5) { + // if (launch_params.is_dropout) { // Need to use the same block + // size as backward + // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, + // 0x08u>; + // run_fmha_fp16_sm80_loop_(launch_params, + // configure); + // } else { + // using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, + // 0x08u>; + // run_fmha_fp16_sm80_loop_(launch_params, + // configure); + // } + // } + // } + // } + // if (launch_params.params.d == 128) { + // if( launch_params.params.s == 128 ) { + // using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, + // 0x08u>; run_fmha_fp16_sm80_loop_(launch_params, + // configure); + // } else { + // auto dprops = at::cuda::getCurrentDeviceProperties(); + // if (dprops->major == 8 && dprops->minor >= 0 && + // !launch_params.is_dropout) { + // // TD [2022-06-05] Keep K in registers to reduce register + // spilling + // // Gives about 6% speedup compared to using block size 128. + // using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, + // 0x18u>; run_fmha_fp16_sm80_loop_(launch_params, + // configure); + // } else { // Need to use the same block size as backward + // using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, + // 0x08u>; run_fmha_fp16_sm80_loop_(launch_params, + // configure); + // } + // } + // } +} diff --git a/python/aitemplate/backend/cuda/attention/src/fmha_fprop_kernel_1xN.h b/python/aitemplate/backend/cuda/attention/src/fmha_fprop_kernel_1xN.h new file mode 100644 index 000000000..1cd4c191c --- /dev/null +++ b/python/aitemplate/backend/cuda/attention/src/fmha_fprop_kernel_1xN.h @@ -0,0 +1,795 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +/*************************************************************************************************** + * Copyright (c) 2022, Tri Dao. + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include +#include +#include +#include "fmha_kernel.h" + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gemm_Q_K_base { + using Smem_tile_o = typename Kernel_traits::Smem_tile_o; + using Smem_tile_q = typename Kernel_traits::Smem_tile_q; + using Smem_tile_k = typename Kernel_traits::Smem_tile_k; + using Fragment_q = typename Smem_tile_q::Fragment; + using Fragment_k = typename Smem_tile_k::Fragment; + + // The description of the CTA tile for the 1st batched GEMM. + using Cta_tile_p = typename Kernel_traits::Cta_tile_p; + + // The MMA tile for the 1st GEMM. + using Mma_tile_p = fmha::Hmma_tile; + + static constexpr int SMEM_BYTES_SOFTMAX = + Cta_tile_p::M * Cta_tile_p::WARPS_N * sizeof(float) * 2; + + __device__ inline Gemm_Q_K_base( + char* smem_ptr_q, + char* smem_ptr_k, + const int tidx) + : smem_q(smem_ptr_q, tidx), smem_k(smem_ptr_k, tidx) {} + + __device__ inline void load_q() { + smem_q.load(frag_q[0], 0); + } + + __device__ inline void reload_q() { + smem_q.load(frag_q[0], 0); + } + + Fragment_q frag_q[2][Mma_tile_p::MMAS_M]; + Smem_tile_q smem_q; + Smem_tile_k smem_k; +}; + +template +struct Gemm_Q_K : public Gemm_Q_K_base { + using Base = Gemm_Q_K_base; + using Smem_tile_o = typename Base::Smem_tile_o; + using Smem_tile_q = typename Base::Smem_tile_q; + using Smem_tile_k = typename Base::Smem_tile_k; + using Fragment_k = typename Base::Fragment_k; + using Mma_tile_p = typename Base::Mma_tile_p; + + static constexpr bool SHARE_SMEM_FOR_K_AND_V = + Kernel_traits::SHARE_SMEM_FOR_K_AND_V; + // If V is stored in shared memory, we can't load K using the same shared + // memory. 
+ static_assert(Kernel_traits::V_IN_REGS); + + static constexpr int SMEM_OFFSET_O = Smem_tile_q::BYTES_PER_TILE; + static constexpr int SMEM_OFFSET_SOFTMAX = + SMEM_OFFSET_O + Smem_tile_o::BYTES_PER_TILE; + static constexpr int SMEM_OFFSET_V = Smem_tile_q::BYTES_PER_TILE + + (SHARE_SMEM_FOR_K_AND_V ? 0 : Smem_tile_k::BYTES_PER_TILE); + + // Q | K / V + // | O | SOFTMAX + static constexpr int SMEM_BYTES = Smem_tile_q::BYTES_PER_TILE + + std::max((SHARE_SMEM_FOR_K_AND_V ? 1 : 2) * Smem_tile_k::BYTES_PER_TILE, + Smem_tile_o::BYTES_PER_TILE + Base::SMEM_BYTES_SOFTMAX); + + __device__ inline Gemm_Q_K(char* smem_, const int tidx) + : Base(smem_, smem_ + Smem_tile_q::BYTES_PER_TILE, tidx) {} + + __device__ inline void load_k() { +#pragma unroll + for (int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki) { + Base::smem_k.load(frag_k[ki], ki); + } + } + + template + __device__ inline void operator()(Acc (&acc_p)[M][N]) { +// Do this part of P^T = (Q * K^T)^T. +#pragma unroll + for (int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki) { + // Trigger the load from shared memory for the next series of Q values. + Base::smem_q.load(Base::frag_q[ki & 1], ki); + // Do the math for the values already in registers. + fmha::gemm_cl(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]); + } + // Do the final stage of math. + { + int ki = Mma_tile_p::MMAS_K; + fmha::gemm_cl(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]); + } + } + + __device__ inline void reload_k() { + // Noop. + } + + Fragment_k frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N]; +}; + +template +struct Gemm_Q_K : public Gemm_Q_K_base { + using Base = Gemm_Q_K_base; + using Smem_tile_o = typename Base::Smem_tile_o; + using Smem_tile_q = typename Base::Smem_tile_q; + using Smem_tile_k = typename Base::Smem_tile_k; + using Smem_tile_v = typename Kernel_traits::Smem_tile_v; + using Fragment_k = typename Base::Fragment_k; + using Mma_tile_p = typename Base::Mma_tile_p; + Fragment_k frag_k[2][Mma_tile_p::MMAS_N]; + + static constexpr bool SHARE_SMEM_FOR_K_AND_V = + Kernel_traits::SHARE_SMEM_FOR_K_AND_V; + static constexpr bool V_IN_REGS = Kernel_traits::V_IN_REGS; + static_assert(V_IN_REGS || !SHARE_SMEM_FOR_K_AND_V); + + static constexpr int SMEM_OFFSET_V = Smem_tile_q::BYTES_PER_TILE + + (SHARE_SMEM_FOR_K_AND_V ? 0 : Smem_tile_k::BYTES_PER_TILE); + static_assert( + Smem_tile_v::BYTES_PER_TILE == (int)Smem_tile_k::BYTES_PER_TILE); + static constexpr int SMEM_OFFSET_O = + SMEM_OFFSET_V + Smem_tile_v::BYTES_PER_TILE; + static constexpr int SMEM_OFFSET_SOFTMAX = + SMEM_OFFSET_O + Smem_tile_o::BYTES_PER_TILE; + + // If V_IN_REGS and SHARE_SMEM_FOR_K_AND_V: Q | K/V | O | SOFTMAX + // If !V_IN_REGS (then !SHARE_SMEM_FOR_K_AND_V): Q | K | V | O | SOFTMAX + static constexpr int SMEM_BYTES = Smem_tile_q::BYTES_PER_TILE + + (SHARE_SMEM_FOR_K_AND_V ? 1 : 2) * Smem_tile_k::BYTES_PER_TILE + + Smem_tile_o::BYTES_PER_TILE + Base::SMEM_BYTES_SOFTMAX; + + __device__ inline Gemm_Q_K(char* smem_, const int tidx) + : Base(smem_, smem_ + Smem_tile_q::BYTES_PER_TILE, tidx) {} + + __device__ inline void load_k() { + Base::smem_k.load(frag_k[0], 0); + } + + template + __device__ inline void operator()(Acc (&acc_p)[M][N]) { +// Do this part of P^T = (Q * K^T)^T. +#pragma unroll + for (int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki) { + // Trigger the load from shared memory for the next series of Q values. + Base::smem_q.load(Base::frag_q[ki & 1], ki); + Base::smem_k.load(frag_k[ki & 1], ki); + // Do the math for the values already in registers. 
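+      // (frag_q / frag_k are double-buffered via the "ki & 1" index: the loads
+      // above fill one buffer for iteration ki while the MMA below consumes
+      // the buffer filled for iteration ki - 1.)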
+ fmha::gemm_cl(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]); + } + // Do the final stage of math. + { + int ki = Mma_tile_p::MMAS_K; + fmha::gemm_cl(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]); + } + } + + __device__ inline void reload_k() { + Base::smem_k.load(frag_k[0], 0); + } +}; + +template +constexpr size_t get_dynamic_smem_size() { + return Gemm_Q_K::SMEM_BYTES; +} + +template < + typename Kernel_traits, + bool Is_dropout, + bool Is_causal, + bool Return_softmax, + bool Is_first, + bool Is_last, + typename Params, + typename Prng> +inline __device__ void device_1xN_( + const Params& params, + const int bidb, + const int bidh, + int begin, + int steps, + Prng& ph0, + Prng& ph1, + const int loop_step_idx) { + // The description of the CTA tile for the 1st batched GEMM. + using Cta_tile_p = typename Kernel_traits::Cta_tile_p; + // The description of the CTA tile for the 2nd batched GEMM. + using Cta_tile_o = typename Kernel_traits::Cta_tile_o; + + // The MMA tile for the 1st GEMM. + using Mma_tile_p = fmha::Hmma_tile; + // The MMA tile for the 2nd GEMM. + using Mma_tile_o = fmha::Hmma_tile; + + // The global memory tile to load Q. + using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q; + + // The global memory tile to load K. + using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k; + + // The global memory tile to load V. + using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v; + // The shared memory tile to swizzle V. + using Smem_tile_v = typename Kernel_traits::Smem_tile_v; + + // The global memory tile to store O. + using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o; + using Gmem_tile_o_tmp = fmha::Gmem_tile_o; + // The shared memory tile to swizzle O. + using Smem_tile_o = typename Kernel_traits::Smem_tile_o; + + using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s; + + using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum; + + using Smem_softmax_sum = typename Kernel_traits::Smem_dp_sum; + + using Gemm1 = Gemm_Q_K; + + using Softmax = fmha::Softmax; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. + const int tidx = threadIdx.x; + + const BlockInfoPadded binfo(params, bidb, bidh, tidx); + // if( binfo.stop_early() ) return; + if (binfo.stop_early(loop_step_idx * Cta_tile_p::N)) + return; + + Gemm1 gemm_q_k(smem_, tidx); + // Allocate the global memory tile loader for Q. + Gmem_tile_q gmem_q( + params.q_ptr, + params.q_row_stride_in_elts, + params.q_head_stride_in_elts, + binfo, + tidx); + // Allocate the global memory tile loader for O. + Gmem_tile_o gmem_o( + params.o_ptr, + params.o_row_stride_in_elts, + params.o_head_stride_in_elts, + binfo, + tidx); + Gmem_tile_o_tmp gmem_o_tmp( + params.o_tmp_ptr, + params.o_row_stride_in_elts, + params.o_head_stride_in_elts, + binfo, + tidx); + // Allocate the global memory tile loader for S. + Gmem_tile_s gmem_s(params, binfo, tidx); + Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx); + + // Wind gmem tiles to the correct position. + static_assert(Cta_tile_p::N % Cta_tile_p::M == 0); + const int begin_og = begin; + begin = Is_causal + ? 
std::max(begin, loop_step_idx * Cta_tile_p::N / Cta_tile_p::M) + : begin; + const int steps_og = steps; + steps -= begin - begin_og; + gmem_q.move(begin); + gmem_o.move(begin); + gmem_o_tmp.move(begin); + if (Return_softmax) { + gmem_s.move(begin); + } + gmem_softmax_lse.move(begin); + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("begin = %d, steps = %d\n", begin, steps); + // } + + fmha::Mask mask(binfo, tidx, loop_step_idx); + + // Allocate the global memory tile loader for K. + Gmem_tile_k gmem_k( + params.k_ptr, + params.k_row_stride_in_elts, + params.k_head_stride_in_elts, + binfo, + tidx); + // Allocate the global memory tile loader for V. + Gmem_tile_v gmem_v( + params.v_ptr, + params.v_row_stride_in_elts, + params.v_head_stride_in_elts, + binfo, + tidx); + // The base pointer of smem_v; + char* smem_v_ = &smem_[Gemm1::SMEM_OFFSET_V]; + + // Allocate the shared memory tile loader for V. We use the same as K so be + // careful!!! + Smem_tile_v smem_v(smem_v_, tidx); + + // Allocate the shared memory tile loader for O. We use the same as K so be + // careful!!! + Smem_tile_o smem_o(&smem_[Gemm1::SMEM_OFFSET_O], tidx); + + if (!Is_first) { + gmem_k.move(loop_step_idx); + gmem_v.move(loop_step_idx); + if (Return_softmax) { + gmem_s.move(loop_step_idx * steps_og); + } + } + + // Trigger the loads for K. + gmem_k.load(); + // Trigger the loads for Q. + gmem_q.load(); + // Trigger the loads for V. + gmem_v.load(); + + if (!Is_first) { + __syncthreads(); + } + + float p_prev_lse[Mma_tile_p::MMAS_M * 2]; + if (!Is_first) { + gmem_softmax_lse.load( + reinterpret_cast(p_prev_lse)); + } + + // Commit the data for Q and V to shared memory. + gmem_q.commit(gemm_q_k.smem_q); + gmem_v.commit(smem_v); + + // const uint32_t scale_bmm1 = reinterpret_cast(params.scale_bmm1); #pragma unroll for(int it=0;it < + // Gmem_tile_k::LDGS;it++){ + // gmem_k.fetch_[it] = fmha::hmul8(scale_bmm1, gmem_k.fetch_[it]); + // } + + // Commit the data for K to shared memory. + if (!Kernel_traits::SHARE_SMEM_FOR_K_AND_V) { + gmem_k.commit(gemm_q_k.smem_k); + } + + __syncthreads(); + + // Load the fragments for Q. + gemm_q_k.load_q(); + + // Load the fragments for V. We keep the data in registers during the entire + // kernel. + typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N]; +#pragma unroll + for (int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki) { + smem_v.load(frag_v[ki], ki); + } + + // Commit the data for V to shared memory if it has not been done already. + if (Kernel_traits::SHARE_SMEM_FOR_K_AND_V) { + // Make sure we are done loading the fragments for K. + __syncthreads(); + + // Commit the data to shared memory for V. + gmem_k.commit(gemm_q_k.smem_k); + + // Make sure the data is in shared memory. + __syncthreads(); + } + + // Load the fragments for K. + gemm_q_k.load_k(); + + // Create the object to do the softmax. + Softmax softmax(params, &smem_[Gemm1::SMEM_OFFSET_SOFTMAX], tidx); + + Smem_softmax_sum smem_softmax_lse( + reinterpret_cast(&smem_[Gemm1::SMEM_BYTES]), tidx); + + // Load over the entire sequence length. + for (int l = 0; l < steps; l++) { + if ((begin + l) * Cta_tile_p::M >= binfo.actual_seqlen) + break; + + // Declare the accumulators for the 1st gemm. + fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; + fmha::Clear_accumulator< + typename fmha::Accumulator_type, + Cta_tile_p::WARPS_K>::apply(acc_p); + + // Do this part of P = Q * K^T. 
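+    // One iteration of this loop handles an M-row slab of queries against the
+    // current K/V chunk: (1) P = Q * K^T, (2) mask + scaled softmax, carrying
+    // the running max / log-sum-exp across chunks, (3) optional dropout, and
+    // (4) O += P^T * V^T with rescaling of the partial output from previous
+    // chunks. Step (1) is the call just below.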
+ gemm_q_k(acc_p); + + uint4 out[Gmem_tile_o::STGS_PER_LOOP]; + if (!Is_first) { + gmem_o_tmp.load(out, 0); + } + + // Trigger the load for the next Q values. + if (l < steps - 1) { + gemm_q_k.smem_q.move_to_next_write_buffer(); + gmem_q.move(); + gmem_q.load(); + } + + // Load the mask for that iteration. + mask.load(begin + l); + + // Convert from the accumulator type to FP32 for Softmax. + softmax.unpack_noscale(acc_p); + + // Apply the mask. + softmax.apply_mask(mask); + + if (Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0) { + // if we share K and V, it could be that V was not fully read yet but we + // write into smem for reduction + __syncthreads(); + } + // if (!Is_first) { + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && + // (l == 0)) { + // printf("p_prev_lse=%.6f, %.6f\n", p_prev_lse[0], p_prev_lse[1]); + // } + // } + // Compute the max. + float p_max[Mma_tile_p::MMAS_M * 2]; + if (!Is_first) { + smem_softmax_lse.store_pair(p_prev_lse, l % 2); + // for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { p_max[mi] = + // p_prev_lse[mi]; } + for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { + p_max[mi] = p_prev_lse[mi] / params.scale_bmm1f; + } + } + + // Trigger the load for the next LSE values. + if (l < steps - 1) { + if (!Is_first) { + gmem_softmax_lse.load_next( + reinterpret_cast(p_prev_lse)); + } + } + + softmax.template reduce_max(p_max); + + // if ((threadIdx.x == 0) && (l == 38)) { + // printf("loop_step_idx %d, p_max = %.6f, %.6f., p_prev_lse = %.6f, + // %.6f\n", loop_step_idx, p_max[0], p_max[1], Is_first ? -10000.f : + // p_prev_lse[0], Is_first ? -10000.f : p_prev_lse[1]); + // } + + // if (!Is_first) { + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && + // (l == 0)) { + // printf("after reduce_max=%.6f, %.6f\n", softmax.elt_[0][0], + // softmax.elt_[0][1]); + // } + // } + + // Compute the exponential value. + // softmax.apply_exp(p_max); + softmax.scale_apply_exp(p_max, params.scale_bmm1f); + + // if (!Is_first) { + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && + // (l == 0)) { + // printf("after apply_exp=%.6f, %.6f\n", softmax.elt_[0][0], + // softmax.elt_[0][1]); + // } + // } + + // Compute the sum. + float p_sum[Mma_tile_p::MMAS_M * 2]; + // if (!Is_first) { + // int warp = tidx / Cta_tile_p::THREADS_PER_WARP; + // int lane = tidx % Cta_tile_p::THREADS_PER_WARP; + // for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { + // p_sum[mi] = ((warp == 0) && (lane % 4 == 0)) ? + // expf(p_prev_lse[mi] - p_max[mi]) : 0; + // } + // } + // softmax.reduce_sum(p_sum); + softmax.reduce_sum_before_sync_(p_sum); + // softmax.template reduce_sum_before_sync_(p_sum); + + // float p_sum_log[Mma_tile_p::MMAS_M * 2]; + // for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; ++mi) { + // float sum = p_sum[mi]; + // // p_sum_log[mi] = (sum == 0.f || sum != sum) ? INFINITY : p_max[mi] + // + __logf(sum); constexpr float kLog2e = M_LOG2E; p_sum_log[mi] = (sum + // == 0.f || sum != sum) ? INFINITY : p_max[mi] * kLog2e + __log2f(sum); + // } + // // gmem_softmax_lse.store(reinterpret_cast(p_sum)); + // gmem_softmax_lse.store(reinterpret_cast(p_sum_log)); gmem_softmax_lse.move(); + + // // Finalize softmax on the accumulators of P^T. 
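+    // Note: the 1 / sum normalization is not applied here; it is deferred to
+    // the output stage below (out[jj] is multiplied by inv_sum), so acc_o
+    // accumulates unnormalized exp-weighted values across K/V chunks.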
+ // softmax.scale(p_sum); + + constexpr bool encode_dropout_in_sign_bit = Return_softmax; + if (Is_dropout) { + // softmax.template apply_dropout(ph0, + // params.p_dropout_in_uint); softmax.template + // apply_dropout(ph0, ph1, + // params.p_dropout_in_uint); + softmax.template apply_dropout_16bits( + ph0, ph1, params.p_dropout_in_uint16_t); + } + + using Frag_p = fmha::Fragment_a; + Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M]; + static_assert(Mma_tile_o::MMAS_M == Mma_tile_p::MMAS_M); + static_assert(Mma_tile_o::MMAS_K == Mma_tile_p::MMAS_N); + softmax.pack(frag_p); + if (Return_softmax) { + gmem_s.store(frag_p, mask); + gmem_s.move(); + } + + // Commit the values for Q into shared memory. + if (l < steps - 1) { + gmem_q.commit(gemm_q_k.smem_q); + } + + if (Is_dropout && encode_dropout_in_sign_bit) { +#pragma unroll + for (int ki = 0; ki < Mma_tile_o::MMAS_K; ki++) { +#pragma unroll + for (int mi = 0; mi < Mma_tile_o::MMAS_M; mi++) { + frag_p[ki][mi].hrelu_(); + } + } + } + + // Declare the accumulators for the 2nd gemm. + fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N]; + fmha::Clear_accumulator< + typename fmha::Accumulator_type, + Cta_tile_o::WARPS_K>::apply(acc_o); + +// Do this part of O = P^T * V^T. +#pragma unroll + for (int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki) { + fmha::gemm_cl(acc_o, frag_p[ki], frag_v[ki]); + // if ((threadIdx.x == 4) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l + // == 0)) { + // float2 tmp_p = __half22float2(reinterpret_cast<__half2 + // &>(frag_p[ki])); float2 tmp_v = + // __half22float2(reinterpret_cast<__half2 &>(frag_v[ki])); + // printf("Per warp, threadIdx.x = %d, frag_p = %.6f, %.6f, frag_v = + // %.6f, %.6f, acc_o=%.6f\n", threadIdx.x, tmp_p.x, tmp_p.y, tmp_v.x, + // tmp_v.y, acc_o[0][0].elt(0)); + // } + } + + // if ((threadIdx.x % 32 == 16) && (blockIdx.x == 0) && (blockIdx.y == 0) && + // (l == 0)) { + // printf("Per warp, threadIdx.x = %d, acc_o=%.6f\n", threadIdx.x, + // acc_o[0][2].elt(0)); + // } + + // The mapping from tidx to rows changes between the softmax and the + // O-reduction. So we recalculate the max. + float p_max_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M]; + // TODO: not sure if this is right for seqlen 128 or 256 + int rows[Gmem_tile_o::STGS_PER_LOOP]; + for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { + rows[jj] = + tidx / Gmem_tile_o::THREADS_PER_ROW + jj * Gmem_tile_o::ROWS_PER_STG; + } + softmax.reduce_max_after_sync_(p_max_o, rows); + static_assert(Mma_tile_o::MMAS_M == 1); + for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { + p_max_o[jj][0] *= params.scale_bmm1f; + } + float p_prev_scale_o[Gmem_tile_o::STGS_PER_LOOP]; + if (!Is_first) { + smem_softmax_lse.load(p_prev_scale_o, rows, l % 2); + } + // if (!Is_first) { + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && + // (l == 0)) { + // printf("p_prev_scale_o=%.6f\n", p_prev_scale_o[0]); + // } + // } + + static_assert(Gmem_tile_o::LOOPS == 1); + + // Swizzle the elements and do the final reduction. + smem_o.store(acc_o, 0); + + // Make sure the data is in shared memory. 
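+    // After the barrier, the row sums of this chunk are merged with the
+    // carried statistics: the contribution of the previous chunks enters as
+    // exp(lse_prev - max_new), and the row LSE is updated to
+    // max_new + log(sum), i.e. the usual streaming log-sum-exp recurrence.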
+ __syncthreads(); + + static_assert(Mma_tile_o::MMAS_M == 1); + float p_sum_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M]; + softmax.reduce_sum_after_sync_(p_sum_o, rows); + if (!Is_first) { + for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { + p_prev_scale_o[jj] = expf(p_prev_scale_o[jj] - p_max_o[jj][0]); + p_sum_o[jj][0] += p_prev_scale_o[jj]; + } + } + + float p_sum_log[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M]; +#pragma unroll + for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { + float sum = p_sum_o[jj][0]; + p_sum_log[jj][0] = + (sum == 0.f || sum != sum) ? -INFINITY : p_max_o[jj][0] + __logf(sum); + // if (sum == 0.f || sum != sum) { + // printf("loop_step_idx = %d, l = %d, tidx = %d, sum = %.6f, p_max_o + // = %.6f\n", loop_step_idx, l, tidx, sum, p_max_o[jj][0]); + // } + // if (Is_first) { + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && + // (l == 0)) { + // printf("p_sum_log=%.6f\n", p_sum_log[jj][0]); + // } + // } + if ((tidx % Gmem_tile_o::THREADS_PER_ROW == 0) && + (tidx / Gmem_tile_o::THREADS_PER_ROW < Gmem_tile_o::ROWS)) { + gmem_softmax_lse.store_row( + reinterpret_cast(p_sum_log[jj]), + rows[jj]); + } + } + gmem_softmax_lse.move(); + + // Load from shared memory. + if (!Is_first) { + for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { + out[jj] = fmha::fmul4(out[jj], p_prev_scale_o[jj]); + } + } + smem_o.template load(out); + + const bool is_final_write = Is_last || + ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen) || + ((Is_causal) && + ((begin + l) * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N)); +#pragma unroll + for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { + float sum = p_sum_o[jj][0]; + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; + if (Is_dropout && is_final_write) { + inv_sum *= params.rp_dropout; + } + out[jj] = fmha::fmul4(out[jj], inv_sum); + } + + // if (Is_dropout && Is_last) { + // for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { + // out[jj] = fmha::fmul4(out[jj], params.rp_dropout); + // } + // } + + // Output the values. + if (is_final_write) { + gmem_o.store(out, 0); + gmem_o.move(); + } else { + gmem_o_tmp.store(out, 0); + } + + // Move to the next part of the output. + if (!(Is_first && Is_last)) { + gmem_o_tmp.move(); + } + gemm_q_k.reload_k(); + + // Make sure we are reading from the correct buffer. + gemm_q_k.smem_q.move_to_next_read_buffer(); + // Trigger the load from shared memory for the next series of Q values. + if (l < steps - 1) { + gemm_q_k.reload_q(); + } + + } // Outer loop over the sequence length. +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Kernel_traits, + bool Is_dropout, + bool Is_causal, + bool Return_softmax, + typename Params> +inline __device__ void device_1xN_loop(const Params& params) { + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.x; + // The thread index. 
+ const int tidx = threadIdx.x; + + const int tidx_global = (bidb * params.h + bidh) * blockDim.x * 2 + tidx; + // auto seeds = at::cuda::philox::unpack(params.philox_args); + auto seeds = std::make_tuple(0, 0); + Philox ph0(std::get<0>(seeds), tidx_global, std::get<1>(seeds)); + Philox ph1(std::get<0>(seeds), tidx_global + blockDim.x, std::get<1>(seeds)); + const int STEPS = params.s / Kernel_traits::Cta_tile_p::M; + + constexpr int N_per_loop = Kernel_traits::Cta_tile_p::N; + if (params.s == N_per_loop) { + fmha::device_1xN_< + Kernel_traits, + Is_dropout, + Is_causal, + Return_softmax, + true, + true>(params, bidb, bidh, 0, STEPS, ph0, ph1, 0); + } else { + const int max_loop_steps = (params.s + N_per_loop - 1) / N_per_loop; + fmha::device_1xN_< + Kernel_traits, + Is_dropout, + Is_causal, + Return_softmax, + true, + false>(params, bidb, bidh, 0, STEPS, ph0, ph1, 0); + for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; + loop_step_idx++) { + fmha::device_1xN_< + Kernel_traits, + Is_dropout, + Is_causal, + Return_softmax, + false, + false>(params, bidb, bidh, 0, STEPS, ph0, ph1, loop_step_idx); + } + fmha::device_1xN_< + Kernel_traits, + Is_dropout, + Is_causal, + Return_softmax, + false, + true>(params, bidb, bidh, 0, STEPS, ph0, ph1, max_loop_steps - 1); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/python/aitemplate/backend/cuda/attention/src/fmha_kernel.h b/python/aitemplate/backend/cuda/attention/src/fmha_kernel.h new file mode 100644 index 000000000..43692802b --- /dev/null +++ b/python/aitemplate/backend/cuda/attention/src/fmha_kernel.h @@ -0,0 +1,204 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct BlockInfoPadded { + template + __device__ BlockInfoPadded( + const Params& params, + const int bidb, + const int bidh, + const int tidx) + : bidb(bidb), bidh(bidh), h(params.h) { + // The block index. + sum_s = params.cu_seqlens[bidb]; + actual_seqlen = params.cu_seqlens[bidb + 1] - sum_s; + bidx = sum_s * params.h + bidh; + + tidx_global = (bidb * params.h + bidh) * THREADS_PER_CTA + tidx; + } + + __device__ bool stop_early(const int start_col = 0) const { + return actual_seqlen <= start_col; + } + + int actual_seqlen; + int bidx; + int sum_s; + int bidh; + int bidb; + int tidx_global; + int h; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Noloop_traits { + // Interpretation of Cta_tile dims, i.e. Cta_tile_p: + enum { STEP = Cta_tile::M }; + enum { SEQLEN = Cta_tile::N }; + + template + inline __device__ Noloop_traits(const int bidc, const Block_info& binfo) + : bidc_(bidc) { + const int seqlen = binfo.actual_seqlen; + const int steps = (seqlen + STEP - 1) / STEP; + const int steps_per_chunk = (steps + CHUNKS - 1) / CHUNKS; + + const int step_begin = bidc_ * steps_per_chunk; + const int step_end = min(steps, (bidc_ + 1) * steps_per_chunk); + const int actual_steps = max(0, step_end - step_begin); + loop_offset_ = step_begin; + num_steps_ = actual_steps; + } + + template + inline __device__ void move_all(Tiles&... tiles) const { + using expand_type = int[]; + for (int s = 0; s < loop_offset_; s++) { + expand_type{(tiles.move(), 0)...}; + } + } + + inline __device__ int get_idx_dk() const { + // return bidc_; + return bidc_ * 2 + 0; + } + + inline __device__ int get_idx_dv() const { + // return CHUNKS + bidc_; + return bidc_ * 2 + 1; + } + + inline __device__ int offset_loop_count(const int l) { + // convert loop counter to position in the outer sequence + return (loop_offset_ + l) * STEP; + } + + const uint32_t bidc_; + int loop_offset_; + int num_steps_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +std::tuple work_dist( + const int total_ctas, + const int heads_total) { + constexpr int STEPS_PER_HEAD = + Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M; + + const int num_full_heads = heads_total / total_ctas; + const int heads_last_wave = heads_total % total_ctas; + + int num_main_groups = 0; + int main_steps = 0; + int rest_steps = 0; + if (heads_last_wave > 0) { + // Number of CTA groups that process within heads. + num_main_groups = total_ctas / heads_last_wave; + // Remaining CTAs that process between heads. 
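+    // Hypothetical example: with total_ctas = 8 and heads_total = 20,
+    // num_full_heads = 2, heads_last_wave = 4 and num_main_groups = 2, i.e.
+    // pairs of CTAs split the steps of each remaining head; the leftover
+    // ("rest") CTAs are counted next.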
+ const int rest_ctas = total_ctas - (heads_last_wave * num_main_groups); + if (rest_ctas == 0) { + // We have exactly "num_main_groups" CTAs to process each of the remaining + // heads. + main_steps = (STEPS_PER_HEAD + num_main_groups - 1) / num_main_groups; + num_main_groups = STEPS_PER_HEAD / main_steps; // Here: main_step > 0 + rest_steps = STEPS_PER_HEAD % main_steps; + + } else { + // Ideal number of steps if we could load-balance as evenly as possible. + const int steps_ideal = + (heads_last_wave * STEPS_PER_HEAD + total_ctas - 1) / total_ctas; + // Iterations that a "rest" CTA has to do at most. + const int max_rest_iters = (heads_last_wave + rest_ctas - 1) / rest_ctas; + // Find the first step distribution, s.t. the maximum work of the "rest" + // CTAs is less than the work of the main CTAs. + main_steps = steps_ideal; + rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups; + for (; main_steps * num_main_groups < STEPS_PER_HEAD; main_steps++) { + rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups; + const int max_rest_total_steps = rest_steps * max_rest_iters; + if (max_rest_total_steps < main_steps) + break; + } + rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups; + } + } + + using Cta_tile_p = typename Kernel_traits::Cta_tile_p; + using Mma_tile_p = fmha::Hmma_tile; + + const int max_steps = + STEPS_PER_HEAD * num_full_heads + std::max(main_steps, rest_steps); + const int elts_per_thread_per_step = + Mma_tile_p::MMAS_M * Mma_tile_p::MMAS_N * 8; + const int elts_per_thread = max_steps * elts_per_thread_per_step; + + return { + num_full_heads, + num_main_groups, + heads_last_wave, + main_steps, + rest_steps, + elts_per_thread}; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/python/aitemplate/backend/cuda/attention/src/fmha_utils.h b/python/aitemplate/backend/cuda/attention/src/fmha_utils.h new file mode 100644 index 000000000..af8456621 --- /dev/null +++ b/python/aitemplate/backend/cuda/attention/src/fmha_utils.h @@ -0,0 +1,111 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. 
+ * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define FMHA_CHECK_CUDA(call) \ + do { \ + cudaError_t status_ = call; \ + if (status_ != cudaSuccess) { \ + auto msg = std::string("CUDA error(") + __FILE__ + ":" + \ + std::to_string(__LINE__) + cudaGetErrorString(status_); \ + std::cerr << msg << std::endl; \ + throw std::runtime_error(msg); \ + } \ + } while (0) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +enum Data_type { + DATA_TYPE_FP16, + DATA_TYPE_FP32, + DATA_TYPE_INT32, + DATA_TYPE_INT8 +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline void set_alpha(uint32_t& alpha, float norm, Data_type dtype) { + if (dtype == DATA_TYPE_FP16) { + half x = __float2half_rn(norm); + uint16_t h = reinterpret_cast(x); + ushort2 h2 = {h, h}; + alpha = reinterpret_cast(h2); + } else if (dtype == DATA_TYPE_FP32) { + alpha = reinterpret_cast(norm); + } else if (dtype == DATA_TYPE_INT32) { + int32_t inorm = static_cast(norm); + alpha = reinterpret_cast(inorm); + } else { + assert(false); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline size_t get_size_in_bytes(size_t n, Data_type dtype) { + switch (dtype) { + case DATA_TYPE_FP32: + return n * 4; + case DATA_TYPE_FP16: + return n * 2; + case DATA_TYPE_INT32: + return n * 4; + case DATA_TYPE_INT8: + return n; + default: + assert(false); + return 0; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/python/aitemplate/backend/cuda/attention/src/licenses/LICENSE b/python/aitemplate/backend/cuda/attention/src/licenses/LICENSE new file mode 100644 index 000000000..261eeb9e9 --- /dev/null +++ b/python/aitemplate/backend/cuda/attention/src/licenses/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. 
+ + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/python/aitemplate/backend/cuda/attention/src/philox.cuh b/python/aitemplate/backend/cuda/attention/src/philox.cuh new file mode 100644 index 000000000..36e788400 --- /dev/null +++ b/python/aitemplate/backend/cuda/attention/src/philox.cuh @@ -0,0 +1,171 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Pytorch also has an implementation of Philox RNG: +// https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu +#pragma once +// Philox CUDA. + +namespace { + +class Philox { + public: + __device__ inline Philox( + unsigned long long seed, + unsigned long long subsequence, + unsigned long long offset) + : STATE(0), key(reinterpret_cast(seed)) { + // key.x = (unsigned int)seed; + // key.y = (unsigned int)(seed >> 32); + // counter = make_uint4(0, 0, 0, 0); + // counter.z = (unsigned int)(subsequence); + // counter.w = (unsigned int)(subsequence >> 32); + // STATE = 0; + // incr_n(offset / 4); + + // key = reinterpret_cast(seed); + ull2* tmp = reinterpret_cast(&counter); + tmp->x = offset / 4; + tmp->y = subsequence; + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("Philox counter: %d, %d, %d, %d\n", counter.x, counter.y, + // counter.z, counter.w); + // } + } + __device__ inline uint4 operator()() { + // if (STATE == 0) { + uint4 counter_ = counter; + uint2 key_ = key; +// 7-round philox +#pragma unroll + for (int i = 0; i < 6; i++) { + counter_ = single_round(counter_, key_); + key_.x += (kPhilox10A); + key_.y += (kPhilox10B); + } + // output = single_round(counter_, key_); + uint4 output = single_round(counter_, key_); + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("Philox counter: %u, %u, %u, %u\n", counter.x, counter.y, + // counter.z, counter.w); printf("Philox output: %u, %u, %u, %u\n", + // output.x, output.y, output.z, output.w); + // } + incr(); + // } + // return a float4 directly + // unsigned long ret; + // switch(STATE) { + // case 0: ret = output.x; break; + // case 1: ret = output.y; break; + // case 2: ret = output.z; break; + // case 3: ret = output.w; break; + //} + // STATE = (STATE + 1) % 4; + return output; + } + + private: + struct ull2 { + uint64_t x; + uint64_t y; + }; + uint4 counter; + // uint4 output; + const uint2 key; + unsigned int STATE; + __device__ inline void incr_n(unsigned long long n) { + unsigned int nlo = (unsigned int)(n); + unsigned int nhi = (unsigned int)(n >> 32); + counter.x += nlo; + if (counter.x < nlo) + nhi++; + counter.y += nhi; + if (nhi <= counter.y) + return; + if (++counter.z) + return; + ++counter.w; + } + + __device__ uint4 incr128(uint4 ctr) { + uint4 res; + asm("add.cc.u32 %0, %4, %8;\n\t" + "addc.cc.u32 %1, %5, %9;\n\t" + "addc.cc.u32 %2, %6, %10;\n\t" + "addc.u32 %3, %7, %11;\n\t" + : "=r"(res.x), "=r"(res.y), "=r"(res.z), "=r"(res.w) + : "r"(ctr.x), + "r"(ctr.y), + "r"(ctr.z), + "r"(ctr.w), + "n"(1), + "n"(0), + "n"(0), + "n"(0)); 
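+    // The add.cc.u32 / addc.u32 chain above adds 1 to the 128-bit counter,
+    // propagating the carry through the four 32-bit words (x is the lowest).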
+ return res; + } + + __device__ inline void incr() { + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("Counter before: %u, %u, %u, %u\n", counter.x, counter.y, + // counter.z, counter.w); + // } + counter = incr128(counter); + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("Counter after: %u, %u, %u, %u\n", counter.x, counter.y, + // counter.z, counter.w); + // } + } + __device__ unsigned int mulhilo32( + unsigned int a, + unsigned int b, + unsigned int* result_high) { + *result_high = __umulhi(a, b); + return a * b; + } + __device__ uint2 mulhilo32_v2(const unsigned int a, const unsigned int b) { + uint2* res; + unsigned long long tmp; + asm("mul.wide.u32 %0, %1, %2;\n\t" : "=l"(tmp) : "r"(a), "r"(b)); + res = (uint2*)(&tmp); + return *res; + } + __device__ inline uint4 single_round(const uint4 ctr, const uint2 key) { + // unsigned int hi0; + // unsigned int hi1; + // unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0); + // unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1); + // uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0}; + uint2 res0 = mulhilo32_v2(kPhiloxSA, ctr.x); + uint2 res1 = mulhilo32_v2(kPhiloxSB, ctr.z); + uint4 ret = { + res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x}; + return ret; + } + static const unsigned long kPhilox10A = 0x9E3779B9; + static const unsigned long kPhilox10B = 0xBB67AE85; + static const unsigned long kPhiloxSA = 0xD2511F53; + static const unsigned long kPhiloxSB = 0xCD9E8D57; +}; +// Inverse of 2^32. +constexpr float M_RAN_INVM32 = 2.3283064e-10f; +__device__ __inline__ float4 uniform4(const uint4 x) { + return make_float4( + x.x * M_RAN_INVM32, + x.y * M_RAN_INVM32, + x.z * M_RAN_INVM32, + x.w * M_RAN_INVM32); +} + +} // namespace diff --git a/python/aitemplate/backend/cuda/common/__init__.py b/python/aitemplate/backend/cuda/common/__init__.py new file mode 100644 index 000000000..2115b6952 --- /dev/null +++ b/python/aitemplate/backend/cuda/common/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# flake8: noqa +""" +CUDA Common module init +""" +from .dummy_op import * diff --git a/python/aitemplate/backend/cuda/common/dummy_op.py b/python/aitemplate/backend/cuda/common/dummy_op.py new file mode 100644 index 000000000..da293ee4e --- /dev/null +++ b/python/aitemplate/backend/cuda/common/dummy_op.py @@ -0,0 +1,36 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Dummy op codegen for CUDA. +""" + +from typing import Any, Dict + +from ... import registry + + +@registry.reg("cuda.size.gen_function") +def dummy_gen_function(func_attrs: Dict[str, Any]) -> str: + return "" + + +@registry.reg("cuda.size.func_decl") +def dummy_gen_function_decl(func_attrs): + return "" + + +@registry.reg("cuda.size.func_call") +def dummy_gen_function_call(func_attrs, indent): + return "" diff --git a/python/aitemplate/backend/cuda/conv2d/__init__.py b/python/aitemplate/backend/cuda/conv2d/__init__.py new file mode 100644 index 000000000..7d83ce1fd --- /dev/null +++ b/python/aitemplate/backend/cuda/conv2d/__init__.py @@ -0,0 +1,33 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# flake8: noqa +""" +cuda conv2d module init +""" +from . import ( + conv2d, + conv2d_bias, + conv2d_bias_add, + conv2d_bias_add_hardswish, + conv2d_bias_add_relu, + conv2d_bias_few_channels, + conv2d_bias_hardswish, + conv2d_bias_hardswish_few_channels, + conv2d_bias_relu, + conv2d_bias_relu_few_channels, + conv2d_bias_sigmoid, + transposed_conv2d, + transposed_conv2d_bias, +) diff --git a/python/aitemplate/backend/cuda/conv2d/common.py b/python/aitemplate/backend/cuda/conv2d/common.py new file mode 100644 index 000000000..9e0de0d91 --- /dev/null +++ b/python/aitemplate/backend/cuda/conv2d/common.py @@ -0,0 +1,244 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +""" +common template for conv2d +""" +import re +from collections import OrderedDict +from hashlib import sha1 +from typing import List + +import jinja2 + +from ...target import Target +from ..gemm_universal.common import add_profiler, build_profiler # noqa: F401 + + +KERNEL_KEY_TEMPLATE = jinja2.Template( + """ +cutlass{{opcode_class}}_{{extended_name}}_{{threadblock}}_{{layout}}_align_{{align_ab}}_{{align_c}} +""" +) + + +def kernel_name(op): + """generate cuda kernel name""" + from cutlass_lib import library + + threadblock = op.tile_description.procedural_name() + extended_name = op.extended_name() + opcode_class_name = library.OpcodeClassNames[ + op.tile_description.math_instruction.opcode_class + ] + layout = op.layout_name() + align_ab = op.A.alignment + align_c = op.C.alignment + name = KERNEL_KEY_TEMPLATE.render( + threadblock=threadblock, + extended_name=extended_name, + opcode_class_name=opcode_class_name, + layout=layout, + align_ab=align_ab, + align_c=align_c, + ) + return name.replace("\n", "") + + +def emit_instance(op): + """emit instance""" + import cutlass_lib + + if hasattr(op, "binary_op"): + emiter = cutlass_lib.conv2d_operation.EmitConv2dWithBroadcastInstance() + else: + emiter = cutlass_lib.conv2d_operation.EmitConv2dInstance() + op_def = emiter.emit(op) + return op_def + + +def extract_config(func_attrs, f_proc_op=None): + """Extracts cutlass config for conv kernels.""" + import copy + + import cutlass_lib + + def f_proc_op_default(op): + # import cutlass_lib + ret = [] + data_type = cutlass_lib.library.DataType.f16 + acc_type = cutlass_lib.library.DataType.f32 + # check target use fp16 acc + if "use_fp16_acc" in Target.current()._kwargs: + if Target.current()._kwargs["use_fp16_acc"]: + acc_type = cutlass_lib.library.DataType.f16 + + if ( + op.A.element == data_type + and op.B.element == data_type + and op.C.element == data_type + and op.iterator_algorithm == cutlass_lib.library.IteratorAlgorithm.Optimized + and op.accumulator_type() == acc_type + ): + + op = copy.deepcopy(op) + # set epilogue + epilogue_name = func_attrs["epilogue"] + op.epilogue_functor = cutlass_lib.library.EpilogueFunctorName[epilogue_name] + op.element_epilogue = acc_type + # set C alignment + for i in [8, 4, 2, 1]: + op = copy.deepcopy(op) + op.C.alignment = i + ret.append(op) + return ret + + op_kind = cutlass_lib.library.OperationKind.Conv2d + conv_kind = cutlass_lib.library.ConvKind.Fprop + ret = [] + conv2d_ops = OrderedDict() + extract_ops = list(Target.current()._operators[op_kind].items()) + + for _, value in extract_ops: + op = value[0] + if op.conv_kind == conv_kind: + if f_proc_op is None: + ret = f_proc_op_default(op) + else: + ret = f_proc_op(op) + if len(ret) > 0: + for op_inst in ret: + key = kernel_name(op_inst) + conv2d_ops[key] = op_inst + return conv2d_ops + + +def extract_config_name(config): + """Extracts config name from a given config.""" + pattern = re.compile(r"\s*using\s(.*?)\s=") + decl = config.split("\n")[2] + match = pattern.match(decl) + if match is None: + raise RuntimeError("Invalid config: \n" + config) + return match.groups()[0] + + +def gen_function( + func_attrs, + instance_template, + exec_template, + src_template, + exec_cond_remplate, + shape_eval_template, + shape_save_template, + f_emit_instance=emit_instance, + extra_header="", +): + """Function definition codegen.""" + func_name = func_attrs["name"] + exec_path = func_attrs["exec_path"] + op_instance = func_attrs["op_instance"] + + inst_def_flag = set() + instances = {} + instance_decl = "" + for 
key, value in exec_path.items(): + fname = "f" + sha1(key.encode()).hexdigest() + if value not in inst_def_flag: + config = f_emit_instance(op_instance[value]) + inst_def_flag.add(value) + else: + config = "" + inst = instance_template.render( + config=config, name=fname, config_name=extract_config_name(config) + ) + instances[key] = inst + instance_decl += inst + shape_eval_func = shape_eval_template.render( + indent=" ", + dtype="int64_t ", + x_dim0="*batch", + x_dim1="*in_h", + x_dim2="*in_w", + x_dim3="*in_ch", + w_dim0="*out_ch", + w_dim1="*kernel_h", + w_dim2="*kernel_w", + stride="stride", + dilate="dilation", + pad="pad", + div="/", + ) + shape_save_func = shape_save_template.render( + indent=" ", + y_dim0="*out_batch", + y_dim1="*out_h", + y_dim2="*out_w", + y_dim3="*out_ch", + ) + shape_func = shape_eval_func + shape_save_func + exec_paths = "" + for key in instances: + fname = "f" + sha1(key.encode()).hexdigest() + program = exec_template.render(indent=" ", instance=fname) + exec_inst = exec_cond_remplate.render(indent=" ", cond=key, program=program) + exec_paths += exec_inst + return src_template.render( + instances=instance_decl, + function_name=func_name, + dtype="cutlass::half_t", + shape_function=shape_func, + exec_paths=exec_paths, + extra_header=extra_header, + ) + + +def cal_align_ab(x_shape: List[int]) -> int: + """Returns input alignment.""" + k = x_shape[3] # CI + if k % 8 == 0: + return 8 + if k % 4 == 0: + return 4 + if k % 2 == 0: + return 2 + raise RuntimeError("a/b is not aligned") + + +def function_filter(cfg, func_attrs, x_shape): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + x_shape: + Input shapes. + + Returns + ------- + bool + If input cfg should be filtered. + """ + ab_alignment = cal_align_ab(x_shape) + tmp = cfg.split("_") + align_c = int(tmp[-1]) + align_ab = int(tmp[-2]) + if align_c != func_attrs["epilogue_alignment"]: + return False + if align_ab != ab_alignment: + return False + return True diff --git a/python/aitemplate/backend/cuda/conv2d/common_conv2d_bias_activation.py b/python/aitemplate/backend/cuda/conv2d/common_conv2d_bias_activation.py new file mode 100644 index 000000000..ddcef02b3 --- /dev/null +++ b/python/aitemplate/backend/cuda/conv2d/common_conv2d_bias_activation.py @@ -0,0 +1,373 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +common templates for conv_bias_activation subgraph +""" +import jinja2 + +from . 
import common + +# pylint: disable=C0103,C0301 + +INSTANCE_TEMPLATE = jinja2.Template( + """ +{{config}} +using {{name}} = cutlass::conv::device::ImplicitGemmConvolution<{{config_name}}>; +""" +) + +EXEC_TEMPLATE = jinja2.Template( + """ +{{indent}}using ElementComputeEpilogue = typename {{instance}}::ElementCompute; +// TODO: cast to right dtype +{{indent}}typename {{instance}}::Arguments arguments{ +{{indent}} problem_size, +{{indent}} {(cutlass::half_t*)(in_ptr), layout_A}, +{{indent}} {(cutlass::half_t*)(weight_ptr), layout_B}, +{{indent}} {(cutlass::half_t*)(bias_ptr), cutlass::layout::TensorNHWC::Stride(0)}, +{{indent}} {(cutlass::half_t*)(out_ptr), layout_C}, +{{indent}} {ElementComputeEpilogue(1), ElementComputeEpilogue(1)}, +{{indent}}}; +{{indent}}{{instance}} implicit_gemm_op; +{% if is_profiler %} +{{indent}}size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments); +{{indent}}cutlass::device_memory::allocation local_workspace(workspace_size); +{{indent}}workspace = local_workspace.get(); +{{indent}}GLOBAL_WORKSPACE_SIZE = workspace_size; +{% endif %} +{{indent}}auto status = implicit_gemm_op.can_implement(arguments); +{{indent}}CUTLASS_CHECK(status); +{{indent}}status = implicit_gemm_op.initialize(arguments, workspace); +{{indent}}CUTLASS_CHECK(status); +{{indent}}status = implicit_gemm_op(stream); +{{indent}}CUTLASS_CHECK(status); +{{indent}}return; +""" +) + + +SRC_TEMPLATE = jinja2.Template( + """ +#include +#include +#include +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d_fprop.h" +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include +#include + +{{extra_header}} + +#define CUTLASS_CHECK(status) \\ + { \\ + cutlass::Status error = status; \\ + if (error != cutlass::Status::kSuccess) { \\ + auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " + \\ + cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__); \\ + std::cerr << msg << std::endl; \\ + throw std::runtime_error(msg); \\ + } \\ + } + +{{instances}} + +{{instances_def}} + +void {{function_name}} ( + cutlass::half_t* in_ptr, + cutlass::half_t* weight_ptr, + cutlass::half_t* out_ptr, + cutlass::half_t* bias_ptr, + uint8_t* workspace, + int64_t* batch, + int64_t* out_ch, + int64_t* in_ch, + int64_t* kernel_h, + int64_t* kernel_w, + int64_t* in_h, + int64_t* in_w, + int64_t* out_batch, + int64_t* out_h, + int64_t* out_w, + int stride, + int dilation, + int pad, + cudaStream_t stream + ) { + + {{shape_function}} + int i32_batch = *batch; + int i32_in_h = *in_h; + int i32_in_w = *in_w; + int i32_in_ch = *in_ch; + int i32_out_ch = *out_ch; + int i32_kernel_h = *kernel_h; + int i32_kernel_w = *kernel_w; + int i32_out_batch = *out_batch; + int i32_out_h = *out_h; + int i32_out_w = *out_w; + + using cutlass::layout::TensorNHWC; + TensorNHWC layout_A(TensorNHWC::packed(cutlass::make_Coord(i32_batch, i32_in_h, i32_in_w, i32_in_ch))); + TensorNHWC layout_B(TensorNHWC::packed(cutlass::make_Coord(i32_out_ch, i32_kernel_h, i32_kernel_w, i32_in_ch))); + TensorNHWC layout_C(TensorNHWC::packed(cutlass::make_Coord(i32_out_batch, i32_out_h, i32_out_w, i32_out_ch))); + + cutlass::conv::Conv2dProblemSize problem_size( + {i32_batch, i32_in_h, i32_in_w, i32_in_ch}, + {i32_out_ch, i32_kernel_h, i32_kernel_w, i32_in_ch}, + {pad, pad, pad, pad}, + {stride, stride}, + {dilation, dilation}, + {i32_out_batch, i32_out_h, i32_out_w, i32_out_ch}, + 
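+    // Note: kCrossCorrelation selects the cross-correlation convention used by
+    // common DL frameworks (no filter flipping); the trailing argument below is
+    // the number of split-K slices, fixed to 1 here.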
cutlass::conv::Mode::kCrossCorrelation, + 1 + ); + + {{exec_paths}} + throw std::runtime_error( + "Unsupported workload for this conv2d specialization." + ); +} +""" +) + + +PROFILER_TEMPLATE = jinja2.Template( + """ +size_t GLOBAL_WORKSPACE_SIZE = 0; +{{op_func}} + +int main(int argc, char** argv) { + int64_t batch = std::stoi(argv[1]); + int64_t in_h = std::stoi(argv[2]); + int64_t in_w = std::stoi(argv[3]); + int64_t in_ch = std::stoi(argv[4]); + int64_t kernel_h = std::stoi(argv[5]); + int64_t kernel_w = std::stoi(argv[6]); + int64_t out_ch = std::stoi(argv[7]); + int stride = std::stoi(argv[8]); + int pad = std::stoi(argv[9]); + int dilation = std::stoi(argv[10]); + {{shape_func}} + using ElementOutput = typename {{name}}::ElementC; + using ElementInputA = typename {{name}}::ElementA; + using ElementInputB = typename {{name}}::ElementB; + + uint8_t* global_workspace = nullptr; + cudaStream_t stream = nullptr; + + cutlass::HostTensor x({NI, HI, WI, CI}); + cutlass::HostTensor w({CO, KH, KW, CI}); + cutlass::HostTensor b({(int)CO, 1, 1, 1}); + cutlass::HostTensor y({NO, HO, WO, CO}); + // + // warmup + conv((cutlass::half_t*) x.device_data(), + (cutlass::half_t*) w.device_data(), + (cutlass::half_t*) y.device_data(), + (cutlass::half_t*) b.device_data(), + global_workspace, + &NI, + &CO, + &CI, + &KH, + &KW, + &HI, + &WI, + &NO, + &HO, + &WO, + stride, + dilation, + pad, + stream); + cudaEvent_t events[2]; + for (auto & event : events) { + cudaEventCreate(&event); + } + cudaEventRecord(events[0]); + for (int i = 0; i < 5; ++i) { + conv((cutlass::half_t*) x.device_data(), + (cutlass::half_t*) w.device_data(), + (cutlass::half_t*) y.device_data(), + (cutlass::half_t*) b.device_data(), + global_workspace, + &NI, + &CO, + &CI, + &KH, + &KW, + &HI, + &WI, + &NO, + &HO, + &WO, + stride, + dilation, + pad, + stream); + } + cudaEventRecord(events[1]); + cudaEventSynchronize(events[1]); + float runtime_ms = 0; + cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + for (auto event : events) { + (void)cudaEventDestroy(event); + } + // TODO: output workspace + if (runtime_ms < 0.00001) { + throw std::runtime_error( + "OOB in cutlass." 
+ ); + } + std::cout << "TIME:" << runtime_ms << std::endl; + std::cout << "WS:" << GLOBAL_WORKSPACE_SIZE << std::endl; +} + +""" +) + +FUNC_DECL_TEMPLATE = jinja2.Template( + """ +void {{func_name}}( + cutlass::half_t*, + cutlass::half_t*, + cutlass::half_t*, + cutlass::half_t*, + uint8_t*, + int64_t*, + int64_t*, + int64_t*, + int64_t*, + int64_t*, + int64_t*, + int64_t*, + int64_t*, + int64_t*, + int64_t*, + int, + int, + int, + cudaStream_t +); +""" +) + +FUNC_CALL_TEMPLATE = jinja2.Template( + """ +{{indent}}{{func_name}}( +{{indent}} {{in_ptr}}, +{{indent}} {{weight_ptr}}, +{{indent}} {{out_ptr}}, +{{indent}} {{bias_ptr}}, +{{indent}} global_workspace, +{{indent}} {{p_batch}}, +{{indent}} {{p_out_ch}}, +{{indent}} {{p_in_ch}}, +{{indent}} {{p_kernel_h}}, +{{indent}} {{p_kernel_w}}, +{{indent}} {{p_in_h}}, +{{indent}} {{p_in_w}}, +{{indent}} {{p_out_batch}}, +{{indent}} {{p_out_h}}, +{{indent}} {{p_out_w}}, +{{indent}} {{stride}}, +{{indent}} {{dilation}}, +{{indent}} {{pad}}, +{{indent}} stream +{{indent}}); +""" +) + + +def gen_profiler(func_attrs, workdir, shape_template, extra_header=""): + op_type = func_attrs["op"] + op_instance = func_attrs["op_instance"] + # shape func + shape_func = shape_template.render( + indent=" ", + dtype="int64_t ", + div="/", + x_dim0="batch", + x_dim1="in_h", + x_dim2="in_w", + x_dim3="in_ch", + w_dim0="out_ch", + w_dim1="kernel_h", + w_dim2="kernel_w", + stride="stride", + dilate="dilation", + pad="pad", + ) + file_pairs = [] + for op_name, op in op_instance.items(): + config = common.emit_instance(op) + + config_name = common.extract_config_name(config) + name = "DeviceConvFwdInstance" + instance = INSTANCE_TEMPLATE.render( + config_name=config_name, name=name, config=config + ) + exec_program = EXEC_TEMPLATE.render( + indent=" ", is_profiler=True, instance=name + ) + op_func = SRC_TEMPLATE.render( + instances=instance, + function_name="conv", + dtype="cutlass::half_t", + shape_func="", + exec_paths=exec_program, + extra_header=extra_header, + ) + code = PROFILER_TEMPLATE.render( + op_func=op_func, shape_func=shape_func, name=name + ) + common.add_profiler(file_pairs, workdir, op_type, op_name, code) + # build + common.build_profiler(file_pairs) + + +def gen_function_call(func_attrs, indent=" "): + x = func_attrs["inputs"][0] + xshape = x._attrs["shape"] + w = func_attrs["inputs"][1] + b = func_attrs["inputs"][2] + wshape = w._attrs["shape"] + y = func_attrs["outputs"][0] + yshape = y._attrs["shape"] + return FUNC_CALL_TEMPLATE.render( + func_name=func_attrs["name"], + in_ptr=x._attrs["name"], + weight_ptr=w._attrs["name"], + out_ptr=y._attrs["name"], + bias_ptr=b._attrs["name"], + p_batch="&" + xshape[0]._attrs["name"], + p_out_ch="&" + wshape[0]._attrs["name"], + p_in_ch="&" + xshape[3]._attrs["name"], + p_kernel_h="&" + wshape[1]._attrs["name"], + p_kernel_w="&" + wshape[2]._attrs["name"], + p_in_h="&" + xshape[1]._attrs["name"], + p_in_w="&" + xshape[2]._attrs["name"], + p_out_batch="&" + yshape[0]._attrs["name"], + p_out_h="&" + yshape[1]._attrs["name"], + p_out_w="&" + yshape[2]._attrs["name"], + stride=func_attrs["stride"], + dilation=func_attrs["dilate"], + pad=func_attrs["pad"], + indent=indent, + ) diff --git a/python/aitemplate/backend/cuda/conv2d/common_conv2d_bias_add_activation.py b/python/aitemplate/backend/cuda/conv2d/common_conv2d_bias_add_activation.py new file mode 100644 index 000000000..0647769a1 --- /dev/null +++ b/python/aitemplate/backend/cuda/conv2d/common_conv2d_bias_add_activation.py @@ -0,0 +1,348 @@ +# Copyright (c) Meta 
Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +common template for conv2d bias act residual add +""" +import jinja2 + +from . import common + +# pylint: disable=C0301,C0103 + +INSTANCE_TEMPLATE = jinja2.Template( + """ +{{config}} +using {{name}} = cutlass::conv::device::ImplicitGemmConvolution<{{config_name}}>; +""" +) + +EXEC_TEMPLATE = jinja2.Template( + """ +{{indent}}using ElementComputeEpilogue = typename {{instance}}::ElementCompute; +// TODO: cast to right dtype +{{indent}}typename {{instance}}::Arguments arguments{ +{{indent}} problem_size, +{{indent}} {(cutlass::half_t*)(in_ptr), layout_A}, +{{indent}} {(cutlass::half_t*)(weight_ptr), layout_B}, +{{indent}} {(cutlass::half_t*)(res_ptr), layout_C}, +{{indent}} {(cutlass::half_t*)(out_ptr), layout_C}, +{{indent}} {ElementComputeEpilogue(1), ElementComputeEpilogue(1)}, +{{indent}} cutlass::conv::SplitKMode::kSerial, +{{indent}} (cutlass::half_t*)(bias_ptr), +{{indent}} nullptr, 0, *out_ch +{{indent}}}; +{{indent}}{{instance}} implicit_gemm_op; +{% if is_profiler %} +{{indent}}size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments); +{{indent}}cutlass::device_memory::allocation local_workspace(workspace_size); +{{indent}}workspace = local_workspace.get(); +{{indent}}GLOBAL_WORKSPACE_SIZE = workspace_size; +{% endif %} +{{indent}}auto status = implicit_gemm_op.can_implement(arguments); +{{indent}}CUTLASS_CHECK(status); +{{indent}}status = implicit_gemm_op.initialize(arguments, workspace); +{{indent}}CUTLASS_CHECK(status); +{{indent}}status = implicit_gemm_op(stream); +{{indent}}CUTLASS_CHECK(status); +return; +""" +) + + +SRC_TEMPLATE = jinja2.Template( + """ +#include +#include +#include +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d_fprop.h" +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include +#include + +#define CUTLASS_CHECK(status) \\ + { \\ + cutlass::Status error = status; \\ + if (error != cutlass::Status::kSuccess) { \\ + auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " + \\ + cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__); \\ + std::cerr << msg << std::endl; \\ + throw std::runtime_error(msg); \\ + } \\ + } + +{{instances}} + +{{instances_def}} + +void {{function_name}} ( + cutlass::half_t* in_ptr, + cutlass::half_t* weight_ptr, + cutlass::half_t* out_ptr, + cutlass::half_t* bias_ptr, + cutlass::half_t* res_ptr, + uint8_t* workspace, + int64_t* batch, + int64_t* out_ch, + int64_t* in_ch, + int64_t* kernel_h, + int64_t* kernel_w, + int64_t* in_h, + int64_t* in_w, + int64_t* out_batch, + int64_t* out_h, + int64_t* out_w, + int stride, + int dilation, + int pad, + cudaStream_t stream + ) { + + {{shape_function}} + int i32_batch = *batch; + int i32_in_h = *in_h; + int i32_in_w = *in_w; + int i32_in_ch = *in_ch; + int i32_out_ch = *out_ch; + int i32_kernel_h = *kernel_h; + int i32_kernel_w = *kernel_w; + 
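+  // The generated kernel receives dynamic shapes as int64_t*, but CUTLASS
+  // tensor coordinates and Conv2dProblemSize take 32-bit ints, so each extent
+  // is narrowed here (this assumes every dimension fits in int32).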
int i32_out_batch = *out_batch; + int i32_out_h = *out_h; + int i32_out_w = *out_w; + + using cutlass::layout::TensorNHWC; + TensorNHWC layout_A(TensorNHWC::packed(cutlass::make_Coord(i32_batch, i32_in_h, i32_in_w, i32_in_ch))); + TensorNHWC layout_B(TensorNHWC::packed(cutlass::make_Coord(i32_out_ch, i32_kernel_h, i32_kernel_w, i32_in_ch))); + TensorNHWC layout_C(TensorNHWC::packed(cutlass::make_Coord(i32_out_batch, i32_out_h, i32_out_w, i32_out_ch))); + + cutlass::conv::Conv2dProblemSize problem_size( + {i32_batch, i32_in_h, i32_in_w, i32_in_ch}, + {i32_out_ch, i32_kernel_h, i32_kernel_w, i32_in_ch}, + {pad, pad, pad, pad}, + {stride, stride}, + {dilation, dilation}, + {i32_out_batch, i32_out_h, i32_out_w, i32_out_ch}, + cutlass::conv::Mode::kCrossCorrelation, + 1 + ); + + {{exec_paths}} + throw std::runtime_error( + "Unsupported workload for this conv2d specialization." + ); +} +""" +) + + +PROFILER_TEMPLATE = jinja2.Template( + """ +size_t GLOBAL_WORKSPACE_SIZE = 0; +{{op_func}} + +int main(int argc, char** argv) { + int64_t batch = std::stoi(argv[1]); + int64_t in_h = std::stoi(argv[2]); + int64_t in_w = std::stoi(argv[3]); + int64_t in_ch = std::stoi(argv[4]); + int64_t kernel_h = std::stoi(argv[5]); + int64_t kernel_w = std::stoi(argv[6]); + int64_t out_ch = std::stoi(argv[7]); + int stride = std::stoi(argv[8]); + int pad = std::stoi(argv[9]); + int dilation = std::stoi(argv[10]); + {{shape_func}} + using ElementOutput = typename {{name}}::ElementC; + using ElementInputA = typename {{name}}::ElementA; + using ElementInputB = typename {{name}}::ElementB; + + uint8_t* global_workspace = nullptr; + cudaStream_t stream = nullptr; + + cutlass::HostTensor x({NI, HI, WI, CI}); + cutlass::HostTensor w({CO, KH, KW, CI}); + cutlass::HostTensor b({(int)CO, 1, 1, 1}); + cutlass::HostTensor r({NO, HO, WO, CO}); + cutlass::HostTensor y({NO, HO, WO, CO}); + // + // warmup + conv((cutlass::half_t*) x.device_data(), + (cutlass::half_t*) w.device_data(), + (cutlass::half_t*) y.device_data(), + (cutlass::half_t*) b.device_data(), + (cutlass::half_t*) r.device_data(), + global_workspace, + &NI, + &CO, + &CI, + &KH, + &KW, + &HI, + &WI, + &NO, + &HO, + &WO, + stride, + dilation, + pad, + stream); + cudaEvent_t events[2]; + for (auto & event : events) { + cudaEventCreate(&event); + } + cudaEventRecord(events[0]); + for (int i = 0; i < 5; ++i) { + conv((cutlass::half_t*) x.device_data(), + (cutlass::half_t*) w.device_data(), + (cutlass::half_t*) y.device_data(), + (cutlass::half_t*) b.device_data(), + (cutlass::half_t*) r.device_data(), + global_workspace, + &NI, + &CO, + &CI, + &KH, + &KW, + &HI, + &WI, + &NO, + &HO, + &WO, + stride, + dilation, + pad, + stream); + } + cudaEventRecord(events[1]); + cudaEventSynchronize(events[1]); + float runtime_ms = 0; + cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + for (auto event : events) { + (void)cudaEventDestroy(event); + } + // TODO: output workspace + if (runtime_ms < 0.00001) { + throw std::runtime_error( + "OOB in cutlass." 
+ ); + } + std::cout << "TIME:" << runtime_ms << std::endl; + std::cout << "WS:" << GLOBAL_WORKSPACE_SIZE << std::endl; +} + +""" +) + + +FUNC_DECL_TEMPLATE = jinja2.Template( + """ +void {{func_name}}( + cutlass::half_t*, + cutlass::half_t*, + cutlass::half_t*, + cutlass::half_t*, + cutlass::half_t*, + uint8_t*, + int64_t*, + int64_t*, + int64_t*, + int64_t*, + int64_t*, + int64_t*, + int64_t*, + int64_t*, + int64_t*, + int64_t*, + int, + int, + int, + cudaStream_t +); +""" +) + +FUNC_CALL_TEMPLATE = jinja2.Template( + """ +{{indent}}{{func_name}}( +{{indent}} {{in_ptr}}, +{{indent}} {{weight_ptr}}, +{{indent}} {{out_ptr}}, +{{indent}} {{bias_ptr}}, +{{indent}} {{res_ptr}}, +{{indent}} global_workspace, +{{indent}} {{p_batch}}, +{{indent}} {{p_out_ch}}, +{{indent}} {{p_in_ch}}, +{{indent}} {{p_kernel_h}}, +{{indent}} {{p_kernel_w}}, +{{indent}} {{p_in_h}}, +{{indent}} {{p_in_w}}, +{{indent}} {{p_out_batch}}, +{{indent}} {{p_out_h}}, +{{indent}} {{p_out_w}}, +{{indent}} {{stride}}, +{{indent}} {{dilation}}, +{{indent}} {{pad}}, +{{indent}} stream +{{indent}}); +""" +) + + +def gen_profiler(func_attrs, workdir, shape_template): + op_type = func_attrs["op"] + op_instance = func_attrs["op_instance"] + # shape func + shape_func = shape_template.render( + indent=" ", + dtype="int64_t ", + div="/", + x_dim0="batch", + x_dim1="in_h", + x_dim2="in_w", + x_dim3="in_ch", + w_dim0="out_ch", + w_dim1="kernel_h", + w_dim2="kernel_w", + stride="stride", + dilate="dilation", + pad="pad", + ) + file_pairs = [] + for op_name, op in op_instance.items(): + config = common.emit_instance(op) + config_name = common.extract_config_name(config) + name = "DeviceConvFwdInstance" + instance = INSTANCE_TEMPLATE.render( + config_name=config_name, name=name, config=config + ) + exec_program = EXEC_TEMPLATE.render( + indent=" ", is_profiler=True, instance=name + ) + op_func = SRC_TEMPLATE.render( + instances=instance, + function_name="conv", + dtype="cutlass::half_t", + shape_func="", + exec_paths=exec_program, + ) + code = PROFILER_TEMPLATE.render( + op_func=op_func, shape_func=shape_func, name=name + ) + common.add_profiler(file_pairs, workdir, op_type, op_name, code) + # build + common.build_profiler(file_pairs) diff --git a/python/aitemplate/backend/cuda/conv2d/common_conv2d_few_channels.py b/python/aitemplate/backend/cuda/conv2d/common_conv2d_few_channels.py new file mode 100644 index 000000000..c24f0a4db --- /dev/null +++ b/python/aitemplate/backend/cuda/conv2d/common_conv2d_few_channels.py @@ -0,0 +1,111 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +common functions for conv2d op with few channels(< 8) +""" + +from collections import OrderedDict + +from ...target import Target +from . 
import common + + +def apply_special_config(func_attrs, op): + import cutlass_lib + + x = func_attrs["inputs"][0] + in_ch = x._attrs["shape"][-1]._attrs["values"][0] + + if in_ch == 3: + # By default we don't use it since the perf is worse than pad4+fixchannel + op.iterator_algorithm = cutlass_lib.library.IteratorAlgorithm.FewChannels + op.A.alignment = 1 + op.B.alignment = 1 + op.tile_description.stages = 2 + elif in_ch in [2, 4, 8]: + op.iterator_algorithm = cutlass_lib.library.IteratorAlgorithm.FixedChannels + op.A.alignment = in_ch + op.B.alignment = in_ch + op.tile_description.stages = 3 + return op + + +def extract_config(func_attrs): + """extract epilogue for conv op + + Parameters + ---------- + func_attrs : Dict + [description] op attributes + + Returns + ------- + [type]: Dict + [description] + + Raises + ------ + NotImplementedError + [description] + """ + import copy + + import cutlass_lib + + def f_proc_op_special(op): + ret = [] + data_type = cutlass_lib.library.DataType.f16 + acc_type = cutlass_lib.library.DataType.f32 + # check target use fp16 acc + if "use_fp16_acc" in Target.current()._kwargs: + if Target.current()._kwargs["use_fp16_acc"]: + acc_type = cutlass_lib.library.DataType.f16 + + if ( + op.A.element == data_type + and op.B.element == data_type + and op.C.element == data_type + and op.iterator_algorithm == cutlass_lib.library.IteratorAlgorithm.Optimized + and op.accumulator_type() == acc_type + ): + + op = copy.deepcopy(op) + # set epilogue + epilogue_name = func_attrs["epilogue"] + op.epilogue_functor = cutlass_lib.library.EpilogueFunctorName[epilogue_name] + op.element_epilogue = acc_type + op = apply_special_config(func_attrs, op) + # set C alignment + for i in [8, 4, 2, 1]: + op = copy.deepcopy(op) + op.C.alignment = i + ret.append(op) + return ret + + op_kind = cutlass_lib.library.OperationKind.Conv2d + conv_kind = cutlass_lib.library.ConvKind.Fprop + ret = [] + conv2d_ops = OrderedDict() + extract_ops = list(Target.current()._operators[op_kind].items()) + + for _, value in extract_ops: + op = value[0] + if op.conv_kind == conv_kind: + ret = f_proc_op_special(op) + if len(ret) > 0: + for op_inst in ret: + key = common.kernel_name(op_inst) + conv2d_ops[key] = op_inst + return conv2d_ops diff --git a/python/aitemplate/backend/cuda/conv2d/conv2d.py b/python/aitemplate/backend/cuda/conv2d/conv2d.py new file mode 100644 index 000000000..7e5da403f --- /dev/null +++ b/python/aitemplate/backend/cuda/conv2d/conv2d.py @@ -0,0 +1,420 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Codegen for conv2d. +""" +import jinja2 + +from ... import registry +from . 
import common + +# pylint: disable=C0103,C0415,W0613,C0301 + +INSTANCE_TEMPLATE = jinja2.Template( + """ +{{config}} +using {{name}} = cutlass::conv::device::ImplicitGemmConvolution<{{config_name}}>; +""" +) + +EXEC_TEMPLATE = jinja2.Template( + """ +{{indent}}using ElementComputeEpilogue = typename {{instance}}::ElementCompute; +// TODO: cast to right dtype +{{indent}}typename {{instance}}::Arguments arguments{ +{{indent}} problem_size, +{{indent}} {(cutlass::half_t*)(in_ptr), layout_A}, +{{indent}} {(cutlass::half_t*)(weight_ptr), layout_B}, +{{indent}} {(cutlass::half_t*)(out_ptr), layout_C}, +{{indent}} {(cutlass::half_t*)(out_ptr), layout_C}, +{{indent}} {ElementComputeEpilogue(1), ElementComputeEpilogue(0)}, +{{indent}}}; +{{indent}}{{instance}} implicit_gemm_op; +{% if is_profiler %} +{{indent}}size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments); +{{indent}}cutlass::device_memory::allocation local_workspace(workspace_size); +{{indent}}workspace = local_workspace.get(); +{{indent}}GLOBAL_WORKSPACE_SIZE = workspace_size; +{% endif %} +{{indent}}auto status = implicit_gemm_op.can_implement(arguments); +{{indent}}CUTLASS_CHECK(status); +{{indent}}status = implicit_gemm_op.initialize(arguments, workspace); +{{indent}}CUTLASS_CHECK(status); +{{indent}}status = implicit_gemm_op(stream); +{{indent}}CUTLASS_CHECK(status); +{{indent}}return; +""" +) + +SRC_TEMPLATE = jinja2.Template( + """ +#include +#include +#include +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d_fprop.h" +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +{{extra_header}} + +#define CUTLASS_CHECK(status) \\ + { \\ + cutlass::Status error = status; \\ + if (error != cutlass::Status::kSuccess) { \\ + auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " + \\ + cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__); \\ + std::cerr << msg << std::endl; \\ + throw std::runtime_error(msg); \\ + } \\ + } + +{{instances}} + +{{instances_def}} + +void {{function_name}} ( + cutlass::half_t* in_ptr, + cutlass::half_t* weight_ptr, + cutlass::half_t* out_ptr, + uint8_t* workspace, + int64_t* batch, + int64_t* out_ch, + int64_t* in_ch, + int64_t* kernel_h, + int64_t* kernel_w, + int64_t* in_h, + int64_t* in_w, + int64_t* out_batch, + int64_t* out_h, + int64_t* out_w, + int stride, + int dilation, + int pad, + cudaStream_t stream + ) { + + {{shape_function}} + int i32_batch = *batch; + int i32_in_h = *in_h; + int i32_in_w = *in_w; + int i32_in_ch = *in_ch; + int i32_out_ch = *out_ch; + int i32_kernel_h = *kernel_h; + int i32_kernel_w = *kernel_w; + int i32_out_batch = *out_batch; + int i32_out_h = *out_h; + int i32_out_w = *out_w; + + using cutlass::layout::TensorNHWC; + TensorNHWC layout_A(TensorNHWC::packed(cutlass::make_Coord(i32_batch, i32_in_h, i32_in_w, i32_in_ch))); + TensorNHWC layout_B(TensorNHWC::packed(cutlass::make_Coord(i32_out_ch, i32_kernel_h, i32_kernel_w, i32_in_ch))); + TensorNHWC layout_C(TensorNHWC::packed(cutlass::make_Coord(i32_out_batch, i32_out_h, i32_out_w, i32_out_ch))); + + cutlass::conv::Conv2dProblemSize problem_size( + {i32_batch, i32_in_h, i32_in_w, i32_in_ch}, + {i32_out_ch, i32_kernel_h, i32_kernel_w, i32_in_ch}, + {pad, pad, pad, pad}, + {stride, stride}, + {dilation, dilation}, + {i32_out_batch, i32_out_h, i32_out_w, i32_out_ch}, + cutlass::conv::Mode::kCrossCorrelation, + 1 + ); + + {{exec_paths}} + throw std::runtime_error( + 
"Unsupported workload for this conv2d specialization." + ); +} +""" +) + + +PROFILER_TEMPLATE = jinja2.Template( + """ +size_t GLOBAL_WORKSPACE_SIZE = 0; + +{{op_func}} + +int main(int argc, char** argv) { + int64_t batch = std::stoi(argv[1]); + int64_t in_h = std::stoi(argv[2]); + int64_t in_w = std::stoi(argv[3]); + int64_t in_ch = std::stoi(argv[4]); + int64_t kernel_h = std::stoi(argv[5]); + int64_t kernel_w = std::stoi(argv[6]); + int64_t out_ch = std::stoi(argv[7]); + int stride = std::stoi(argv[8]); + int pad = std::stoi(argv[9]); + int dilation = std::stoi(argv[10]); + {{shape_func}} + using ElementOutput = typename {{name}}::ElementC; + using ElementInputA = typename {{name}}::ElementA; + using ElementInputB = typename {{name}}::ElementB; + + uint8_t* global_workspace = nullptr; + cudaStream_t stream = nullptr; + + cutlass::HostTensor x({NI, HI, WI, CI}); + cutlass::HostTensor w({CO, KH, KW, CI}); + cutlass::HostTensor y({NO, HO, WO, CO}); + + // + // warmup + conv((cutlass::half_t*) x.device_data(), + (cutlass::half_t*) w.device_data(), + (cutlass::half_t*) y.device_data(), + global_workspace, + &NI, + &CO, + &CI, + &KH, + &KW, + &HI, + &WI, + &NO, + &HO, + &WO, + stride, + dilation, + pad, + stream); + cudaEvent_t events[2]; + for (auto & event : events) { + cudaEventCreate(&event); + } + cudaEventRecord(events[0]); + for (int i = 0; i < 5; ++i) { + conv((cutlass::half_t*) x.device_data(), + (cutlass::half_t*) w.device_data(), + (cutlass::half_t*) y.device_data(), + global_workspace, + &NI, + &CO, + &CI, + &KH, + &KW, + &HI, + &WI, + &NO, + &HO, + &WO, + stride, + dilation, + pad, + stream); + } + cudaEventRecord(events[1]); + cudaEventSynchronize(events[1]); + float runtime_ms = 0; + cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + for (auto event : events) { + (void)cudaEventDestroy(event); + } + // TODO: output workspace + if (runtime_ms < 0.00001) { + throw std::runtime_error( + "OOB in cutlass." 
+ ); + } + std::cout << "TIME:" << runtime_ms << std::endl; + std::cout << "WS:" << GLOBAL_WORKSPACE_SIZE << std::endl; +} + +""" +) + +FUNC_DECL_TEMPLATE = jinja2.Template( + """ +void {{func_name}}( + cutlass::half_t*, + cutlass::half_t*, + cutlass::half_t*, + uint8_t*, + int64_t*, + int64_t*, + int64_t*, + int64_t*, + int64_t*, + int64_t*, + int64_t*, + int64_t*, + int64_t*, + int64_t*, + int, + int, + int, + cudaStream_t +); +""" +) + +FUNC_CALL_TEMPLATE = jinja2.Template( + """ +{{indent}}{{func_name}}( +{{indent}} {{in_ptr}}, +{{indent}} {{weight_ptr}}, +{{indent}} {{out_ptr}}, +{{indent}} global_workspace, +{{indent}} {{p_batch}}, +{{indent}} {{p_out_ch}}, +{{indent}} {{p_in_ch}}, +{{indent}} {{p_kernel_h}}, +{{indent}} {{p_kernel_w}}, +{{indent}} {{p_in_h}}, +{{indent}} {{p_in_w}}, +{{indent}} {{p_out_batch}}, +{{indent}} {{p_out_h}}, +{{indent}} {{p_out_w}}, +{{indent}} {{stride}}, +{{indent}} {{dilation}}, +{{indent}} {{pad}}, +{{indent}} stream +{{indent}}); +""" +) + + +@registry.reg("cuda.conv2d.config") +def conv2d_config(func_attrs, dtype="float16"): + """Populates conv2d cutlass configs into 'op_instance' field.""" + func_attrs["op_instance"] = common.extract_config(func_attrs) + + +@registry.reg("cuda.conv2d.gen_profiler") +def gen_profiler(func_attrs, workdir, shape_template): + """Codegen for conv2d profiler.""" + op_type = func_attrs["op"] + op_instance = func_attrs["op_instance"] + # shape func + shape_func = shape_template.render( + indent=" ", + dtype="int64_t ", + div="/", + x_dim0="batch", + x_dim1="in_h", + x_dim2="in_w", + x_dim3="in_ch", + w_dim0="out_ch", + w_dim1="kernel_h", + w_dim2="kernel_w", + stride="stride", + dilate="dilation", + pad="pad", + ) + file_pairs = [] + for op_name, op in op_instance.items(): + config = common.emit_instance(op) + config_name = common.extract_config_name(config) + name = "DeviceConvFwdInstance" + instance = INSTANCE_TEMPLATE.render( + config_name=config_name, name=name, config=config + ) + exec_program = EXEC_TEMPLATE.render( + indent=" ", is_profiler=True, instance=name + ) + op_func = SRC_TEMPLATE.render( + instances=instance, + function_name="conv", + dtype="cutlass::half_t", + shape_func="", + exec_paths=exec_program, + ) + code = PROFILER_TEMPLATE.render( + op_func=op_func, shape_func=shape_func, name=name + ) + common.add_profiler(file_pairs, workdir, op_type, op_name, code) + # build + common.build_profiler(file_pairs) + + +@registry.reg("cuda.conv2d.gen_function") +def gen_function( + func_attrs, + exec_cond_remplate, + shape_eval_template, + shape_save_template, +): + """Codegen for conv2d function.""" + return common.gen_function( + func_attrs, + INSTANCE_TEMPLATE, + EXEC_TEMPLATE, + SRC_TEMPLATE, + exec_cond_remplate, + shape_eval_template, + shape_save_template, + ) + + +@registry.reg("cuda.conv2d.func_decl") +def conv2d_gen_function_decl(func_attrs): + """Codegen for conv2d function declaration.""" + func_name = func_attrs["name"] + return FUNC_DECL_TEMPLATE.render(func_name=func_name) + + +@registry.reg("cuda.conv2d.func_call") +def conv2d_gen_function_call(func_attrs, indent=" "): + """Codegen for conv2d function call.""" + x = func_attrs["inputs"][0] + xshape = x._attrs["shape"] + w = func_attrs["inputs"][1] + wshape = w._attrs["shape"] + y = func_attrs["outputs"][0] + yshape = y._attrs["shape"] + return FUNC_CALL_TEMPLATE.render( + func_name=func_attrs["name"], + in_ptr=x._attrs["name"], + weight_ptr=w._attrs["name"], + out_ptr=y._attrs["name"], + p_batch="&" + xshape[0]._attrs["name"], + p_out_ch="&" + 
wshape[0]._attrs["name"], + p_in_ch="&" + xshape[3]._attrs["name"], + p_kernel_h="&" + wshape[1]._attrs["name"], + p_kernel_w="&" + wshape[2]._attrs["name"], + p_in_h="&" + xshape[1]._attrs["name"], + p_in_w="&" + xshape[2]._attrs["name"], + p_out_batch="&" + yshape[0]._attrs["name"], + p_out_h="&" + yshape[1]._attrs["name"], + p_out_w="&" + yshape[2]._attrs["name"], + stride=func_attrs["stride"], + dilation=func_attrs["dilate"], + pad=func_attrs["pad"], + indent=indent, + ) + + +@registry.reg("cuda.conv2d.filter") +def conv2d_function_filter(cfg, func_attrs, x_shape): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + x_shape: + Input shapes. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, x_shape) diff --git a/python/aitemplate/backend/cuda/conv2d/conv2d_bias.py b/python/aitemplate/backend/cuda/conv2d/conv2d_bias.py new file mode 100644 index 000000000..c1ce2ac94 --- /dev/null +++ b/python/aitemplate/backend/cuda/conv2d/conv2d_bias.py @@ -0,0 +1,86 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +conv2d bias codegen +""" +from ... import registry +from . import common, common_conv2d_bias_activation as cba + +# pylint: disable=C0103,C0415,W0613,C0301 + + +@registry.reg("cuda.conv2d_bias.config") +def conv2d_config(func_attrs, dtype="float16"): + """Populates all available conv2d configs into the op_instance field.""" + func_attrs["op_instance"] = common.extract_config(func_attrs) + + +@registry.reg("cuda.conv2d_bias.gen_profiler") +def gen_profiler(func_attrs, workdir, shape_template): + """Codegen for conv2d profiler.""" + cba.gen_profiler(func_attrs, workdir, shape_template) + + +@registry.reg("cuda.conv2d_bias.gen_function") +def gen_function( + func_attrs, + exec_cond_remplate, + shape_eval_template, + shape_save_template, +): + """Codegen for conv2d function.""" + return common.gen_function( + func_attrs, + cba.INSTANCE_TEMPLATE, + cba.EXEC_TEMPLATE, + cba.SRC_TEMPLATE, + exec_cond_remplate, + shape_eval_template, + shape_save_template, + ) + + +@registry.reg("cuda.conv2d_bias.func_decl") +def conv2d_gen_function_decl(func_attrs): + """Codegen for conv2d function declaration.""" + func_name = func_attrs["name"] + return cba.FUNC_DECL_TEMPLATE.render(func_name=func_name) + + +@registry.reg("cuda.conv2d_bias.func_call") +def conv2d_gen_function_call(func_attrs, indent=" "): + """Codegen for conv2d function call.""" + return cba.gen_function_call(func_attrs, indent) + + +@registry.reg("cuda.conv2d_bias.filter") +def conv2d_function_filter(cfg, func_attrs, x_shape): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + x_shape: + Input shapes. + + Returns + ------- + bool + If input cfg should be filtered. 
+ """ + return common.function_filter(cfg, func_attrs, x_shape) diff --git a/python/aitemplate/backend/cuda/conv2d/conv2d_bias_add.py b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_add.py new file mode 100644 index 000000000..663495f22 --- /dev/null +++ b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_add.py @@ -0,0 +1,149 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +conv2d bias add codegen +""" +from ... import registry +from ...target import Target +from . import common, common_conv2d_bias_add_activation as cbaa + +# pylint: disable=C0103,C0415,W0613,C0301 + + +@registry.reg("cuda.conv2d_bias_add_identity.config") +def conv2d_config(func_attrs, dtype="float16"): + def fproc_f16(op): + import copy + + import cutlass_lib + + ret = [] + data_type = cutlass_lib.library.DataType.f16 + acc_type = cutlass_lib.library.DataType.f32 + # check target use fp16 acc + if "use_fp16_acc" in Target.current()._kwargs: + if Target.current()._kwargs["use_fp16_acc"]: + acc_type = cutlass_lib.library.DataType.f16 + + if ( + op.A.element == data_type + and op.B.element == data_type + and op.C.element == data_type + and op.iterator_algorithm == cutlass_lib.library.IteratorAlgorithm.Optimized + and op.accumulator_type() == acc_type + ): + + op = copy.deepcopy(op) + # set epilogue + epilogue_name = func_attrs["epilogue"] + op.epilogue_functor = cutlass_lib.library.EpilogueFunctorName[epilogue_name] + op.element_epilogue = acc_type + + op.activation_op = cutlass_lib.library.EpilogueMathName["Identity"] + op.binary_op = cutlass_lib.library.EpilogueMathName["Plus"] + op.unary_op = cutlass_lib.library.EpilogueMathName["Identity"] + + # set C alignment + for i in [8, 4, 2, 1]: + op = copy.deepcopy(op) + op.C.alignment = i + ret.append(op) + return ret + + func_attrs["op_instance"] = common.extract_config(func_attrs, fproc_f16) + + +@registry.reg("cuda.conv2d_bias_add_identity.gen_profiler") +def gen_profiler(func_attrs, workdir, shape_template): + cbaa.gen_profiler(func_attrs, workdir, shape_template) + + +@registry.reg("cuda.conv2d_bias_add_identity.gen_function") +def gen_function( + func_attrs, + exec_cond_remplate, + shape_eval_template, + shape_save_template, +): + return common.gen_function( + func_attrs, + cbaa.INSTANCE_TEMPLATE, + cbaa.EXEC_TEMPLATE, + cbaa.SRC_TEMPLATE, + exec_cond_remplate, + shape_eval_template, + shape_save_template, + ) + + +@registry.reg("cuda.conv2d_bias_add_identity.func_decl") +def conv2d_gen_function_decl(func_attrs): + func_name = func_attrs["name"] + return cbaa.FUNC_DECL_TEMPLATE.render(func_name=func_name) + + +@registry.reg("cuda.conv2d_bias_add_identity.func_call") +def conv2d_gen_function_call(func_attrs, indent=" "): + x = func_attrs["inputs"][0] + xshape = x._attrs["shape"] + w = func_attrs["inputs"][1] + b = func_attrs["inputs"][2] + r = func_attrs["inputs"][3] + wshape = w._attrs["shape"] + y = func_attrs["outputs"][0] + yshape = y._attrs["shape"] + return cbaa.FUNC_CALL_TEMPLATE.render( + 
func_name=func_attrs["name"], + in_ptr=x._attrs["name"], + weight_ptr=w._attrs["name"], + out_ptr=y._attrs["name"], + bias_ptr=b._attrs["name"], + res_ptr=r._attrs["name"], + p_batch="&" + xshape[0]._attrs["name"], + p_out_ch="&" + wshape[0]._attrs["name"], + p_in_ch="&" + xshape[3]._attrs["name"], + p_kernel_h="&" + wshape[1]._attrs["name"], + p_kernel_w="&" + wshape[2]._attrs["name"], + p_in_h="&" + xshape[1]._attrs["name"], + p_in_w="&" + xshape[2]._attrs["name"], + p_out_batch="&" + yshape[0]._attrs["name"], + p_out_h="&" + yshape[1]._attrs["name"], + p_out_w="&" + yshape[2]._attrs["name"], + stride=func_attrs["stride"], + dilation=func_attrs["dilate"], + pad=func_attrs["pad"], + indent=indent, + ) + + +@registry.reg("cuda.conv2d_bias_add_identity.filter") +def conv2d_function_filter(cfg, func_attrs, x_shape): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + x_shape: + Input shapes. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, x_shape) diff --git a/python/aitemplate/backend/cuda/conv2d/conv2d_bias_add_hardswish.py b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_add_hardswish.py new file mode 100644 index 000000000..10aa46619 --- /dev/null +++ b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_add_hardswish.py @@ -0,0 +1,149 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +conv2d bias add hardswish codegen +""" +from ... import registry +from ...target import Target +from . 
import common, common_conv2d_bias_add_activation as cbaa + +# pylint: disable=C0103,C0415,W0613,C0301 + + +@registry.reg("cuda.conv2d_bias_add_hardswish.config") +def conv2d_config(func_attrs, dtype="float16"): + def fproc_f16(op): + import copy + + import cutlass_lib + + ret = [] + data_type = cutlass_lib.library.DataType.f16 + acc_type = cutlass_lib.library.DataType.f32 + # check target use fp16 acc + if "use_fp16_acc" in Target.current()._kwargs: + if Target.current()._kwargs["use_fp16_acc"]: + acc_type = cutlass_lib.library.DataType.f16 + + if ( + op.A.element == data_type + and op.B.element == data_type + and op.C.element == data_type + and op.iterator_algorithm == cutlass_lib.library.IteratorAlgorithm.Optimized + and op.accumulator_type() == acc_type + ): + + op = copy.deepcopy(op) + # set epilogue + epilogue_name = func_attrs["epilogue"] + op.epilogue_functor = cutlass_lib.library.EpilogueFunctorName[epilogue_name] + op.element_epilogue = acc_type + + op.activation_op = cutlass_lib.library.EpilogueMathName["Identity"] + op.binary_op = cutlass_lib.library.EpilogueMathName["Add"] + op.unary_op = cutlass_lib.library.EpilogueMathName["HardSwish"] + + # set C alignment + for i in [8, 4, 2, 1]: + op = copy.deepcopy(op) + op.C.alignment = i + ret.append(op) + return ret + + func_attrs["op_instance"] = common.extract_config(func_attrs, fproc_f16) + + +@registry.reg("cuda.conv2d_bias_add_hardswish.gen_profiler") +def gen_profiler(func_attrs, workdir, shape_template): + cbaa.gen_profiler(func_attrs, workdir, shape_template) + + +@registry.reg("cuda.conv2d_bias_add_hardswish.gen_function") +def gen_function( + func_attrs, + exec_cond_remplate, + shape_eval_template, + shape_save_template, +): + return common.gen_function( + func_attrs, + cbaa.INSTANCE_TEMPLATE, + cbaa.EXEC_TEMPLATE, + cbaa.SRC_TEMPLATE, + exec_cond_remplate, + shape_eval_template, + shape_save_template, + ) + + +@registry.reg("cuda.conv2d_bias_add_hardswish.func_decl") +def conv2d_gen_function_decl(func_attrs): + func_name = func_attrs["name"] + return cbaa.FUNC_DECL_TEMPLATE.render(func_name=func_name) + + +@registry.reg("cuda.conv2d_bias_add_hardswish.func_call") +def conv2d_gen_function_call(func_attrs, indent=" "): + x = func_attrs["inputs"][0] + xshape = x._attrs["shape"] + w = func_attrs["inputs"][1] + b = func_attrs["inputs"][2] + r = func_attrs["inputs"][3] + wshape = w._attrs["shape"] + y = func_attrs["outputs"][0] + yshape = y._attrs["shape"] + return cbaa.FUNC_CALL_TEMPLATE.render( + func_name=func_attrs["name"], + in_ptr=x._attrs["name"], + weight_ptr=w._attrs["name"], + out_ptr=y._attrs["name"], + bias_ptr=b._attrs["name"], + res_ptr=r._attrs["name"], + p_batch="&" + xshape[0]._attrs["name"], + p_out_ch="&" + wshape[0]._attrs["name"], + p_in_ch="&" + xshape[3]._attrs["name"], + p_kernel_h="&" + wshape[1]._attrs["name"], + p_kernel_w="&" + wshape[2]._attrs["name"], + p_in_h="&" + xshape[1]._attrs["name"], + p_in_w="&" + xshape[2]._attrs["name"], + p_out_batch="&" + yshape[0]._attrs["name"], + p_out_h="&" + yshape[1]._attrs["name"], + p_out_w="&" + yshape[2]._attrs["name"], + stride=func_attrs["stride"], + dilation=func_attrs["dilate"], + pad=func_attrs["pad"], + indent=indent, + ) + + +@registry.reg("cuda.conv2d_bias_add_hardswish.filter") +def conv2d_function_filter(cfg, func_attrs, x_shape): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + x_shape: + Input shapes. 
+ + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, x_shape) diff --git a/python/aitemplate/backend/cuda/conv2d/conv2d_bias_add_relu.py b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_add_relu.py new file mode 100644 index 000000000..b6b96704f --- /dev/null +++ b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_add_relu.py @@ -0,0 +1,149 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +conv2d bias add relu codegen +""" +from ... import registry +from ...target import Target +from . import common, common_conv2d_bias_add_activation as cbaa + +# pylint: disable=C0103,C0415,W0613,C0301 + + +@registry.reg("cuda.conv2d_bias_add_relu.config") +def conv2d_config(func_attrs, dtype="float16"): + def fproc_f16(op): + import copy + + import cutlass_lib + + ret = [] + data_type = cutlass_lib.library.DataType.f16 + acc_type = cutlass_lib.library.DataType.f32 + # check target use fp16 acc + if "use_fp16_acc" in Target.current()._kwargs: + if Target.current()._kwargs["use_fp16_acc"]: + acc_type = cutlass_lib.library.DataType.f16 + + if ( + op.A.element == data_type + and op.B.element == data_type + and op.C.element == data_type + and op.iterator_algorithm == cutlass_lib.library.IteratorAlgorithm.Optimized + and op.accumulator_type() == acc_type + ): + + op = copy.deepcopy(op) + # set epilogue + epilogue_name = func_attrs["epilogue"] + op.epilogue_functor = cutlass_lib.library.EpilogueFunctorName[epilogue_name] + op.element_epilogue = acc_type + + op.activation_op = cutlass_lib.library.EpilogueMathName["Identity"] + op.binary_op = cutlass_lib.library.EpilogueMathName["Plus"] + op.unary_op = cutlass_lib.library.EpilogueMathName["ReLu"] + + # set C alignment + for i in [8, 4, 2, 1]: + op = copy.deepcopy(op) + op.C.alignment = i + ret.append(op) + return ret + + func_attrs["op_instance"] = common.extract_config(func_attrs, fproc_f16) + + +@registry.reg("cuda.conv2d_bias_add_relu.gen_profiler") +def gen_profiler(func_attrs, workdir, shape_template): + cbaa.gen_profiler(func_attrs, workdir, shape_template) + + +@registry.reg("cuda.conv2d_bias_add_relu.gen_function") +def gen_function( + func_attrs, + exec_cond_remplate, + shape_eval_template, + shape_save_template, +): + return common.gen_function( + func_attrs, + cbaa.INSTANCE_TEMPLATE, + cbaa.EXEC_TEMPLATE, + cbaa.SRC_TEMPLATE, + exec_cond_remplate, + shape_eval_template, + shape_save_template, + ) + + +@registry.reg("cuda.conv2d_bias_add_relu.func_decl") +def conv2d_gen_function_decl(func_attrs): + func_name = func_attrs["name"] + return cbaa.FUNC_DECL_TEMPLATE.render(func_name=func_name) + + +@registry.reg("cuda.conv2d_bias_add_relu.func_call") +def conv2d_gen_function_call(func_attrs, indent=" "): + x = func_attrs["inputs"][0] + xshape = x._attrs["shape"] + w = func_attrs["inputs"][1] + b = func_attrs["inputs"][2] + r = func_attrs["inputs"][3] + wshape = w._attrs["shape"] + y = func_attrs["outputs"][0] + yshape = 
y._attrs["shape"] + return cbaa.FUNC_CALL_TEMPLATE.render( + func_name=func_attrs["name"], + in_ptr=x._attrs["name"], + weight_ptr=w._attrs["name"], + out_ptr=y._attrs["name"], + bias_ptr=b._attrs["name"], + res_ptr=r._attrs["name"], + p_batch="&" + xshape[0]._attrs["name"], + p_out_ch="&" + wshape[0]._attrs["name"], + p_in_ch="&" + xshape[3]._attrs["name"], + p_kernel_h="&" + wshape[1]._attrs["name"], + p_kernel_w="&" + wshape[2]._attrs["name"], + p_in_h="&" + xshape[1]._attrs["name"], + p_in_w="&" + xshape[2]._attrs["name"], + p_out_batch="&" + yshape[0]._attrs["name"], + p_out_h="&" + yshape[1]._attrs["name"], + p_out_w="&" + yshape[2]._attrs["name"], + stride=func_attrs["stride"], + dilation=func_attrs["dilate"], + pad=func_attrs["pad"], + indent=indent, + ) + + +@registry.reg("cuda.conv2d_bias_add_relu.filter") +def conv2d_function_filter(cfg, func_attrs, x_shape): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + x_shape: + Input shapes. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, x_shape) diff --git a/python/aitemplate/backend/cuda/conv2d/conv2d_bias_few_channels.py b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_few_channels.py new file mode 100644 index 000000000..b8ddfa205 --- /dev/null +++ b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_few_channels.py @@ -0,0 +1,211 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +specialize conv2d op with few channels(< 8) +""" +from collections import OrderedDict + +from ... import registry +from ...target import Target +from . 
import common, common_conv2d_bias_activation as cba + +# pylint: disable=C0103,C0415,W0613,C0301 + + +def apply_special_config(func_attrs, op): + import cutlass_lib + + x = func_attrs["inputs"][0] + in_ch = x._attrs["shape"][-1]._attrs["values"][0] + + if in_ch == 3: + # By default we don't use it since the perf is worse than pad4+fixchannel + op.iterator_algorithm = cutlass_lib.library.IteratorAlgorithm.FewChannels + op.A.alignment = 1 + op.B.alignment = 1 + op.tile_description.stages = 2 + elif in_ch in [2, 4, 8]: + op.iterator_algorithm = cutlass_lib.library.IteratorAlgorithm.FixedChannels + op.A.alignment = in_ch + op.B.alignment = in_ch + op.tile_description.stages = 3 + return op + + +def extract_config(func_attrs): + """extract epilogue for conv op + + Parameters + ---------- + func_attrs : Dict + [description] op attributes + + Returns + ------- + [type]: Dict + [description] + + Raises + ------ + NotImplementedError + [description] + """ + import copy + + import cutlass_lib + + def f_proc_op_special(op): + ret = [] + data_type = cutlass_lib.library.DataType.f16 + acc_type = cutlass_lib.library.DataType.f32 + # check target use fp16 acc + if "use_fp16_acc" in Target.current()._kwargs: + if Target.current()._kwargs["use_fp16_acc"]: + acc_type = cutlass_lib.library.DataType.f16 + + if ( + op.A.element == data_type + and op.B.element == data_type + and op.C.element == data_type + and op.iterator_algorithm == cutlass_lib.library.IteratorAlgorithm.Optimized + and op.accumulator_type() == acc_type + ): + + op = copy.deepcopy(op) + # set epilogue + epilogue_name = func_attrs["epilogue"] + op.epilogue_functor = cutlass_lib.library.EpilogueFunctorName[epilogue_name] + op.element_epilogue = acc_type + op = apply_special_config(func_attrs, op) + # set C alignment + for i in [8, 4, 2, 1]: + op = copy.deepcopy(op) + op.C.alignment = i + ret.append(op) + return ret + + op_kind = cutlass_lib.library.OperationKind.Conv2d + conv_kind = cutlass_lib.library.ConvKind.Fprop + ret = [] + conv2d_ops = OrderedDict() + extract_ops = list(Target.current()._operators[op_kind].items()) + + for _, value in extract_ops: + op = value[0] + if op.conv_kind == conv_kind: + ret = f_proc_op_special(op) + if len(ret) > 0: + for op_inst in ret: + key = common.kernel_name(op_inst) + conv2d_ops[key] = op_inst + return conv2d_ops + + +@registry.reg("cuda.conv2d_bias_few_channels.config") +def conv2d_config(func_attrs, dtype="float16"): + """extract configurations for profiling + + Parameters + ---------- + func_attrs : Dict + [description] op attributes + dtype : str, optional + [description] by default "float16" + + Returns + ------- + [type] + [description] + + Raises + ------ + NotImplementedError + [description] + """ + func_attrs["op_instance"] = extract_config(func_attrs) + + +@registry.reg("cuda.conv2d_bias_few_channels.gen_profiler") +def gen_profiler(func_attrs, workdir, shape_template): + """generate code for profiling""" + cba.gen_profiler(func_attrs, workdir, shape_template) + + +@registry.reg("cuda.conv2d_bias_few_channels.gen_function") +def gen_function( + func_attrs, + exec_cond_remplate, + shape_eval_template, + shape_save_template, +): + """generating special conv2d kernel and all of its auxiliary functions + + Parameters + ---------- + func_attrs : Dict + [description] attributes of conv2d op + exec_cond_remplate : [type] + [description] + shape_eval_template : [type] + [description] + shape_save_template : [type] + [description] + + Returns + ------- + [type] + [description] + """ + return 
common.gen_function( + func_attrs, + cba.INSTANCE_TEMPLATE, + cba.EXEC_TEMPLATE, + cba.SRC_TEMPLATE, + exec_cond_remplate, + shape_eval_template, + shape_save_template, + ) + + +@registry.reg("cuda.conv2d_bias_few_channels.func_decl") +def conv2d_gen_function_decl(func_attrs): + func_name = func_attrs["name"] + return cba.FUNC_DECL_TEMPLATE.render(func_name=func_name) + + +@registry.reg("cuda.conv2d_bias_few_channels.func_call") +def conv2d_gen_function_call(func_attrs, indent=" "): + return cba.gen_function_call(func_attrs, indent) + + +@registry.reg("cuda.conv2d_bias_few_channels.filter") +def conv2d_function_filter(cfg, func_attrs, x_shape): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + x_shape: + Input shapes. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, x_shape) diff --git a/python/aitemplate/backend/cuda/conv2d/conv2d_bias_hardswish.py b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_hardswish.py new file mode 100644 index 000000000..e31ad9095 --- /dev/null +++ b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_hardswish.py @@ -0,0 +1,81 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +conv2d bias hardswish codegen +""" +from ... import registry +from . import common, common_conv2d_bias_activation as cba + +# pylint: disable=C0103,C0415,W0613,C0301 + + +@registry.reg("cuda.conv2d_bias_hardswish.config") +def conv2d_config(func_attrs, dtype="float16"): + func_attrs["op_instance"] = common.extract_config(func_attrs) + + +@registry.reg("cuda.conv2d_bias_hardswish.gen_profiler") +def gen_profiler(func_attrs, workdir, shape_template): + cba.gen_profiler(func_attrs, workdir, shape_template) + + +@registry.reg("cuda.conv2d_bias_hardswish.gen_function") +def gen_function( + func_attrs, + exec_cond_remplate, + shape_eval_template, + shape_save_template, +): + return common.gen_function( + func_attrs, + cba.INSTANCE_TEMPLATE, + cba.EXEC_TEMPLATE, + cba.SRC_TEMPLATE, + exec_cond_remplate, + shape_eval_template, + shape_save_template, + ) + + +@registry.reg("cuda.conv2d_bias_hardswish.func_decl") +def conv2d_gen_function_decl(func_attrs): + func_name = func_attrs["name"] + return cba.FUNC_DECL_TEMPLATE.render(func_name=func_name) + + +@registry.reg("cuda.conv2d_bias_hardswish.func_call") +def conv2d_gen_function_call(func_attrs, indent=" "): + return cba.gen_function_call(func_attrs, indent) + + +@registry.reg("cuda.conv2d_bias_hardswish.filter") +def conv2d_function_filter(cfg, func_attrs, x_shape): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + x_shape: + Input shapes. + + Returns + ------- + bool + If input cfg should be filtered. 
+ """ + return common.function_filter(cfg, func_attrs, x_shape) diff --git a/python/aitemplate/backend/cuda/conv2d/conv2d_bias_hardswish_few_channels.py b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_hardswish_few_channels.py new file mode 100644 index 000000000..f305f3344 --- /dev/null +++ b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_hardswish_few_channels.py @@ -0,0 +1,123 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +specialize conv2d op with few channels(< 8) +""" + +from ... import registry + +from . import common, common_conv2d_bias_activation as cba +from .common_conv2d_few_channels import extract_config + + +# pylint: disable=C0103,C0415,W0613,C0301 + + +@registry.reg("cuda.conv2d_bias_hardswish_few_channels.config") +def conv2d_config(func_attrs, dtype="float16"): + """extract configurations for profiling + + Parameters + ---------- + func_attrs : Dict + [description] op attributes + dtype : str, optional + [description] by default "float16" + + Returns + ------- + [type] + [description] + + Raises + ------ + NotImplementedError + [description] + """ + func_attrs["op_instance"] = extract_config(func_attrs) + + +@registry.reg("cuda.conv2d_bias_hardswish_few_channels.gen_profiler") +def gen_profiler(func_attrs, workdir, shape_template): + """generate code for profiling""" + cba.gen_profiler(func_attrs, workdir, shape_template) + + +@registry.reg("cuda.conv2d_bias_hardswish_few_channels.gen_function") +def gen_function( + func_attrs, + exec_cond_remplate, + shape_eval_template, + shape_save_template, +): + """generating special conv2d kernel and all of its auxiliary functions + + Parameters + ---------- + func_attrs : Dict + [description] attributes of conv2d op + exec_cond_remplate : [type] + [description] + shape_eval_template : [type] + [description] + shape_save_template : [type] + [description] + + Returns + ------- + [type] + [description] + """ + return common.gen_function( + func_attrs, + cba.INSTANCE_TEMPLATE, + cba.EXEC_TEMPLATE, + cba.SRC_TEMPLATE, + exec_cond_remplate, + shape_eval_template, + shape_save_template, + ) + + +@registry.reg("cuda.conv2d_bias_hardswish_few_channels.func_decl") +def conv2d_gen_function_decl(func_attrs): + func_name = func_attrs["name"] + return cba.FUNC_DECL_TEMPLATE.render(func_name=func_name) + + +@registry.reg("cuda.conv2d_bias_hardswish_few_channels.func_call") +def conv2d_gen_function_call(func_attrs, indent=" "): + return cba.gen_function_call(func_attrs, indent) + + +@registry.reg("cuda.conv2d_bias_hardswish_few_channels.filter") +def conv2d_function_filter(cfg, func_attrs, x_shape): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + x_shape: + Input shapes. + + Returns + ------- + bool + If input cfg should be filtered. 
+ """ + return common.function_filter(cfg, func_attrs, x_shape) diff --git a/python/aitemplate/backend/cuda/conv2d/conv2d_bias_relu.py b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_relu.py new file mode 100644 index 000000000..ea75bdd9d --- /dev/null +++ b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_relu.py @@ -0,0 +1,81 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +conv2d bias relu codegen +""" +from ... import registry +from . import common, common_conv2d_bias_activation as cba + +# pylint: disable=C0103,C0415,W0613,C0301 + + +@registry.reg("cuda.conv2d_bias_relu.config") +def conv2d_config(func_attrs, dtype="float16"): + func_attrs["op_instance"] = common.extract_config(func_attrs) + + +@registry.reg("cuda.conv2d_bias_relu.gen_profiler") +def gen_profiler(func_attrs, workdir, shape_template): + cba.gen_profiler(func_attrs, workdir, shape_template) + + +@registry.reg("cuda.conv2d_bias_relu.gen_function") +def gen_function( + func_attrs, + exec_cond_remplate, + shape_eval_template, + shape_save_template, +): + return common.gen_function( + func_attrs, + cba.INSTANCE_TEMPLATE, + cba.EXEC_TEMPLATE, + cba.SRC_TEMPLATE, + exec_cond_remplate, + shape_eval_template, + shape_save_template, + ) + + +@registry.reg("cuda.conv2d_bias_relu.func_decl") +def conv2d_gen_function_decl(func_attrs): + func_name = func_attrs["name"] + return cba.FUNC_DECL_TEMPLATE.render(func_name=func_name) + + +@registry.reg("cuda.conv2d_bias_relu.func_call") +def conv2d_gen_function_call(func_attrs, indent=" "): + return cba.gen_function_call(func_attrs, indent) + + +@registry.reg("cuda.conv2d_bias_relu.filter") +def conv2d_function_filter(cfg, func_attrs, x_shape): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + x_shape: + Input shapes. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, x_shape) diff --git a/python/aitemplate/backend/cuda/conv2d/conv2d_bias_relu_few_channels.py b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_relu_few_channels.py new file mode 100644 index 000000000..e207bc10a --- /dev/null +++ b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_relu_few_channels.py @@ -0,0 +1,115 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +""" +specialize conv2d op with few channels(< 8) +""" + +from ... import registry +from . import common, common_conv2d_bias_activation as cba +from .common_conv2d_few_channels import extract_config + +# pylint: disable=C0103,C0415,W0613,C0301 + + +@registry.reg("cuda.conv2d_bias_relu_few_channels.config") +def conv2d_config(func_attrs, dtype="float16"): + """extract configurations for profiling + + Parameters + ---------- + func_attrs : Dict + op attributes + dtype : str, optional + by default "float16" + + Returns + ------- + None + """ + func_attrs["op_instance"] = extract_config(func_attrs) + + +@registry.reg("cuda.conv2d_bias_relu_few_channels.gen_profiler") +def gen_profiler(func_attrs, workdir, shape_template): + """generate code for profiling""" + cba.gen_profiler(func_attrs, workdir, shape_template) + + +@registry.reg("cuda.conv2d_bias_relu_few_channels.gen_function") +def gen_function( + func_attrs, + exec_cond_remplate, + shape_eval_template, + shape_save_template, +): + """generating special conv2d kernel and all of its auxiliary functions + + Parameters + ---------- + func_attrs : Dict + [description] attributes of conv2d op + exec_cond_remplate : [type] + [description] + shape_eval_template : [type] + [description] + shape_save_template : [type] + [description] + + Returns + ------- + [type] + [description] + """ + return common.gen_function( + func_attrs, + cba.INSTANCE_TEMPLATE, + cba.EXEC_TEMPLATE, + cba.SRC_TEMPLATE, + exec_cond_remplate, + shape_eval_template, + shape_save_template, + ) + + +@registry.reg("cuda.conv2d_bias_relu_few_channels.func_decl") +def conv2d_gen_function_decl(func_attrs): + func_name = func_attrs["name"] + return cba.FUNC_DECL_TEMPLATE.render(func_name=func_name) + + +@registry.reg("cuda.conv2d_bias_relu_few_channels.func_call") +def conv2d_gen_function_call(func_attrs, indent=" "): + return cba.gen_function_call(func_attrs, indent) + + +@registry.reg("cuda.conv2d_bias_relu_few_channels.filter") +def conv2d_function_filter(cfg, func_attrs, x_shape): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + x_shape: + Input shapes. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, x_shape) diff --git a/python/aitemplate/backend/cuda/conv2d/conv2d_bias_sigmoid.py b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_sigmoid.py new file mode 100644 index 000000000..5ad4ccd6a --- /dev/null +++ b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_sigmoid.py @@ -0,0 +1,82 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +conv2d bias sigmoid codegen +""" + +from ... import registry +from . 
import common, common_conv2d_bias_activation as cba + +# pylint: disable=C0103,C0415,W0613,C0301 + + +@registry.reg("cuda.conv2d_bias_sigmoid.config") +def conv2d_config(func_attrs, dtype="float16"): + func_attrs["op_instance"] = common.extract_config(func_attrs) + + +@registry.reg("cuda.conv2d_bias_sigmoid.gen_profiler") +def gen_profiler(func_attrs, workdir, shape_template): + cba.gen_profiler(func_attrs, workdir, shape_template) + + +@registry.reg("cuda.conv2d_bias_sigmoid.gen_function") +def gen_function( + func_attrs, + exec_cond_remplate, + shape_eval_template, + shape_save_template, +): + return common.gen_function( + func_attrs, + cba.INSTANCE_TEMPLATE, + cba.EXEC_TEMPLATE, + cba.SRC_TEMPLATE, + exec_cond_remplate, + shape_eval_template, + shape_save_template, + ) + + +@registry.reg("cuda.conv2d_bias_sigmoid.func_decl") +def conv2d_gen_function_decl(func_attrs): + func_name = func_attrs["name"] + return cba.FUNC_DECL_TEMPLATE.render(func_name=func_name) + + +@registry.reg("cuda.conv2d_bias_sigmoid.func_call") +def conv2d_gen_function_call(func_attrs, indent=" "): + return cba.gen_function_call(func_attrs, indent) + + +@registry.reg("cuda.conv2d_bias_sigmoid.filter") +def conv2d_function_filter(cfg, func_attrs, x_shape): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + x_shape: + Input shapes. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, x_shape) diff --git a/python/aitemplate/backend/cuda/conv2d/transposed_conv2d.py b/python/aitemplate/backend/cuda/conv2d/transposed_conv2d.py new file mode 100644 index 000000000..b1b6acbc1 --- /dev/null +++ b/python/aitemplate/backend/cuda/conv2d/transposed_conv2d.py @@ -0,0 +1,256 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +transposed conv2d op codegen +""" +import re + +import jinja2 + +from ... import registry +from . 
import common, conv2d + +# pylint: disable=C0103,C0415,W0613,C0301 + +SRC_TEMPLATE = jinja2.Template( + """ +#include +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d_dgrad.h" +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +{{extra_header}} + +#define CUTLASS_CHECK(status) \\ + { \\ + cutlass::Status error = status; \\ + if (error != cutlass::Status::kSuccess) { \\ + auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " + \\ + cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__); \\ + std::cerr << msg << std::endl; \\ + throw std::runtime_error(msg); \\ + } \\ + } + +{{instances}} + +{{instances_def}} + +void {{function_name}} ( + cutlass::half_t* in_ptr, + cutlass::half_t* weight_ptr, + cutlass::half_t* out_ptr, + uint8_t* workspace, + int64_t* batch, + int64_t* out_ch, + int64_t* in_ch, + int64_t* kernel_h, + int64_t* kernel_w, + int64_t* in_h, + int64_t* in_w, + int64_t* out_batch, + int64_t* out_h, + int64_t* out_w, + int stride, + int dilation, + int pad, + cudaStream_t stream + ) { + + {{shape_function}} + int i32_batch = *batch; + int i32_in_h = *in_h; + int i32_in_w = *in_w; + int i32_in_ch = *in_ch; + int i32_out_ch = *out_ch; + int i32_kernel_h = *kernel_h; + int i32_kernel_w = *kernel_w; + int i32_out_batch = *out_batch; + int i32_out_h = *out_h; + int i32_out_w = *out_w; + + using cutlass::layout::TensorNHWC; + TensorNHWC layout_A(TensorNHWC::packed(cutlass::make_Coord(i32_batch, i32_in_h, i32_in_w, i32_in_ch))); + TensorNHWC layout_B(TensorNHWC::packed(cutlass::make_Coord(i32_out_ch, i32_kernel_h, i32_kernel_w, i32_in_ch))); + TensorNHWC layout_C(TensorNHWC::packed(cutlass::make_Coord(i32_out_batch, i32_out_h, i32_out_w, i32_out_ch))); + + cutlass::conv::Conv2dProblemSize problem_size( + {i32_out_batch, i32_out_h, i32_out_w, i32_out_ch}, + {i32_out_ch, i32_kernel_h, i32_kernel_w, i32_in_ch}, + {pad, pad, pad, pad}, + {stride, stride}, + {dilation, dilation}, + {i32_batch, i32_in_h, i32_in_w, i32_in_ch}, + cutlass::conv::Mode::kCrossCorrelation, + 1 + ); + + {{exec_paths}} + throw std::runtime_error( + "Unsupported workload for this conv2d specialization." 
+ ); +} +""" +) + + +def conv_transpose_instance(op_def): + tmp = op_def.replace("DefaultConv2dFprop", "DefaultConv2dDgrad") + tmp = re.sub( + r"cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<\d>", + "cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>", + tmp, + ) + return tmp + + +def emit_instance(op, f_instance_convertor=conv_transpose_instance): + """Emits cutlass instance.""" + import cutlass_lib + + emiter = cutlass_lib.conv2d_operation.EmitConv2dInstance() + op_def = emiter.emit(op) + op_def = f_instance_convertor(op_def) + return op_def + + +@registry.reg("cuda.transposed_conv2d.config") +def conv2d_config(func_attrs, dtype="float16"): + func_attrs["op_instance"] = common.extract_config(func_attrs) + + +@registry.reg("cuda.transposed_conv2d.gen_function") +def gen_function( + func_attrs, + exec_cond_remplate, + shape_eval_template, + shape_save_template, +): + return common.gen_function( + func_attrs, + conv2d.INSTANCE_TEMPLATE, + conv2d.EXEC_TEMPLATE, + SRC_TEMPLATE, + exec_cond_remplate, + shape_eval_template, + shape_save_template, + f_emit_instance=emit_instance, + ) + + +@registry.reg("cuda.transposed_conv2d.func_decl") +def conv2d_gen_function_decl(func_attrs): + func_name = func_attrs["name"] + return conv2d.FUNC_DECL_TEMPLATE.render(func_name=func_name) + + +@registry.reg("cuda.transposed_conv2d.func_call") +def conv2d_gen_function_call(func_attrs, indent=" "): + x = func_attrs["inputs"][0] + xshape = x._attrs["shape"] + w = func_attrs["inputs"][1] + wshape = w._attrs["shape"] + y = func_attrs["outputs"][0] + yshape = y._attrs["shape"] + return conv2d.FUNC_CALL_TEMPLATE.render( + func_name=func_attrs["name"], + in_ptr=x._attrs["name"], + weight_ptr=w._attrs["name"], + out_ptr=y._attrs["name"], + p_batch="&" + xshape[0]._attrs["name"], + p_out_ch="&" + wshape[0]._attrs["name"], + p_in_ch="&" + xshape[3]._attrs["name"], + p_kernel_h="&" + wshape[1]._attrs["name"], + p_kernel_w="&" + wshape[2]._attrs["name"], + p_in_h="&" + xshape[1]._attrs["name"], + p_in_w="&" + xshape[2]._attrs["name"], + p_out_batch="&" + yshape[0]._attrs["name"], + p_out_h="&" + yshape[1]._attrs["name"], + p_out_w="&" + yshape[2]._attrs["name"], + stride=func_attrs["stride"], + dilation=func_attrs["dilate"], + pad=func_attrs["pad"], + indent=indent, + ) + + +@registry.reg("cuda.transposed_conv2d.gen_profiler") +def gen_profiler(func_attrs, workdir, shape_template): + op_type = func_attrs["op"] + op_instance = func_attrs["op_instance"] + # shape func + shape_func = shape_template.render( + indent=" ", + dtype="int64_t ", + div="/", + x_dim0="batch", + x_dim1="in_h", + x_dim2="in_w", + x_dim3="in_ch", + w_dim0="out_ch", + w_dim1="kernel_h", + w_dim2="kernel_w", + stride="stride", + dilate="dilation", + pad="pad", + ) + file_pairs = [] + for op_name, op in op_instance.items(): + config = emit_instance(op) + + config_name = common.extract_config_name(config) + name = "DeviceConvBwdInstance" + instance = conv2d.INSTANCE_TEMPLATE.render( + config_name=config_name, name=name, config=config + ) + exec_program = conv2d.EXEC_TEMPLATE.render( + indent=" ", is_profiler=True, instance=name + ) + op_func = SRC_TEMPLATE.render( + instances=instance, + function_name="conv", + dtype="cutlass::half_t", + shape_func="", + exec_paths=exec_program, + ) + code = conv2d.PROFILER_TEMPLATE.render( + op_func=op_func, shape_func=shape_func, name=name + ) + common.add_profiler(file_pairs, workdir, op_type, op_name, code) + # build + common.build_profiler(file_pairs) + + 
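For reference, the Fprop-to-Dgrad rewrite performed by conv_transpose_instance above is a pure string transformation on the emitted CUTLASS instance: the forward-propagation kernel name is swapped for its data-gradient counterpart, and the GEMM identity threadblock swizzle is replaced by the strided-dgrad swizzle. Below is a minimal, standalone sketch of that transformation; the op_def literal is a made-up fragment for illustration only, not a real emitted instance.

    import re

    def convert_fprop_to_dgrad(op_def: str) -> str:
        # Mirror of conv_transpose_instance above: swap the kernel template ...
        tmp = op_def.replace("DefaultConv2dFprop", "DefaultConv2dDgrad")
        # ... and substitute the threadblock swizzle used by strided dgrad.
        return re.sub(
            r"cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<\d>",
            "cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>",
            tmp,
        )

    # Hypothetical fragment of an emitted instance (illustration only).
    op_def = (
        "using Operation = cutlass::conv::kernel::DefaultConv2dFprop<"
        "..., cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>, ...>;"
    )
    print(convert_fprop_to_dgrad(op_def))
    # Prints the same fragment with DefaultConv2dDgrad and
    # StridedDgradIdentityThreadblockSwizzle<1> substituted in.
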
+@registry.reg("cuda.transposed_conv2d.filter") +def conv2d_function_filter(cfg, func_attrs, x_shape): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + x_shape: + Input shapes. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, x_shape) diff --git a/python/aitemplate/backend/cuda/conv2d/transposed_conv2d_bias.py b/python/aitemplate/backend/cuda/conv2d/transposed_conv2d_bias.py new file mode 100644 index 000000000..2df9642fa --- /dev/null +++ b/python/aitemplate/backend/cuda/conv2d/transposed_conv2d_bias.py @@ -0,0 +1,264 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +transposed conv2d + bias + (relu) codegen +""" +import re + +import jinja2 + +from ... import registry +from . import common, common_conv2d_bias_activation as cba + +# pylint: disable=C0103,C0415,W0613,C0301 + +SRC_TEMPLATE = jinja2.Template( + """ +#include +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d_dgrad.h" +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +{{extra_header}} + +#define CUTLASS_CHECK(status) \\ + { \\ + cutlass::Status error = status; \\ + if (error != cutlass::Status::kSuccess) { \\ + auto msg = std::string("Got cutlass error: ") + cutlassGetStatusString(error) + \\ + " at: " + std::to_string(__LINE__); \\ + std::cerr << msg << std::endl; \\ + throw std::runtime_error(msg); \\ + } \\ + } + +{{instances}} + +{{instances_def}} + +void {{function_name}} ( + cutlass::half_t* in_ptr, + cutlass::half_t* weight_ptr, + cutlass::half_t* out_ptr, + cutlass::half_t* bias_ptr, + uint8_t* workspace, + int64_t* batch, + int64_t* out_ch, + int64_t* in_ch, + int64_t* kernel_h, + int64_t* kernel_w, + int64_t* in_h, + int64_t* in_w, + int64_t* out_batch, + int64_t* out_h, + int64_t* out_w, + int stride, + int dilation, + int pad, + cudaStream_t stream + ) { + + {{shape_function}} + int i32_batch = *batch; + int i32_in_h = *in_h; + int i32_in_w = *in_w; + int i32_in_ch = *in_ch; + int i32_out_ch = *out_ch; + int i32_kernel_h = *kernel_h; + int i32_kernel_w = *kernel_w; + int i32_out_batch = *out_batch; + int i32_out_h = *out_h; + int i32_out_w = *out_w; + + using cutlass::layout::TensorNHWC; + TensorNHWC layout_A(TensorNHWC::packed(cutlass::make_Coord(i32_batch, i32_in_h, i32_in_w, i32_in_ch))); + TensorNHWC layout_B(TensorNHWC::packed(cutlass::make_Coord(i32_out_ch, i32_kernel_h, i32_kernel_w, i32_in_ch))); + TensorNHWC layout_C(TensorNHWC::packed(cutlass::make_Coord(i32_out_batch, i32_out_h, i32_out_w, i32_out_ch))); + + cutlass::conv::Conv2dProblemSize problem_size( + {i32_out_batch, i32_out_h, i32_out_w, i32_out_ch}, + {i32_out_ch, i32_kernel_h, i32_kernel_w, i32_in_ch}, + {pad, pad, pad, pad}, + {stride, stride}, + {dilation, dilation}, + 
{i32_batch, i32_in_h, i32_in_w, i32_in_ch}, + cutlass::conv::Mode::kCrossCorrelation, + 1 + ); + + {{exec_paths}} + throw std::runtime_error( + "Unsupported workload for this conv2d specialization." + ); +} +""" +) + + +def _conv_transpose_instance(op_def): + tmp = op_def.replace("DefaultConv2dFprop", "DefaultConv2dDgrad") + tmp = re.sub( + r"cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<\d>", + "cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>", + tmp, + ) + return tmp + + +def emit_instance(op, f_instance_convertor=_conv_transpose_instance): + import cutlass_lib + + emiter = cutlass_lib.conv2d_operation.EmitConv2dInstance() + op_def = emiter.emit(op) + op_def = f_instance_convertor(op_def) + return op_def + + +@registry.reg("cuda.transposed_conv2d_bias.config") +@registry.reg("cuda.transposed_conv2d_bias_relu.config") +def conv2d_config(func_attrs, dtype="float16"): + func_attrs["op_instance"] = common.extract_config(func_attrs) + + +@registry.reg("cuda.transposed_conv2d_bias.gen_function") +@registry.reg("cuda.transposed_conv2d_bias_relu.gen_function") +def gen_function( + func_attrs, + exec_cond_remplate, + shape_eval_template, + shape_save_template, +): + return common.gen_function( + func_attrs, + cba.INSTANCE_TEMPLATE, + cba.EXEC_TEMPLATE, + SRC_TEMPLATE, + exec_cond_remplate, + shape_eval_template, + shape_save_template, + f_emit_instance=emit_instance, + ) + + +@registry.reg("cuda.transposed_conv2d_bias.func_decl") +@registry.reg("cuda.transposed_conv2d_bias_relu.func_decl") +def conv2d_gen_function_decl(func_attrs): + func_name = func_attrs["name"] + return cba.FUNC_DECL_TEMPLATE.render(func_name=func_name) + + +@registry.reg("cuda.transposed_conv2d_bias.func_call") +@registry.reg("cuda.transposed_conv2d_bias_relu.func_call") +def conv2d_gen_function_call(func_attrs, indent=" "): + x = func_attrs["inputs"][0] + xshape = x._attrs["shape"] + w = func_attrs["inputs"][1] + b = func_attrs["inputs"][2] + wshape = w._attrs["shape"] + y = func_attrs["outputs"][0] + yshape = y._attrs["shape"] + return cba.FUNC_CALL_TEMPLATE.render( + func_name=func_attrs["name"], + in_ptr=x._attrs["name"], + weight_ptr=w._attrs["name"], + out_ptr=y._attrs["name"], + bias_ptr=b._attrs["name"], + p_batch="&" + xshape[0]._attrs["name"], + p_out_ch="&" + wshape[0]._attrs["name"], + p_in_ch="&" + xshape[3]._attrs["name"], + p_kernel_h="&" + wshape[1]._attrs["name"], + p_kernel_w="&" + wshape[2]._attrs["name"], + p_in_h="&" + xshape[1]._attrs["name"], + p_in_w="&" + xshape[2]._attrs["name"], + p_out_batch="&" + yshape[0]._attrs["name"], + p_out_h="&" + yshape[1]._attrs["name"], + p_out_w="&" + yshape[2]._attrs["name"], + stride=func_attrs["stride"], + dilation=func_attrs["dilate"], + pad=func_attrs["pad"], + indent=indent, + ) + + +@registry.reg("cuda.transposed_conv2d_bias.gen_profiler") +@registry.reg("cuda.transposed_conv2d_bias_relu.gen_profiler") +def gen_profiler(func_attrs, workdir, shape_template): + op_type = func_attrs["op"] + op_instance = func_attrs["op_instance"] + # shape func + shape_func = shape_template.render( + indent=" ", + dtype="int64_t ", + div="/", + x_dim0="batch", + x_dim1="in_h", + x_dim2="in_w", + x_dim3="in_ch", + w_dim0="out_ch", + w_dim1="kernel_h", + w_dim2="kernel_w", + stride="stride", + dilate="dilation", + pad="pad", + ) + file_pairs = [] + for op_name, op in op_instance.items(): + config = emit_instance(op) + + config_name = common.extract_config_name(config) + name = "DeviceConvBwdInstance" + instance = cba.INSTANCE_TEMPLATE.render( + 
config_name=config_name, name=name, config=config + ) + exec_program = cba.EXEC_TEMPLATE.render( + indent=" ", is_profiler=True, instance=name + ) + op_func = SRC_TEMPLATE.render( + instances=instance, + function_name="conv", + dtype="cutlass::half_t", + shape_func="", + exec_paths=exec_program, + ) + code = cba.PROFILER_TEMPLATE.render( + op_func=op_func, shape_func=shape_func, name=name + ) + common.add_profiler(file_pairs, workdir, op_type, op_name, code) + # build + common.build_profiler(file_pairs) + + +@registry.reg("cuda.transposed_conv2d_bias.filter") +@registry.reg("cuda.transposed_conv2d_bias_relu.filter") +def conv2d_function_filter(cfg, func_attrs, x_shape): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + x_shape: + Input shapes. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, x_shape) diff --git a/python/aitemplate/backend/cuda/cuda_common.py b/python/aitemplate/backend/cuda/cuda_common.py new file mode 100644 index 000000000..20093b05c --- /dev/null +++ b/python/aitemplate/backend/cuda/cuda_common.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +CUDA common functions for codegen. +""" +from typing import Dict + +DTYPE_TO_CUDATYPE: Dict[str, str] = { + "float16": "half", + "float": "float", + "int64": "int64_t", +} + + +DTYPE_TO_CUTLASSTYPE: Dict[str, str] = { + "float16": "cutlass::half_t", + "float": "float", +} + + +def dtype_to_cuda_type(dtype: str): + """Returns the corresponding cuda type.""" + cuda_type = DTYPE_TO_CUDATYPE.get(dtype) + + if cuda_type is None: + raise NotImplementedError("CUDA - Unsupported dtype: {}".format(dtype)) + return cuda_type + + +def dtype_to_cutlass_type(dtype: str): + """Returns the corresponding cutlass type.""" + cutlass_type = DTYPE_TO_CUTLASSTYPE.get(dtype) + + if cutlass_type is None: + raise NotImplementedError("CUDA - Unsupported dtype: {}".format(dtype)) + return cutlass_type diff --git a/python/aitemplate/backend/cuda/elementwise/__init__.py b/python/aitemplate/backend/cuda/elementwise/__init__.py new file mode 100644 index 000000000..0bf6e473f --- /dev/null +++ b/python/aitemplate/backend/cuda/elementwise/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +(c) Meta Platforms, Inc. 
and affiliates. Confidential and proprietary. +""" +from . import fused_elementwise + +__all__ = ["fused_elementwise"] diff --git a/python/aitemplate/backend/cuda/elementwise/custom_math.cuh b/python/aitemplate/backend/cuda/elementwise/custom_math.cuh new file mode 100644 index 000000000..2adddd531 --- /dev/null +++ b/python/aitemplate/backend/cuda/elementwise/custom_math.cuh @@ -0,0 +1,299 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +#ifndef CUSTOM_MATH +#define CUSTOM_MATH + +#ifndef __HALF2_TO_UI +#define __HALF2_TO_UI(var) *(reinterpret_cast(&(var))) +#endif + +#ifndef __HALF_TO_US +#define __HALF_TO_US(var) *(reinterpret_cast(&(var))) +#endif + +template +__device__ T sign_custom(const T a) { + return T(a > T(0)) - T(a < T(0)); +} + +__device__ half2 h2sign_custom(const half2 a) { + return half2(sign_custom(a.x), sign_custom(a.y)); +} + +__device__ half2 fast_tanh(half2 x) { +#if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 11) && \ + (__CUDA_ARCH__ >= 750) + + asm volatile("tanh.approx.f16x2 %0, %1;" + : "=r"(__HALF2_TO_UI(x)) + : "r"(__HALF2_TO_UI(x))); + return x; + +#else + CUTLASS_NOT_IMPLEMENTED(); +#endif +} + +__device__ half fast_tanh(half x) { +#if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 11) && \ + (__CUDA_ARCH__ >= 750) + + asm volatile("tanh.approx.f16 %0, %1;" + : "=h"(__HALF_TO_US(x)) + : "h"(__HALF_TO_US(x))); + return x; + +#else + return half(cutlass::fast_tanh(float(x))); +#endif +} + +// Return 1 +__device__ half one() { + uint16_t bits = 0x3c00u; + return reinterpret_cast(bits); +} + +/// Returns (1/2) (specialization for half_t) +__device__ half constant_half() { + uint16_t bits = 0x3800u; + return reinterpret_cast(bits); +} + +__device__ float fsigmoid_custom(const float a) { + return (cutlass::fast_tanh(a * 0.5f) + 1.0f) * 0.5f; +} + +__device__ half hsigmoid_custom(const half a) { + half half_val = constant_half(); + half one_val = one(); + return __hmul((__hadd(fast_tanh(__hmul(a, half_val)), one_val)), half_val); +} + +__device__ half2 h2sigmoid_custom(const half2 a) { + half2 halfX2 = half2(constant_half(), constant_half()); + half2 oneX2 = half2(one(), one()); + return __hmul2((__hadd2(fast_tanh(__hmul2(a, halfX2)), oneX2)), halfX2); +} + +__device__ float fsilu(const float a) { + return a * fsigmoid_custom(a); +} + +__device__ half hsilu(const half a) { + return __hmul(a, hsigmoid_custom(a)); +} + +__device__ half2 h2silu(const half2 a) { + return __hmul2(a, h2sigmoid_custom(a)); +} + +__device__ float leaky_relu(const float a, const float negativeSlope) { + return a > 0.f ? a : a * negativeSlope; +} + +__device__ half leaky_relu(const half a, const half negativeSlope) { + return a > half(0.f) ? a : __hmul(a, negativeSlope); +} + +__device__ half2 leaky_relu(const half2 a, const half2 negativeSlope) { + return half2( + leaky_relu(a.x, negativeSlope.x), leaky_relu(a.y, negativeSlope.y)); +} + +__device__ float relu(const float a) { + return a > 0.f ? 
a : 0.f; +} + +__device__ half relu(const half a) { + return a > half(0.f) ? a : half(0.f); +} + +__device__ half2 relu(const half2 a) { + half2 zeroX2 = half2(half(0.f), half(0.f)); +#if __CUDA_ARCH__ >= 800 + return __hmax2(a, zeroX2); +#else + return half2(relu(a.x), relu(a.y)); +#endif +} + +template +__device__ T hard_tanh(const T a, T min_val, T max_val) { + if (a <= min_val) { + return min_val; + } else if (a >= max_val) { + return max_val; + } else { + return a; + } +} + +__device__ half2 +h2hard_tanh(const half2 a, const half2 min_val, const half2 max_val) { + return half2( + hard_tanh(a.x, min_val.x, max_val.x), + hard_tanh(a.y, min_val.y, max_val.y)); +} + +__device__ half replace_if_inf( + const half a, + const half inf_replace, + const half neginf_replace) { + auto is_inf = __hisinf(a); + if (is_inf == -1) { + return neginf_replace; + } + if (is_inf == 1) { + return inf_replace; + } + return a; +} + +__device__ float replace_if_inf( + const float a, + const float inf_replace, + const float neginf_replace) { + auto is_inf = isinf(a); + if (is_inf == -1) { + return neginf_replace; + } + if (is_inf == 1) { + return inf_replace; + } + return a; +} + +__device__ half2 nan_to_num( + const half2 a, + const half2 nan_replace, + const half2 inf_replace, + const half2 neginf_replace) { + half2 isnan = __hisnan2(a); + return half2( + isnan.x ? nan_replace.x + : replace_if_inf(a.x, inf_replace.x, neginf_replace.x), + isnan.y ? nan_replace.y + : replace_if_inf(a.y, inf_replace.y, neginf_replace.y)); +} + +__device__ half nan_to_num( + const half a, + const half nan_replace, + const half inf_replace, + const half neginf_replace) { + if (__hisnan(a)) { + return nan_replace; + } + return replace_if_inf(a, inf_replace, neginf_replace); +} + +__device__ float nan_to_num( + const float a, + const float nan_replace, + const float inf_replace, + const float neginf_replace) { + if (isnan(a)) { + return nan_replace; + } + return replace_if_inf(a, inf_replace, neginf_replace); +} + +__device__ half2 clamp_nan_to_num( + const half2 a, + const half2 clamp_min, + const half2 clamp_max, + const half2 nan_replace) { + half2 isnan = __hisnan2(a); + return half2( + isnan.x ? nan_replace.x : hard_tanh(a.x, clamp_min.x, clamp_max.x), + isnan.y ? nan_replace.y : hard_tanh(a.y, clamp_min.y, clamp_max.y)); +} + +__device__ half clamp_nan_to_num( + const half a, + const half clamp_min, + const half clamp_max, + const half nan_replace) { + return __hisnan(a) ? nan_replace : hard_tanh(a, clamp_min, clamp_max); +} + +__device__ float clamp_nan_to_num( + const float a, + const float clamp_min, + const float clamp_max, + const float nan_replace) { + return isnan(a) ? nan_replace : hard_tanh(a, clamp_min, clamp_max); +} + +// Backup functions for CUDA_ARCH < 800 +__device__ half nanh() { + return __float2half(nanf("")); +} + +__device__ bool half_isnan(half h) { + return h != h; +} + +__device__ half hmin(half a, half b) { + return (a < b) ? a : b; +} + +__device__ half hmax(half a, half b) { + return (a > b) ? a : b; +} + +// max/min functions that let NaNs pass through +__device__ float fmaxf_nan(const float a, const float b) { + return (isnan(a) || isnan(b)) ? nanf("") : fmaxf(a, b); +} + +__device__ half hmax_nan(const half a, const half b) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return __hmax_nan(a, b); +#else + return (half_isnan(a) || half_isnan(b)) ? 
nanh() : hmax(a, b); +#endif +} + +__device__ half2 hmax2_nan(const half2 a, const half2 b) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return __hmax2_nan(a, b); +#else + return half2(hmax_nan(a.x, b.x), hmax_nan(a.y, b.y)); +#endif +} + +__device__ float fminf_nan(const float a, const float b) { + return (isnan(a) || isnan(b)) ? nanf("") : fminf(a, b); +} + +__device__ half hmin_nan(const half a, const half b) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return __hmin_nan(a, b); +#else + return (half_isnan(a) || half_isnan(b)) ? nanh() : hmin(a, b); +#endif +} + +__device__ half2 hmin2_nan(const half2 a, const half2 b) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return __hmin2_nan(a, b); +#else + return half2(hmin_nan(a.x, b.x), hmin_nan(a.y, b.y)); +#endif +} + +#endif diff --git a/python/aitemplate/backend/cuda/elementwise/fused_elementwise.py b/python/aitemplate/backend/cuda/elementwise/fused_elementwise.py new file mode 100644 index 000000000..f25013aec --- /dev/null +++ b/python/aitemplate/backend/cuda/elementwise/fused_elementwise.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Elementwise codegen for CUDA. +""" + +import os +from typing import Any, Dict + +from ... import registry +from ...backend_spec import CUDASpec +from ...common import elementwise_common +from ...target import Target + +HEAD_TEMPLATE = """ +#include +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/constants.h" +""" + + +@registry.reg("cuda.fused_elementwise.gen_function") +def fused_elementwise_gen_function(func_attrs: Dict[str, Any]) -> str: + """Generates fused_elementwise function definition.""" + custom_libs = Target.current().get_custom_libs( + os.path.dirname(__file__), "custom_math.cuh" + ) + return elementwise_common.fused_elementwise_gen_function( + func_attrs=func_attrs, + custom_libs=custom_libs, + head_template=HEAD_TEMPLATE, + backend_spec=CUDASpec(), + ) + + +@registry.reg("cuda.fused_elementwise.func_decl") +def fused_elementwise_gen_function_decl(func_attrs): + """Generates fused_elementwise function declaration.""" + return elementwise_common.fused_elementwise_gen_function_decl( + func_attrs=func_attrs, + backend_spec=CUDASpec(), + ) + + +@registry.reg("cuda.fused_elementwise.func_call") +def fused_elementwise_gen_function_call(func_attrs, indent): + """Generates fused_elementwise function call.""" + return elementwise_common.fused_elementwise_gen_function_call( + func_attrs=func_attrs, + indent=indent, + backend_spec=CUDASpec(), + ) diff --git a/python/aitemplate/backend/cuda/embedding/__init__.py b/python/aitemplate/backend/cuda/embedding/__init__.py new file mode 100644 index 000000000..3e3aab46b --- /dev/null +++ b/python/aitemplate/backend/cuda/embedding/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# flake8: noqa +from .bert_embeddings import * diff --git a/python/aitemplate/backend/cuda/embedding/bert_embeddings.py b/python/aitemplate/backend/cuda/embedding/bert_embeddings.py new file mode 100644 index 000000000..2ca8d5816 --- /dev/null +++ b/python/aitemplate/backend/cuda/embedding/bert_embeddings.py @@ -0,0 +1,450 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +bert_embeddings kernel codegen for CUDA. +""" + +from typing import Any, Dict + +import jinja2 + +from ... import registry + +# pylint: disable=C0301 + +FUNC_TEMPLATE = jinja2.Template( + """ +#include +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" + +#define FINAL_MASK 0xffffffff + +namespace { + +template +__inline__ __device__ T warpReduceSum(T* val) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + val[0] += __shfl_xor_sync(FINAL_MASK, val[0], mask, 32); + } + return (T)(0.0f); +} + +template +__inline__ __device__ T blockReduceSum(T* val) { + __shared__ T shared[33]; + int lane = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + warpReduceSum(val); + + if (lane == 0) { +#pragma unroll + shared[wid] = val[0]; + } + + __syncthreads(); + + // blockDim.x is round up to multiples of 32 + bool is_mask = threadIdx.x < (blockDim.x / 32); +#pragma unroll + val[0] = is_mask ? 
shared[lane] : (T)(0.0f); + + warpReduceSum(val); + return (T)0.0f; +} + +template +__inline__ __device__ T normalize(T val, T mean, T variance, T gamma, T beta) { + return (val - mean) * variance * gamma + beta; +} + +// __inline__ __device__ float sigmoid(float val) { +// return 1.0f / (1.0f + expf(-1.0f * val)); +// } + +// fast sigmoid +__inline__ __device__ float sigmoid(float val) { + return (cutlass::fast_tanh(val * 0.5f) + 1.0f) * 0.5f; +} + +template +__global__ void bert_embeddings_kernel( + uint4* output, + INDEX_T* input_ids, + INDEX_T* token_type_ids, + INDEX_T* position_ids, + uint4* word_embeddings, + uint4* token_type_embeddings, + uint4* position_embeddings, + uint4* gamma, + uint4* beta, + const int64_t embedding_dim, + const int64_t vocab_size, + const int64_t type_vocab_size, + const int64_t max_position_embeddings, + const float eps) { + const int tid = threadIdx.x; + const int bid = blockIdx.x; + const int embedding_dim_div_8 = embedding_dim / 8; + + const int64_t input_id = input_ids[bid]; + const int64_t token_type_id = token_type_ids[bid]; + const int64_t position_id = position_ids[bid]; + + // index bound check + if (input_id < 0 || input_id >= vocab_size || token_type_id < 0 || + token_type_id >= type_vocab_size || position_id < 0 || + position_id >= max_position_embeddings) { + return; + } + + word_embeddings = word_embeddings + input_id * embedding_dim_div_8; + token_type_embeddings = + token_type_embeddings + token_type_id * embedding_dim_div_8; + position_embeddings = position_embeddings + position_id * embedding_dim_div_8; + + uint4 word_embedding{0, 0, 0, 0}; + uint4 token_type_embedding{0, 0, 0, 0}; + uint4 position_embedding{0, 0, 0, 0}; + + if (tid < embedding_dim_div_8) { + word_embedding = word_embeddings[tid]; + token_type_embedding = token_type_embeddings[tid]; + position_embedding = position_embeddings[tid]; + } + uint4 embedding{0, 0, 0, 0}; + + half* word_emb_vec = reinterpret_cast(&word_embedding); + half* token_emb_vec = reinterpret_cast(&token_type_embedding); + half* pos_emb_vec = reinterpret_cast(&position_embedding); + + half* emb_vec = reinterpret_cast(&embedding); + + // layernorm + __shared__ float s_mean, s_variance; + float local_sums[1] = {0.0f}; + +#pragma unroll + for (int i = 0; i < 8; i++) { + float sum = word_emb_vec[i] + token_emb_vec[i] + pos_emb_vec[i]; + local_sums[0] += sum; + emb_vec[i] = (half)sum; + } + + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_mean = local_sums[0] / embedding_dim; + } + __syncthreads(); + + local_sums[0] = 0.0f; + + if (tid < embedding_dim_div_8) { +#pragma unroll + for (int i = 0; i < 8; i++) { + float val = emb_vec[i]; + local_sums[0] += (val - s_mean) * (val - s_mean); + } + } + + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_variance = rsqrtf(local_sums[0] / embedding_dim + eps); + } + __syncthreads(); + + if (tid < embedding_dim_div_8) { + uint4 local_gamma = gamma[tid]; + half* gamma_vec = reinterpret_cast(&local_gamma); + uint4 local_beta = beta[tid]; + half* beta_vec = reinterpret_cast(&local_beta); +#pragma unroll + for (int i = 0; i < 8; i++) { + emb_vec[i] = normalize( + (float)emb_vec[i], + s_mean, + s_variance, + (float)gamma_vec[i], + (float)beta_vec[i]); + } + } + + // write to output + if (tid < embedding_dim_div_8) { + output = output + bid * embedding_dim_div_8; + output[tid] = embedding; + } +} + +template +void 
bert_embeddings_launcher( + half* output, + INDEX_T* input_ids, + INDEX_T* token_type_ids, + INDEX_T* position_ids, + half* word_embeddings, + half* token_type_embeddings, + half* position_embeddings, + half* gamma, + half* beta, + const int64_t indices_num, + const int64_t embedding_dim, + const int64_t vocab_size, + const int64_t type_vocab_size, + const int64_t max_position_embeddings, + const float eps, + cudaStream_t stream) { + if (embedding_dim % 8 != 0) { + throw std::runtime_error("embedding dim must be multiple of 8"); + } + dim3 grid(indices_num); + + // round up to multiple of 32 + int64_t num_threads = embedding_dim / 8; + num_threads = (num_threads + 31) / 32 * 32; + dim3 block(num_threads); + + bert_embeddings_kernel<<>>( + reinterpret_cast(output), + input_ids, + token_type_ids, + position_ids, + reinterpret_cast(word_embeddings), + reinterpret_cast(token_type_embeddings), + reinterpret_cast(position_embeddings), + reinterpret_cast(gamma), + reinterpret_cast(beta), + embedding_dim, + vocab_size, + type_vocab_size, + max_position_embeddings, + eps); +} + +} // namespace + +{{func_signature}} +{ + bert_embeddings_launcher<{{index_type}}>( + output, + input_ids, + token_type_ids, + position_ids, + word_embeddings, + token_type_embeddings, + position_embeddings, + gamma, + beta, + indices_num, + embedding_dim, + vocab_size, + type_vocab_size, + max_position_embeddings, + eps, + stream + ); +} + +""" +) + +FUNC_SIGNATURE = jinja2.Template( + """ +void {{func_name}}(half* output, + {{index_type}}* input_ids, + {{index_type}}* token_type_ids, + {{index_type}}* position_ids, + half* word_embeddings, + half* token_type_embeddings, + half* position_embeddings, + half* gamma, + half* beta, + const int64_t indices_num, + const int64_t embedding_dim, + const int64_t vocab_size, + const int64_t type_vocab_size, + const int64_t max_position_embeddings, + const float eps, + cudaStream_t stream) + """ +) + +FUNC_DECL = jinja2.Template( + """ + {{func_signature}}; + """ +) + +FUNC_CALL_TEMPLATE = jinja2.Template( + """ +{{indent}}{ +{{indent}} {{calculate_indices_num}} +{{indent}} {{func_name}}( +{{indent}} {{output}}, +{{indent}} {{input_ids}}, +{{indent}} {{token_type_ids}}, +{{indent}} {{position_ids}}, +{{indent}} {{word_embeddings}}, +{{indent}} {{token_type_embeddings}}, +{{indent}} {{position_embeddings}}, +{{indent}} {{gamma}}, +{{indent}} {{beta}}, +{{indent}} {{indices_num}}, +{{indent}} {{embedding_dim}}, +{{indent}} {{vocab_size}}, +{{indent}} {{type_vocab_size}}, +{{indent}} {{max_position_embeddings}}, +{{indent}} {{eps}}, +{{indent}} stream /* default stream */ +{{indent}} ); + +{{indent}}} + """ +) + +INDICES_NUM_TEMPLATE = jinja2.Template( + """ + int64_t indices_num = 1; + {% for dim_name in dim_names %} + indices_num *= {{dim_name}}; + {% endfor %} + """ +) + + +def python_int_dtype_to_c_dtype(dtype): + if dtype == "int64": + return "int64_t" + if dtype in ["int", "int32"]: + return "int32_t" + return dtype + + +@registry.reg("cuda.bert_embeddings.gen_function") +def bert_embeddings_gen_function(func_attrs: Dict[str, Any]) -> str: + dtype = python_int_dtype_to_c_dtype(func_attrs["inputs"][0]._attrs["dtype"]) + return FUNC_TEMPLATE.render( + index_type=dtype, + func_signature=FUNC_SIGNATURE.render( + func_name=func_attrs["name"], + index_type=dtype, + ).strip(), + ) + + +@registry.reg("cuda.bert_embeddings.func_decl") +def bert_embeddings_gen_function_decl(func_attrs: Dict[str, Any]) -> str: + dtype = python_int_dtype_to_c_dtype(func_attrs["inputs"][0]._attrs["dtype"]) + 
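+    # The declaration must use the same C index type as the generated kernel
+    # definition: the dtype of the first input (input_ids) is mapped by
+    # python_int_dtype_to_c_dtype ("int64" -> "int64_t", "int"/"int32" ->
+    # "int32_t") so that FUNC_SIGNATURE renders identically in both places.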
return FUNC_DECL.render( + func_signature=FUNC_SIGNATURE.render( + func_name=func_attrs["name"], + index_type=dtype, + ).strip() + ) + + +FUNC_CALL_FP16_PARAM_TEMPLATE = jinja2.Template( + "reinterpret_cast(&({{name}}->raw()))" +) + +FUNC_CALL_INT64_PARAM_TEMPLATE = jinja2.Template("reinterpret_cast({{name}})") +FUNC_CALL_INT32_PARAM_TEMPLATE = jinja2.Template("reinterpret_cast({{name}})") + + +def get_int_param_template(tensor): + name = tensor._attrs["name"] + dtype = tensor._attrs["dtype"] + if dtype == "int64": + return FUNC_CALL_INT64_PARAM_TEMPLATE.render(name=name) + elif dtype in ("int", "int32"): + return FUNC_CALL_INT32_PARAM_TEMPLATE.render(name=name) + else: + raise NotImplementedError(f"Unsupported dtype: {dtype}") + + +@registry.reg("cuda.bert_embeddings.func_call") +def bert_embeddings_gen_function_call(func_attrs: Dict[str, Any], indent=" ") -> str: + ( + input_ids, + token_type_ids, + position_ids, + word_embeddings, + token_type_embeddings, + position_embeddings, + gamma, + beta, + ) = func_attrs["inputs"] + + indices_dims = [shape._attrs["name"] for shape in input_ids.shape()] + indices_num_str = INDICES_NUM_TEMPLATE.render( + dim_names=indices_dims, + ) + embedding_dim = word_embeddings._size(-1).value() + vocab_size = word_embeddings._size(0).value() + type_vocab_size = token_type_embeddings._size(0).value() + max_position_embeddings = position_embeddings._size(0).value() + + eps = func_attrs["eps"] + output_str = FUNC_CALL_FP16_PARAM_TEMPLATE.render( + name=func_attrs["outputs"][0]._attrs["name"] + ) + + input_ids_str = get_int_param_template(input_ids) + token_type_ids_str = get_int_param_template(token_type_ids) + position_ids_str = get_int_param_template(position_ids) + + word_embeddings_str = FUNC_CALL_FP16_PARAM_TEMPLATE.render( + name=word_embeddings._attrs["name"] + ) + token_type_embeddings_str = FUNC_CALL_FP16_PARAM_TEMPLATE.render( + name=token_type_embeddings._attrs["name"] + ) + position_embeddings_str = FUNC_CALL_FP16_PARAM_TEMPLATE.render( + name=position_embeddings._attrs["name"] + ) + + gamma_str = FUNC_CALL_FP16_PARAM_TEMPLATE.render(name=gamma._attrs["name"]) + beta_str = FUNC_CALL_FP16_PARAM_TEMPLATE.render(name=beta._attrs["name"]) + + return FUNC_CALL_TEMPLATE.render( + func_name=func_attrs["name"], + calculate_indices_num=indices_num_str, + output=output_str, + input_ids=input_ids_str, + token_type_ids=token_type_ids_str, + position_ids=position_ids_str, + word_embeddings=word_embeddings_str, + token_type_embeddings=token_type_embeddings_str, + position_embeddings=position_embeddings_str, + gamma=gamma_str, + beta=beta_str, + indices_num="indices_num", + embedding_dim=embedding_dim, + vocab_size=vocab_size, + type_vocab_size=type_vocab_size, + max_position_embeddings=max_position_embeddings, + eps=eps, + indent=indent, + ) diff --git a/python/aitemplate/backend/cuda/gemm_epilogue_vistor/__init__.py b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/__init__.py new file mode 100644 index 000000000..604984059 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from . import bmm_rcr_softmax, gemm_rcr_bias_softmax, gemm_rcr_softmax + +__all__ = ["bmm_rcr_softmax", "gemm_rcr_bias_softmax", "gemm_rcr_softmax"] diff --git a/python/aitemplate/backend/cuda/gemm_epilogue_vistor/bmm_common_softmax.py b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/bmm_common_softmax.py new file mode 100644 index 000000000..4a63ff1fc --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/bmm_common_softmax.py @@ -0,0 +1,256 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Common functions and templates for bmm-family ops +""" +import jinja2 + +from ...common import gemm_common +from ..gemm_universal import common + +from . import common_softmax + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + +FUNC_DECL_TEMPLATE = jinja2.Template( + """ +void {{func_name}}( + cutlass::half_t*, + cutlass::half_t*, +{% if has_bias %} + cutlass::half_t*, +{% endif %} + cutlass::half_t*, + cutlass::half_t*, + float*, + cutlass::half_t*, + uint8_t*, +{% if support_split_k %} + int, +{% endif %} +{% for idx in range(ndims) %} + int64_t*, +{% endfor %} +{% for idx in range(ndims) %} + int64_t*, +{% endfor %} +{% for idx in range(ndims) %} + int64_t*, +{% endfor %} + cudaStream_t +); +""" +) + + +FUNC_CALL_TEMPLATE = jinja2.Template( + """ +{{indent}}{{func_name}}( +{{indent}} {{a_ptr}}, +{{indent}} {{b_ptr}}, +{% if has_bias %} +{{indent}} {{bias_ptr}}, +{% endif %} +{{indent}} {{c_ptr}}, +{{indent}} {{d_ptr}}, +{{indent}} {{n_ptr}}, +{{indent}} {{soft_ptr}}, +{{indent}} global_workspace, +{{indent}} {{a_dim0_ptr}}, +{{indent}} {{a_dim1_ptr}}, +{{indent}} {{a_dim2_ptr}}, +{{indent}} {{b_dim0_ptr}}, +{{indent}} {{b_dim1_ptr}}, +{{indent}} {{b_dim2_ptr}}, +{{indent}} {{c_dim0_ptr}}, +{{indent}} {{c_dim1_ptr}}, +{{indent}} {{c_dim2_ptr}}, +{{indent}} stream +{{indent}}); +""" +) + +TENSOR_DECL_TEMPLATE = jinja2.Template( + """ + // cast to int64_t to avoid overflow + int64_t a_ptr_sz = static_cast(a_dim0) * static_cast(a_dim1) * static_cast(a_dim2); + int64_t b_ptr_sz = static_cast(b_dim0) * static_cast(b_dim1) * static_cast(b_dim2); + int64_t c_ptr_sz = static_cast(c_dim0) * static_cast(c_dim1) * static_cast(c_dim2); + int64_t ptr_max_sz = std::max({a_ptr_sz, b_ptr_sz, c_ptr_sz}); + // TODO: special pool size for A100 L2 cache 40M + // need to tune it for other devices + int64_t mem_pool_sz = std::max(2, std::min(64, int((1 << 25) / ptr_max_sz))); + + + memory_pool->AllocateHalfTensor(a_ptr_sz, mem_pool_sz); // a_ptr: index 0 + memory_pool->AllocateHalfTensor(b_ptr_sz, mem_pool_sz); // 
b_ptr: index 1 + memory_pool->AllocateHalfTensor(c_ptr_sz, mem_pool_sz); // c_ptr: index 2 + memory_pool->AllocateHalfTensor(c_ptr_sz, mem_pool_sz); // d_ptr: index 3 + memory_pool->AllocateFloatTensor(c_dim0 * c_dim1, mem_pool_sz); // n_ptr: index 4 + memory_pool->AllocateHalfTensor(c_ptr_sz, mem_pool_sz); // soft_ptr: index 5 +""" +) + + +def gen_profiler( + func_attrs, + workdir, + dim_info_dict, + src_template, + problem_args_template, + args_parser_template, + emit_kernel=False, + bias_ptr_arg=None, +): + """Generate code for profiling""" + op_type = func_attrs["op"] + op_instance = func_attrs["op_instance"] + has_d = False + if "has_d" in func_attrs: + has_d = func_attrs["has_d"] + shape_func = gemm_common.gen_shape_eval_code( + indent=2, dtype="int64_t", dim_info_dict=dim_info_dict, is_ptr=True + ) + + file_pairs = [] + has_bias = bias_ptr_arg is not None + assert not (has_d and has_bias) + for op_name, op in op_instance.items(): + config = common_softmax.emit_instance(op, emit_kernel=emit_kernel) + config_name = common.extract_config_name(config) + name = "GemmInstance" + instance = common.INSTANCE_TEMPLATE.render( + config_name=config_name, name=name, config=config + ) + exec_program = common_softmax.EXEC_TEMPLATE.render( + indent=" ", + instance=name, + is_profiler=True, + problem_args=problem_args_template.render(), + ) + op_func = src_template.render( + custom_libs=common_softmax.gen_custom_libs(), + instances=instance, + function_name="bmm", + input_ndims=3, + weight_ndims=3, + shape_eval=shape_func, + exec_paths=exec_program, + has_d=has_d, + ) + func_call = FUNC_CALL_TEMPLATE.render( + func_name="bmm", + a_ptr="memory_pool->RequestTensorByIdx(0)", + b_ptr="memory_pool->RequestTensorByIdx(1)", + has_bias=has_bias, + bias_ptr=bias_ptr_arg, + c_ptr="memory_pool->RequestTensorByIdx(2)", + d_ptr="memory_pool->RequestTensorByIdx(3)", + n_ptr="memory_pool->RequestTensorByIdx(4)", + soft_ptr="memory_pool->RequestTensorByIdx(5)", + has_d=has_d, + a_dim0_ptr="&a_dim0", + a_dim1_ptr="&a_dim1", + a_dim2_ptr="&a_dim2", + b_dim0_ptr="&b_dim0", + b_dim1_ptr="&b_dim1", + b_dim2_ptr="&b_dim2", + c_dim0_ptr="&c_dim0", + c_dim1_ptr="&c_dim1", + c_dim2_ptr="&c_dim2", + ) + code = common_softmax.PROFILER_TEMPLATE.render( + op_func=op_func, + args_parse=args_parser_template.render(), + func_call=func_call, + name=name, + tensor_decl=TENSOR_DECL_TEMPLATE.render( + name=name, has_d=has_d, has_bias=has_bias + ), + ) + common.add_profiler(file_pairs, workdir, op_type, op_name, code) + # build + common.build_profiler(file_pairs) + + +def gen_function_decl(func_attrs): + """Rendering argument to function declaration template""" + func_name = func_attrs["name"] + has_d = False + if "has_d" in func_attrs: + has_d = func_attrs["has_d"] + return FUNC_DECL_TEMPLATE.render(func_name=func_name, ndims=3, has_d=has_d) + + +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, + problem_args, +): + """Generate the code for main function""" + input_ndims = len(func_attrs["input_accessors"][0].original_shapes) + weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) + return common_softmax.gen_function( + func_attrs, + common_softmax.SRC_TEMPLATE, + exec_cond_template, + problem_args, + input_ndims=input_ndims, + weight_ndims=weight_ndims, + dim_info_dict=dim_info_dict, + emit_kernel=True, + ) + + +def gen_function_call(func_attrs, indent=" ", bias_ptr_arg=None): + """Rendering the code to function call template""" + + a = func_attrs["inputs"][0] + ashape = 
func_attrs["input_accessors"][0].original_shapes + b = func_attrs["inputs"][1] + bshape = func_attrs["input_accessors"][1].original_shapes + + c = func_attrs["inputs"][2] + d = func_attrs["inputs"][3] + n = func_attrs["inputs"][4] + + soft = func_attrs["outputs"][0] + cshape = func_attrs["output_accessors"][0].original_shapes + has_d = False + has_bias = bias_ptr_arg is not None + assert not (has_d and has_bias) + return FUNC_CALL_TEMPLATE.render( + func_name=func_attrs["name"], + a_ptr=a._attrs["name"], + b_ptr=b._attrs["name"], + has_bias=has_bias, + bias_ptr=bias_ptr_arg, + c_ptr=c._attrs["name"], + d_ptr=d._attrs["name"], + n_ptr=n._attrs["name"], + soft_ptr=soft._attrs["name"], + has_d=has_d, + a_dim0_ptr="&" + ashape[0]._attrs["name"], + a_dim1_ptr="&" + ashape[1]._attrs["name"], + a_dim2_ptr="&" + ashape[2]._attrs["name"], + b_dim0_ptr="&" + bshape[0]._attrs["name"], + b_dim1_ptr="&" + bshape[1]._attrs["name"], + b_dim2_ptr="&" + bshape[2]._attrs["name"], + c_dim0_ptr="&" + cshape[0]._attrs["name"], + c_dim1_ptr="&" + cshape[1]._attrs["name"], + c_dim2_ptr="&" + cshape[2]._attrs["name"], + indent=indent, + ) diff --git a/python/aitemplate/backend/cuda/gemm_epilogue_vistor/bmm_rcr_softmax.py b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/bmm_rcr_softmax.py new file mode 100644 index 000000000..751a19a84 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/bmm_rcr_softmax.py @@ -0,0 +1,161 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +GEMM Specialization for A[RowMajor], B[ColMajor], C[RowMajor] +This is special in template based gemm solution +This is used for `torch.nn.functional.linear` +When use for `linear`, need set A->Data, B->Weight +""" +import jinja2 + +from ... import registry +from ..gemm_universal import common +from ..gemm_universal.layout import RCR +from . import bmm_common_softmax as bmm_common, common_softmax + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + + +ARGS_PARSER_TEMPLATE = jinja2.Template( + """ + int64_t B = std::atoi(argv[1]); + int64_t M = std::atoi(argv[2]); + int64_t N = std::atoi(argv[3]); + int64_t K = std::atoi(argv[4]); + + int64_t a_dim0 = B; + int64_t a_dim1 = M; + int64_t a_dim2 = K; + int64_t b_dim0 = B; + int64_t b_dim1 = N; + int64_t b_dim2 = K; + int64_t c_dim0 = B; + int64_t c_dim1 = M; + int64_t c_dim2 = N; +""" +) + +PROBLEM_ARGS_TEMPLATE = jinja2.Template( + """ + /* + A: B*M*K (RowMajor) + B: B*N*K (ColumnMajor) + C/D/sofmax: B*M*N (RowMajor) + N: B*M*1 (RowMajor) + */ + + {M, N, K}, + B, + {a_ptr, LayoutA(K)}, + {b_ptr, LayoutB(K)}, + {c_ptr, LayoutC(N)}, + {d_ptr, LayoutC(N)}, + { + float(1.0), + float(0.0) + }, + {n_ptr, LayoutC(1)}, + {soft_ptr, LayoutC(N)}, + M*K, + N*K, + M*N, + M*N, + M*N, + M*N + + +""" +) + + +@registry.reg("cuda.bmm_rcr_softmax.config") +def bmm_rcr_softmax_config(func_attrs, dtype="float16"): + """This function sets a callback for processing the epilogue of the kernel + associated with func_attrs. 
+ + Parameters + ---------- + func_attrs: Dictionary + kernel attributes dictionary + layout: layout object + kernel layout + Returns + ------- + None + """ + common.make_fproc_f16(func_attrs, RCR) + + +@registry.reg("cuda.bmm_rcr_softmax.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + """Generate code for profiling""" + bmm_common.gen_profiler( + func_attrs, + workdir, + dim_info_dict, + common_softmax.SRC_TEMPLATE, + PROBLEM_ARGS_TEMPLATE, + ARGS_PARSER_TEMPLATE, + emit_kernel=True, + ) + + +@registry.reg("cuda.bmm_rcr_softmax.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, +): + """Generate the code for main function""" + return bmm_common.gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, + PROBLEM_ARGS_TEMPLATE.render(), + ) + + +@registry.reg("cuda.bmm_rcr_softmax.func_decl") +def gen_function_decl(func_attrs): + """Rendering argument to function declaration template""" + func_name = func_attrs["name"] + return bmm_common.FUNC_DECL_TEMPLATE.render(func_name=func_name, ndims=3) + + +@registry.reg("cuda.bmm_rcr_softmax.func_call") +def gen_function_call(func_attrs, indent=" "): + """Rendering the code to function call template""" + return bmm_common.gen_function_call(func_attrs, indent) + + +@registry.reg("cuda.bmm_rcr_softmax.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_epilogue_vistor/common_softmax.py b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/common_softmax.py new file mode 100644 index 000000000..ff5e4b084 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/common_softmax.py @@ -0,0 +1,538 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Common template for softmax. 
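
Each @registry.reg("cuda.bmm_rcr_softmax.<stage>") decorator above registers one backend stage (config, gen_profiler, gen_function, func_decl, func_call, filter) under a string key that the compiler looks up when it lowers the op. A minimal sketch of that dispatch pattern, using a hypothetical reg/get pair rather than AITemplate's real registry module:

    # Hypothetical string-keyed registry; shows the dispatch pattern only.
    _FUNCS = {}

    def reg(key):
        def _wrap(fn):
            assert key not in _FUNCS, f"duplicate backend key: {key}"
            _FUNCS[key] = fn
            return fn
        return _wrap

    def get(key):
        return _FUNCS[key]

    @reg("cuda.bmm_rcr_softmax.func_decl")
    def _func_decl(func_attrs):
        return f"void {func_attrs['name']}(...);"

    # The compiler resolves backend code generators per op and stage by key:
    print(get("cuda.bmm_rcr_softmax.func_decl")({"name": "bmm_rcr_softmax_0"}))
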
+""" +import os +import re +from hashlib import sha1 + +import jinja2 + +from ...common import gemm_common +from ...target import Target +from ..gemm_universal import common + +# pylint: disable=C0301,C0415,R1705 + + +SRC_TEMPLATE = jinja2.Template( + """ +#include +#include +#include +#include +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/kernel/gemm_grouped.h" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" +#include "cutlass/gemm/device/gemm_grouped.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/device_memory.h" + +#include "gemm_with_softmax.h" + +{{custom_libs}} + +{{extra_code}} + +#define CUTLASS_CHECK(status) \\ + { \\ + cutlass::Status error = status; \\ + if (error != cutlass::Status::kSuccess) { \\ + auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " + \\ + cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__); \\ + std::cerr << msg << std::endl; \\ + throw std::runtime_error(msg); \\ + } \\ + } + +{{instances}} + +using LayoutA = cutlass::layout::RowMajor; +using LayoutB = cutlass::layout::ColumnMajor; +using LayoutC = cutlass::layout::RowMajor; + + +void {{function_name}} ( + cutlass::half_t* a_ptr, + cutlass::half_t* b_ptr, +{% if has_d %} + cutlass::half_t* d_ptr, +{% endif %} + cutlass::half_t* c_ptr, + cutlass::half_t* d_ptr, + float* n_ptr, + cutlass::half_t* soft_ptr, + uint8_t* workspace, +{% if support_split_k %} + int split_k, +{% endif %} +{% for idx in range(input_ndims) %} + int64_t* a_dim{{idx}}, +{% endfor %} +{% for idx in range(weight_ndims) %} + int64_t* b_dim{{idx}}, +{% endfor %} +{% for idx in range(input_ndims) %} + int64_t* c_dim{{idx}}, +{% endfor %} + cudaStream_t stream + ) { + {{shape_eval}} + {{output_addr_calculator}} + {{extra_shape}} + + {{exec_paths}} + throw std::runtime_error( + "Unsupported workload for this gemm specialization." 
+ ); +} +""", + trim_blocks=True, + lstrip_blocks=True, +) + + +EXEC_TEMPLATE = jinja2.Template( + """ +{{indent}}typename {{instance}}::Arguments arguments{ + +{{problem_args}} + +{{indent}}}; +{{indent}}{{instance}} gemm_op; +{% if is_profiler %} +{{indent}}size_t workspace_size = 0; //gemm_op.get_workspace_size(arguments); +{{indent}}cutlass::device_memory::allocation local_workspace(workspace_size); +{{indent}}workspace = local_workspace.get(); +{{indent}}GLOBAL_WORKSPACE_SIZE = workspace_size; +{% endif %} + +{{indent}}auto status = gemm_op.initialize(arguments); +{{indent}}CUTLASS_CHECK(status); +{{indent}}status = gemm_op(stream); +{{indent}}CUTLASS_CHECK(status); +{{indent}}return; + +""" +) + +FUNC_DECL_TEMPLATE = jinja2.Template( + """ +void {{func_name}}( + cutlass::half_t*, + cutlass::half_t*, + cutlass::half_t*, + cutlass::half_t*, + float*, + cutlass::half_t*, + uint8_t*, +{% if support_split_k %} + int, +{% endif %} +{% for idx in range(input_ndims) %} + int64_t*, +{% endfor %} +{% for idx in range(weight_ndims) %} + int64_t*, +{% endfor %} +{% for idx in range(input_ndims) %} + int64_t*, +{% endfor %} + cudaStream_t +); +""" +) + + +FUNC_CALL_TEMPLATE = jinja2.Template( + """ +{{indent}}{{func_name}}( +{{indent}} {{a_ptr}}, +{{indent}} {{b_ptr}}, +{% if has_bias %} +{{indent}} {{bias_ptr}}, +{% endif %} +{{indent}} {{c_ptr}}, +{{indent}} {{d_ptr}}, +{{indent}} {{n_ptr}}, +{{indent}} {{soft_ptr}}, +{{indent}} global_workspace, +{{indent}} {{split_k}}, +{% for dim in adims %} +{{indent}} {{dim}}, +{% endfor %} +{% for dim in bdims %} +{{indent}} {{dim}}, +{% endfor %} +{% for dim in cdims %} +{{indent}} {{dim}}, +{% endfor %} +{{indent}} stream +{{indent}}); +""" +) + + +TENSOR_DECL_TEMPLATE = jinja2.Template( + """ + // cast to int64_t to avoid overflow + int64_t a_ptr_sz = static_cast(a_dim0) * static_cast(a_dim1); + int64_t b_ptr_sz = static_cast(b_dim0) * static_cast(b_dim1); + int64_t c_ptr_sz = static_cast(c_dim0) * static_cast(c_dim1); + int64_t ptr_max_sz = std::max({a_ptr_sz, b_ptr_sz, c_ptr_sz}); + // TODO: special pool size for A100 L2 cache 40M + // need to tune it for other devices + int64_t mem_pool_sz = std::max(2, std::min(64, int((1 << 25) / ptr_max_sz))); + + memory_pool->AllocateHalfTensor(a_ptr_sz, mem_pool_sz); // a_ptr: index 0 + memory_pool->AllocateHalfTensor(b_ptr_sz, mem_pool_sz); // b_ptr: index 1 + memory_pool->AllocateHalfTensor(c_ptr_sz, mem_pool_sz); // c_ptr: index 2 + memory_pool->AllocateHalfTensor(c_ptr_sz, mem_pool_sz); // d_ptr: index 3 + memory_pool->AllocateFloatTensor(c_dim0, mem_pool_sz); // n_ptr: index 4 + memory_pool->AllocateHalfTensor(c_ptr_sz, mem_pool_sz); // soft_ptr: index 5 +""" +) + + +DEFAULT_EXTRA_SHAPE_TEMPLATE = jinja2.Template( + """ +{{indent}}const int M = AM; +{{indent}}const int N = BN; +{{indent}}const int K = AK; +""" +) + + +# TODO Merge all alignment into single profiler +PROFILER_TEMPLATE = jinja2.Template( + """ +size_t GLOBAL_WORKSPACE_SIZE = 0; + +{{op_func}} + +struct ProfilerMemoryPool { + ProfilerMemoryPool() { + std::random_device rd; + gen = std::mt19937(rd()); + uniform_dist = std::uniform_int_distribution(1, 48964896); + offsets.reserve(512); + strides.reserve(512); + copies.reserve(512); + ptrs.reserve(512); + blobs.reserve(512); + } + ~ProfilerMemoryPool() {} + + template + DType* AllocateGaussianTensor(int64_t size) { + size_t length = size * sizeof(DType); + blobs.emplace_back(length); + DType* ptr = reinterpret_cast(blobs.back().get()); + + uint64_t seed = uniform_dist(gen); + double mean = 0.f; 
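
TENSOR_DECL_TEMPLATE sizes the profiler's rotating tensor pool with a small clamp expression: a fixed budget of 1 << 25 elements is divided by the largest operand's element count and the result is clamped to between 2 and 64 copies, presumably so the timing loop rotates through several buffers instead of re-reading one L2-resident tensor (the TODO notes the constant is tuned around A100's 40 MB L2). A sketch of that arithmetic:

    # Mirrors mem_pool_sz in TENSOR_DECL_TEMPLATE:
    # max(2, min(64, (1 << 25) // largest_operand_element_count))
    def mem_pool_copies(a_elems: int, b_elems: int, c_elems: int) -> int:
        largest = max(a_elems, b_elems, c_elems)
        return max(2, min(64, (1 << 25) // largest))

    # Large bmm problem -> only 2 copies; tiny problem -> capped at 64 copies.
    print(mem_pool_copies(256 * 512 * 64, 256 * 512 * 64, 256 * 512 * 512))  # 2
    print(mem_pool_copies(8 * 64, 8 * 64, 8 * 8))                            # 64
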
+ double std = 1.f; + + cutlass::reference::device::BlockFillRandomGaussian(ptr, size, seed, mean, + std); + + return ptr; + } + + + cutlass::half_t* AllocateHalfGaussianTensor(int64_t size) { + return reinterpret_cast( + AllocateGaussianTensor<__half>(size)); + } + + int AllocateHalfTensor(int64_t size, int64_t copy) { + offsets.push_back(0); + strides.push_back(size); + copies.push_back(copy); + auto ptr = AllocateHalfGaussianTensor(size * copy); + ptrs.push_back(reinterpret_cast(ptr)); + return ptrs.size() - 1; + } + + float* AllocateFloatGaussianTensor(int64_t size) { + return reinterpret_cast( + AllocateGaussianTensor(size)); + } + + int AllocateFloatTensor(int64_t size, int64_t copy) { + offsets.push_back(0); + strides.push_back(size); + copies.push_back(copy); + auto ptr = AllocateFloatGaussianTensor(size * copy); + ptrs.push_back(reinterpret_cast(ptr)); + return ptrs.size() - 1; + } + + template + T* RequestTensorByIdx(int idx) { + auto copy = copies.at(idx); + auto offset = offsets.at(idx); + auto stride = strides.at(idx); + T* ptr = reinterpret_cast(ptrs.at(idx)); + ptr += offset; + offset += stride; + if (offset == copy * stride) { + offset = 0; + } + offsets[idx] = offset; + return ptr; + } + + std::vector offsets; + std::vector strides; + std::vector copies; + std::vector ptrs; + std::vector > blobs; + std::mt19937 gen; + std::uniform_int_distribution uniform_dist; +}; + +int main(int argc, char** argv) { + int device_idx; + cudaDeviceProp device_properties; + cudaError_t result = cudaGetDevice(&device_idx); + auto memory_pool = std::make_unique(); + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&device_properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + + + {{args_parse}} + + using ElementOutput = typename {{name}}::ElementC; + using ElementInputA = typename {{name}}::ElementA; + using ElementInputB = typename {{name}}::ElementB; + using ElementInputN = typename {{name}}::ElementN; + uint8_t* global_workspace = nullptr; + cudaStream_t stream = nullptr; + + {{tensor_decl}} + + // warmup + {{func_call}} + cudaEvent_t events[2]; + for (auto & event : events) { + cudaEventCreate(&event); + } + cudaEventRecord(events[0]); + for (int i = 0; i < 5; ++i) { + {{func_call}} + } + cudaEventRecord(events[1]); + cudaEventSynchronize(events[1]); + float runtime_ms = 0; + cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + for (auto event : events) { + (void)cudaEventDestroy(event); + } + // TODO: output workspace + if (runtime_ms < 0.00001) { + throw std::runtime_error( + "OOB in cutlass." 
+ ); + } + std::cout << "TIME:" << runtime_ms << std::endl; + std::cout << "WS:" << GLOBAL_WORKSPACE_SIZE << std::endl; +} +""" +) + + +def gen_custom_libs(): + custom_libs = Target.current().get_custom_libs( + os.path.dirname(__file__), "include/gemm_with_softmax.h" + ) + return custom_libs + + +def _gemm_softmax_instance(op_def): + tmp = op_def.replace("GemmSoftmax", "GemmSoftmaxUniversal") + tmp = re.sub( + r"GemmIdentityThreadblockSwizzle<\d>", + "GemmBatchedIdentityThreadblockSwizzle", + tmp, + ) + return tmp + + +def emit_instance(op, f_instance_convertor=_gemm_softmax_instance, emit_kernel=False): + import cutlass_lib + + emiter = cutlass_lib.gemm_operation.EmitGemmInstance() + if emit_kernel: + emiter = cutlass_lib.gemm_operation.EmitGemmSoftmaxInstance() + + op_def = emiter.emit(op) + op_def = f_instance_convertor(op_def) + return op_def + + +def gen_function( + func_attrs, + src_template, + exec_cond_template, + problem_args, + input_ndims, + weight_ndims, + dim_info_dict, + f_instance_convertor=_gemm_softmax_instance, + emit_kernel=False, + support_split_k=False, + output_addr_calculator="", + extra_code="", +): + func_name = func_attrs["name"] + exec_path = func_attrs["exec_path"] + op_instance = func_attrs["op_instance"] + inst_def_flag = set() + instances = {} + instance_decl = "" + for exec_item in exec_path.values(): + fname = "f" + sha1(exec_item.exec_cond.encode()).hexdigest() + algo = exec_item.algo + if algo not in inst_def_flag: + config = emit_instance(op_instance[algo], f_instance_convertor, emit_kernel) + inst_def_flag.add(algo) + else: + config = "" + inst = common.INSTANCE_TEMPLATE.render( + config=config, name=fname, config_name=common.extract_config_name(config) + ) + instances[exec_item.exec_cond] = inst + instance_decl += inst + shape_eval_func = gemm_common.gen_shape_eval_code( + indent=1, dtype="int64_t", dim_info_dict=dim_info_dict, is_ptr=True + ) + exec_paths = "" + for key, _ in instances.items(): + fname = "f" + sha1(key.encode()).hexdigest() + program = EXEC_TEMPLATE.render( + indent=" ", + instance=fname, + problem_args=problem_args, + support_split_k=support_split_k, + ) + exec_inst = exec_cond_template.render(indent=" ", cond=key, program=program) + exec_paths += exec_inst + return src_template.render( + custom_libs=gen_custom_libs(), + instances=instance_decl, + function_name=func_name, + dtype="cutlass::half_t", + shape_eval=shape_eval_func, + output_addr_calculator=output_addr_calculator, + exec_paths=exec_paths, + input_ndims=input_ndims, + weight_ndims=weight_ndims, + support_split_k=support_split_k, + has_d=common.has_d(func_attrs), + has_d1=common.has_d1(func_attrs), + extra_code=extra_code, + ) + + +def gen_profiler( + func_attrs, + workdir, + dim_info_dict, + src_template, + problem_args_template, + args_parser_template, + emit_kernel=False, + support_split_k=False, + output_addr_calculator="", + bias_ptr_arg=None, + extra_code="", +): + op_type = func_attrs["op"] + op_instance = func_attrs["op_instance"] + + ndims = 2 + adims = ["&a_dim" + str(i) for i in range(ndims)] + bdims = ["&b_dim" + str(i) for i in range(ndims)] + cdims = ["&c_dim" + str(i) for i in range(ndims)] + shape_func = gemm_common.gen_shape_eval_code( + indent=2, dtype="int64_t", dim_info_dict=dim_info_dict, is_ptr=True + ) + + file_pairs = [] + has_bias = bias_ptr_arg is not None + for op_name, op in op_instance.items(): + config = emit_instance(op, emit_kernel=emit_kernel) + config_name = common.extract_config_name(config) + name = "GemmInstance" + instance = 
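
Two small pieces of string processing here are easy to check in isolation: gen_function derives a deterministic C++ identifier for every exec path by hashing its condition string, and _gemm_softmax_instance rewrites the CUTLASS-emitted operator definition with a rename plus a regex swap of the threadblock swizzle. A standalone sketch with a made-up condition and op definition:

    import re
    from hashlib import sha1

    # Deterministic per-condition function name, as in gen_function above.
    exec_cond = "M >= 1 && N >= 1 && K >= 1"
    fname = "f" + sha1(exec_cond.encode()).hexdigest()
    print(fname)  # "f" + 40 hex chars, stable for the same condition string

    # Instance conversion, as in _gemm_softmax_instance: rename the op and use
    # the batched identity swizzle instead of the plain one.
    op_def = "using Op = cutlass::GemmSoftmax<..., GemmIdentityThreadblockSwizzle<8>, ...>;"
    op_def = op_def.replace("GemmSoftmax", "GemmSoftmaxUniversal")
    op_def = re.sub(
        r"GemmIdentityThreadblockSwizzle<\d>",
        "GemmBatchedIdentityThreadblockSwizzle",
        op_def,
    )
    print(op_def)
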
common.INSTANCE_TEMPLATE.render( + config_name=config_name, name=name, config=config + ) + exec_program = EXEC_TEMPLATE.render( + indent=" ", + instance=name, + is_profiler=True, + support_split_k=support_split_k, + problem_args=problem_args_template.render(), + ) + op_func = src_template.render( + custom_libs=gen_custom_libs(), + instances=instance, + function_name="gemm", + input_ndims=2, + weight_ndims=2, + shape_eval=shape_func, + exec_paths=exec_program, + output_addr_calculator=output_addr_calculator, + support_split_k=support_split_k, + extra_code=extra_code, + ) + func_call = FUNC_CALL_TEMPLATE.render( + func_name="gemm", + a_ptr="memory_pool->RequestTensorByIdx(0)", + b_ptr="memory_pool->RequestTensorByIdx(1)", + has_bias=has_bias, + bias_ptr=bias_ptr_arg, + c_ptr="memory_pool->RequestTensorByIdx(2)", + d_ptr="memory_pool->RequestTensorByIdx(3)", + n_ptr="memory_pool->RequestTensorByIdx(4)", + soft_ptr="memory_pool->RequestTensorByIdx(5)", + split_k="split_k", + adims=adims, + bdims=bdims, + cdims=cdims, + ) + code = PROFILER_TEMPLATE.render( + op_func=op_func, + args_parse=args_parser_template.render(), + func_call=func_call, + name=name, + tensor_decl=TENSOR_DECL_TEMPLATE.render(name=name, has_bias=has_bias), + ) + common.add_profiler(file_pairs, workdir, op_type, op_name, code) + # build + common.build_profiler(file_pairs) diff --git a/python/aitemplate/backend/cuda/gemm_epilogue_vistor/gemm_rcr_bias_softmax.py b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/gemm_rcr_bias_softmax.py new file mode 100644 index 000000000..90e9d25a6 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/gemm_rcr_bias_softmax.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +GEMM Specialization for A[RowMajor], B[ColMajor], C[RowMajor] +This is special in template based gemm solution +This is used for `torch.nn.functional.linear` +When use for `linear`, need set A->Data, B->Weight +""" +import jinja2 + +from ... import registry +from ..gemm_universal import common +from . 
import common_softmax, gemm_rcr_softmax + + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + + +PROBLEM_ARGS_TEMPLATE = jinja2.Template( + """ + /* + A: M*K (RowMajor) + B: N*K (ColumnMajor) + C/D/sofmax: M*N (RowMajor) + N: M*1 (RowMajor) + */ + + {M, N, K}, + 1, + {a_ptr, LayoutA(K)}, + {b_ptr, LayoutB(K)}, + {c_ptr, 0}, + {d_ptr, LayoutC(N)}, + { + float(1.0), + float(1.0) + }, + {n_ptr, LayoutC(1)}, + {soft_ptr, LayoutC(N)} + +""" +) + + +@registry.reg("cuda.gemm_rcr_bias_softmax.config") +def gemm_rcr_bias_softmax_config(func_attrs, dtype="float16"): + return gemm_rcr_softmax.gemm_rcr_softmax_config(func_attrs, dtype) + + +@registry.reg("cuda.gemm_rcr_bias_softmax.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + return gemm_rcr_softmax.common_gen_profiler( + func_attrs, + workdir, + dim_info_dict, + common_softmax.SRC_TEMPLATE, + PROBLEM_ARGS_TEMPLATE, + ) + + +@registry.reg("cuda.gemm_rcr_bias_softmax.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, +): + return gemm_rcr_softmax.gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, + PROBLEM_ARGS_TEMPLATE, + ) + + +@registry.reg("cuda.gemm_rcr_bias_softmax.func_decl") +def gen_function_decl(func_attrs): + return gemm_rcr_softmax.gen_function_decl(func_attrs) + + +@registry.reg("cuda.gemm_rcr_bias_softmax.func_call") +def gen_function_call(func_attrs, indent=" "): + return gemm_rcr_softmax.gen_function_call( + func_attrs, + indent, + ) + + +@registry.reg("cuda.gemm_rcr_bias_softmax.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_epilogue_vistor/gemm_rcr_softmax.py b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/gemm_rcr_softmax.py new file mode 100644 index 000000000..eb3fcde49 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/gemm_rcr_softmax.py @@ -0,0 +1,216 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +GEMM Specialization for A[RowMajor], B[ColMajor], C[RowMajor] +This is special in template based gemm solution +This is used for `torch.nn.functional.linear` +When use for `linear`, need set A->Data, B->Weight +""" +import jinja2 + +from ... import registry +from ..gemm_universal import common +from ..gemm_universal.layout import RCR +from . 
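
Compared with the non-bias gemm_rcr_softmax problem arguments further below, the bias variant above binds the C operand with a leading stride of 0 ({c_ptr, 0}) and uses beta = 1.0 instead of 0.0, so the epilogue's alpha * AB + beta * C term reads the same row of C for every output row, which amounts to adding a broadcast bias vector before the softmax. A NumPy sketch of that reading of the arguments (treating c as the length-N bias row):

    import numpy as np

    # alpha * (A @ B^T) + beta * C, with C read at row stride 0: every output
    # row sees the same length-N vector, i.e. a broadcast bias add.
    def gemm_rcr_bias_reference(a, b, bias, alpha=1.0, beta=1.0):
        return alpha * (a @ b.T) + beta * bias[None, :]

    a = np.random.randn(8, 16).astype(np.float16)
    b = np.random.randn(4, 16).astype(np.float16)   # RCR: B stored as [N, K]
    bias = np.random.randn(4).astype(np.float16)
    print(gemm_rcr_bias_reference(a, b, bias).shape)  # (8, 4)
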
import common_softmax + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + + +ARGS_PARSER_TEMPLATE = jinja2.Template( + """ + int64_t M = std::atoi(argv[1]); + int64_t N = std::atoi(argv[2]); + int64_t K = std::atoi(argv[3]); + int64_t split_k = std::atoi(argv[4]); + + int64_t a_dim0 = M; + int64_t a_dim1 = K; + int64_t b_dim0 = N; + int64_t b_dim1 = K; + int64_t c_dim0 = M; + int64_t c_dim1 = N; +""" +) + +PROBLEM_ARGS_TEMPLATE = jinja2.Template( + """ + /* + A: M*K (RowMajor) + B: N*K (ColumnMajor) + C/D/sofmax: M*N (RowMajor) + N: M*1 (RowMajor) + */ + + {M, N, K}, + 1, + {a_ptr, LayoutA(K)}, + {b_ptr, LayoutB(K)}, + {c_ptr, LayoutC(N)}, + {d_ptr, LayoutC(N)}, + { + float(1.0), + float(0.0) + }, + {n_ptr, LayoutC(1)}, + {soft_ptr, LayoutC(N)} + +""" +) + + +@registry.reg("cuda.gemm_rcr_softmax.config") +def gemm_rcr_softmax_config(func_attrs, dtype="float16"): + common.make_fproc_f16(func_attrs, RCR) + + +def common_gen_profiler( + func_attrs, + workdir, + dim_info_dict, + src_template, + problem_args_template, + bias_ptr_arg=None, + extra_code="", +): + output_addr_calculator = common.DEFAULT_OUTPUT_ADDR_CALCULATOR.render( + stride_dim="*b_dim0" + ) + common_softmax.gen_profiler( + func_attrs, + workdir, + dim_info_dict, + src_template, + problem_args_template, + ARGS_PARSER_TEMPLATE, + emit_kernel=True, + support_split_k=True, + output_addr_calculator=output_addr_calculator, + bias_ptr_arg=bias_ptr_arg, + extra_code=extra_code, + ) + + +@registry.reg("cuda.gemm_rcr_softmax.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + return common_gen_profiler( + func_attrs, + workdir, + dim_info_dict, + common_softmax.SRC_TEMPLATE, + PROBLEM_ARGS_TEMPLATE, + ) + + +@registry.reg("cuda.gemm_rcr_softmax.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, + problem_args_template=None, +): + if problem_args_template is None: + problem_args = PROBLEM_ARGS_TEMPLATE.render() + else: + problem_args = problem_args_template.render() + input_ndims = len(func_attrs["input_accessors"][0].original_shapes) + weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) + return common_softmax.gen_function( + func_attrs, + common_softmax.SRC_TEMPLATE, + exec_cond_template, + problem_args, + input_ndims, + weight_ndims, + dim_info_dict, + emit_kernel=True, + support_split_k=True, + output_addr_calculator=common.OUTPUT_ADDR_CALCULATOR.render( + stride_dim="N", output_accessor=func_attrs["output_accessors"][0] + ), + ) + + +@registry.reg("cuda.gemm_rcr_softmax.func_decl") +def gen_function_decl(func_attrs): + func_name = func_attrs["name"] + input_ndims = len(func_attrs["input_accessors"][0].original_shapes) + weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) + return common_softmax.FUNC_DECL_TEMPLATE.render( + func_name=func_name, + input_ndims=input_ndims, + weight_ndims=weight_ndims, + support_split_k=True, + ) + + +@registry.reg("cuda.gemm_rcr_softmax.func_call") +def gen_function_call(func_attrs, indent=" "): + a = func_attrs["inputs"][0] + b = func_attrs["inputs"][1] + + tmp_c = func_attrs["inputs"][2] + tmp_d = func_attrs["inputs"][3] + tmp_n = func_attrs["inputs"][4] + + soft = func_attrs["outputs"][0] + has_bias = False + adims = [ + "&" + dim._attrs["name"] + for dim in func_attrs["input_accessors"][0].original_shapes + ] + bdims = [ + "&" + dim._attrs["name"] + for dim in func_attrs["input_accessors"][1].original_shapes + ] + cdims = [ + "&" + dim._attrs["name"] + for dim in 
func_attrs["output_accessors"][0].original_shapes + ] + return common_softmax.FUNC_CALL_TEMPLATE.render( + func_name=func_attrs["name"], + a_ptr=a._attrs["name"], + b_ptr=b._attrs["name"], + has_bias=has_bias, + c_ptr=tmp_c._attrs["name"], + d_ptr=tmp_d._attrs["name"], + n_ptr=tmp_n._attrs["name"], + soft_ptr=soft._attrs["name"], + split_k=func_attrs["split_k"], + adims=adims, + bdims=bdims, + cdims=cdims, + indent=indent, + ) + + +@registry.reg("cuda.gemm_rcr_softmax.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_epilogue_vistor/include/gemm_with_softmax.h b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/include/gemm_with_softmax.h new file mode 100644 index 000000000..3b168b3d8 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/include/gemm_with_softmax.h @@ -0,0 +1,302 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +namespace cutlass { + +template < + typename ElementA_, + typename LayoutA_, + int kAlignmentA, + typename ElementB_, + typename LayoutB_, + int kAlignmentB, + typename ElementC_, + int kAlignmentC, + typename OperatorClass, + typename ArchTag, + typename ElementAccumulator, + int kStages, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueFunctorOp, + typename ThreadblockSwizzle, + typename ElementSum_ = ElementAccumulator, + typename ElementSoftmax_ = ElementC_> + +class GemmSoftmaxUniversal { + public: + /////////////////////////////////////////////////////////////////////////////////////////////// + + // + // Type definitions + // + + using ElementA = ElementA_; + using ElementB = ElementB_; + using ElementC = ElementC_; + using ElementCompute = ElementAccumulator; + using ElementSum = ElementSum_; + using ElementSoft = ElementSoftmax_; + + using LayoutA = LayoutA_; + using LayoutB = LayoutB_; + + static int const kAlignment = kAlignmentA; + + /////////////////////////////////////////////////////////////////////////////////////////////// + + /// Linear scaling operator + // using EpilogueFunctorOp = cutlass::epilogue::thread::LinearCombination< + // ElementC, + // kAlignment, + // ElementCompute, + // ElementCompute + // >; + + /////////////////////////////////////////////////////////////////////////////////////////////// + + // This is a mandatory data type for the atomic reduction in the GEMM epilogue + // to function. + + using ElementN = float; + + // These are mandatory layouts. 
+ using LayoutC = cutlass::layout::RowMajor; + using LayoutN = cutlass::layout::RowMajor; + using LayoutSoft = cutlass::layout::RowMajor; + + using TensorRefA = TensorRef; + using TensorRefB = TensorRef; + using TensorRefC = TensorRef; + using TensorRefN = TensorRef; + using TensorRefSoft = TensorRef; + + // using OperatorClass = cutlass::arch::OpClassTensorOp; + // using ArchTag = cutlass::arch::Sm80; + // static int const kStages = Stages; + // using ThreadblockSwizzle = + // cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle; + + /////////////////////////////////////////////////////////////////////////////////////////////// + + // basic GEMM kernel + using DefaultGemmKernel = typename cutlass::gemm::kernel::DefaultGemm< + ElementA, + LayoutA, + kAlignment, + ElementB, + LayoutB, + kAlignment, + ElementC, + LayoutC, + ElementCompute, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueFunctorOp, + ThreadblockSwizzle, + kStages, + true, + typename cutlass::gemm::device::DefaultGemmConfiguration< + OperatorClass, + ArchTag, + ElementA, + ElementB, + ElementC, + ElementCompute>::Operator, + cutlass::gemm::SharedMemoryClearOption::kNone>::GemmKernel; + + /////////////////////////////////////////////////////////////////////////////////////////////// + + // Epilogue visitor + using EpilogueVisitor = kernel::EpilogueVisitorBiasMax< + ThreadblockShape, + DefaultGemmKernel::kThreadCount, + typename DefaultGemmKernel::Epilogue::OutputTileIterator, + ElementCompute, + EpilogueFunctorOp>; + + /// Epilogue + using Epilogue = typename cutlass::epilogue::threadblock:: + EpilogueWithVisitorFromExistingEpilogue< + EpilogueVisitor, + typename DefaultGemmKernel::Epilogue>::Epilogue; + + // GEMM + using GemmKernel = gemm::kernel::GemmWithEpilogueVisitor< + typename DefaultGemmKernel::Mma, + Epilogue, + ThreadblockSwizzle>; + + // Softmax kernel + using SoftmaxApplyKernel = kernel::ApplySoftmax< + ElementC, + ElementN, + ElementSum, + ElementSoft, + kAlignmentC, + MatrixShape<1, 1024>>; + + public: + /// Arguments class + struct Arguments { + typename GemmKernel::Arguments gemm; + + typename SoftmaxApplyKernel::Arguments softmax; + + // + // Methods + // + Arguments() {} + + Arguments( + cutlass::gemm::GemmCoord problem_size, + int32_t batch_count_, + TensorRefA ref_A_, + TensorRefB ref_B_, + TensorRefC ref_C_, + TensorRefC ref_D_, + typename EpilogueFunctorOp::Params linear_scaling, + TensorRefN ref_N_, + TensorRefSoft ref_Softmax_, + int64_t batch_stride_A_ = 0, + int64_t batch_stride_B_ = 0, + int64_t batch_stride_C_ = 0, + int64_t batch_stride_D_ = 0, + int64_t batch_stride_Max_ = 0, + int64_t batch_stride_Softmax_ = 0) + : gemm( + cutlass::gemm::GemmUniversalMode::kBatched, + problem_size, + batch_count_, + ref_A_, + ref_B_, + batch_stride_A_, + batch_stride_B_, + typename EpilogueVisitor::Arguments( + linear_scaling, + ref_C_, + ref_D_, + ref_N_.data(), + batch_stride_C_, + batch_stride_D_, + batch_stride_Max_)), + softmax( + MatrixCoord(problem_size.m(), problem_size.n()), + batch_count_, + ref_D_, + ref_N_, + ref_Softmax_, + batch_stride_D_, + batch_stride_Max_, + batch_stride_Softmax_) {} + }; + + struct Params { + typename GemmKernel::Params gemm; + + typename SoftmaxApplyKernel::Params softmax; + + // + // Methods + // + Params() {} + + Params(Arguments const& args) : gemm(args.gemm), softmax(args.softmax) {} + }; + + public: + // Gemm + + // + // Methods + // + + private: + Params params_; + + public: + /// Ctor + GemmSoftmaxUniversal() {} + + 
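
GemmSoftmaxUniversal composes two launches: the GEMM kernel, whose epilogue visitor writes D along with a per-row statistic into the auxiliary N tensor, and the ApplySoftmax kernel, which normalizes D using that statistic. A NumPy sketch of the end-to-end math the pair implements, assuming N holds the per-row maximum (consistent with the EpilogueVisitorBiasMax name and the float ElementN used for the reduction):

    import numpy as np

    # Reference for the fused bmm_rcr + softmax path: D = A @ B^T per batch,
    # N = row-wise max of D in float, soft = softmax(D) along the last dim.
    def bmm_rcr_softmax_reference(a, b):
        d = np.matmul(a, b.swapaxes(-1, -2)).astype(np.float32)
        n = d.max(axis=-1, keepdims=True)
        e = np.exp(d - n)
        soft = e / e.sum(axis=-1, keepdims=True)
        return d.astype(a.dtype), n, soft.astype(a.dtype)

    a = np.random.randn(2, 8, 16).astype(np.float16)
    b = np.random.randn(2, 4, 16).astype(np.float16)
    d, n, soft = bmm_rcr_softmax_reference(a, b)
    print(soft.sum(-1))  # each row sums to ~1
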
/// Initialize + Status initialize(Arguments const& args) { + params_ = Params(args); + + return cutlass::Status::kSuccess; + } + + /// Run + Status run(cudaStream_t stream) { + // + // Launch the GEMM + max kernel + // + + dim3 gemm_grid = + ThreadblockSwizzle().get_grid_shape(params_.gemm.grid_tiled_shape); + + dim3 gemm_block(GemmKernel::kThreadCount, 1, 1); + + int gemm_smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + cutlass::Kernel + <<>>(params_.gemm); + + cudaError_t result = cudaGetLastError(); + + if (result != cudaSuccess) { + return cutlass::Status::kErrorInternal; + } + + // + // Launch the SoftmaxApplyKernel + // + + dim3 apply_block( + SoftmaxApplyKernel::Shape::kColumn, SoftmaxApplyKernel::Shape::kRow); + + int cta_rows = SoftmaxApplyKernel::Shape::kRow; + int cta_columns = + SoftmaxApplyKernel::Shape::kColumn * SoftmaxApplyKernel::kAlignment; + + dim3 apply_grid( + (params_.softmax.args.extent.row() + cta_rows - 1) / cta_rows, + (params_.softmax.args.extent.column() + cta_columns - 1) / cta_columns, + params_.softmax.args.batch_count); + + Kernel + <<>>(params_.softmax); + + result = cudaGetLastError(); + + if (result != cudaSuccess) { + return cutlass::Status::kErrorInternal; + } + + return cutlass::Status::kSuccess; + } + + /// Function call operator + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } +}; + +} // namespace cutlass diff --git a/python/aitemplate/backend/cuda/gemm_special/__init__.py b/python/aitemplate/backend/cuda/gemm_special/__init__.py new file mode 100644 index 000000000..93043be2c --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_special/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +special gemm ops +""" +from . import bmm_rcr_n1, bmm_rrr_k1_tanh, gemm_rrr_small_nk + + +__all__ = ["bmm_rcr_n1", "bmm_rrr_k1_tanh", "gemm_rrr_small_nk"] diff --git a/python/aitemplate/backend/cuda/gemm_special/bmm_rcr_n1.py b/python/aitemplate/backend/cuda/gemm_special/bmm_rcr_n1.py new file mode 100644 index 000000000..5582ee24e --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_special/bmm_rcr_n1.py @@ -0,0 +1,616 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +""" +GEMM Specialization for A[RowMajor], B[ColMajor], C[RowMajor] +This is special in template based gemm solution +This is used for `torch.nn.functional.linear` +When use for `linear`, need set A->Data, B->Weight + +Special kernel for GEMV case: +A: [B, M, K] +B: [B, N, K] +C: [B, M, N] +where N = 1 + +This kernel computes C = alpha * A @ B +""" + +import jinja2 + +from ....compiler.base import IntImm + +from ... import registry +from ...backend_spec import CUDASpec +from ...common import gemm_common, tensor_accessor_codegen +from ...target import Target +from ..gemm_universal import common + +# pylint: disable=C0301,W0613,W0612 + + +FUNC_DECL_TEMPLATE = jinja2.Template( + """ +void {{func_name}}( + {{elem_input_type}}*, + {{elem_input_type}}*, + {{elem_input_type}}*, + {% for i in range(3) %} + int64_t*, + {% endfor %} + {% for i in range(3) %} + int64_t*, + {% endfor %} + {% for i in range(3) %} + int64_t*, + {% endfor %} + float, + bool, + cudaStream_t +); +""" +) + + +FUNC_CALL_TEMPLATE = jinja2.Template( + """ +{{indent}}{ +{{indent}}{{local_dim_defs}} +{{indent}}{{func_name}}( +{{indent}} {{a_ptr}}, +{{indent}} {{b_ptr}}, +{{indent}} {{c_ptr}}, +{% for adim in adims %} +{{indent}} {{adim}}, +{% endfor %} +{% for bdim in bdims %} +{{indent}} {{bdim}}, +{% endfor %} +{% for cdim in cdims %} +{{indent}} {{cdim}}, +{% endfor %} +{{indent}} {{alpha}}, +{{indent}} {{use_fp16_acc}}, +{{indent}} stream +{{indent}}); +{{indent}}} +""" +) + + +EXEC_TEMPLATE = jinja2.Template( + """ +{{indent}}bmm_rcr_n1_launcher<{{elem_input_type}}, {{read_vec_type}}, {{K}}>( +{{indent}} a_ptr, +{{indent}} b_ptr, +{{indent}} c_ptr, +{{indent}} B, +{{indent}} M, +{{indent}} alpha, +{{indent}} use_fp16_acc, +{{indent}} stream, +{{intent}} input_a_accessor, +{{intent}} input_b_accessor, +{{intent}} output_accessor +{{indent}}); +{{indent}}return; +""" +) + + +SRC_TEMPLATE = jinja2.Template( + """ +#include +#include +#include "cutlass/util/host_tensor.h" + +namespace { + +{{tensor_accessor_libs}} + +template +__forceinline__ __device__ bool load_vec_data( + ReadVecT* a_ptr, + ReadVecT* b_ptr, + const int64_t M, + float alpha, + TensorAccessor input_a_accessor, + TensorAccessor input_b_accessor, + TensorAccessor output_accessor, + ReadVecT *a_vec, + ReadVecT *b_vec) { + + int64_t batch_idx = blockIdx.y; + int64_t row_idx = blockIdx.x * blockDim.x + threadIdx.x; + + constexpr int64_t N_READ_ELEMS_IN_V = sizeof(ReadVecT) / sizeof(ElemT); + constexpr int64_t N_NUM_ELEMS_IN_V = K / N_READ_ELEMS_IN_V; + + int64_t b_idx_base = (batch_idx * K) / N_READ_ELEMS_IN_V; + + if (blockDim.x >= N_NUM_ELEMS_IN_V) { + // We have enough threads in a thread block where each thread takes care + // of loading one vector. + if (threadIdx.x < N_NUM_ELEMS_IN_V) { + b_vec[threadIdx.x] = *input_b_accessor.get(b_ptr, b_idx_base + threadIdx.x); + } + } else { + // We have more vectors than the available threads of a thread block, so each + // thread may read multiple vectors. 
+ for (int64_t i = 0; i < N_NUM_ELEMS_IN_V / blockDim.x + 1; i++) { + int64_t idx = i * blockDim.x + threadIdx.x; + if (idx < N_NUM_ELEMS_IN_V) { + b_vec[idx] = *input_b_accessor.get(b_ptr, b_idx_base + idx); + } + } + } + + __syncthreads(); + if (row_idx >= M) { + return false; + } + + int64_t a_batch_stride = M * K; + int64_t a_idx_base = (batch_idx * a_batch_stride + row_idx * K) / N_READ_ELEMS_IN_V; + + CUTLASS_PRAGMA_UNROLL + for (int64_t k = 0, i = 0; k < K; k += N_READ_ELEMS_IN_V, i++) { + a_vec[i] = *input_a_accessor.get(a_ptr, a_idx_base++); + } + + return true; +} + +// Each thread reads one row from "a" and one column from "b", +// computes dot_product(a_row, b_col), and writes the result to "c". +// This kernel assumes loading "a" and "b" can be fully vectorized, +// so it reads both "a" and "b" in ReadVecT. +template +__global__ void bmm_rcr_n1_kernel_fp32_acc_vec( + ReadVecT* a_ptr, + ReadVecT* b_ptr, + ElemT* c_ptr, + const int64_t M, + float alpha, + TensorAccessor input_a_accessor, + TensorAccessor input_b_accessor, + TensorAccessor output_accessor) { + + static_assert(sizeof(ReadVecT) % sizeof(ElemT) == 0, "invalid vector type"); + constexpr int64_t N_READ_ELEMS_IN_V = sizeof(ReadVecT) / sizeof(ElemT); + static_assert(N_READ_ELEMS_IN_V % 2 == 0, "invalid vector type for read"); + static_assert(K % N_READ_ELEMS_IN_V == 0, "cannot vectorize input"); + constexpr int64_t N_NUM_ELEMS_IN_V = K / N_READ_ELEMS_IN_V; + + __shared__ ReadVecT b_vec[N_NUM_ELEMS_IN_V]; + ReadVecT a_vec[N_NUM_ELEMS_IN_V]; + + if (!load_vec_data( + a_ptr, b_ptr, M, alpha, input_a_accessor, input_b_accessor, + output_accessor, a_vec, b_vec)) { + return; + } + + float result = 0.0; + + CUTLASS_PRAGMA_UNROLL + for (int64_t i = 0; i < N_NUM_ELEMS_IN_V; i++) { + const half2* a_vec_h2 = reinterpret_cast(&a_vec[i]); + const half2* b_vec_h2 = reinterpret_cast(&b_vec[i]); + CUTLASS_PRAGMA_UNROLL + for (int64_t j = 0; j < N_READ_ELEMS_IN_V / 2; ++j) { + half2 c_h2 = __hmul2(a_vec_h2[j], b_vec_h2[j]); + result += float(__low2half(c_h2)) + float(__high2half(c_h2)); + } + } + + int64_t batch_idx = blockIdx.y; + int64_t row_idx = blockIdx.x * blockDim.x + threadIdx.x; + *output_accessor.get(c_ptr, batch_idx * M + row_idx) = alpha * result; +} + +template +__forceinline__ __device__ bool load_data( + ElemT* a_ptr, + ElemT* b_ptr, + const int64_t M, + float alpha, + TensorAccessor input_a_accessor, + TensorAccessor input_b_accessor, + TensorAccessor output_accessor, + ElemT *a_data, + ElemT *b_data) { + + int64_t batch_idx = blockIdx.y; + int64_t b_idx_base = batch_idx * K; + int64_t row_idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (blockDim.x >= K) { + // We have enough threads in a thread block where each thread takes care + // of loading one element. + if (threadIdx.x < K) { + b_data[threadIdx.x] = *input_b_accessor.get(b_ptr, b_idx_base + threadIdx.x); + } + } else { + // We have more elements than the available threads of a thread block, so each + // thread may load multiple elements. 
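
bmm_rcr_n1 is effectively a batched GEMV with one output column: for each batch, every output element is alpha times the dot product of one row of A with the single K-length row of B, which is why B can be staged once in shared memory and reused by every thread in the block while each thread streams its own row of A. A NumPy reference of what the kernels above compute:

    import numpy as np

    # C[b, m, 0] = alpha * dot(A[b, m, :], B[b, 0, :])  -- the N == 1 case of bmm_rcr.
    def bmm_rcr_n1_reference(a, b, alpha=1.0):
        # a: [B, M, K], b: [B, 1, K] -> c: [B, M, 1]
        return alpha * np.einsum("bmk,bnk->bmn", a, b)

    a = np.random.randn(3, 128, 32).astype(np.float16)
    b = np.random.randn(3, 1, 32).astype(np.float16)
    print(bmm_rcr_n1_reference(a, b, alpha=0.5).shape)  # (3, 128, 1)
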
+ for (int64_t i = 0; i < K / blockDim.x + 1; i++) { + int64_t idx = i * blockDim.x + threadIdx.x; + if (idx < K) { + b_data[idx] = *input_b_accessor.get(b_ptr, b_idx_base + idx); + } + } + } + + __syncthreads(); + + if (row_idx >= M) { + return false; + } + + int64_t a_batch_stride = M * K; + int64_t a_idx_base = batch_idx * a_batch_stride + row_idx * K; + + CUTLASS_PRAGMA_UNROLL + for (int64_t i = 0; i < K; i++) { + a_data[i] = *input_a_accessor.get(a_ptr, a_idx_base++); + } + + return true; +} + +// Each thread reads one row from "a" and one column from "b", +// computes dot_product(a_row, b_col), and writes the result to "c". +// It reads both "a" and "b" one by one in ElemT. +template +__global__ void bmm_rcr_n1_kernel_fp32_acc( + ElemT* a_ptr, + ElemT* b_ptr, + ElemT* c_ptr, + const int64_t M, + float alpha, + TensorAccessor input_a_accessor, + TensorAccessor input_b_accessor, + TensorAccessor output_accessor) { + + __shared__ ElemT b_data[K]; + ElemT a_data[K]; + + if (!load_data( + a_ptr, b_ptr, M, alpha, input_a_accessor, input_b_accessor, + output_accessor, a_data, b_data)) { + return; + } + + float result = 0.0; + + const half2* a_data_h2 = reinterpret_cast(&a_data[0]); + const half2* b_data_h2 = reinterpret_cast(&b_data[0]); + CUTLASS_PRAGMA_UNROLL + for (int64_t i = 0; i < K / 2; ++i) { + half2 c_h2 = __hmul2(a_data_h2[i], b_data_h2[i]); + result += float(__low2half(c_h2)) + float(__high2half(c_h2)); + } + if (K % 2) { + result += float(__hmul(reinterpret_cast(a_data[K-1]), + reinterpret_cast(b_data[K-1]))); + } + + int64_t batch_idx = blockIdx.y; + int64_t row_idx = blockIdx.x * blockDim.x + threadIdx.x; + + *output_accessor.get(c_ptr, batch_idx * M + row_idx) = alpha * result; +} + +template +__global__ void bmm_rcr_n1_kernel_fp16_acc_vec( + ReadVecT* a_ptr, + ReadVecT* b_ptr, + ElemT* c_ptr, + const int64_t M, + float alpha, + TensorAccessor input_a_accessor, + TensorAccessor input_b_accessor, + TensorAccessor output_accessor) { + + static_assert(sizeof(ReadVecT) % sizeof(ElemT) == 0, "invalid vector type"); + constexpr int64_t N_READ_ELEMS_IN_V = sizeof(ReadVecT) / sizeof(ElemT); + static_assert(N_READ_ELEMS_IN_V % 2 == 0, "invalid vector type for read"); + static_assert(K % N_READ_ELEMS_IN_V == 0, "cannot vectorize input"); + constexpr int64_t N_NUM_ELEMS_IN_V = K / N_READ_ELEMS_IN_V; + + __shared__ ReadVecT b_vec[N_NUM_ELEMS_IN_V]; + ReadVecT a_vec[N_NUM_ELEMS_IN_V]; + + if (!load_vec_data( + a_ptr, b_ptr, M, alpha, input_a_accessor, input_b_accessor, + output_accessor, a_vec, b_vec)) { + return; + } + + half2 result_h2 = {0.0, 0.0}; + + CUTLASS_PRAGMA_UNROLL + for (int64_t i = 0; i < N_NUM_ELEMS_IN_V; i++) { + const half2* a_vec_h2 = reinterpret_cast(&a_vec[i]); + const half2* b_vec_h2 = reinterpret_cast(&b_vec[i]); + CUTLASS_PRAGMA_UNROLL + for (int64_t j = 0; j < N_READ_ELEMS_IN_V / 2; ++j) { + result_h2 = __hfma2(a_vec_h2[j], b_vec_h2[j], result_h2); + } + } + + float result = __hadd(__low2half(result_h2), __high2half(result_h2)); + + int64_t batch_idx = blockIdx.y; + int64_t row_idx = blockIdx.x * blockDim.x + threadIdx.x; + *output_accessor.get(c_ptr, batch_idx * M + row_idx) = alpha * result; +} + +template +__global__ void bmm_rcr_n1_kernel_fp16_acc( + ElemT* a_ptr, + ElemT* b_ptr, + ElemT* c_ptr, + const int64_t M, + float alpha, + TensorAccessor input_a_accessor, + TensorAccessor input_b_accessor, + TensorAccessor output_accessor) { + + __shared__ ElemT b_data[K]; + ElemT a_data[K]; + + if (!load_data( + a_ptr, b_ptr, M, alpha, input_a_accessor, 
input_b_accessor, + output_accessor, a_data, b_data)) { + return; + } + + half2 result_h2 = {0.0, 0.0}; + + const half2* a_data_h2 = reinterpret_cast(&a_data[0]); + const half2* b_data_h2 = reinterpret_cast(&b_data[0]); + CUTLASS_PRAGMA_UNROLL + for (int64_t i = 0; i < K / 2; ++i) { + result_h2 = __hfma2(a_data_h2[i], b_data_h2[i], result_h2); + } + + half result = __hadd(__low2half(result_h2), __high2half(result_h2)); + if (K % 2) { + result = __hfma(reinterpret_cast(a_data[K-1]), + reinterpret_cast(b_data[K-1]), + result); + } + + int64_t batch_idx = blockIdx.y; + int64_t row_idx = blockIdx.x * blockDim.x + threadIdx.x; + *output_accessor.get(c_ptr, batch_idx * M + row_idx) = + alpha * (float)result; +} + +// N = 1, K is small +template +void bmm_rcr_n1_launcher(ElemT* a_ptr, + ElemT* b_ptr, + ElemT* c_ptr, + int64_t B, + int64_t M, + float alpha, + bool use_fp16_acc, + cudaStream_t stream, + const TensorAccessor& input_a_accessor, + const TensorAccessor& input_b_accessor, + const TensorAccessor& output_accessor) { + const int nthread = 256; + dim3 thread_block(nthread); + dim3 grid((M + nthread - 1) / nthread, B); + + if(use_fp16_acc) { + {{bmm_rcr_n1_kernel_fp16}} + <<>>( + (ReadVecT*)a_ptr, + (ReadVecT*)b_ptr, + c_ptr, + M, + alpha, + input_a_accessor, + input_b_accessor, + output_accessor + ); + } else { + {{bmm_rcr_n1_kernel_fp32}} + <<>>( + (ReadVecT*)a_ptr, + (ReadVecT*)b_ptr, + c_ptr, + M, + alpha, + input_a_accessor, + input_b_accessor, + output_accessor + ); + } +} + +} // namespace + +void {{function_name}} ( + {{elem_input_type}}* a_ptr, + {{elem_input_type}}* b_ptr, + {{elem_input_type}}* c_ptr, + {% for i in range(3) %} + int64_t *a_dim{{loop.index0}}, + {% endfor %} + {% for i in range(3) %} + int64_t *b_dim{{loop.index0}}, + {% endfor %} + {% for i in range(3) %} + int64_t *c_dim{{loop.index0}}, + {% endfor %} + float alpha, + bool use_fp16_acc, + cudaStream_t stream +) { + {{shape_function}} + {{input_output_checks}} + {{input_accessors}} + {{output_accessors}} + {{exec_paths}} +} + +""" +) + + +@registry.reg("cuda.bmm_rcr_n1.gen_function") +def gen_function(func_attrs, exec_cond_template, dim_info_dict): + func_name = func_attrs["name"] + shape_func = gemm_common.gen_shape_eval_code( + indent=1, dtype="int64_t", dim_info_dict=dim_info_dict, is_ptr=True + ) + + def _get_original_dim_val(func_attrs, input_idx, dim): + accessor = func_attrs["input_accessors"][input_idx] + shape = accessor.original_shapes + assert isinstance( + shape[dim], IntImm + ), f"input {input_idx}'s dim {dim} must be static. Instead it's dynamic" + k = shape[dim]._attrs["values"][0] + return k + + # Get original k value in case it's changed to a strided tensor after + # fusing split op into bmm_rcr. Strided dim can only be the last dim. + ak = _get_original_dim_val(func_attrs, 0, 2) + bk = _get_original_dim_val(func_attrs, 1, 2) + assert ak == bk, f"ak is not equal to bk. 
ak: {ak}, bk: {bk}" + + elem_input_type = "cutlass::half_t" + backend_spec = CUDASpec() + vec_lens = list(zip(*backend_spec.read_num_elements_to_backend_type))[0][:-1] + alignment = tensor_accessor_codegen.find_max_alignment( + ak, func_attrs["input_accessors"] + ) + if alignment % 2: + bmm_rcr_n1_kernel_fp32 = "bmm_rcr_n1_kernel_fp32_acc" + bmm_rcr_n1_kernel_fp16 = "bmm_rcr_n1_kernel_fp16_acc" + read_vec_type = elem_input_type + else: + for vec_idx, vec_len in enumerate(vec_lens): + if ak % vec_len == 0: + bmm_rcr_n1_kernel_fp32 = "bmm_rcr_n1_kernel_fp32_acc_vec" + bmm_rcr_n1_kernel_fp16 = "bmm_rcr_n1_kernel_fp16_acc_vec" + read_vec_type = backend_spec.read_num_elements_to_backend_type[vec_idx][ + 1 + ] + break + + input_output_checks = common.INPUT_OUTPUT_CHECKS_TEMPLATE.render( + input_ndims=3, + weight_ndims=3, + output_ndims=3, + ) + if ak == 0: + # avoid compilation failure (zero-sized variable not alowed in device code) + # caused by instantiating the template with K=0 + exec_paths = "" + else: + exec_paths = EXEC_TEMPLATE.render( + indent=" ", + read_vec_type=read_vec_type, + elem_input_type=elem_input_type, + K=ak, + ) + + input_a_accessor = tensor_accessor_codegen.TENSOR_ACCESSOR_TEMPLATE.render( + name="input_a_accessor", tensor_accessor=func_attrs["input_accessors"][0] + ) + + input_b_accessor = tensor_accessor_codegen.TENSOR_ACCESSOR_TEMPLATE.render( + name="input_b_accessor", tensor_accessor=func_attrs["input_accessors"][1] + ) + + return SRC_TEMPLATE.render( + function_name=func_name, + elem_input_type=elem_input_type, + bmm_rcr_n1_kernel_fp32=bmm_rcr_n1_kernel_fp32, + bmm_rcr_n1_kernel_fp16=bmm_rcr_n1_kernel_fp16, + shape_function=shape_func, + input_output_checks=input_output_checks, + exec_paths=exec_paths, + tensor_accessor_libs=tensor_accessor_codegen.get_libs(), + input_accessors=input_a_accessor + input_b_accessor, + output_accessors=tensor_accessor_codegen.TENSOR_ACCESSOR_TEMPLATE.render( + name="output_accessor", tensor_accessor=func_attrs["output_accessors"][0] + ), + ) + + +@registry.reg("cuda.bmm_rcr_n1.func_decl") +def gen_function_decl(func_attrs): + func_name = func_attrs["name"] + return FUNC_DECL_TEMPLATE.render( + func_name=func_name, elem_input_type="cutlass::half_t" + ) + + +@registry.reg("cuda.bmm_rcr_n1.func_call") +def gen_function_call(func_attrs, indent=" "): + a = func_attrs["inputs"][0] + ashape = func_attrs["input_accessors"][0].original_shapes + adims = ["&" + dim._attrs["name"] for dim in ashape] + b = func_attrs["inputs"][1] + bshape = func_attrs["input_accessors"][1].original_shapes + bdims = ["&" + dim._attrs["name"] for dim in bshape] + c = func_attrs["outputs"][0] + cshape = func_attrs["output_accessors"][0].original_shapes + cdims = ["&" + dim._attrs["name"] for dim in cshape] + alpha = func_attrs["alpha"] + use_fp16_acc = False + if "use_fp16_acc" in Target.current()._kwargs: + use_fp16_acc = Target.current()._kwargs["use_fp16_acc"] + return FUNC_CALL_TEMPLATE.render( + local_dim_defs=common.gen_local_dim_defs(func_attrs, indent=indent), + func_name=func_attrs["name"], + a_ptr=a._attrs["name"], + b_ptr=b._attrs["name"], + c_ptr=c._attrs["name"], + adims=adims, + bdims=bdims, + cdims=cdims, + alpha=alpha, + use_fp16_acc="true" if use_fp16_acc else "false", + indent=indent, + ) + + +@registry.reg("cuda.bmm_rcr_n1.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. 
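
The alignment handling directly above selects the kernel flavor: when the maximum alignment over the input accessors is odd, the scalar bmm_rcr_n1_kernel_*_acc kernels are used; otherwise the widest read vector whose element count divides K is chosen from the backend spec. A simplified sketch of that selection, with candidate widths of 8/4/2 halves assumed here in place of CUDASpec.read_num_elements_to_backend_type:

    # Simplified vector-width selection: widest of 8/4/2 halves that divides K,
    # scalar fallback when the maximum alignment is odd. The assumed widths stand
    # in for the backend spec's table of read types.
    VEC_WIDTHS = (8, 4, 2)

    def pick_read_vec(k: int, max_alignment: int) -> int:
        if max_alignment % 2:
            return 1  # scalar cutlass::half_t loads
        for width in VEC_WIDTHS:
            if k % width == 0:
                return width
        return 1

    print(pick_read_vec(k=64, max_alignment=8))  # 8 halves per 16-byte load
    print(pick_read_vec(k=6, max_alignment=2))   # 2 halves per load
    print(pick_read_vec(k=7, max_alignment=1))   # scalar fallback
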
+ ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_special/bmm_rrr_k1_tanh.py b/python/aitemplate/backend/cuda/gemm_special/bmm_rrr_k1_tanh.py new file mode 100644 index 000000000..de29a6ab7 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_special/bmm_rrr_k1_tanh.py @@ -0,0 +1,258 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Codegen for bmm_rrr_k1_tanh. + +This kernel computes C = tanh(alpha * A @ B), where: +A[RowMajor]: [B, M, 1] +B[RowMajor]: [B, 1, N] +C[RowMajor]: [B, M, N] +""" +import jinja2 + +from ... import registry +from ...common import gemm_common +from ..gemm_universal import common + +# pylint: disable=C0301,W0613,W0612 + +FUNC_DECL_TEMPLATE = jinja2.Template( + """ +void {{func_name}}( + cutlass::half_t*, + cutlass::half_t*, + cutlass::half_t*, + {% for i in range(3) %} + int64_t*, + {% endfor %} + {% for i in range(3) %} + int64_t*, + {% endfor %} + {% for i in range(3) %} + int64_t*, + {% endfor %} +cudaStream_t +); +""" +) + +FUNC_CALL_TEMPLATE = jinja2.Template( + """ +{{indent}}{{func_name}}( +{{indent}} {{a_ptr}}, +{{indent}} {{b_ptr}}, +{{indent}} {{c_ptr}}, +{% for adim in adims %} +{{indent}} {{adim}}, +{% endfor %} +{% for bdim in bdims %} +{{indent}} {{bdim}}, +{% endfor %} +{% for cdim in cdims %} +{{indent}} {{cdim}}, +{% endfor %} +{{indent}} stream +{{indent}}); +""" +) + + +EXEC_TEMPLATE = jinja2.Template( + """ +{{indent}}bmm_rrr_k1_tanh_launcher( +{{indent}} a_ptr, +{{indent}} b_ptr, +{{indent}} c_ptr, +{{indent}} B, +{{indent}} M, +{{indent}} N, +{{indent}} stream +{{indent}}); +{{indent}}return; +""" +) + +SRC_TEMPLATE = jinja2.Template( + """ +#include +#include +#include "cutlass/util/host_tensor.h" +#include "cutlass/fast_math.h" + +#ifndef __HALF_TO_US +#define __HALF_TO_US(var) *(reinterpret_cast(&(var))) +#endif + +namespace { + +__device__ half fast_tanh(half x) { + #if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 11) && (__CUDA_ARCH__ >= 750) + + asm volatile ( "tanh.approx.f16 %0, %1;" : "=h"(__HALF_TO_US(x)) : "h"(__HALF_TO_US(x))); + return x; + + #else + return half(cutlass::fast_tanh(float(x))); + #endif +} + +template +__global__ void bmm_rrr_k1_tanh_kernel(const float4* a_ptr, + const float4* b_ptr, + float4* c_ptr, + const int B, + const int M, + const int N) { + // TODO: check boundary + half tmp[64]; + int idx = blockIdx.x * num_thread + threadIdx.x; + int m = idx % M; + int b = idx / M; + int a_idx_base = b * M + m; + float4 a_vec = __ldg(a_ptr + a_idx_base); + half* a_vec_ptr = (half*)(&a_vec); + for (int n = 0; n < N; ++n) { + int b_idx_base = b * N + n; + float4 b_vec = __ldg(b_ptr + b_idx_base); + half* b_vec_ptr = (half*)(&b_vec); + for (int i = 0; i < 8; ++i) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < 8; ++j) { + tmp[i * 8 + j] = fast_tanh(__hmul(a_vec_ptr[i], b_vec_ptr[j])); + } + } 
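
bmm_rrr_k1_tanh covers the K = 1 case, where the batched matmul collapses into an outer product per batch: each thread loads eight halves of the A column and eight halves of the B row per float4 and writes an 8 x 8 tile of tanh(a_i * b_j), so the launch math assumes M and N are multiples of 8 (the TODO in the kernel flags the missing boundary checks). A NumPy reference of the computation the kernel performs:

    import numpy as np

    # K == 1: A is [B, M, 1], B is [B, 1, N], so A @ B is a per-batch outer
    # product and C[b, m, n] = tanh(A[b, m, 0] * B[b, 0, n]).
    def bmm_rrr_k1_tanh_reference(a, b):
        return np.tanh(np.matmul(a.astype(np.float32), b.astype(np.float32)))

    a = np.random.randn(2, 64, 1).astype(np.float16)
    b = np.random.randn(2, 1, 32).astype(np.float16)
    print(bmm_rrr_k1_tanh_reference(a, b).shape)  # (2, 64, 32)
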
+ CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 8; ++i) { + int c_idx = (b * M * 8 + m * 8 + i) * N + n; + c_ptr[c_idx] = *((const float4*)(tmp + i * 8)); + } + } +} + + +void bmm_rrr_k1_tanh_launcher(cutlass::half_t* a_ptr, + cutlass::half_t* b_ptr, + cutlass::half_t* c_ptr, + int B, + int M, + int N, + cudaStream_t stream) { + const int nthread = 256; + dim3 thread_block(nthread); + dim3 grid(B * M / nthread / 8); + bmm_rrr_k1_tanh_kernel<<>>( + (const float4*)a_ptr, + (const float4*)b_ptr, + (float4*) c_ptr, + B, + M / 8, + N / 8 + ); +} + +} // namespace + +void {{function_name}} ( + cutlass::half_t* a_ptr, + cutlass::half_t* b_ptr, + cutlass::half_t* c_ptr, + {% for i in range(3) %} + int64_t *a_dim{{loop.index0}}, + {% endfor %} + {% for i in range(3) %} + int64_t *b_dim{{loop.index0}}, + {% endfor %} + {% for i in range(3) %} + int64_t *c_dim{{loop.index0}}, + {% endfor %} + cudaStream_t stream +) { + {{shape_function}} + {{input_output_checks}} + {{exec_paths}} +} + +""" +) + + +@registry.reg("cuda.bmm_rrr_k1_tanh.gen_function") +def gen_function(func_attrs, exec_cond_template, dim_info_dict): + func_name = func_attrs["name"] + shape_func = gemm_common.gen_shape_eval_code( + indent=1, dtype="int64_t", dim_info_dict=dim_info_dict, is_ptr=True + ) + input_output_checks = common.INPUT_OUTPUT_CHECKS_TEMPLATE.render( + input_ndims=3, + weight_ndims=3, + output_ndims=3, + ) + exec_paths = EXEC_TEMPLATE.render() + return SRC_TEMPLATE.render( + function_name=func_name, + shape_function=shape_func, + input_output_checks=input_output_checks, + exec_paths=exec_paths, + ) + + +@registry.reg("cuda.bmm_rrr_k1_tanh.func_decl") +def gen_function_decl(func_attrs): + func_name = func_attrs["name"] + return FUNC_DECL_TEMPLATE.render(func_name=func_name) + + +@registry.reg("cuda.bmm_rrr_k1_tanh.func_call") +def gen_function_call(func_attrs, indent=" "): + a = func_attrs["inputs"][0] + ashape = a._attrs["shape"] + adims = ["&" + dim._attrs["name"] for dim in ashape] + b = func_attrs["inputs"][1] + bshape = b._attrs["shape"] + bdims = ["&" + dim._attrs["name"] for dim in bshape] + c = func_attrs["outputs"][0] + cshape = c._attrs["shape"] + cdims = ["&" + dim._attrs["name"] for dim in cshape] + return FUNC_CALL_TEMPLATE.render( + func_name=func_attrs["name"], + a_ptr=a._attrs["name"], + b_ptr=b._attrs["name"], + c_ptr=c._attrs["name"], + adims=adims, + bdims=bdims, + cdims=cdims, + indent=indent, + ) + + +@registry.reg("cuda.bmm_rrr_k1_tanh.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_special/gemm_rrr_small_nk.py b/python/aitemplate/backend/cuda/gemm_special/gemm_rrr_small_nk.py new file mode 100644 index 000000000..81ed764e8 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_special/gemm_rrr_small_nk.py @@ -0,0 +1,374 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
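For readers cross-checking the kernel above, a minimal PyTorch reference of what `bmm_rrr_k1_tanh` is meant to produce (assuming `alpha = 1`; the module docstring mentions an alpha scale, but the kernel body multiplies the two operands directly). Shapes follow the docstring: A is [B, M, 1] and B is [B, 1, N], so the op is a batched outer product followed by tanh, and the launcher's float4 loads assume M and N are multiples of 8.

```python
import torch

def bmm_rrr_k1_tanh_ref(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # a: [B, M, 1], b: [B, 1, N], fp16, row-major; C = tanh(A @ B)
    return torch.tanh(torch.bmm(a, b))

a = torch.randn(4, 64, 1, dtype=torch.half, device="cuda")
b = torch.randn(4, 1, 32, dtype=torch.half, device="cuda")
c = bmm_rrr_k1_tanh_ref(a, b)  # [4, 64, 32]
```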
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +GEMM Specialization for A[RowMajor], B[RowMajor], C[RowMajor] +This is special in template based gemm solution +This is used for `torch.nn.functional.linear` +When use for `linear`, need set A->Data, B->Weight + +Special kernel for small K and N +K <= 8, N <= 8 +A: [M, K] A can be ND with the first N - 1 dimensions as batch dimensions +B: [K, N] +C: [M, N] +""" + +import jinja2 + +from ... import registry +from ...common import gemm_common +from ...target import Target +from ..gemm_universal import common + +# pylint: disable=C0301,W0613,W0612 + + +FUNC_DECL_TEMPLATE = jinja2.Template( + """ +void {{func_name}}( + cutlass::half_t*, + cutlass::half_t*, + cutlass::half_t*, + {% for i in range(a_ndim) %} + int64_t*, + {% endfor %} + {% for i in range(b_ndim) %} + int64_t*, + {% endfor %} + {% for i in range(c_ndim) %} + int64_t*, + {% endfor %} + bool, + cudaStream_t +); +""" +) + + +FUNC_CALL_TEMPLATE = jinja2.Template( + """ +{{indent}}{{func_name}}( +{{indent}} {{a_ptr}}, +{{indent}} {{b_ptr}}, +{{indent}} {{c_ptr}}, +{% for adim in adims %} +{{indent}} {{adim}}, +{% endfor %} +{% for bdim in bdims %} +{{indent}} {{bdim}}, +{% endfor %} +{% for cdim in cdims %} +{{indent}} {{cdim}}, +{% endfor %} +{{indent}} {{use_fp16_acc}}, +{{indent}} stream +{{indent}}); +""" +) + + +EXEC_TEMPLATE = jinja2.Template( + """ +{{indent}}gemm_rrr_small_nk_launcher<{{N}}, {{K}}>( +{{indent}} a_ptr, +{{indent}} b_ptr, +{{indent}} c_ptr, +{{indent}} M, +{{indent}} use_fp16_acc, +{{indent}} stream +{{indent}}); +{{indent}}return; +""" +) + + +SRC_TEMPLATE = jinja2.Template( + """ +#include +#include +#include "cutlass/util/host_tensor.h" + +namespace { + +// For each thread, read +// A tile: 8 x K +// B matrix: K x N +// C tile: 8 x N +template +__global__ void gemm_rrr_small_nk_kernel(float4* a_ptr, + float4* b_ptr, + float4* c_ptr, + int M) { + int idx = blockIdx.x * num_thread + threadIdx.x; + + if (idx >= (M + 7) / 8) { + return; + } + + int a_idx_base = idx * K; + a_ptr += a_idx_base; + + // load b matrix + half b[K][N]; + half* b_half = reinterpret_cast(b_ptr); + for (int i = 0; i < K; ++i) { + for (int j = 0; j < N; ++j) { + b[i][j] = b_half[i * N + j]; + } + } + + int c_idx_base = idx * N; + c_ptr += c_idx_base; + + half c_tile[8][N]; + + if (idx <= M / 8 - 1) { + // fast kernel + // load a + float4 a_tile_vec[K]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < K; i++) { + a_tile_vec[i] = __ldg(a_ptr++); + } + half* a_tile = reinterpret_cast(&a_tile_vec); + + // compute + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 8; ++i) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < N; ++j) { + if (USE_FP16_ACC) { + half sum = 0; + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < K; ++k) { + sum = __hfma(a_tile[i * K + k], b[k][j], sum); + } + c_tile[i][j] = sum; + } else { + float sum = 0; + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < K; ++k) { + sum += __half2float(__hmul(a_tile[i * K + k], b[k][j])); + } + c_tile[i][j] = __float2half_rn(sum); + } + } + } + + // write c + float4* c_tile_vec = reinterpret_cast(&c_tile); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; i++) { + c_ptr[i] = c_tile_vec[i]; + 
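+        // the 8 consecutive output rows owned by this thread are contiguous
+        // in row-major C, so the whole 8 x N half tile is N float4 stores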
} + } else { + // process tail + // load a + half* a_h = reinterpret_cast(a_ptr); + int m = M - M / 8 * 8; + half a_tile[8][K]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < m; i++) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < K; j++) { + a_tile[i][j] = a_h[i * K + j]; + } + } + + // compute + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < m; ++i) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < N; ++j) { + if (USE_FP16_ACC) { + half sum = 0; + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < K; ++k) { + sum = __hfma(a_tile[i][k], b[k][j], sum); + } + c_tile[i][j] = sum; + } else { + float sum = 0; + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < K; ++k) { + sum += __half2float(__hmul(a_tile[i][k], b[k][j])); + } + c_tile[i][j] = __float2half_rn(sum); + } + } + } + + // write c + half* c_h = reinterpret_cast(c_ptr); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < m; i++) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < N; j++) { + c_h[i * N + j] = c_tile[i][j]; + } + } + } +} + +// N <= 8, K <= 8 +template +void gemm_rrr_small_nk_launcher(cutlass::half_t* a_ptr, + cutlass::half_t* b_ptr, + cutlass::half_t* c_ptr, + int M, + bool use_fp16_acc, + cudaStream_t stream) { + const int nthread = 256; + dim3 thread_block(nthread); + const int n_element_per_t = nthread * 8; + dim3 grid((M + n_element_per_t - 1) / n_element_per_t); + if(use_fp16_acc) { + gemm_rrr_small_nk_kernel<<>>( + (float4*)a_ptr, + (float4*)b_ptr, + (float4*)c_ptr, + M + ); + } else { + gemm_rrr_small_nk_kernel<<>>( + (float4*)a_ptr, + (float4*)b_ptr, + (float4*)c_ptr, + M + ); + } +} + +} // namespace + +void {{function_name}} ( + cutlass::half_t* a_ptr, + cutlass::half_t* b_ptr, + cutlass::half_t* c_ptr, + {% for i in range(a_ndim) %} + int64_t *a_dim{{loop.index0}}, + {% endfor %} + {% for i in range(b_ndim) %} + int64_t *b_dim{{loop.index0}}, + {% endfor %} + {% for i in range(c_ndim) %} + int64_t *c_dim{{loop.index0}}, + {% endfor %} + bool use_fp16_acc, + cudaStream_t stream +) { + {{shape_function}} + {{input_output_checks}} + {{exec_paths}} +} + +""" +) + + +@registry.reg("cuda.gemm_rrr_small_nk.gen_function") +def gen_function(func_attrs, exec_cond_template, dim_info_dict): + func_name = func_attrs["name"] + shape_func = gemm_common.gen_shape_eval_code( + indent=1, dtype="int64_t", dim_info_dict=dim_info_dict, is_ptr=True + ) + + b = func_attrs["inputs"][1] + bshape = b._attrs["shape"] + k = bshape[0]._attrs["values"][0] + n = bshape[1]._attrs["values"][0] + + a_ndim = func_attrs["inputs"][0]._rank() + b_ndim = func_attrs["inputs"][1]._rank() + c_ndim = func_attrs["outputs"][0]._rank() + + input_output_checks = common.INPUT_OUTPUT_CHECKS_TEMPLATE.render( + input_ndims=a_ndim, + weight_ndims=2, + output_ndims=c_ndim, + ) + if n == 0 or k == 0: + # avoid "zero-sized variable not allowed in device code" error + exec_paths = "" + else: + exec_paths = EXEC_TEMPLATE.render(indent=" ", N=n, K=k) + return SRC_TEMPLATE.render( + function_name=func_name, + shape_function=shape_func, + input_output_checks=input_output_checks, + exec_paths=exec_paths, + a_ndim=a_ndim, + b_ndim=b_ndim, + c_ndim=c_ndim, + ) + + +@registry.reg("cuda.gemm_rrr_small_nk.func_decl") +def gen_function_decl(func_attrs): + func_name = func_attrs["name"] + a_ndim = func_attrs["inputs"][0]._rank() + b_ndim = func_attrs["inputs"][1]._rank() + c_ndim = func_attrs["outputs"][0]._rank() + return FUNC_DECL_TEMPLATE.render( + func_name=func_name, a_ndim=a_ndim, b_ndim=b_ndim, c_ndim=c_ndim + ) + + +@registry.reg("cuda.gemm_rrr_small_nk.func_call") +def 
gen_function_call(func_attrs, indent=" "): + a = func_attrs["inputs"][0] + ashape = a._attrs["shape"] + adims = ["&" + dim._attrs["name"] for dim in ashape] + b = func_attrs["inputs"][1] + bshape = b._attrs["shape"] + bdims = ["&" + dim._attrs["name"] for dim in bshape] + c = func_attrs["outputs"][0] + cshape = c._attrs["shape"] + cdims = ["&" + dim._attrs["name"] for dim in cshape] + use_fp16_acc = False + if "use_fp16_acc" in Target.current()._kwargs: + use_fp16_acc = Target.current()._kwargs["use_fp16_acc"] + return FUNC_CALL_TEMPLATE.render( + func_name=func_attrs["name"], + a_ptr=a._attrs["name"], + b_ptr=b._attrs["name"], + c_ptr=c._attrs["name"], + adims=adims, + bdims=bdims, + cdims=cdims, + use_fp16_acc="true" if use_fp16_acc else "false", + indent=indent, + ) + + +@registry.reg("cuda.gemm_rrr_small_nk.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/__init__.py b/python/aitemplate/backend/cuda/gemm_universal/__init__.py new file mode 100644 index 000000000..c07983128 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/__init__.py @@ -0,0 +1,61 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# flake8: noqa +from . import ( + bmm_ccr, + bmm_ccr_add, + bmm_crr, + bmm_crr_add, + bmm_rcr, + bmm_rcr_permute, + bmm_rrr, + bmm_rrr_add, + bmm_rrr_permute, + gemm_rcr, + gemm_rcr_bias, + gemm_rcr_bias_add, + gemm_rcr_bias_add_add, + gemm_rcr_bias_add_add_relu, + gemm_rcr_bias_add_relu, + gemm_rcr_bias_fast_gelu, + gemm_rcr_bias_gelu, + gemm_rcr_bias_hardswish, + gemm_rcr_bias_mul, + gemm_rcr_bias_mul_add, + gemm_rcr_bias_mul_tanh, + gemm_rcr_bias_permute, + gemm_rcr_bias_relu, + gemm_rcr_bias_sigmoid, + gemm_rcr_bias_sigmoid_mul, + gemm_rcr_bias_sigmoid_mul_tanh, + gemm_rcr_bias_swish, + gemm_rcr_bias_tanh, + gemm_rcr_permute, + gemm_rrr, + gemm_rrr_permute, + group_gemm_rcr, + group_gemm_rcr_bias, + group_gemm_rcr_bias_relu, + group_gemm_rcr_bias_sigmoid, + perm021fc_ccr, + perm021fc_ccr_bias, + perm021fc_ccr_bias_permute, + perm021fc_crc, + perm021fc_crc_bias, + perm102_bmm_rcr, + perm102_bmm_rcr_bias, + perm102_bmm_rrr, + perm102_bmm_rrr_bias, +) diff --git a/python/aitemplate/backend/cuda/gemm_universal/bmm_ccr.py b/python/aitemplate/backend/cuda/gemm_universal/bmm_ccr.py new file mode 100644 index 000000000..25ad9e9a8 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/bmm_ccr.py @@ -0,0 +1,142 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
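As a sanity check for `gemm_rrr_small_nk` above, here is a hedged PyTorch sketch of the two accumulation paths selected by `use_fp16_acc` (which the generated call reads from `Target.current()._kwargs`). It is a reference for the math only: the kernel multiplies in fp16 in both paths and only the accumulator width changes, so the fp32 path below can differ from the kernel in the last bits.

```python
import torch

def gemm_rrr_small_nk_ref(a, b, use_fp16_acc=False):
    # a: [M, K] (leading batch dims flattened into M), b: [K, N]; N <= 8, K <= 8
    if use_fp16_acc:
        return a.half() @ b.half()          # fp16 products, fp16 accumulation
    return (a.float() @ b.float()).half()   # fp32 accumulation, cast back to fp16
```

Note that the Python-side codegen above also emits an empty exec path when `n == 0` or `k == 0` (to avoid zero-sized device arrays), so the reference assumes non-degenerate shapes.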
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Codegen for bmm_ccr, which computes A @ B + bias. +A[ColMajor], B[ColMajor], bias[RowMajor] +""" +from ... import registry +from ...common import gemm_common +from . import bmm_common, common + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + + +def _get_problem_info(**kwargs): + problem_args = { + "bias_ptr": "c_ptr", + "a_batch_stride": "M * K", + "b_batch_stride": "N * K", + "bias_batch_stride": "M * N", + "c_batch_stride": "M * N", + "lda": "M", + "ldb": "K", + "ldbias": "N", + "ldc": "N", + } + for k, v in kwargs.items(): + problem_args[k] = v + + bmm_problem_info = bmm_common.Bmm_problem_info(**problem_args) + return bmm_problem_info + + +@registry.reg("cuda.bmm_ccr.config") +def bmm_ccr_config(func_attrs, dtype="float16"): + def fproc_f16(op): + import cutlass_lib + + return common.default_fproc_f16( + op=op, + a_layout=cutlass_lib.library.LayoutType.ColumnMajor, + b_layout=cutlass_lib.library.LayoutType.ColumnMajor, + c_layout=cutlass_lib.library.LayoutType.RowMajor, + epiligue_name=func_attrs["epilogue"], + ) + + func_attrs["op_instance"] = common.extract_config(fproc_f16) + + +@registry.reg("cuda.bmm_ccr.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + a_dims = bmm_common.reverse_dim_info_mapping( + dim_info_dict, gemm_common.Source.INPUT, 0 + ) + b_dims = bmm_common.reverse_dim_info_mapping( + dim_info_dict, gemm_common.Source.INPUT, 1 + ) + c_dims = bmm_common.reverse_dim_info_mapping( + dim_info_dict, gemm_common.Source.OUTPUT, 0 + ) + + args_parser = bmm_common.ARGS_PARSER_TEMPLATE.render( + a_dims=a_dims, b_dims=b_dims, c_dims=c_dims + ) + + mm_info = _get_problem_info(alpha_value=func_attrs.get("alpha", 1)) + a_shapes = func_attrs["input_accessors"][0].original_shapes + b_shapes = func_attrs["input_accessors"][1].original_shapes + bmm_common._update_stride_info(mm_info, a_shapes, b_shapes) + + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + + bmm_common.gen_profiler( + func_attrs, + workdir, + dim_info_dict, + common.SRC_TEMPLATE, + problem_args, + args_parser, + ) + + +@registry.reg("cuda.bmm_ccr.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, +): + mm_info = _get_problem_info(alpha_value=func_attrs.get("alpha", 1)) + a_shapes = func_attrs["input_accessors"][0].original_shapes + b_shapes = func_attrs["input_accessors"][1].original_shapes + bmm_common._update_stride_info(mm_info, a_shapes, b_shapes) + + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + return bmm_common.gen_function( + func_attrs, + exec_cond_template, + problem_args, + dim_info_dict, + ) + + +@registry.reg("cuda.bmm_ccr.func_decl") +def gen_function_decl(func_attrs): + return bmm_common.gen_function_decl(func_attrs) + + +@registry.reg("cuda.bmm_ccr.func_call") +def gen_function_call(func_attrs, indent=" "): + return bmm_common.gen_function_call(func_attrs, indent) + + +@registry.reg("cuda.bmm_ccr.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. 
+ func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/bmm_ccr_add.py b/python/aitemplate/backend/cuda/gemm_universal/bmm_ccr_add.py new file mode 100644 index 000000000..ea9ff0510 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/bmm_ccr_add.py @@ -0,0 +1,120 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Codegen for bmm_ccr_add, which computes A @ B + bias + C. +A[ColMajor], B[ColMajor], bias / C[RowMajor] +""" +from ... import registry +from ...common import gemm_common +from . import bmm_ccr, bmm_common, common + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + + +@registry.reg("cuda.bmm_ccr_add.config") +def bmm_ccr_add_config(func_attrs, dtype="float16"): + return bmm_ccr.bmm_ccr_config(func_attrs, dtype) + + +@registry.reg("cuda.bmm_ccr_add.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + a_dims = bmm_common.reverse_dim_info_mapping( + dim_info_dict, gemm_common.Source.INPUT, 0 + ) + b_dims = bmm_common.reverse_dim_info_mapping( + dim_info_dict, gemm_common.Source.INPUT, 1 + ) + c_dims = bmm_common.reverse_dim_info_mapping( + dim_info_dict, gemm_common.Source.OUTPUT, 0 + ) + + args_parser = bmm_common.ARGS_PARSER_TEMPLATE.render( + a_dims=a_dims, b_dims=b_dims, c_dims=c_dims + ) + + mm_info = bmm_ccr._get_problem_info( + bias_ptr="d_ptr", + alpha_value=func_attrs.get("alpha", 1), + beta_value=1, + ) + a_shapes = func_attrs["input_accessors"][0].original_shapes + b_shapes = func_attrs["input_accessors"][1].original_shapes + d_shapes = func_attrs["input_accessors"][2].original_shapes + bmm_common._update_stride_info(mm_info, a_shapes, b_shapes, d_shapes) + + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + + bmm_common.gen_profiler( + func_attrs, + workdir, + dim_info_dict, + common.SRC_TEMPLATE, + problem_args, + args_parser, + ) + + +@registry.reg("cuda.bmm_ccr_add.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, +): + mm_info = bmm_ccr._get_problem_info( + bias_ptr="d_ptr", alpha_value=func_attrs.get("alpha", 1), beta_value=1 + ) + a_shapes = func_attrs["input_accessors"][0].original_shapes + b_shapes = func_attrs["input_accessors"][1].original_shapes + d_shapes = func_attrs["input_accessors"][2].original_shapes + bmm_common._update_stride_info(mm_info, a_shapes, b_shapes, d_shapes) + + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + return bmm_common.gen_function( + func_attrs, + exec_cond_template, + problem_args, + dim_info_dict, + ) + + +@registry.reg("cuda.bmm_ccr_add.func_decl") +def gen_function_decl(func_attrs): + return bmm_common.gen_function_decl(func_attrs) + + +@registry.reg("cuda.bmm_ccr_add.func_call") +def gen_function_call(func_attrs, indent=" 
"): + return bmm_common.gen_function_call(func_attrs, indent) + + +@registry.reg("cuda.bmm_ccr_add.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/bmm_common.py b/python/aitemplate/backend/cuda/gemm_universal/bmm_common.py new file mode 100644 index 000000000..7b22806e3 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/bmm_common.py @@ -0,0 +1,391 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Common functions and templates for bmm-family ops +""" +from dataclasses import dataclass + +import jinja2 + +from ...common import gemm_common +from . import common + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + +# ARGS_PARSER is only used by profiler, so the batch is not of concern. +ARGS_PARSER_TEMPLATE = jinja2.Template( + """ + int64_t B = std::atoi(argv[1]); + int64_t M = std::atoi(argv[2]); + int64_t N = std::atoi(argv[3]); + int64_t K = std::atoi(argv[4]); + +{% for dim in a_dims %} + int64_t a_dim{{loop.index0}} = {{dim}}; +{% endfor %} +{% for dim in b_dims %} + int64_t b_dim{{loop.index0}} = {{dim}}; +{% endfor %} +{% for dim in c_dims %} + int64_t c_dim{{loop.index0}} = {{dim}}; +{% endfor %} +""" +) + +OUTPUT_ADDR_CALCULATOR = jinja2.Template( + """ + int64_t output_batch_stride = {{output_batch_stride_dim}}; + int64_t output_stride = {{output_stride_dim}}; + int64_t output_offset = {{output_offset_val}}; // default to 0 + """ +) + +FUNC_DECL_TEMPLATE = jinja2.Template( + """ +void {{func_name}}( + cutlass::half_t*, + cutlass::half_t*, +{% if has_d %} + cutlass::half_t*, +{% endif %} + cutlass::half_t*, + uint8_t*, +{% if support_split_k %} + int, +{% endif %} +{% for idx in range(a_ndims) %} + int64_t*, +{% endfor %} +{% for idx in range(b_ndims) %} + int64_t*, +{% endfor %} +{% for idx in range(c_ndims) %} + int64_t*, +{% endfor %} + cudaStream_t +); +""" +) + + +FUNC_CALL_TEMPLATE = jinja2.Template( + """ +{{indent}}{ +{{indent}}{{local_dim_defs}} +{{indent}}{{func_name}}( +{{indent}} {{a_ptr}}, +{{indent}} {{b_ptr}}, +{% if has_d %} +{{indent}} {{d_ptr}}, +{% endif %} +{% if has_bias %} +{{indent}} {{bias_ptr}}, +{% endif %} +{{indent}} {{c_ptr}}, +{{indent}} global_workspace, +{% for dim in a_dims_ptr %} +{{indent}} {{dim}}, +{% endfor %} +{% for dim in b_dims_ptr %} +{{indent}} {{dim}}, +{% endfor %} +{% for dim in c_dims_ptr %} +{{indent}} {{dim}}, +{% endfor %} +{{indent}} stream +{{indent}}); +{{indent}}} +""" +) + + +TENSOR_DECL_TEMPLATE = jinja2.Template( + """ + // cast to int64_t to avoid overflow + int64_t a_ptr_sz = 1; + {% for idx in range(a_ndims) %} + {{indent}} {{indent}} a_ptr_sz 
*= static_cast(a_dim{{idx}}); + {% endfor %} + + int64_t b_ptr_sz = 1; + {% for idx in range(b_ndims) %} + {{indent}} {{indent}} b_ptr_sz *= static_cast(b_dim{{idx}}); + {% endfor %} + + int64_t c_ptr_sz = 1; + {% for idx in range(c_ndims) %} + {{indent}} {{indent}} c_ptr_sz *= static_cast(c_dim{{idx}}); + {% endfor %} + + // The value 1 is used to force ptr_max_sz to be non-zero + int64_t ptr_max_sz = std::max({1, a_ptr_sz, b_ptr_sz, c_ptr_sz}); + // TODO: special pool size for A100 L2 cache 40M + // need to tune it for other devices + int64_t mem_pool_sz = std::max(2, std::min(64, int((1 << 25) / ptr_max_sz))); + + memory_pool->AllocateHalfTensor(a_ptr_sz, mem_pool_sz); // a_ptr: index 0 + memory_pool->AllocateHalfTensor(b_ptr_sz, mem_pool_sz); // b_ptr: index 1 + memory_pool->AllocateHalfTensor(c_ptr_sz, mem_pool_sz); // c_ptr: index 2 +{% if has_bias %} + memory_pool->AllocateHalfTensor(c_dim2, mem_pool_sz); // bias_ptr: index 3 +{% endif %} +{% if has_d %} + memory_pool->AllocateHalfTensor(c_ptr_sz, mem_pool_sz); // d_ptr: index 3 (no bias) or 4 +{% endif %} +""" +) + + +@dataclass +class Bmm_problem_info: + alpha_value: float = 1 + beta_value: float = 0 + problem_size: str = "{M, N, K}" + batch_size: str = "B" + a_ptr: str = "a_ptr" + b_ptr: str = "b_ptr" + bias_ptr: str = "d_ptr" + c_ptr: str = "c_ptr" + a_batch_stride: str = "0" + b_batch_stride: str = "0" + bias_batch_stride: str = "0" + c_batch_stride: str = "0" + lda: str = "0" + ldb: str = "0" + ldbias: str = "0" + ldc: str = "0" + + +def _update_stride_info(mm_info, a_shapes, b_shapes, bias_shapes=None): + if len(a_shapes) == 2 or a_shapes[0] == 1: + mm_info.a_batch_stride = "0" + if len(b_shapes) == 2 or b_shapes[0] == 1: + mm_info.b_batch_stride = "0" + + if bias_shapes is None: + return + + if len(bias_shapes) < 3 or bias_shapes[0] == 1: + mm_info.bias_batch_stride = "0" + if len(bias_shapes) < 2 or all([x == 1 for x in bias_shapes[:-1]]): + mm_info.ldbias = "0" + + +PROBLEM_ARGS_TEMPLATE = jinja2.Template( + """ + cutlass::gemm::GemmUniversalMode::kBatched, + {{mm_info.problem_size}}, + {{mm_info.batch_size}}, + {ElementComputeEpilogue({{mm_info.alpha_value}}), ElementComputeEpilogue({{mm_info.beta_value}})}, + (void*) {{mm_info.a_ptr}}, + (void*) {{mm_info.b_ptr}}, + (void*) {{mm_info.bias_ptr}}, + (void*) {{mm_info.c_ptr}}, + {{mm_info.a_batch_stride}}, + {{mm_info.b_batch_stride}}, + {{mm_info.bias_batch_stride}}, + {{mm_info.c_batch_stride}}, + {{mm_info.lda}}, + {{mm_info.ldb}}, + {{mm_info.ldbias}}, + {{mm_info.ldc}} +""" +) + + +def reverse_dim_info_mapping(dim_info_dict, source, tensor_idx): + def _fill(arr, idx, val): + if len(arr) <= idx: + arr = arr + [None] * (idx - len(arr) + 1) + arr[idx] = val + return arr + + ret = [] + for name, dim_infos in dim_info_dict.items(): + for dim_info in dim_infos: + if dim_info.source == source and dim_info.tensor_idx == tensor_idx: + for dim_idx in dim_info.dim_idx: + ret = _fill(ret, dim_idx, name) + + if None in ret: + raise RuntimeError( + "dim_info_dict for source: {}, tensor_idx: {} not complete.".format( + source, tensor_idx + ) + ) + + return ret + + +def gen_profiler( + func_attrs, + workdir, + dim_info_dict, + src_template, + problem_args, + args_parser, + bias_ptr_arg=None, +): + op_type = func_attrs["op"] + op_instance = func_attrs["op_instance"] + has_d = False + if "has_d" in func_attrs: + has_d = func_attrs["has_d"] + + a_ndims = len(func_attrs["input_accessors"][0].original_shapes) + b_ndims = len(func_attrs["input_accessors"][1].original_shapes) + c_ndims = 
len(func_attrs["output_accessors"][0].original_shapes) + shape_func = gemm_common.gen_shape_eval_code( + indent=2, dtype="int64_t", dim_info_dict=dim_info_dict, is_ptr=True + ) + + file_pairs = [] + has_bias = bias_ptr_arg is not None + assert not (has_d and has_bias) + for op_name, op in op_instance.items(): + config = common.emit_instance(op, for_profiler=True) + config_name = common.extract_config_name(config) + name = "GemmInstance" + instance = common.INSTANCE_TEMPLATE.render( + config_name=config_name, name=name, config=config + ) + exec_program = common.EXEC_TEMPLATE.render( + indent=" ", + instance=name, + is_profiler=True, + problem_args=problem_args, + ) + input_output_checks = common.INPUT_OUTPUT_CHECKS_TEMPLATE.render( + input_ndims=a_ndims, + weight_ndims=b_ndims, + output_ndims=c_ndims, + ) + op_func = src_template.render( + instances=instance, + function_name="bmm", + input_ndims=a_ndims, + weight_ndims=b_ndims, + output_ndims=c_ndims, + shape_eval=shape_func, + input_output_checks=input_output_checks, + exec_paths=exec_program, + has_d=has_d, + ) + a_dims_ptr = [f"&a_dim{idx}" for idx in range(a_ndims)] + b_dims_ptr = [f"&b_dim{idx}" for idx in range(b_ndims)] + c_dims_ptr = [f"&c_dim{idx}" for idx in range(c_ndims)] + func_call = FUNC_CALL_TEMPLATE.render( + func_name="bmm", + a_ptr="memory_pool->RequestHalfTensorByIdx(0)", + b_ptr="memory_pool->RequestHalfTensorByIdx(1)", + has_bias=has_bias, + bias_ptr=bias_ptr_arg, + c_ptr="memory_pool->RequestHalfTensorByIdx(2)", + d_ptr="memory_pool->RequestHalfTensorByIdx(%d)" % (4 if has_bias else 3), + has_d=has_d, + a_dims_ptr=a_dims_ptr, + b_dims_ptr=b_dims_ptr, + c_dims_ptr=c_dims_ptr, + ) + code = common.PROFILER_TEMPLATE.render( + op_func=op_func, + args_parse=args_parser, + func_call=func_call, + name=name, + tensor_decl=TENSOR_DECL_TEMPLATE.render( + name=name, + a_ndims=a_ndims, + b_ndims=b_ndims, + c_ndims=c_ndims, + has_d=has_d, + has_bias=has_bias, + ), + ) + common.add_profiler(file_pairs, workdir, op_type, op_name, code) + # build + common.build_profiler(file_pairs) + + +def gen_function_decl(func_attrs): + func_name = func_attrs["name"] + has_d = False + if "has_d" in func_attrs: + has_d = func_attrs["has_d"] + return FUNC_DECL_TEMPLATE.render( + func_name=func_name, + a_ndims=len(func_attrs["input_accessors"][0].original_shapes), + b_ndims=len(func_attrs["input_accessors"][1].original_shapes), + c_ndims=len(func_attrs["output_accessors"][0].original_shapes), + has_d=has_d, + ) + + +def gen_function( + func_attrs, + exec_cond_template, + problem_args, + dim_info_dict, + input_addr_calculator="", + output_addr_calculator="", +): + return common.gen_function( + func_attrs, + common.SRC_TEMPLATE, + exec_cond_template, + problem_args, + input_ndims=len(func_attrs["input_accessors"][0].original_shapes), + weight_ndims=len(func_attrs["input_accessors"][1].original_shapes), + output_ndims=len(func_attrs["output_accessors"][0].original_shapes), + dim_info_dict=dim_info_dict, + input_addr_calculator=input_addr_calculator, + output_addr_calculator=output_addr_calculator, + ) + + +def gen_function_call(func_attrs, indent=" ", bias_ptr_arg=None): + a = func_attrs["inputs"][0] + ashape = func_attrs["input_accessors"][0].original_shapes + a_dims_ptr = [f'&{ashape[idx]._attrs["name"]}' for idx in range(len(ashape))] + b = func_attrs["inputs"][1] + bshape = func_attrs["input_accessors"][1].original_shapes + b_dims_ptr = [f'&{bshape[idx]._attrs["name"]}' for idx in range(len(bshape))] + c = func_attrs["outputs"][0] + cshape = 
func_attrs["output_accessors"][0].original_shapes + c_dims_ptr = [f'&{cshape[idx]._attrs["name"]}' for idx in range(len(cshape))] + has_d = False + d_ptr = None + if "has_d" in func_attrs: + has_d = func_attrs["has_d"] + d_ptr = func_attrs["inputs"][2]._attrs["name"] + has_bias = bias_ptr_arg is not None + assert not (has_d and has_bias) + + local_dim_defs = common.gen_local_dim_defs(func_attrs, indent=indent) + + return FUNC_CALL_TEMPLATE.render( + local_dim_defs=local_dim_defs, + func_name=func_attrs["name"], + a_ptr=a._attrs["name"], + b_ptr=b._attrs["name"], + has_bias=has_bias, + bias_ptr=bias_ptr_arg, + c_ptr=c._attrs["name"], + d_ptr=d_ptr, + has_d=has_d, + a_dims_ptr=a_dims_ptr, + b_dims_ptr=b_dims_ptr, + c_dims_ptr=c_dims_ptr, + indent=indent, + ) diff --git a/python/aitemplate/backend/cuda/gemm_universal/bmm_crr.py b/python/aitemplate/backend/cuda/gemm_universal/bmm_crr.py new file mode 100644 index 000000000..62d6eee96 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/bmm_crr.py @@ -0,0 +1,144 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Codegen for bmm_crr, which computes A @ B + bias. +A[ColMajor], B[RowMajor], bias[RowMajor] +""" + +from ... import registry +from ...common import gemm_common +from . 
import bmm_common, common + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + + +def _get_problem_info(**kwargs): + problem_args = { + "bias_ptr": "c_ptr", + "a_batch_stride": "M * K", + "b_batch_stride": "N * K", + "bias_batch_stride": "M * N", + "c_batch_stride": "M * N", + "lda": "M", + "ldb": "N", + "ldbias": "N", + "ldc": "N", + } + for k, v in kwargs.items(): + problem_args[k] = v + + bmm_problem_info = bmm_common.Bmm_problem_info(**problem_args) + return bmm_problem_info + + +@registry.reg("cuda.bmm_crr.config") +def bmm_crr_config(func_attrs, dtype="float16"): + def fproc_f16(op): + import cutlass_lib + + return common.default_fproc_f16( + op=op, + a_layout=cutlass_lib.library.LayoutType.ColumnMajor, + b_layout=cutlass_lib.library.LayoutType.RowMajor, + c_layout=cutlass_lib.library.LayoutType.RowMajor, + epiligue_name=func_attrs["epilogue"], + ) + + func_attrs["op_instance"] = common.extract_config(fproc_f16) + + +@registry.reg("cuda.bmm_crr.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + a_dims = bmm_common.reverse_dim_info_mapping( + dim_info_dict, gemm_common.Source.INPUT, 0 + ) + b_dims = bmm_common.reverse_dim_info_mapping( + dim_info_dict, gemm_common.Source.INPUT, 1 + ) + c_dims = bmm_common.reverse_dim_info_mapping( + dim_info_dict, gemm_common.Source.OUTPUT, 0 + ) + + args_parser = bmm_common.ARGS_PARSER_TEMPLATE.render( + a_dims=a_dims, b_dims=b_dims, c_dims=c_dims + ) + + mm_info = _get_problem_info(alpha_value=func_attrs.get("alpha", 1)) + a_shapes = func_attrs["input_accessors"][0].original_shapes + b_shapes = func_attrs["input_accessors"][1].original_shapes + bmm_common._update_stride_info(mm_info, a_shapes, b_shapes) + + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + + bmm_common.gen_profiler( + func_attrs, + workdir, + dim_info_dict, + common.SRC_TEMPLATE, + problem_args, + args_parser, + ) + + +@registry.reg("cuda.bmm_crr.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, +): + mm_info = _get_problem_info(alpha_value=func_attrs.get("alpha", 1)) + a_shapes = func_attrs["input_accessors"][0].original_shapes + b_shapes = func_attrs["input_accessors"][1].original_shapes + bmm_common._update_stride_info(mm_info, a_shapes, b_shapes) + + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + return bmm_common.gen_function( + func_attrs, + exec_cond_template, + problem_args, + dim_info_dict, + ) + + +@registry.reg("cuda.bmm_crr.func_decl") +def gen_function_decl(func_attrs): + return bmm_common.gen_function_decl(func_attrs) + + +@registry.reg("cuda.bmm_crr.func_call") +def gen_function_call(func_attrs, indent=" "): + return bmm_common.gen_function_call(func_attrs, indent) + + +@registry.reg("cuda.bmm_crr.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/bmm_crr_add.py b/python/aitemplate/backend/cuda/gemm_universal/bmm_crr_add.py new file mode 100644 index 000000000..2767af9b0 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/bmm_crr_add.py @@ -0,0 +1,104 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
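The `_get_problem_info` helpers in `bmm_ccr` and `bmm_crr` above (and in the rcr/rrr variants later in this patch) differ only in the leading dimensions they hand to GemmUniversal. A small standalone helper, not part of the patch, that captures the convention: a row-major operand uses its inner extent as its leading dimension, a column-major operand uses its outer one, with A being M x K and B being K x N.

```python
def leading_dims(a_layout: str, b_layout: str, M: int, N: int, K: int):
    # "r" = row-major, "c" = column-major; A is [M, K], B is [K, N]
    lda = K if a_layout == "r" else M
    ldb = N if b_layout == "r" else K
    return lda, ldb

assert leading_dims("c", "c", 4, 5, 6) == (4, 6)  # bmm_ccr: lda=M, ldb=K
assert leading_dims("c", "r", 4, 5, 6) == (4, 5)  # bmm_crr: lda=M, ldb=N
assert leading_dims("r", "c", 4, 5, 6) == (6, 6)  # bmm_rcr: lda=K, ldb=K
assert leading_dims("r", "r", 4, 5, 6) == (6, 5)  # bmm_rrr: lda=K, ldb=N
```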
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Codegen for bmm_crr_add, which computes A @ B + bias + C. +A[ColMajor], B[RowMajor], bias / C[RowMajor] +""" + +from ... import registry +from ...common import gemm_common +from . import bmm_common, bmm_crr, common + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + + +@registry.reg("cuda.bmm_crr_add.config") +def bmm_crr_add_config(func_attrs, dtype="float16"): + return bmm_crr.bmm_crr_config(func_attrs, dtype) + + +@registry.reg("cuda.bmm_crr_add.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + a_dims = bmm_common.reverse_dim_info_mapping( + dim_info_dict, gemm_common.Source.INPUT, 0 + ) + b_dims = bmm_common.reverse_dim_info_mapping( + dim_info_dict, gemm_common.Source.INPUT, 1 + ) + c_dims = bmm_common.reverse_dim_info_mapping( + dim_info_dict, gemm_common.Source.OUTPUT, 0 + ) + + args_parser = bmm_common.ARGS_PARSER_TEMPLATE.render( + a_dims=a_dims, b_dims=b_dims, c_dims=c_dims + ) + + mm_info = bmm_crr._get_problem_info( + bias_ptr="d_ptr", alpha_value=func_attrs.get("alpha", 1), beta_value=1 + ) + a_shapes = func_attrs["input_accessors"][0].original_shapes + b_shapes = func_attrs["input_accessors"][1].original_shapes + d_shapes = func_attrs["input_accessors"][2].original_shapes + bmm_common._update_stride_info(mm_info, a_shapes, b_shapes, d_shapes) + + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + + bmm_common.gen_profiler( + func_attrs, + workdir, + dim_info_dict, + common.SRC_TEMPLATE, + problem_args, + args_parser, + ) + + +@registry.reg("cuda.bmm_crr_add.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, +): + mm_info = bmm_crr._get_problem_info( + bias_ptr="d_ptr", alpha_value=func_attrs.get("alpha", 1), beta_value=1 + ) + a_shapes = func_attrs["input_accessors"][0].original_shapes + b_shapes = func_attrs["input_accessors"][1].original_shapes + d_shapes = func_attrs["input_accessors"][2].original_shapes + bmm_common._update_stride_info(mm_info, a_shapes, b_shapes, d_shapes) + + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + return bmm_common.gen_function( + func_attrs, + exec_cond_template, + problem_args, + dim_info_dict, + ) + + +@registry.reg("cuda.bmm_crr_add.func_decl") +def gen_function_decl(func_attrs): + return bmm_common.gen_function_decl(func_attrs) + + +@registry.reg("cuda.bmm_crr_add.func_call") +def gen_function_call(func_attrs, indent=" "): + return bmm_common.gen_function_call(func_attrs, indent) + + +@registry.reg("cuda.bmm_crr_add.filter") +def function_filter(cfg, func_attrs, ab_alignment): + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/bmm_permute_common.py b/python/aitemplate/backend/cuda/gemm_universal/bmm_permute_common.py new file mode 100644 index 000000000..582bfd38e --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/bmm_permute_common.py @@ -0,0 +1,166 @@ +# Copyright (c) Meta Platforms, 
Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Common functions and templates for bmm_permute-family ops +""" +from ...common import gemm_common +from ..gemm_universal import common, common_bias + +from . import bmm_common, common_permute + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + + +def gen_profiler( + func_attrs, + workdir, + dim_info_dict, + src_template, + problem_args, + args_parser, + emit_kernel=False, + bias_ptr_arg=None, + extra_code="", +): + """Generate code for profiling""" + op_type = func_attrs["op"] + op_instance = func_attrs["op_instance"] + has_d = False + if "has_d" in func_attrs: + has_d = func_attrs["has_d"] + + a_ndims = len(func_attrs["input_accessors"][0].original_shapes) + b_ndims = len(func_attrs["input_accessors"][1].original_shapes) + c_ndims = len(func_attrs["output_accessors"][0].original_shapes) + shape_func = gemm_common.gen_shape_eval_code( + indent=2, dtype="int64_t", dim_info_dict=dim_info_dict, is_ptr=True + ) + + file_pairs = [] + has_bias = bias_ptr_arg is not None + assert not (has_d and has_bias) + for op_name, op in op_instance.items(): + config = common_permute.emit_instance( + op, + for_profiler=True, + emit_kernel=emit_kernel, + func_attrs=func_attrs, + ) + config_name = common.extract_config_name(config) + name = "GemmInstance" + instance = common.INSTANCE_TEMPLATE.render( + config_name=config_name, name=name, config=config + ) + exec_program = common.EXEC_TEMPLATE.render( + indent=" ", + instance=name, + is_profiler=True, + problem_args=problem_args, + ) + input_output_checks = common.INPUT_OUTPUT_CHECKS_TEMPLATE.render( + input_ndims=a_ndims, + weight_ndims=b_ndims, + output_ndims=c_ndims, + ) + op_func = src_template.render( + instances=instance, + function_name="bmm", + input_ndims=a_ndims, + weight_ndims=b_ndims, + output_ndims=c_ndims, + shape_eval=shape_func, + input_output_checks=input_output_checks, + exec_paths=exec_program, + has_d=has_d, + extra_code=extra_code, + ) + a_dims_ptr = [f"&a_dim{idx}" for idx in range(a_ndims)] + b_dims_ptr = [f"&b_dim{idx}" for idx in range(b_ndims)] + c_dims_ptr = [f"&c_dim{idx}" for idx in range(c_ndims)] + func_call = bmm_common.FUNC_CALL_TEMPLATE.render( + func_name="bmm", + a_ptr="memory_pool->RequestHalfTensorByIdx(0)", + b_ptr="memory_pool->RequestHalfTensorByIdx(1)", + has_bias=has_bias, + bias_ptr=bias_ptr_arg, + c_ptr="memory_pool->RequestHalfTensorByIdx(2)", + d_ptr="memory_pool->RequestHalfTensorByIdx(%d)" % (4 if has_bias else 3), + has_d=has_d, + a_dims_ptr=a_dims_ptr, + b_dims_ptr=b_dims_ptr, + c_dims_ptr=c_dims_ptr, + ) + code = common.PROFILER_TEMPLATE.render( + op_func=op_func, + args_parse=args_parser, + func_call=func_call, + name=name, + tensor_decl=bmm_common.TENSOR_DECL_TEMPLATE.render( + name=name, + a_ndims=a_ndims, + b_ndims=b_ndims, + c_ndims=c_ndims, + has_d=has_d, + has_bias=has_bias, + ), + ) + common.add_profiler(file_pairs, workdir, op_type, op_name, code) + # build + common.build_profiler(file_pairs) + + +def 
gen_function_decl(func_attrs): + """Rendering argument to function declaration template""" + func_name = func_attrs["name"] + has_d = False + if "has_d" in func_attrs: + has_d = func_attrs["has_d"] + return bmm_common.FUNC_DECL_TEMPLATE.render( + func_name=func_name, + a_ndims=len(func_attrs["input_accessors"][0].original_shapes), + b_ndims=len(func_attrs["input_accessors"][1].original_shapes), + c_ndims=len(func_attrs["output_accessors"][0].original_shapes), + has_d=has_d, + ) + + +def gen_function( + func_attrs, + exec_cond_template, + problem_args, + dim_info_dict, + input_addr_calculator="", + output_addr_calculator="", + extra_code="", + has_bias=False, +): + return common_permute.gen_function( + func_attrs, + common_bias.SRC_TEMPLATE if has_bias else common.SRC_TEMPLATE, + exec_cond_template, + problem_args, + input_ndims=len(func_attrs["input_accessors"][0].original_shapes), + weight_ndims=len(func_attrs["input_accessors"][1].original_shapes), + output_ndims=len(func_attrs["output_accessors"][0].original_shapes), + dim_info_dict=dim_info_dict, + input_addr_calculator=input_addr_calculator, + output_addr_calculator=output_addr_calculator, + emit_kernel=True, + extra_code=extra_code, + ) + + +def gen_function_call(func_attrs, indent=" ", bias_ptr_arg=None): + return bmm_common.gen_function_call(func_attrs, indent, bias_ptr_arg) diff --git a/python/aitemplate/backend/cuda/gemm_universal/bmm_rcr.py b/python/aitemplate/backend/cuda/gemm_universal/bmm_rcr.py new file mode 100644 index 000000000..d660f3c61 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/bmm_rcr.py @@ -0,0 +1,211 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Codegen for bmm_rcr, which computes A @ B + bias. +A[RowMajor], B[ColMajor], bias[RowMajor] +""" + +from ... import registry +from ...common import gemm_common +from . 
import bmm_common, common +from .layout import RCR + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + + +def _get_default_problem_info(**kwargs): + problem_args = { + "bias_ptr": "c_ptr", + "a_batch_stride": "M * K", + "b_batch_stride": "N * K", + "bias_batch_stride": "M * N", + "c_batch_stride": "M * N", + "lda": "K", + "ldb": "K", + "ldbias": "N", + "ldc": "N", + } + for k, v in kwargs.items(): + problem_args[k] = v + + bmm_problem_info = bmm_common.Bmm_problem_info(**problem_args) + return bmm_problem_info + + +@registry.reg("cuda.bmm_rcr.config") +def bmm_rcr_config(func_attrs, dtype="float16"): + common.make_fproc_f16(func_attrs, RCR) + + +@registry.reg("cuda.bmm_rcr.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + a_dims = bmm_common.reverse_dim_info_mapping( + dim_info_dict, gemm_common.Source.INPUT, 0 + ) + b_dims = bmm_common.reverse_dim_info_mapping( + dim_info_dict, gemm_common.Source.INPUT, 1 + ) + c_dims = bmm_common.reverse_dim_info_mapping( + dim_info_dict, gemm_common.Source.OUTPUT, 0 + ) + + args_parser = bmm_common.ARGS_PARSER_TEMPLATE.render( + a_dims=a_dims, b_dims=b_dims, c_dims=c_dims + ) + + mm_info = _get_default_problem_info(alpha_value=func_attrs.get("alpha", 1)) + a_shapes = func_attrs["input_accessors"][0].original_shapes + b_shapes = func_attrs["input_accessors"][1].original_shapes + bmm_common._update_stride_info(mm_info, a_shapes, b_shapes) + + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + + bmm_common.gen_profiler( + func_attrs, + workdir, + dim_info_dict, + common.SRC_TEMPLATE, + problem_args, + args_parser, + ) + + +@registry.reg("cuda.bmm_rcr.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, +): + input_a_batch_stride_dim = "M * K" + input_a_stride_k_dim = "K" + input_a_offset = 0 + input_b_batch_stride_dim = "N * K" + input_b_stride_k_dim = "K" + input_b_offset = 0 + + if "input_accessors" in func_attrs: + input_a_accessor = func_attrs["input_accessors"][0] + input_b_accessor = func_attrs["input_accessors"][1] + + if input_a_accessor.is_from_strided_tensor: + input_a_offset = input_a_accessor.offset + if not input_a_accessor.is_contiguous: + a_dims = bmm_common.reverse_dim_info_mapping( + dim_info_dict, gemm_common.Source.INPUT, 0 + ) + + input_a_batch_stride_dim = input_a_accessor.gen_stride_str(0, a_dims) + input_a_stride_k_dim = input_a_accessor.stride(1) + + if input_b_accessor.is_from_strided_tensor: + input_b_offset = input_b_accessor.offset + if not input_b_accessor.is_contiguous: + b_dims = bmm_common.reverse_dim_info_mapping( + dim_info_dict, gemm_common.Source.INPUT, 1 + ) + input_b_batch_stride_dim = input_b_accessor.gen_stride_str(0, b_dims) + input_b_stride_k_dim = input_b_accessor.stride(1) + + input_addr_calculator = common.INPUT_ADDR_CALCULATOR.render( + input_a_batch_stride_dim=input_a_batch_stride_dim, + input_a_stride_dim=input_a_stride_k_dim, + input_a_offset_val=input_a_offset, + input_b_batch_stride_dim=input_b_batch_stride_dim, + input_b_stride_dim=input_b_stride_k_dim, + input_b_offset_val=input_b_offset, + ) + + output_batch_stride_dim = "M * N" + output_stride_n_dim = "N" + output_offset = 0 + + if "output_accessors" in func_attrs: + output_accessor = func_attrs["output_accessors"][0] + if output_accessor.is_from_strided_tensor: + output_offset = output_accessor.offset + if not output_accessor.is_contiguous: + c_dims = bmm_common.reverse_dim_info_mapping( + dim_info_dict, gemm_common.Source.OUTPUT, 0 + ) + output_batch_stride_dim = 
output_accessor.gen_stride_str(0, c_dims) + output_stride_n_dim = output_accessor.stride(1) + + output_addr_calculator = bmm_common.OUTPUT_ADDR_CALCULATOR.render( + output_batch_stride_dim=output_batch_stride_dim, + output_stride_dim=output_stride_n_dim, + output_offset_val=output_offset, + ) + + bmm_problem_info = bmm_common.Bmm_problem_info( + alpha_value=func_attrs.get("alpha", 1), + a_ptr="(a_ptr + input_a_offset)", + b_ptr="(b_ptr + input_b_offset)", + bias_ptr="(c_ptr + output_offset)", + c_ptr="(c_ptr + output_offset)", + a_batch_stride="input_a_batch_stride", + b_batch_stride="input_b_batch_stride", + bias_batch_stride="output_batch_stride", + c_batch_stride="output_batch_stride", + lda="input_a_stride", + ldb="input_b_stride", + ldbias="output_stride", + ldc="output_stride", + ) + a_shapes = func_attrs["input_accessors"][0].original_shapes + b_shapes = func_attrs["input_accessors"][1].original_shapes + bmm_common._update_stride_info(bmm_problem_info, a_shapes, b_shapes) + + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=bmm_problem_info) + + return bmm_common.gen_function( + func_attrs, + exec_cond_template, + problem_args, + dim_info_dict, + input_addr_calculator, + output_addr_calculator, + ) + + +@registry.reg("cuda.bmm_rcr.func_decl") +def gen_function_decl(func_attrs): + return bmm_common.gen_function_decl(func_attrs) + + +@registry.reg("cuda.bmm_rcr.func_call") +def gen_function_call(func_attrs, indent=" "): + return bmm_common.gen_function_call(func_attrs, indent) + + +@registry.reg("cuda.bmm_rcr.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/bmm_rcr_permute.py b/python/aitemplate/backend/cuda/gemm_universal/bmm_rcr_permute.py new file mode 100644 index 000000000..2dc737be5 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/bmm_rcr_permute.py @@ -0,0 +1,211 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Codegen for bmm_rcr_permute, which computes permute(A @ B + bias). +A[RowMajor], B[ColMajor], bias[RowMajor] +""" + +from ... import registry +from ...common import gemm_common +from . 
import bmm_common, bmm_permute_common, common, common_permute + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + + +@registry.reg("cuda.bmm_rcr_permute.config") +def bmm_rcr_permute_config(func_attrs, dtype="float16"): + def fproc_f16(op): + import cutlass_lib + + return common_permute.default_fproc_f16( + op=op, + a_layout=cutlass_lib.library.LayoutType.RowMajor, + b_layout=cutlass_lib.library.LayoutType.ColumnMajor, + c_layout=cutlass_lib.library.LayoutType.RowMajor, + epiligue_name=func_attrs["epilogue"], + permute_layout=func_attrs["layout"], + ) + + func_attrs["op_instance"] = common_permute.extract_config(fproc_f16, func_attrs) + + +@registry.reg("cuda.bmm_rcr_permute.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + a_dims = bmm_common.reverse_dim_info_mapping( + dim_info_dict, gemm_common.Source.INPUT, 0 + ) + b_dims = bmm_common.reverse_dim_info_mapping( + dim_info_dict, gemm_common.Source.INPUT, 1 + ) + c_dims = bmm_common.reverse_dim_info_mapping( + dim_info_dict, gemm_common.Source.OUTPUT, 0 + ) + + args_parser = bmm_common.ARGS_PARSER_TEMPLATE.render( + a_dims=a_dims, b_dims=b_dims, c_dims=c_dims + ) + + bmm_problem_info = bmm_common.Bmm_problem_info( + alpha_value=func_attrs.get("alpha", 1), + bias_ptr="c_ptr", + a_batch_stride="M * K", + b_batch_stride="N * K", + bias_batch_stride="M * N", + c_batch_stride="0", + lda="K", + ldb="K", + ldbias="N", + ldc="N", + ) + a_shapes = func_attrs["input_accessors"][0].original_shapes + b_shapes = func_attrs["input_accessors"][1].original_shapes + bmm_common._update_stride_info(bmm_problem_info, a_shapes, b_shapes) + + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render( + mm_info=bmm_problem_info, + ) + + bmm_permute_common.gen_profiler( + func_attrs, + workdir, + dim_info_dict, + common.SRC_TEMPLATE, + problem_args, + args_parser, + emit_kernel=True, + extra_code=common_permute.EXTRA_CODE.render(), + ) + + +@registry.reg("cuda.bmm_rcr_permute.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, +): + input_a_batch_stride_dim = "M * K" + input_a_stride_k_dim = "K" + input_a_offset = 0 + input_b_batch_stride_dim = "N * K" + input_b_stride_k_dim = "K" + input_b_offset = 0 + + if "input_accessors" in func_attrs: + input_a_accessor = func_attrs["input_accessors"][0] + input_b_accessor = func_attrs["input_accessors"][1] + + if input_a_accessor.is_from_strided_tensor: + input_a_offset = input_a_accessor.offset + if not input_a_accessor.is_contiguous: + input_a_batch_stride_dim = input_a_accessor.stride(0) + input_a_stride_k_dim = input_a_accessor.stride(1) + + if input_b_accessor.is_from_strided_tensor: + input_b_offset = input_b_accessor.offset + if not input_b_accessor.is_contiguous: + input_b_batch_stride_dim = input_b_accessor.stride(0) + input_b_stride_k_dim = input_b_accessor.stride(1) + + input_addr_calculator = common.INPUT_ADDR_CALCULATOR.render( + input_a_batch_stride_dim=input_a_batch_stride_dim, + input_a_stride_dim=input_a_stride_k_dim, + input_a_offset_val=input_a_offset, + input_b_batch_stride_dim=input_b_batch_stride_dim, + input_b_stride_dim=input_b_stride_k_dim, + input_b_offset_val=input_b_offset, + ) + + output_batch_stride_dim = "M * N" + output_stride_n_dim = "N" + output_offset = 0 + + if "output_accessors" in func_attrs: + output_accessor = func_attrs["output_accessors"][0] + if output_accessor.is_from_strided_tensor: + output_offset = output_accessor.offset + if not output_accessor.is_contiguous: + output_batch_stride_dim = output_accessor.stride(0) + 
output_stride_n_dim = output_accessor.stride(1) + + output_addr_calculator = bmm_common.OUTPUT_ADDR_CALCULATOR.render( + output_batch_stride_dim=output_batch_stride_dim, + output_stride_dim=output_stride_n_dim, + output_offset_val=output_offset, + ) + + bmm_problem_info = bmm_common.Bmm_problem_info( + alpha_value=func_attrs.get("alpha", 1), + a_ptr="(a_ptr + input_a_offset)", + b_ptr="(b_ptr + input_b_offset)", + bias_ptr="(c_ptr + output_offset)", + c_ptr="(c_ptr + output_offset)", + a_batch_stride="input_a_batch_stride", + b_batch_stride="input_b_batch_stride", + bias_batch_stride="output_batch_stride", + c_batch_stride="0", + lda="input_a_stride", + ldb="input_b_stride", + ldbias="output_stride", + ldc="output_stride", + ) + a_shapes = func_attrs["input_accessors"][0].original_shapes + b_shapes = func_attrs["input_accessors"][1].original_shapes + bmm_common._update_stride_info(bmm_problem_info, a_shapes, b_shapes) + + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render( + mm_info=bmm_problem_info, + ) + + return bmm_permute_common.gen_function( + func_attrs, + exec_cond_template, + problem_args, + dim_info_dict, + input_addr_calculator, + output_addr_calculator, + extra_code=common_permute.EXTRA_CODE.render(), + ) + + +@registry.reg("cuda.bmm_rcr_permute.func_decl") +def gen_function_decl(func_attrs): + return bmm_permute_common.gen_function_decl(func_attrs) + + +@registry.reg("cuda.bmm_rcr_permute.func_call") +def gen_function_call(func_attrs, indent=" "): + return bmm_permute_common.gen_function_call(func_attrs, indent) + + +@registry.reg("cuda.bmm_rcr_permute.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/bmm_rrr.py b/python/aitemplate/backend/cuda/gemm_universal/bmm_rrr.py new file mode 100644 index 000000000..bc752b1bb --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/bmm_rrr.py @@ -0,0 +1,145 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Codegen for bmm_rrr, which computes A @ B + bias. +A[RowMajor], B[RowMajor], bias / C[RowMajor] +""" + +from ... import registry +from ...common import gemm_common +from . 
import bmm_common, common + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + + +def _get_problem_info(**kwargs): + problem_args = { + "bias_ptr": "c_ptr", + "a_batch_stride": "M * K", + "b_batch_stride": "N * K", + "bias_batch_stride": "M * N", + "c_batch_stride": "M * N", + "lda": "K", + "ldb": "N", + "ldbias": "N", + "ldc": "N", + } + for k, v in kwargs.items(): + problem_args[k] = v + + bmm_problem_info = bmm_common.Bmm_problem_info(**problem_args) + return bmm_problem_info + + +@registry.reg("cuda.bmm_rrr.config") +def bmm_rrr_config(func_attrs, dtype="float16"): + def fproc_f16(op): + import cutlass_lib + + return common.default_fproc_f16( + op=op, + a_layout=cutlass_lib.library.LayoutType.RowMajor, + b_layout=cutlass_lib.library.LayoutType.RowMajor, + c_layout=cutlass_lib.library.LayoutType.RowMajor, + epiligue_name=func_attrs["epilogue"], + ) + + func_attrs["op_instance"] = common.extract_config(fproc_f16) + + +@registry.reg("cuda.bmm_rrr.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + a_dims = bmm_common.reverse_dim_info_mapping( + dim_info_dict, gemm_common.Source.INPUT, 0 + ) + b_dims = bmm_common.reverse_dim_info_mapping( + dim_info_dict, gemm_common.Source.INPUT, 1 + ) + c_dims = bmm_common.reverse_dim_info_mapping( + dim_info_dict, gemm_common.Source.OUTPUT, 0 + ) + + args_parser = bmm_common.ARGS_PARSER_TEMPLATE.render( + a_dims=a_dims, b_dims=b_dims, c_dims=c_dims + ) + + mm_info = _get_problem_info(alpha_value=func_attrs.get("alpha", 1)) + a_shapes = func_attrs["input_accessors"][0].original_shapes + b_shapes = func_attrs["input_accessors"][1].original_shapes + bmm_common._update_stride_info(mm_info, a_shapes, b_shapes) + + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + + bmm_common.gen_profiler( + func_attrs, + workdir, + dim_info_dict, + common.SRC_TEMPLATE, + problem_args, + args_parser, + ) + + +@registry.reg("cuda.bmm_rrr.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, +): + mm_info = _get_problem_info(alpha_value=func_attrs.get("alpha", 1)) + a_shapes = func_attrs["input_accessors"][0].original_shapes + b_shapes = func_attrs["input_accessors"][1].original_shapes + bmm_common._update_stride_info(mm_info, a_shapes, b_shapes) + + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + + return bmm_common.gen_function( + func_attrs, + exec_cond_template, + problem_args, + dim_info_dict, + ) + + +@registry.reg("cuda.bmm_rrr.func_decl") +def gen_function_decl(func_attrs): + return bmm_common.gen_function_decl(func_attrs) + + +@registry.reg("cuda.bmm_rrr.func_call") +def gen_function_call(func_attrs, indent=" "): + return bmm_common.gen_function_call(func_attrs, indent) + + +@registry.reg("cuda.bmm_rrr.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/bmm_rrr_add.py b/python/aitemplate/backend/cuda/gemm_universal/bmm_rrr_add.py new file mode 100644 index 000000000..bb8201291 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/bmm_rrr_add.py @@ -0,0 +1,121 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Codegen for bmm_rrr_add, which computes A @ B + bias + C. +A[RowMajor], B[RowMajor], bias / C[RowMajor] +""" + +from ... import registry +from ...common import gemm_common +from . import bmm_common, bmm_rrr, common + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + + +@registry.reg("cuda.bmm_rrr_add.config") +def bmm_rrr_add_config(func_attrs, dtype="float16"): + return bmm_rrr.bmm_rrr_config(func_attrs, dtype) + + +@registry.reg("cuda.bmm_rrr_add.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + a_dims = bmm_common.reverse_dim_info_mapping( + dim_info_dict, gemm_common.Source.INPUT, 0 + ) + b_dims = bmm_common.reverse_dim_info_mapping( + dim_info_dict, gemm_common.Source.INPUT, 1 + ) + c_dims = bmm_common.reverse_dim_info_mapping( + dim_info_dict, gemm_common.Source.OUTPUT, 0 + ) + + args_parser = bmm_common.ARGS_PARSER_TEMPLATE.render( + a_dims=a_dims, b_dims=b_dims, c_dims=c_dims + ) + + mm_info = bmm_rrr._get_problem_info( + bias_ptr="d_ptr", alpha_value=func_attrs.get("alpha", 1), beta_value=1 + ) + a_shapes = func_attrs["input_accessors"][0].original_shapes + b_shapes = func_attrs["input_accessors"][1].original_shapes + d_shapes = func_attrs["input_accessors"][2].original_shapes + bmm_common._update_stride_info(mm_info, a_shapes, b_shapes, d_shapes) + + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + + bmm_common.gen_profiler( + func_attrs, + workdir, + dim_info_dict, + common.SRC_TEMPLATE, + problem_args, + args_parser, + ) + + +@registry.reg("cuda.bmm_rrr_add.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, +): + mm_info = bmm_rrr._get_problem_info( + bias_ptr="d_ptr", alpha_value=func_attrs.get("alpha", 1), beta_value=1 + ) + a_shapes = func_attrs["input_accessors"][0].original_shapes + b_shapes = func_attrs["input_accessors"][1].original_shapes + d_shapes = func_attrs["input_accessors"][2].original_shapes + bmm_common._update_stride_info(mm_info, a_shapes, b_shapes, d_shapes) + + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + + return bmm_common.gen_function( + func_attrs, + exec_cond_template, + problem_args, + dim_info_dict, + ) + + +@registry.reg("cuda.bmm_rrr_add.func_decl") +def gen_function_decl(func_attrs): + return bmm_common.gen_function_decl(func_attrs) + + +@registry.reg("cuda.bmm_rrr_add.func_call") +def gen_function_call(func_attrs, indent=" "): + return bmm_common.gen_function_call(func_attrs, indent) + + +@registry.reg("cuda.bmm_rrr_add.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. 
+ """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/bmm_rrr_permute.py b/python/aitemplate/backend/cuda/gemm_universal/bmm_rrr_permute.py new file mode 100644 index 000000000..d1d17ee8d --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/bmm_rrr_permute.py @@ -0,0 +1,219 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Codegen for bmm_rrr_permute, which computes permute(A @ B + bias). +A[RowMajor], B[RowMajor], bias / C[RowMajor] +""" + +from ... import registry +from ...common import gemm_common +from . import bmm_common, bmm_permute_common, common, common_permute + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + + +@registry.reg("cuda.bmm_rrr_permute.config") +def bmm_rrr_permute_config(func_attrs, dtype="float16"): + def fproc_f16(op): + import cutlass_lib + + return common_permute.default_fproc_f16( + op=op, + a_layout=cutlass_lib.library.LayoutType.RowMajor, + b_layout=cutlass_lib.library.LayoutType.RowMajor, + c_layout=cutlass_lib.library.LayoutType.RowMajor, + epiligue_name=func_attrs["epilogue"], + permute_layout=func_attrs["layout"], + ) + + func_attrs["op_instance"] = common_permute.extract_config(fproc_f16, func_attrs) + + +@registry.reg("cuda.bmm_rrr_permute.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + a_dims = bmm_common.reverse_dim_info_mapping( + dim_info_dict, gemm_common.Source.INPUT, 0 + ) + b_dims = bmm_common.reverse_dim_info_mapping( + dim_info_dict, gemm_common.Source.INPUT, 1 + ) + c_dims = bmm_common.reverse_dim_info_mapping( + dim_info_dict, gemm_common.Source.OUTPUT, 0 + ) + + args_parser = bmm_common.ARGS_PARSER_TEMPLATE.render( + a_dims=a_dims, b_dims=b_dims, c_dims=c_dims + ) + + bmm_problem_info = bmm_common.Bmm_problem_info( + alpha_value=func_attrs.get("alpha", 1), + bias_ptr="c_ptr", + a_batch_stride="M * K", + b_batch_stride="K * N", + bias_batch_stride="M * N", + c_batch_stride="0", + lda="K", + ldb="N", + ldbias="N", + ldc="N", + ) + a_shapes = func_attrs["input_accessors"][0].original_shapes + b_shapes = func_attrs["input_accessors"][1].original_shapes + bmm_common._update_stride_info(bmm_problem_info, a_shapes, b_shapes) + + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render( + mm_info=bmm_problem_info, + ) + + bmm_permute_common.gen_profiler( + func_attrs, + workdir, + dim_info_dict, + common.SRC_TEMPLATE, + problem_args, + args_parser, + emit_kernel=True, + extra_code=common_permute.EXTRA_CODE.render(), + ) + + +@registry.reg("cuda.bmm_rrr_permute.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, +): + input_a_batch_stride_dim = "M * K" + input_a_stride_k_dim = "K" + input_a_offset = 0 + input_b_batch_stride_dim = "K * N" + input_b_stride_k_dim = "N" + input_b_offset = 0 + + if "input_accessors" in func_attrs: + input_a_accessor = func_attrs["input_accessors"][0] + input_b_accessor = func_attrs["input_accessors"][1] + + if 
input_a_accessor.is_from_strided_tensor: + input_a_offset = input_a_accessor.offset + if not input_a_accessor.is_contiguous: + a_dims = bmm_common.reverse_dim_info_mapping( + dim_info_dict, gemm_common.Source.INPUT, 0 + ) + + input_a_batch_stride_dim = input_a_accessor.gen_stride_str(0, a_dims) + input_a_stride_k_dim = input_a_accessor.stride(1) + + if input_b_accessor.is_from_strided_tensor: + input_b_offset = input_b_accessor.offset + if not input_b_accessor.is_contiguous: + b_dims = bmm_common.reverse_dim_info_mapping( + dim_info_dict, gemm_common.Source.INPUT, 1 + ) + input_b_batch_stride_dim = input_b_accessor.gen_stride_str(0, b_dims) + input_b_stride_k_dim = input_b_accessor.stride(1) + + input_addr_calculator = common.INPUT_ADDR_CALCULATOR.render( + input_a_batch_stride_dim=input_a_batch_stride_dim, + input_a_stride_dim=input_a_stride_k_dim, + input_a_offset_val=input_a_offset, + input_b_batch_stride_dim=input_b_batch_stride_dim, + input_b_stride_dim=input_b_stride_k_dim, + input_b_offset_val=input_b_offset, + ) + + output_batch_stride_dim = "M * N" + output_stride_n_dim = "N" + output_offset = 0 + + if "output_accessors" in func_attrs: + output_accessor = func_attrs["output_accessors"][0] + if output_accessor.is_from_strided_tensor: + output_offset = output_accessor.offset + if not output_accessor.is_contiguous: + c_dims = bmm_common.reverse_dim_info_mapping( + dim_info_dict, gemm_common.Source.OUTPUT, 0 + ) + output_batch_stride_dim = output_accessor.gen_stride_str(0, c_dims) + output_stride_n_dim = output_accessor.stride(1) + + output_addr_calculator = bmm_common.OUTPUT_ADDR_CALCULATOR.render( + output_batch_stride_dim=output_batch_stride_dim, + output_stride_dim=output_stride_n_dim, + output_offset_val=output_offset, + ) + + bmm_problem_info = bmm_common.Bmm_problem_info( + alpha_value=func_attrs.get("alpha", 1), + a_ptr="(a_ptr + input_a_offset)", + b_ptr="(b_ptr + input_b_offset)", + bias_ptr="(c_ptr + output_offset)", + c_ptr="(c_ptr + output_offset)", + a_batch_stride="input_a_batch_stride", + b_batch_stride="input_b_batch_stride", + bias_batch_stride="output_batch_stride", + c_batch_stride="0", + lda="input_a_stride", + ldb="input_b_stride", + ldbias="output_stride", + ldc="output_stride", + ) + a_shapes = func_attrs["input_accessors"][0].original_shapes + b_shapes = func_attrs["input_accessors"][1].original_shapes + bmm_common._update_stride_info(bmm_problem_info, a_shapes, b_shapes) + + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=bmm_problem_info) + + return bmm_permute_common.gen_function( + func_attrs, + exec_cond_template, + problem_args, + dim_info_dict, + input_addr_calculator, + output_addr_calculator, + extra_code=common_permute.EXTRA_CODE.render(), + ) + + +@registry.reg("cuda.bmm_rrr_permute.func_decl") +def gen_function_decl(func_attrs): + return bmm_permute_common.gen_function_decl(func_attrs) + + +@registry.reg("cuda.bmm_rrr_permute.func_call") +def gen_function_call(func_attrs, indent=" "): + return bmm_permute_common.gen_function_call(func_attrs, indent) + + +@registry.reg("cuda.bmm_rrr_permute.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. 
+ """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/bmm_softmax_bmm_permute.py b/python/aitemplate/backend/cuda/gemm_universal/bmm_softmax_bmm_permute.py new file mode 100644 index 000000000..742d601a0 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/bmm_softmax_bmm_permute.py @@ -0,0 +1,31 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from ... import registry + + +@registry.reg("cuda.bmm_softmax_bmm_permute.func_decl") +def gen_function_decl(func_attrs): + raise NotImplementedError("bmm_softmax_bmm_permute kernel is not implemented.") + + +@registry.reg("cuda.bmm_softmax_bmm_permute.gen_function") +def gen_function(func_attrs): + raise NotImplementedError("bmm_softmax_bmm_permute kernel is not implemented.") + + +@registry.reg("cuda.bmm_softmax_bmm_permute.func_call") +def gen_function_call(func_attrs, indent=" "): + raise NotImplementedError("bmm_softmax_bmm_permute kernel is not implemented.") diff --git a/python/aitemplate/backend/cuda/gemm_universal/common.py b/python/aitemplate/backend/cuda/gemm_universal/common.py new file mode 100644 index 000000000..199311035 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/common.py @@ -0,0 +1,944 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Common codegen functions for gemm. +""" + +import os +import random +import re +from collections import OrderedDict +from hashlib import sha1 +from typing import Any, Dict, List, Tuple + +import jinja2 + +from ....compiler.base import IntImm + +from ... 
import builder +from ...common import gemm_common, tensor_accessor_codegen +from ...target import Target + +# pylint: disable=C0301,C0415,R1705 + + +INPUT_ADDR_CALCULATOR = jinja2.Template( + """ + int64_t input_a_batch_stride = {{input_a_batch_stride_dim}}; + int64_t input_a_stride = {{input_a_stride_dim}}; + int64_t input_a_offset = {{input_a_offset_val}}; // default to 0 + int64_t input_b_batch_stride = {{input_b_batch_stride_dim}}; + int64_t input_b_stride = {{input_b_stride_dim}}; + int64_t input_b_offset = {{input_b_offset_val}}; // default to 0 + """ +) + + +# These should be only used for 2D gemm +# For templates for bmm, see bmm_common +OUTPUT_ADDR_CALCULATOR = jinja2.Template( + """ + {% if not output_accessor.is_from_strided_tensor %} + int64_t output_stride = {{stride_dim}}; + int64_t output_offset = 0; + {% else %} + int64_t output_stride = {{output_accessor.actual_total_elements_from_stride_dim}}; + int64_t output_offset = {{output_accessor.offset}}; + {% endif %} + """ +) + +DEFAULT_OUTPUT_ADDR_CALCULATOR = jinja2.Template( + """ + int64_t output_stride = {{stride_dim}}; + int64_t output_offset = 0; + """ +) + +DIM_DEFS_TEMPLATE = jinja2.Template( + """ +{% for dim, value in dims.items() %} +{{indent}}int64_t {{dim}} = {{value}}; +{% endfor %} +""" +) + + +INPUT_OUTPUT_CHECKS_TEMPLATE = jinja2.Template( + """ + int64_t a_size = 1; +{% for idx in range(input_ndims) %} + a_size *= *a_dim{{idx}}; +{% endfor %} + if (a_size != 0 && !a_ptr) { + throw std::runtime_error("input a is null!"); + } + + int64_t b_size = 1; +{% for idx in range(weight_ndims) %} + b_size *= *b_dim{{idx}}; +{% endfor %} + if (b_size != 0 && !b_ptr) { + throw std::runtime_error("input b is null!"); + } + + int64_t c_size = 1; +{% for idx in range(output_ndims) %} + c_size *= *c_dim{{idx}}; +{% endfor %} + if (c_size != 0) { + if (!c_ptr) { + throw std::runtime_error("input c is null!"); + } + } else { + // output is empty and safe to return + return; + } + + // One of the input tensor are empty + if (a_size == 0 || b_size == 0) { + return; + } +""" +) + +INSTANCE_TEMPLATE = jinja2.Template( + """ +{{config}} +using {{name}} = {{config_name}}; +""" +) + + +SRC_TEMPLATE = jinja2.Template( + """ +#include +#include +#include +#include +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/kernel/gemm_grouped.h" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" +#include "cutlass/gemm/device/gemm_grouped.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/device_memory.h" + +{{extra_code}} + +#define CUTLASS_CHECK(status) \\ + { \\ + cutlass::Status error = status; \\ + if (error != cutlass::Status::kSuccess) { \\ + auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " + \\ + cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__); \\ + std::cerr << msg << std::endl; \\ + throw std::runtime_error(msg); \\ + } \\ + } + +{{instances}} + +void {{function_name}} ( + cutlass::half_t* a_ptr, + cutlass::half_t* b_ptr, +{% if has_d %} + cutlass::half_t* d_ptr, +{% endif %} + cutlass::half_t* c_ptr, + uint8_t* workspace, +{% if support_split_k %} + int split_k, +{% endif %} +{% for idx in range(input_ndims) %} + int64_t* a_dim{{idx}}, +{% endfor %} +{% for idx in range(weight_ndims) %} + int64_t* b_dim{{idx}}, +{% endfor %} +{% for idx in range(output_ndims) %} + int64_t* c_dim{{idx}}, +{% endfor %} + cudaStream_t stream + 
) { + {{shape_eval}} + {{input_addr_calculator}} + {{output_addr_calculator}} + {{extra_shape}} + {{input_output_checks}} + + {{exec_paths}} + {% for idx in range(input_ndims) %} + std::cout << "input_ndims{{idx}}: " << *a_dim{{idx}} << std::endl; + {% endfor %} + {% for idx in range(weight_ndims) %} + std::cout << "weight_ndims{{idx}}: " << *b_dim{{idx}} << std::endl; + {% endfor %} + {% for idx in range(output_ndims) %} + std::cout << "output_ndims{{idx}}: " << *c_dim{{idx}} << std::endl; + {% endfor %} + throw std::runtime_error( + "Unsupported workload for this {{function_name}} specialization." + ); +} +""", + trim_blocks=True, + lstrip_blocks=True, +) + + +EXEC_TEMPLATE = jinja2.Template( + """ +// TODO: cast to right dtype +{{indent}}using ElementComputeEpilogue = typename {{instance}}::ElementAccumulator; + +{{indent}}typename {{instance}}::Arguments arguments{ + +{{problem_args}} + +{{indent}}}; +{{indent}}{{instance}} gemm_op; +{% if is_profiler %} +{{indent}}// https://www.youtube.com/watch?v=rRwxfYlgG-M +{{indent}}size_t workspace_size = gemm_op.get_workspace_size(arguments); +{{indent}}cutlass::device_memory::allocation local_workspace(workspace_size); +{{indent}}workspace = local_workspace.get(); +{{indent}}GLOBAL_WORKSPACE_SIZE = workspace_size; +{% endif %} +{{indent}}auto status = gemm_op.can_implement(arguments); +{{indent}}CUTLASS_CHECK(status); +{{indent}}status = gemm_op.initialize(arguments, workspace, stream); +{{indent}}CUTLASS_CHECK(status); +{{indent}}status = gemm_op(stream); +{{indent}}CUTLASS_CHECK(status); +{{indent}}return; +""" +) + + +FUNC_DECL_TEMPLATE = jinja2.Template( + """ +void {{func_name}}( + cutlass::half_t*, + cutlass::half_t*, + cutlass::half_t*, + uint8_t*, +{% if support_split_k %} + int, +{% endif %} +{% for idx in range(input_ndims) %} + int64_t*, +{% endfor %} +{% for idx in range(weight_ndims) %} + int64_t*, +{% endfor %} +{% for idx in range(input_ndims) %} + int64_t*, +{% endfor %} + cudaStream_t +); +""" +) + + +FUNC_CALL_TEMPLATE = jinja2.Template( + """ +{{indent}}{ +{{indent}}{{local_dim_defs}} +{{indent}}{{func_name}}( +{{indent}} {{a_ptr}}, +{{indent}} {{b_ptr}}, +{% if has_bias %} +{{indent}} {{bias_ptr}}, +{% endif %} +{{indent}} {{c_ptr}}, +{{indent}} global_workspace, +{{indent}} {{split_k}}, +{% for dim in adims %} +{{indent}} {{dim}}, +{% endfor %} +{% for dim in bdims %} +{{indent}} {{dim}}, +{% endfor %} +{% for dim in cdims %} +{{indent}} {{dim}}, +{% endfor %} +{{indent}} stream +{{indent}}); +{{indent}}} +""" +) + + +TENSOR_DECL_TEMPLATE = jinja2.Template( + """ + int64_t a_ptr_sz = a_dim0 * a_dim1; + int64_t b_ptr_sz = b_dim0 * b_dim1; + int64_t c_ptr_sz = c_dim0 * c_dim1; + + // The value 1 is used to force ptr_max_sz to be non-zero + int64_t ptr_max_sz = std::max({1, a_ptr_sz, b_ptr_sz, c_ptr_sz}); + // TODO: special pool size for A100 L2 cache 40M + // need to tune it for other devices + int64_t mem_pool_sz = std::max(2, std::min(64, int((1 << 25) / ptr_max_sz))); + + memory_pool->AllocateHalfTensor(a_ptr_sz, mem_pool_sz); // a_ptr: index 0 + memory_pool->AllocateHalfTensor(b_ptr_sz, mem_pool_sz); // b_ptr: index 1 + memory_pool->AllocateHalfTensor(c_ptr_sz, mem_pool_sz); // c_ptr: index 2 + +{% if has_bias %} + memory_pool->AllocateHalfTensor(c_dim1, mem_pool_sz); // bias_ptr: index 3 +{% endif %} + +""" +) + + +# TODO Merge all alignment into single profiler +PROFILER_TEMPLATE = jinja2.Template( + """ +size_t GLOBAL_WORKSPACE_SIZE = 0; + +{{op_func}} + +struct ProfilerMemoryPool { + ProfilerMemoryPool() { + 
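    // Seed the RNG used by AllocateGaussianTensor and reserve bookkeeping
+    // storage up front; tensors are later handed out round-robin through
+    // RequestHalfTensorByIdx so repeated profiler calls touch different copies. +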
    std::random_device rd;
+    gen = std::mt19937(rd());
+    uniform_dist = std::uniform_int_distribution<int64_t>(1, 48964896);
+    offsets.reserve(512);
+    strides.reserve(512);
+    copies.reserve(512);
+    ptrs.reserve(512);
+    blobs.reserve(512);
+  }
+  ~ProfilerMemoryPool() {}
+
+  template <typename DType>
+  DType* AllocateGaussianTensor(int64_t size) {
+    size_t length = size * sizeof(DType);
+    blobs.emplace_back(length);
+    DType* ptr = reinterpret_cast<DType*>(blobs.back().get());
+
+    uint64_t seed = uniform_dist(gen);
+    double mean = 0.f;
+    double std = 1.f;
+
+    cutlass::reference::device::BlockFillRandomGaussian(ptr, size, seed, mean,
+                                                        std);
+
+    return ptr;
+  }
+
+
+  cutlass::half_t* AllocateHalfGaussianTensor(int64_t size) {
+    return reinterpret_cast<cutlass::half_t*>(
+        AllocateGaussianTensor<__half>(size));
+  }
+
+  int AllocateHalfTensor(int64_t size, int64_t copy) {
+    offsets.push_back(0);
+    strides.push_back(size);
+    copies.push_back(copy);
+    auto ptr = AllocateHalfGaussianTensor(size * copy);
+    ptrs.push_back(reinterpret_cast<void*>(ptr));
+    return ptrs.size() - 1;
+  }
+
+  cutlass::half_t* RequestHalfTensorByIdx(int idx) {
+    auto copy = copies.at(idx);
+    auto offset = offsets.at(idx);
+    auto stride = strides.at(idx);
+    cutlass::half_t* ptr = reinterpret_cast<cutlass::half_t*>(ptrs.at(idx));
+    ptr += offset;
+    offset += stride;
+    if (offset == copy * stride) {
+      offset = 0;
+    }
+    offsets[idx] = offset;
+    return ptr;
+  }
+
+  std::vector<int64_t> offsets;
+  std::vector<int64_t> strides;
+  std::vector<int64_t> copies;
+  std::vector<void*> ptrs;
+  std::vector<cutlass::DeviceAllocation<uint8_t> > blobs;
+  std::mt19937 gen;
+  std::uniform_int_distribution<int64_t> uniform_dist;
+};
+
+
+int main(int argc, char** argv) {
+  int device_idx;
+  cudaDeviceProp device_properties;
+  cudaError_t result = cudaGetDevice(&device_idx);
+  auto memory_pool = std::make_unique<ProfilerMemoryPool>();
+  if (result != cudaSuccess) {
+    throw std::runtime_error("cudaGetDevice() API call failed.");
+  }
+
+  result = cudaGetDeviceProperties(&device_properties, device_idx);
+
+  if (result != cudaSuccess) {
+    throw std::runtime_error("cudaGetDeviceProperties() failed");
+  }
+
+  {{args_parse}}
+
+  using ElementOutput = typename {{name}}::ElementC;
+  using ElementInputA = typename {{name}}::ElementA;
+  using ElementInputB = typename {{name}}::ElementB;
+  uint8_t* global_workspace = nullptr;
+  cudaStream_t stream = nullptr;
+
+  {{tensor_decl}}
+
+  // warmup
+  for (int i = 0; i < 5; ++i) {
+    {{func_call}}
+  }
+  cudaEvent_t events[2];
+  for (auto & event : events) {
+    cudaEventCreate(&event);
+  }
+  cudaEventRecord(events[0]);
+  for (int i = 0; i < 10; ++i) {
+    {{func_call}}
+  }
+  cudaEventRecord(events[1]);
+  cudaEventSynchronize(events[1]);
+  float runtime_ms = 0;
+  cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
+  for (auto event : events) {
+    (void)cudaEventDestroy(event);
+  }
+  // TODO: output workspace
+  if (runtime_ms < 0.00001) {
+    throw std::runtime_error(
+      "OOB in cutlass."
+ ); + } + std::cout << "TIME:" << runtime_ms << std::endl; + std::cout << "WS:" << GLOBAL_WORKSPACE_SIZE << std::endl; + return 0; +} +""" +) + + +KERNEL_KEY_TEMPLATE = jinja2.Template( + """ +cutlass_{{opcode_class_name}}_{{extended_name}}_{{threadblock}}_{{layout}}_align_{{align_ab}}_{{align_c}} +""" +) + + +def has_d(func_attrs): + if "has_d" in func_attrs: + return func_attrs["has_d"] + else: + return False + + +def has_d1(func_attrs): + return func_attrs.get("num_sources", 0) >= 2 + + +def get_gemm_instance_template_params( + op_def: str, + kernel_config: Tuple[str, int, int] = ("cutlass::gemm::device::Gemm", 21, 3), +) -> List[str]: + """ + For a given op_def string generated by cutlass's gemm emiter, parse and + return the gemm instance's template parameters. + kernel_config is a tuple used for finding kernel params. The first element + of kernel_config is the kernel kind, the second is the expected number + of params, and the third is the index offset of alignment values in the + full op_def string. + """ + kernel_kind, expected_num_params, _ = kernel_config + params = re.findall(rf"{kernel_kind}<([\s\S]+)>;", op_def) + assert len(params) == 1 + param = params[0] + gemm_universal_params = param.strip().split("\n") + gemm_universal_params = [param.strip(",") for param in gemm_universal_params] + assert len(gemm_universal_params) == expected_num_params, ( + f"expected len(gemm_universal_params) to be {expected_num_params}, but got " + f"{len(gemm_universal_params)}, {gemm_universal_params=}" + ) + return gemm_universal_params + + +def update_alignments_in_gemm_instance( + op_def: str, + func_attrs: Dict[str, Any], + for_profiler: bool, + kernel_config: Tuple[str, int, int] = ("cutlass::gemm::device::Gemm", 21, 3), +) -> str: + """ + update kAlignmentA, kAlignmentB, and epilogue_alignment in op_def, + which is a gemm instance emitted by the gemm instance emitter of cutlass. + kernel_config is a tuple used for finding kernel params. The first element + of kernel_config is the kernel kind, the second is the expected number + of params, and the third is the index offset of alignment values in the + full op_def string. 
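+
+    Note that alignments are only ever lowered here: each emitted value is the
+    minimum of the alignment already present in the instance and the maximum
+    alignment allowed by the corresponding input/output TensorAccessor.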
+ """ + if for_profiler: + return op_def + + input_accessors = func_attrs["input_accessors"] + a_alignment = tensor_accessor_codegen.find_max_alignment_for_accessor( + input_accessors[0] + ) + b_alignment = tensor_accessor_codegen.find_max_alignment_for_accessor( + input_accessors[1] + ) + output_accessor = func_attrs["output_accessors"][0] + epilogue_alignment = tensor_accessor_codegen.find_max_alignment_for_accessor( + output_accessor + ) + gemm_params = get_gemm_instance_template_params(op_def, kernel_config) + epilogue_align_idx = 11 + a_align_idx = 17 + b_align_idx = 18 + a_curr_align = gemm_params[a_align_idx].strip() + b_curr_align = gemm_params[b_align_idx].strip() + epilogue_curr_align = gemm_params[epilogue_align_idx].strip() + a_alignment = min(a_alignment, int(a_curr_align)) + b_alignment = min(b_alignment, int(b_curr_align)) + epilogue_alignment = min(epilogue_alignment, int(epilogue_curr_align)) + instance_lines = op_def.split("\n") + # a_align_idx + idx_offset in the full instance string + idx_offset = kernel_config[2] + + def _replace_align(align_idx, curr_align, alignment): + curr_align_line = instance_lines[align_idx + idx_offset] + assert curr_align == curr_align_line.strip( + " ," + ), f"expected {curr_align=} equal to {curr_align_line=}" + instance_lines[align_idx + idx_offset] = curr_align_line.replace( + curr_align, str(alignment) + ) + + _replace_align(a_align_idx, a_curr_align, a_alignment) + _replace_align(b_align_idx, b_curr_align, b_alignment) + _replace_align(epilogue_align_idx, epilogue_curr_align, epilogue_alignment) + return "\n".join(instance_lines) + + +def universal_gemm_instance( + op_def: str, func_attrs: Dict[str, Any], for_profiler: bool +) -> str: + op_def = update_alignments_in_gemm_instance(op_def, func_attrs, for_profiler) + tmp = op_def.replace( + "cutlass::gemm::device::Gemm", "cutlass::gemm::device::GemmUniversal" + ) + tmp = tmp.replace("false,", "") + return tmp + + +def kernel_name(op): + """Returns kernel_name of a given cutlass op_instance.""" + from cutlass_lib import library + + threadblock = op.tile_description.procedural_name() + extended_name = op.extended_name() + opcode_class_name = library.OpcodeClassNames[ + op.tile_description.math_instruction.opcode_class + ] + layout = op.layout_name() + align_ab = op.A.alignment + align_c = op.C.alignment + name = KERNEL_KEY_TEMPLATE.render( + threadblock=threadblock, + extended_name=extended_name, + opcode_class_name=opcode_class_name, + layout=layout, + align_ab=align_ab, + align_c=align_c, + ) + return name.replace("\n", "") + + +def emit_instance( + op, + for_profiler, + f_instance_convertor=universal_gemm_instance, + emit_kernel=False, + func_attrs=None, +): + import cutlass_lib + + emitter = cutlass_lib.gemm_operation.EmitGemmInstance() + if emit_kernel: + emitter = cutlass_lib.gemm_operation.EmitGemmUniversalInstance() + op_def = emitter.emit(op) + op_def = f_instance_convertor(op_def, func_attrs, for_profiler) + return op_def + + +def extract_config(f_proc_op): + import cutlass_lib + + op_kind = cutlass_lib.library.OperationKind.Gemm + gemm_kind = cutlass_lib.library.GemmKind.Universal + gemm_ops = OrderedDict() + extract_ops = list(Target.current()._operators[op_kind].items()) + + for _, value in extract_ops: + op = value[0] + if op.gemm_kind == gemm_kind: + ret = f_proc_op(op) + if len(ret) > 0: + for op_inst in ret: + key = kernel_name(op_inst) + gemm_ops[key] = op_inst + return gemm_ops + + +def extract_config_name(config): + pattern = re.compile(r"\s*using\s(.*?)\s=") + decl = 
config.split("\n")[2] + match = pattern.match(decl) + if match is None: + raise RuntimeError("Invalid config: \n" + config) + return match.groups()[0] + + +def gen_function( + func_attrs, + src_template, + exec_cond_template, + problem_args, + input_ndims, + weight_ndims, + output_ndims, + dim_info_dict, + f_instance_convertor=universal_gemm_instance, + emit_kernel=False, + support_split_k=False, + input_addr_calculator="", + output_addr_calculator="", + extra_code="", +): + func_name = func_attrs["name"] + exec_path = func_attrs["exec_path"] + op_instance = func_attrs["op_instance"] + inst_def_flag = set() + instances = {} + instance_decl = "" + for exec_item in exec_path.values(): + fname = "f" + sha1(exec_item.exec_cond.encode()).hexdigest() + algo = exec_item.algo + if algo not in inst_def_flag: + config = emit_instance( + op_instance[algo], + for_profiler=False, + f_instance_convertor=f_instance_convertor, + emit_kernel=emit_kernel, + func_attrs=func_attrs, + ) + inst_def_flag.add(algo) + else: + config = "" + inst = INSTANCE_TEMPLATE.render( + config=config, name=fname, config_name=extract_config_name(config) + ) + instances[exec_item.exec_cond] = inst + instance_decl += inst + shape_eval_func = gemm_common.gen_shape_eval_code( + indent=1, dtype="int64_t", dim_info_dict=dim_info_dict, is_ptr=True + ) + + exec_paths = "" + for key in instances: + fname = "f" + sha1(key.encode()).hexdigest() + program = EXEC_TEMPLATE.render( + indent=" ", + instance=fname, + problem_args=problem_args, + support_split_k=support_split_k, + ) + exec_inst = exec_cond_template.render(indent=" ", cond=key, program=program) + exec_paths += exec_inst + input_output_checks = INPUT_OUTPUT_CHECKS_TEMPLATE.render( + input_ndims=input_ndims, + weight_ndims=weight_ndims, + output_ndims=output_ndims, + ) + return src_template.render( + instances=instance_decl, + function_name=func_name, + dtype="cutlass::half_t", + shape_eval=shape_eval_func, + input_addr_calculator=input_addr_calculator, + output_addr_calculator=output_addr_calculator, + input_output_checks=input_output_checks, + exec_paths=exec_paths, + input_ndims=input_ndims, + weight_ndims=weight_ndims, + output_ndims=output_ndims, + support_split_k=support_split_k, + has_d=has_d(func_attrs), + has_d1=has_d1(func_attrs), + extra_code=extra_code, + ) + + +def build_profiler(file_pairs): + target = Target.current() + if target.disable_profiler_codegen(): + file_pairs = [] + elif target.use_dummy_profiling_results(): + # if it is circle CI only random build 2 profiler + random.shuffle(file_pairs) + file_pairs = file_pairs[:2] + compile_engine = builder.Builder() + compile_engine.build_objs(file_pairs, target.compile_cmd(executable=True)) + + +def add_profiler(file_pairs, workdir, op_type, output_name, code): + prefix = os.path.join(workdir, "profiler", op_type) + if not os.path.exists(prefix): + os.makedirs(prefix) + src_path = os.path.join(prefix, output_name + ".cu") + obj_path = os.path.join(prefix, output_name) + if os.path.exists(obj_path): + return + with open(src_path, "w") as f: + f.write(code) + file_pairs.append((src_path, obj_path)) + + +def gen_profiler( + func_attrs, + workdir, + dim_info_dict, + src_template, + problem_args_template, + args_parser_template, + support_split_k=False, + output_addr_calculator="", + bias_ptr_arg=None, + extra_code="", +): + op_type = func_attrs["op"] + op_instance = func_attrs["op_instance"] + ndims = 2 + adims = ["&a_dim" + str(i) for i in range(ndims)] + bdims = ["&b_dim" + str(i) for i in range(ndims)] + cdims = 
["&c_dim" + str(i) for i in range(ndims)] + shape_func = gemm_common.gen_shape_eval_code( + indent=2, dtype="int64_t", dim_info_dict=dim_info_dict, is_ptr=True + ) + + file_pairs = [] + has_bias = bias_ptr_arg is not None + for op_name, op in op_instance.items(): + config = emit_instance(op, for_profiler=True) + config_name = extract_config_name(config) + name = "GemmInstance" + instance = INSTANCE_TEMPLATE.render( + config_name=config_name, name=name, config=config + ) + exec_program = EXEC_TEMPLATE.render( + indent=" ", + instance=name, + is_profiler=True, + support_split_k=support_split_k, + problem_args=problem_args_template.render(), + ) + input_output_checks = INPUT_OUTPUT_CHECKS_TEMPLATE.render( + input_ndims=ndims, + weight_ndims=ndims, + output_ndims=ndims, + ) + op_func = src_template.render( + instances=instance, + function_name="gemm", + input_ndims=ndims, + weight_ndims=ndims, + output_ndims=ndims, + shape_eval=shape_func, + input_output_checks=input_output_checks, + exec_paths=exec_program, + output_addr_calculator=output_addr_calculator, + support_split_k=support_split_k, + extra_code=extra_code, + ) + func_call = FUNC_CALL_TEMPLATE.render( + func_name="gemm", + a_ptr="memory_pool->RequestHalfTensorByIdx(0)", + b_ptr="memory_pool->RequestHalfTensorByIdx(1)", + has_bias=has_bias, + bias_ptr=bias_ptr_arg, + c_ptr="memory_pool->RequestHalfTensorByIdx(2)", + split_k="split_k", + adims=adims, + bdims=bdims, + cdims=cdims, + ) + # TODO: Render args_parse by caller. + args_parse = ( + args_parser_template + if isinstance(args_parser_template, str) + else args_parser_template.render() + ) + code = PROFILER_TEMPLATE.render( + op_func=op_func, + args_parse=args_parse, + func_call=func_call, + name=name, + tensor_decl=TENSOR_DECL_TEMPLATE.render(name=name, has_bias=has_bias), + ) + add_profiler(file_pairs, workdir, op_type, op_name, code) + # build + build_profiler(file_pairs) + + +def gen_local_dim_defs(func_attrs, indent=" "): + """ + used together with input TensorAccessor to access a strided input + """ + if "input_accessors" not in func_attrs: + return "" + + dims = {} + for input_idx, input_accessor in enumerate(func_attrs["input_accessors"]): + if not input_accessor.is_from_strided_tensor: + continue + original_shape = input_accessor.original_shapes + for idx, dim in enumerate(original_shape): + # skip dynamic dims + if isinstance(dim, IntImm): + input_shape = func_attrs["inputs"][input_idx]._attrs["shape"] + name = input_shape[idx]._attrs["name"] + if name in dims: + assert dims[name] == dim.value(), "bmm inputs shape mismatch" + else: + dims[name] = dim.value() + return DIM_DEFS_TEMPLATE.render(dims=dims, indent=indent) + + +def gen_function_call(func_attrs, indent=" ", bias_ptr_arg=None): + a = func_attrs["inputs"][0] + ashapes = func_attrs["input_accessors"][0].original_shapes + b = func_attrs["inputs"][1] + bshapes = func_attrs["input_accessors"][1].original_shapes + c = func_attrs["outputs"][0] + cshapes = func_attrs["output_accessors"][0].original_shapes + has_bias = bias_ptr_arg is not None + # overwrite the global defs if we have input TensorAccessor + local_dim_defs = gen_local_dim_defs(func_attrs, indent=indent) + adims = ["&" + dim._attrs["name"] for dim in ashapes] + bdims = ["&" + dim._attrs["name"] for dim in bshapes] + cdims = ["&" + dim._attrs["name"] for dim in cshapes] + return FUNC_CALL_TEMPLATE.render( + local_dim_defs=local_dim_defs, + func_name=func_attrs["name"], + a_ptr=a._attrs["name"], + b_ptr=b._attrs["name"], + has_bias=has_bias, + 
bias_ptr=bias_ptr_arg, + c_ptr=c._attrs["name"], + split_k=func_attrs["split_k"], + adims=adims, + bdims=bdims, + cdims=cdims, + indent=indent, + ) + + +def default_fproc_f16(*, op, a_layout, b_layout, c_layout, epiligue_name): + import copy + + import cutlass_lib + + ret = [] + data_type = cutlass_lib.library.DataType.f16 + acc_type = cutlass_lib.library.DataType.f32 + # check target use fp16 acc + if "use_fp16_acc" in Target.current()._kwargs: + if Target.current()._kwargs["use_fp16_acc"]: + acc_type = cutlass_lib.library.DataType.f16 + if ( + op.A.element == data_type + and op.B.element == data_type + and op.C.element == data_type + and op.accumulator_type() == acc_type + and op.A.layout == a_layout + and op.B.layout == b_layout + ): + op = copy.deepcopy(op) + # set output major + op.C.layout = c_layout + # set epilogue + op.epilogue_functor = cutlass_lib.library.EpilogueFunctorName[epiligue_name] + op.element_epilogue = acc_type + # set C alignment + for i in [8, 4, 2, 1]: + op = copy.deepcopy(op) + op.C.alignment = i + ret.append(op) + return ret + + +def make_fproc_f16(func_attrs, layout): + """ + This function sets a callback for processing the epilogue of the kernel + associated with func_attrs. + """ + + def fproc_f16(op): + a_layout, b_layout, c_layout = layout.cutlass_lib_layouts() + return default_fproc_f16( + op=op, + a_layout=a_layout, + b_layout=b_layout, + c_layout=c_layout, + epiligue_name=func_attrs["epilogue"], + ) + + func_attrs["op_instance"] = extract_config(fproc_f16) + + +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + tmp = cfg.split("_") + align_c = int(tmp[-1]) + align_ab = int(tmp[-2]) + if align_c != func_attrs["epilogue_alignment"]: + return False + if align_ab != ab_alignment: + return False + return True diff --git a/python/aitemplate/backend/cuda/gemm_universal/common_bias.py b/python/aitemplate/backend/cuda/gemm_universal/common_bias.py new file mode 100644 index 000000000..98d8e979c --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/common_bias.py @@ -0,0 +1,134 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Common codegen functions for gemm with bias. 
+""" + +import jinja2 + +# pylint: disable=C0301,C0415,R1705 + +INSTANCE_TEMPLATE = jinja2.Template( + """ +{{config}} +using {{name}} = {{config_name}}; +""" +) + + +SRC_TEMPLATE = jinja2.Template( + """ +#include +#include +#include +#include +#include +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/epilogue/thread/linear_combination_silu.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/device_memory.h" + +{{extra_code}} + +#define CUTLASS_CHECK(status) \\ + { \\ + cutlass::Status error = status; \\ + if (error != cutlass::Status::kSuccess) { \\ + auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " + \\ + cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__); \\ + std::cerr << msg << std::endl; \\ + throw std::runtime_error(msg); \\ + } \\ + } + +{{instances}} + +void {{function_name}} ( + cutlass::half_t* a_ptr, + cutlass::half_t* b_ptr, + cutlass::half_t* bias_ptr, + cutlass::half_t* c_ptr, + uint8_t* workspace, +{% if support_split_k %} + int split_k, +{% endif %} +{% for idx in range(input_ndims) %} + int64_t* a_dim{{idx}}, +{% endfor %} +{% for idx in range(weight_ndims) %} + int64_t* b_dim{{idx}}, +{% endfor %} +{% for idx in range(input_ndims) %} + int64_t* c_dim{{idx}}, +{% endfor %} + cudaStream_t stream + ) { + {{shape_eval}} + {{input_addr_calculator}} + {{output_addr_calculator}} + {{extra_shape}} + {{input_output_checks}} + + if (!bias_ptr) { + throw std::runtime_error("bias_ptr is null!"); + } + + {{exec_paths}} + {% for idx in range(input_ndims) %} + std::cout << "input_ndims{{idx}}: " << *a_dim{{idx}} << std::endl; + {% endfor %} + {% for idx in range(weight_ndims) %} + std::cout << "weight_ndims{{idx}}: " << *b_dim{{idx}} << std::endl; + {% endfor %} + {% for idx in range(input_ndims) %} + std::cout << "output_ndims{{idx}}: " << *c_dim{{idx}} << std::endl; + {% endfor %} + throw std::runtime_error( + "Unsupported workload for this {{function_name}} specialization." + ); +} +""", + trim_blocks=True, + lstrip_blocks=True, +) + + +FUNC_DECL_TEMPLATE = jinja2.Template( + """ +void {{func_name}}( + cutlass::half_t*, + cutlass::half_t*, + cutlass::half_t*, + cutlass::half_t*, + uint8_t*, +{% if support_split_k %} + int, +{% endif %} +{% for idx in range(input_ndims) %} + int64_t*, +{% endfor %} +{% for idx in range(weight_ndims) %} + int64_t*, +{% endfor %} +{% for idx in range(input_ndims) %} + int64_t*, +{% endfor %} + cudaStream_t +); +""" +) diff --git a/python/aitemplate/backend/cuda/gemm_universal/common_bias_activation.py b/python/aitemplate/backend/cuda/gemm_universal/common_bias_activation.py new file mode 100644 index 000000000..843230243 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/common_bias_activation.py @@ -0,0 +1,93 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +""" +Common codegen functions for gemm_bias_activation. +""" + +from . import common, common_bias, gemm_rcr +from .layout import RCR + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + + +def gemm_rcr_config(func_attrs, dtype="float16"): + common.make_fproc_f16(func_attrs, RCR) + + +def gen_profiler( + func_attrs, + workdir, + dim_info_dict, + problem_args_template, + extra_code="", +): + gemm_rcr.common_gen_profiler( + func_attrs, + workdir, + dim_info_dict, + common_bias.SRC_TEMPLATE, + problem_args_template, + bias_ptr_arg="memory_pool->RequestHalfTensorByIdx(3)", + extra_code=extra_code, + ) + + +def gen_function( + func_attrs, + problem_args_template, + exec_cond_template, + dim_info_dict, + extra_code="", +): + input_ndims = len(func_attrs["input_accessors"][0].original_shapes) + weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) + output_ndims = len(func_attrs["output_accessors"][0].original_shapes) + problem_args = problem_args_template.render() + return common.gen_function( + func_attrs, + common_bias.SRC_TEMPLATE, + exec_cond_template, + problem_args, + input_ndims, + weight_ndims, + output_ndims, + dim_info_dict, + support_split_k=True, + output_addr_calculator=common.OUTPUT_ADDR_CALCULATOR.render( + stride_dim="N", + output_accessor=func_attrs["output_accessors"][0], + ), + extra_code=extra_code, + ) + + +def gen_function_decl(func_attrs): + func_name = func_attrs["name"] + input_ndims = len(func_attrs["input_accessors"][0].original_shapes) + weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) + return common_bias.FUNC_DECL_TEMPLATE.render( + func_name=func_name, + input_ndims=input_ndims, + weight_ndims=weight_ndims, + support_split_k=True, + ) + + +def gen_function_call(func_attrs, indent=" "): + bias = func_attrs["inputs"][2] + return common.gen_function_call( + func_attrs, indent, bias_ptr_arg=bias._attrs["name"] + ) diff --git a/python/aitemplate/backend/cuda/gemm_universal/common_bias_broadcast.py b/python/aitemplate/backend/cuda/gemm_universal/common_bias_broadcast.py new file mode 100644 index 000000000..5c46b3cc5 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/common_bias_broadcast.py @@ -0,0 +1,585 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +GEMM Specialization for +C = UnaryOp2(BinaryOp2(BinaryOp1(UnaryOp1(GeMM(A, B) + bias), D1), D2)), +""" + +import re +from functools import partial + +import jinja2 + +from ...common import gemm_common +from ...target import Target + +from . import common, gemm_rcr + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + + +# For config extraction. 
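+# The jinja template below rewrites a plain cutlass::gemm::device::Gemm instance
+# (as emitted by the CUTLASS generator) into a GemmUniversalWithBroadcast
+# instance whose epilogue fuses the UnaryOp1 / BinaryOp1 / BinaryOp2 / UnaryOp2
+# chain; BinaryOp2 is only emitted when a second residual input (d1) is present.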
+GEMM_UNIVERSAL_WITH_BROADCAST_TEMPLATE = jinja2.Template( + """ + cutlass::gemm::device::GemmUniversalWithBroadcast< + cutlass::half_t, {{layout.cutlass_layout_a}}, + cutlass::half_t, {{layout.cutlass_layout_b}}, + cutlass::half_t, {{layout.cutlass_layout_c}}, + {{acc_type}}, + cutlass::arch::OpClassTensorOp, + {{arch}}, + {{tb_shape}}, + {{warp_shape}}, + {{instruction_shape}}, + {{epilogue_functor}}< + cutlass::half_t, {{acc_type}}, {{acc_type}}, + cutlass::half_t, {{epilogue_vector_length}}, + {{unary_op1}}, {{binary_op1}}, {{unary_op2}} +{% if has_d1 %} + , {{binary_op2}} +{% endif %} + >, + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, + {{stage}}, + {{alignment_a}}, + {{alignment_b}} + >; +""" +) + +# For func codegen. +PROBLEM_ARGS_TEMPLATE = jinja2.Template( + """ + cutlass::gemm::GemmUniversalMode::kGemm, + { {{layout.m}}, {{layout.n}}, {{layout.k}} }, +{% if support_split_k %} + split_k, +{% else %} + 1, +{% endif %} + {ElementComputeEpilogue(1), ElementComputeEpilogue(1)}, + (void*) (a_ptr + input_a_offset), + (void*) (b_ptr + input_b_offset), + (void*) d0_ptr, +{% if has_d1 %} + (void*) d1_ptr, +{% else %} + nullptr, +{% endif %} + (void*) (c_ptr + output_offset), + (void*) bias_ptr, + nullptr, + /*batch_stride_A*/ input_a_batch_stride, + /*batch_stride_B*/ input_b_batch_stride, + /*batch_stride_C1*/ 0, + /*batch_stride_C2*/ 0, + /*batch_stride_D*/ 0, + /*batch_stride_Vector*/ 0, + /*batch_stride_Tensor*/ 0, + input_a_stride, + input_b_stride, + {{layout.stride_c}}, +{% if has_d1 %} + {{layout.stride_c}}, +{% else %} + 0, +{% endif %} + output_stride, + /*ldr*/ 0, + /*/ldt*/ 0 +""" +) + +# for profiler, no need to include TensorAccessor +PROFILER_PROBLEM_ARGS_TEMPLATE = jinja2.Template( + """ + cutlass::gemm::GemmUniversalMode::kGemm, + { {{layout.m}}, {{layout.n}}, {{layout.k}} }, +{% if support_split_k %} + split_k, +{% else %} + 1, +{% endif %} + {ElementComputeEpilogue(1), ElementComputeEpilogue(1)}, + (void*) a_ptr, + (void*) b_ptr, + (void*) d0_ptr, +{% if has_d1 %} + (void*) d1_ptr, +{% else %} + nullptr, +{% endif %} + (void*) (c_ptr + output_offset), + (void*) bias_ptr, + nullptr, + /*batch_stride_A*/ 0, + /*batch_stride_B*/ 0, + /*batch_stride_C1*/ 0, + /*batch_stride_C2*/ 0, + /*batch_stride_D*/ 0, + /*batch_stride_Vector*/ 0, + /*batch_stride_Tensor*/ 0, + {{layout.stride_a}}, + {{layout.stride_b}}, + {{layout.stride_c}}, +{% if has_d1 %} + {{layout.stride_c}}, +{% else %} + 0, +{% endif %} + output_stride, + /*ldr*/ 0, + /*/ldt*/ 0 +""" +) + +SRC_TEMPLATE = jinja2.Template( + """ +#include +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/linear_combination_residual_block_v2.h" +#include "cutlass/gemm/device/gemm_universal_with_broadcast.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/device_memory.h" + +#define CUTLASS_CHECK(status) \\ + { \\ + cutlass::Status error = status; \\ + if (error != cutlass::Status::kSuccess) { \\ + auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " + \\ + cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__); \\ + std::cerr << msg << std::endl; \\ + throw std::runtime_error(msg); \\ + } \\ + } + +{{instances}} + +void {{function_name}} ( + cutlass::half_t* a_ptr, + cutlass::half_t* b_ptr, + cutlass::half_t* bias_ptr, + cutlass::half_t* d0_ptr, +{% if has_d1 %} + cutlass::half_t* d1_ptr, +{% endif %} + 
cutlass::half_t* c_ptr, + uint8_t* workspace, +{% if support_split_k %} + int split_k, +{% endif %} +{% for idx in range(input_ndims) %} + int64_t* a_dim{{idx}}, +{% endfor %} +{% for idx in range(weight_ndims) %} + int64_t* b_dim{{idx}}, +{% endfor %} +{% for idx in range(input_ndims) %} + int64_t* c_dim{{idx}}, +{% endfor %} + cudaStream_t stream + ) { + {{shape_eval}} + {{input_addr_calculator}} + {{output_addr_calculator}} + {{extra_shape}} + {{input_output_checks}} + + if (!bias_ptr) { + throw std::runtime_error("bias is null!"); + } + if (!d0_ptr) { + throw std::runtime_error("d0_ptr is null!"); + } +{% if has_d1 %} + if (!d1_ptr) { + throw std::runtime_error("d1_ptr is null!"); + } +{% endif %} + + {{exec_paths}} + throw std::runtime_error( + "Unsupported workload for this {{function_name}} specialization." + ); +} +""", + trim_blocks=True, + lstrip_blocks=True, +) + +# For function declaration codegen. +FUNC_DECL_TEMPLATE = jinja2.Template( + """ +void {{func_name}}( + cutlass::half_t*, + cutlass::half_t*, + cutlass::half_t*, + cutlass::half_t*, +{% if has_d1 %} + cutlass::half_t*, +{% endif %} + cutlass::half_t*, + uint8_t*, +{% if support_split_k %} + int, +{% endif %} +{% for idx in range(input_ndims) %} + int64_t*, +{% endfor %} +{% for idx in range(weight_ndims) %} + int64_t*, +{% endfor %} +{% for idx in range(input_ndims) %} + int64_t*, +{% endfor %} + cudaStream_t +); +""" +) + + +# For function call codegen. +FUNC_CALL_TEMPLATE = jinja2.Template( + """ +{{indent}}{ +{{indent}}{{local_dim_defs}} +{{indent}}{{func_name}}( +{{indent}} {{a_ptr}}, +{{indent}} {{b_ptr}}, +{{indent}} {{bias_ptr}}, +{{indent}} {{d0_ptr}}, +{% if has_d1 %} +{{indent}} {{d1_ptr}}, +{% endif %} +{{indent}} {{c_ptr}}, +{{indent}} global_workspace, +{% if support_split_k %} +{{indent}} {{split_k}}, +{% endif %} +{% for dim in adims %} +{{indent}} {{dim}}, +{% endfor %} +{% for dim in bdims %} +{{indent}} {{dim}}, +{% endfor %} +{% for dim in cdims %} +{{indent}} {{dim}}, +{% endfor %} +{{indent}} stream +{{indent}}); +{{indent}}} +""" +) + +# For profiler codegen. 
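+# The standalone profiler binary receives the GEMM problem size on the command
+# line, e.g. "./<profiler> M N K [split_k]" (illustrative invocation); the
+# template below turns argv into M/N/K (and split_k) before handing control to
+# the layout-specific argument parser.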
+ARGS_PARSER_TEMPLATE = jinja2.Template( + """ + int64_t M = std::atoi(argv[1]); + int64_t N = std::atoi(argv[2]); + int64_t K = std::atoi(argv[3]); +{% if support_split_k %} + int split_k = std::atoi(argv[4]); +{% endif %} + {{layout.args_parser}} +""" +) + +TENSOR_DECL_TEMPLATE = jinja2.Template( + """ + int64_t a_ptr_sz = a_dim0 * a_dim1; + int64_t b_ptr_sz = b_dim0 * b_dim1; + int64_t c_ptr_sz = c_dim0 * c_dim1; + // The value 1 is used to force ptr_max_sz to be non-zero + int64_t ptr_max_sz = std::max({1, a_ptr_sz, b_ptr_sz, c_ptr_sz}); + // TODO: special pool size for A100 L2 cache 40M + // need to tune it for other devices + int64_t mem_pool_sz = std::max(2, std::min(64, int((1 << 25) / ptr_max_sz))); + + memory_pool->AllocateHalfTensor(a_ptr_sz, mem_pool_sz); // a_ptr: index 0 + memory_pool->AllocateHalfTensor(b_ptr_sz, mem_pool_sz); // b_ptr: index 1 + memory_pool->AllocateHalfTensor(c_ptr_sz, mem_pool_sz); // c_ptr: index 2 + memory_pool->AllocateHalfTensor(c_dim1, mem_pool_sz); // bias_ptr: index 3 + memory_pool->AllocateHalfTensor(c_ptr_sz, mem_pool_sz); // d0 ptr: index 4 +{% if has_d1 %} + memory_pool->AllocateHalfTensor(c_ptr_sz, mem_pool_sz); // d1 ptr: index 5 +{% endif %} +""" +) + + +def _support_split_k(func_attrs): + return func_attrs["split_k"] is not None + + +def gemm_bias_broadcast_instance( + op_def, + func_attrs, + for_profiler, + layout, + unary_op1, + binary_op1, + binary_op2, + unary_op2, +): + """ + adjust gemm instance with respect to input_accessors, layout and epilogue ops + """ + op_def = common.update_alignments_in_gemm_instance(op_def, func_attrs, for_profiler) + gemm_universal_params = common.get_gemm_instance_template_params(op_def) + epilogue_pattern = re.compile(r"\s*(cutlass::epilogue::thread::.*)\s*<") + match = epilogue_pattern.match(gemm_universal_params[9]) + if match is None: + raise RuntimeError("Invalid epilogue functor:\n" + gemm_universal_params[9]) + epilogue_functor = match.groups()[0] + + if ( + "use_fp16_acc" in Target.current()._kwargs + and Target.current()._kwargs["use_fp16_acc"] + ): + acc_type = "cutlass::half_t" + else: + acc_type = "float" + gemm_universal_with_broadcast_params = ( + GEMM_UNIVERSAL_WITH_BROADCAST_TEMPLATE.render( + arch=gemm_universal_params[5], + tb_shape=gemm_universal_params[6], + warp_shape=gemm_universal_params[7], + instruction_shape=gemm_universal_params[8], + epilogue_functor=epilogue_functor, + epilogue_vector_length=gemm_universal_params[11], + unary_op1=unary_op1, + binary_op1=binary_op1, + binary_op2=binary_op2, + unary_op2=unary_op2, + stage=gemm_universal_params[16], + alignment_a=gemm_universal_params[17], + alignment_b=gemm_universal_params[18], + layout=layout, + acc_type=acc_type, + has_d1=(binary_op2 is not None), + ) + ) + res = re.sub( + r"cutlass::gemm::device::Gemm<[\s\S]+>;", + gemm_universal_with_broadcast_params, + op_def, + ) + return res + + +def gemm_bias_broadcast_config(func_attrs, layout, dtype="float16"): + common.make_fproc_f16(func_attrs, layout) + + +def gen_profiler( + func_attrs, + workdir, + dim_info_dict, + layout, + unary_op1, + binary_op1, + binary_op2, + unary_op2, +): + op_type = func_attrs["op"] + support_split_k = _support_split_k(func_attrs) + op_instance = func_attrs["op_instance"] + has_d1 = common.has_d1(func_attrs) + + ndims = 2 + adims = ["&a_dim" + str(i) for i in range(ndims)] + bdims = ["&b_dim" + str(i) for i in range(ndims)] + cdims = ["&c_dim" + str(i) for i in range(ndims)] + shape_func = gemm_common.gen_shape_eval_code( + indent=2, dtype="int64_t", 
dim_info_dict=dim_info_dict, is_ptr=True + ) + + file_pairs = [] + for op_name, op in op_instance.items(): + config = common.emit_instance( + op, + for_profiler=True, + f_instance_convertor=partial( + gemm_bias_broadcast_instance, + layout=layout, + unary_op1=unary_op1, + binary_op1=binary_op1, + binary_op2=binary_op2, + unary_op2=unary_op2, + ), + ) + config_name = common.extract_config_name(config) + name = "GemmInstance" + instance = common.INSTANCE_TEMPLATE.render( + config_name=config_name, name=name, config=config + ) + exec_program = common.EXEC_TEMPLATE.render( + indent=" ", + instance=name, + is_profiler=True, + problem_args=PROFILER_PROBLEM_ARGS_TEMPLATE.render( + support_split_k=support_split_k, layout=layout, has_d1=has_d1 + ), + ) + input_output_checks = common.INPUT_OUTPUT_CHECKS_TEMPLATE.render( + input_ndims=ndims, + weight_ndims=ndims, + output_ndims=ndims, + ) + op_func = SRC_TEMPLATE.render( + instances=instance, + function_name="gemm", + input_ndims=ndims, + weight_ndims=ndims, + shape_eval=shape_func, + input_output_checks=input_output_checks, + exec_paths=exec_program, + output_addr_calculator=common.DEFAULT_OUTPUT_ADDR_CALCULATOR.render( + stride_dim="N" + ), + support_split_k=support_split_k, + has_d1=has_d1, + ) + func_call = FUNC_CALL_TEMPLATE.render( + func_name="gemm", + a_ptr="memory_pool->RequestHalfTensorByIdx(0)", + b_ptr="memory_pool->RequestHalfTensorByIdx(1)", + c_ptr="memory_pool->RequestHalfTensorByIdx(2)", + d0_ptr="memory_pool->RequestHalfTensorByIdx(4)", + d1_ptr="memory_pool->RequestHalfTensorByIdx(5)", + bias_ptr="memory_pool->RequestHalfTensorByIdx(3)", + adims=adims, + bdims=bdims, + cdims=cdims, + support_split_k=support_split_k, + split_k="split_k", + has_d1=has_d1, + ) + code = common.PROFILER_TEMPLATE.render( + op_func=op_func, + args_parse=ARGS_PARSER_TEMPLATE.render( + layout=layout, support_split_k=support_split_k + ), + func_call=func_call, + name=name, + tensor_decl=TENSOR_DECL_TEMPLATE.render(name=name, has_d1=has_d1), + ) + common.add_profiler(file_pairs, workdir, op_type, op_name, code) + # build + common.build_profiler(file_pairs) + + +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, + layout, + unary_op1, + binary_op1, + binary_op2, + unary_op2, +): + input_addr_calculator = gemm_rcr.get_input_addr_calculator(func_attrs) + input_ndims = len(func_attrs["input_accessors"][0].original_shapes) + weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) + output_ndims = len(func_attrs["output_accessors"][0].original_shapes) + support_split_k = _support_split_k(func_attrs) + has_d1 = common.has_d1(func_attrs) + problem_args = PROBLEM_ARGS_TEMPLATE.render( + layout=layout, support_split_k=support_split_k, has_d1=has_d1 + ) + return common.gen_function( + func_attrs, + SRC_TEMPLATE, + exec_cond_template, + problem_args, + input_ndims, + weight_ndims, + output_ndims, + dim_info_dict, + f_instance_convertor=partial( + gemm_bias_broadcast_instance, + layout=layout, + unary_op1=unary_op1, + binary_op1=binary_op1, + binary_op2=binary_op2, + unary_op2=unary_op2, + ), + support_split_k=support_split_k, + input_addr_calculator=input_addr_calculator, + output_addr_calculator=common.OUTPUT_ADDR_CALCULATOR.render( + stride_dim="N", + output_accessor=func_attrs["output_accessors"][0], + ), + ) + + +def gen_function_decl(func_attrs): + input_ndims = len(func_attrs["input_accessors"][0].original_shapes) + weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) + return FUNC_DECL_TEMPLATE.render( + 
func_name=func_attrs["name"], + input_ndims=input_ndims, + weight_ndims=weight_ndims, + support_split_k=_support_split_k(func_attrs), + has_d1=common.has_d1(func_attrs), + ) + + +def gen_function_call(func_attrs, indent=" "): + has_d1 = common.has_d1(func_attrs) + if has_d1: + (a, b, bias, d0, d1) = func_attrs["inputs"] + else: + (a, b, bias, d0) = func_attrs["inputs"] + d1 = None + c = func_attrs["outputs"][0] + # overwrite the global defs if we have input TensorAccessor + local_dim_defs = common.gen_local_dim_defs(func_attrs, indent=indent) + adims = [ + "&" + dim._attrs["name"] + for dim in func_attrs["input_accessors"][0].original_shapes + ] + bdims = [ + "&" + dim._attrs["name"] + for dim in func_attrs["input_accessors"][1].original_shapes + ] + cdims = [ + "&" + dim._attrs["name"] + for dim in func_attrs["output_accessors"][0].original_shapes + ] + return FUNC_CALL_TEMPLATE.render( + local_dim_defs=local_dim_defs, + func_name=func_attrs["name"], + a_ptr=a._attrs["name"], + b_ptr=b._attrs["name"], + bias_ptr=bias._attrs["name"], + d0_ptr=d0._attrs["name"], + d1_ptr=d1._attrs["name"] if has_d1 else "", + c_ptr=c._attrs["name"], + split_k=func_attrs["split_k"], + adims=adims, + bdims=bdims, + cdims=cdims, + indent=indent, + support_split_k=_support_split_k(func_attrs), + has_d1=has_d1, + ) diff --git a/python/aitemplate/backend/cuda/gemm_universal/common_permute.py b/python/aitemplate/backend/cuda/gemm_universal/common_permute.py new file mode 100644 index 000000000..2f3f1e903 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/common_permute.py @@ -0,0 +1,351 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Common codegen functions for gemm + permute. +""" + +import re +from collections import OrderedDict +from hashlib import sha1 + +import jinja2 + +from ...common import gemm_common +from ...target import Target +from ..gemm_universal import common + +# pylint: disable=C0301,C0415,R1705 + +EXTRA_CODE = jinja2.Template( + """ +#include "cutlass/layout/permute.h" +""" +) + +# HACK: we don't record different permutation shape, +# because it has little impact on execution time compared. +# Therefore, no matter what permutation shape it is, +# we will use the same kernel, i.e. the first generated perm_shape +# At runtime, the kernel will be regenerated and thus the correctness will not be affected. 
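To make the caching scheme above concrete, here is a small, self-contained sketch (not from the patch; every value is made up for illustration) of how a key in the style of the KERNEL_KEY_TEMPLATE defined next is assembled, and why kernel_name() flattens the rendered result with name.replace("\n", "").

import jinja2

# Single-line stand-in for KERNEL_KEY_TEMPLATE; the real template string carries
# leading/trailing newlines, which is why kernel_name() strips them after rendering.
key_tpl = jinja2.Template(
    "cutlass_{{opcode_class_name}}_{{extended_name}}_{{threadblock}}_{{layout}}"
    "_{{perm_type}}_{{perm_shape}}_align_{{align_ab}}_{{align_c}}"
)

key = key_tpl.render(
    opcode_class_name="tensorop",        # hypothetical values, not taken from cutlass_lib
    extended_name="f16_s16816gemm_f16",
    threadblock="128x128_32x4",
    layout="tn",
    perm_type="perm4d",
    perm_shape="8",                      # only the first generated shape is recorded
    align_ab=8,
    align_c=8,
)
print(key)  # cutlass_tensorop_f16_s16816gemm_f16_128x128_32x4_tn_perm4d_8_align_8_8

Because the permutation shape has little effect on execution time, every shape reuses the kernel tuned for the first recorded key; the perm_shape component only keeps the key unambiguous when the kernel is regenerated at runtime.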
+KERNEL_KEY_TEMPLATE = jinja2.Template( + """ +cutlass_{{opcode_class_name}}_{{extended_name}}_{{threadblock}}_{{layout}}_{{perm_type}}_{{perm_shape}}_align_{{align_ab}}_{{align_c}} +""" +) + + +def kernel_name(op, func_attrs): + """Returns kernel_name given input cutlass op_instance and operator attrs.""" + + from cutlass_lib import library + + threadblock = op.tile_description.procedural_name() + extended_name = op.extended_name() + opcode_class_name = library.OpcodeClassNames[ + op.tile_description.math_instruction.opcode_class + ] + layout = op.layout_name() + align_ab = op.A.alignment + align_c = op.C.alignment + shape = func_attrs["shape"] + if len(shape) == 1: + perm_type = "perm4d" + perm_shape = f"{shape[0]}" + elif len(shape) == 3: + perm_type = "perm5d" + perm_shape = f"{shape[0]}_{shape[1]}_{shape[2]}" + else: + raise NotImplementedError( + f"gemm permute shape with {shape} is not implemented!" + ) + name = KERNEL_KEY_TEMPLATE.render( + threadblock=threadblock, + extended_name=extended_name, + opcode_class_name=opcode_class_name, + layout=layout, + align_ab=align_ab, + align_c=align_c, + perm_type=perm_type, + perm_shape=perm_shape, + ) + return name.replace("\n", "") + + +def default_fproc_f16( + *, op, a_layout, b_layout, c_layout, epiligue_name, permute_layout +): + """Generates new op_instances by adding alignment info, permute_layout, etc.""" + import copy + + import cutlass_lib + + ret = [] + data_type = cutlass_lib.library.DataType.f16 + acc_type = cutlass_lib.library.DataType.f32 + # check target use fp16 acc + if "use_fp16_acc" in Target.current()._kwargs: + if Target.current()._kwargs["use_fp16_acc"]: + acc_type = cutlass_lib.library.DataType.f16 + if ( + op.A.element == data_type + and op.B.element == data_type + and op.C.element == data_type + and op.accumulator_type() == acc_type + and op.A.layout == a_layout + and op.B.layout == b_layout + ): + op = copy.deepcopy(op) + # set output major + op.C.layout = c_layout + # set epilogue + op.epilogue_functor = cutlass_lib.library.EpilogueFunctorName[epiligue_name] + op.element_epilogue = acc_type + op.permute_layout = cutlass_lib.library.EpiloguePermuteLayoutName[ + permute_layout + ] + # set C alignment + for i in [8, 4, 2, 1]: + op = copy.deepcopy(op) + op.C.alignment = i + ret.append(op) + return ret + + +def extract_config(f_proc_op, func_attrs): + import cutlass_lib + + op_kind = cutlass_lib.library.OperationKind.Gemm + gemm_kind = cutlass_lib.library.GemmKind.Universal + gemm_ops = OrderedDict() + extract_ops = list(Target.current()._operators[op_kind].items()) + + for _, value in extract_ops: + op = value[0] + if op.gemm_kind == gemm_kind: + ret = f_proc_op(op) + if len(ret) > 0: + for op_inst in ret: + key = kernel_name(op_inst, func_attrs) + gemm_ops[key] = op_inst + return gemm_ops + + +def gemm_permute_instance(op_def, func_attrs, for_profiler): + import cutlass_lib + + op_def = common.update_alignments_in_gemm_instance( + op_def, + func_attrs, + for_profiler, + # expected to have 26 of params, the index offset of alignment value + # in the full op_def string is 4 + kernel_config=("cutlass::gemm::device::GemmUniversal", 26, 4), + ) + shape_info = ", ".join(map(str, func_attrs["shape"])) + layout = cutlass_lib.library.EpiloguePermuteLayoutName[func_attrs["layout"]] + layout_class = cutlass_lib.library.EpiloguePermuteLayoutTag[layout] + tmp = re.sub( + r"{}".format(layout_class), "{}<{}>".format(layout_class, shape_info), op_def + ) + return tmp + + +def emit_instance( + op, + for_profiler, + 
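+    # f_instance_convertor post-processes the emitted op definition
+    # (gemm_permute_instance by default); emit_kernel switches to the
+    # permute-aware EmitGemmPermuteInstance emitter selected in the body below.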
f_instance_convertor=gemm_permute_instance, + emit_kernel=False, + func_attrs=None, +): + import cutlass_lib + + emiter = cutlass_lib.gemm_operation.EmitGemmInstance() + if emit_kernel: + emiter = cutlass_lib.gemm_operation.EmitGemmPermuteInstance() + + op_def = emiter.emit(op) + op_def = f_instance_convertor(op_def, func_attrs, for_profiler) + return op_def + + +def gen_function( + func_attrs, + src_template, + exec_cond_template, + problem_args, + input_ndims, + weight_ndims, + output_ndims, + dim_info_dict, + f_instance_convertor=gemm_permute_instance, + emit_kernel=False, + support_split_k=False, + input_addr_calculator="", + output_addr_calculator="", + extra_code="", +): + func_name = func_attrs["name"] + exec_path = func_attrs["exec_path"] + op_instance = func_attrs["op_instance"] + inst_def_flag = set() + instances = {} + instance_decl = "" + for exec_item in exec_path.values(): + fname = "f" + sha1(exec_item.exec_cond.encode()).hexdigest() + algo = exec_item.algo + if algo not in inst_def_flag: + config = emit_instance( + op_instance[algo], + for_profiler=False, + f_instance_convertor=f_instance_convertor, + emit_kernel=emit_kernel, + func_attrs=func_attrs, + ) + inst_def_flag.add(algo) + else: + config = "" + inst = common.INSTANCE_TEMPLATE.render( + config=config, name=fname, config_name=common.extract_config_name(config) + ) + instances[exec_item.exec_cond] = inst + instance_decl += inst + shape_eval_func = gemm_common.gen_shape_eval_code( + indent=1, dtype="int64_t", dim_info_dict=dim_info_dict, is_ptr=True + ) + exec_paths = "" + for key, _ in instances.items(): + fname = "f" + sha1(key.encode()).hexdigest() + program = common.EXEC_TEMPLATE.render( + indent=" ", + instance=fname, + problem_args=problem_args, + support_split_k=support_split_k, + ) + exec_inst = exec_cond_template.render(indent=" ", cond=key, program=program) + exec_paths += exec_inst + input_output_checks = common.INPUT_OUTPUT_CHECKS_TEMPLATE.render( + input_ndims=input_ndims, + weight_ndims=weight_ndims, + output_ndims=output_ndims, + ) + return src_template.render( + instances=instance_decl, + function_name=func_name, + dtype="cutlass::half_t", + shape_eval=shape_eval_func, + input_addr_calculator=input_addr_calculator, + output_addr_calculator=output_addr_calculator, + input_output_checks=input_output_checks, + exec_paths=exec_paths, + input_ndims=input_ndims, + weight_ndims=weight_ndims, + output_ndims=output_ndims, + support_split_k=support_split_k, + has_d=common.has_d(func_attrs), + has_d1=common.has_d1(func_attrs), + extra_code=extra_code, + ) + + +def gen_profiler( + func_attrs, + workdir, + dim_info_dict, + src_template, + problem_args_template, + args_parser_template, + emit_kernel=False, + support_split_k=False, + output_addr_calculator="", + bias_ptr_arg=None, + extra_code="", +): + op_type = func_attrs["op"] + op_instance = func_attrs["op_instance"] + + ndims = 2 + adims = ["&a_dim" + str(i) for i in range(ndims)] + bdims = ["&b_dim" + str(i) for i in range(ndims)] + cdims = ["&c_dim" + str(i) for i in range(ndims)] + shape_func = gemm_common.gen_shape_eval_code( + indent=2, dtype="int64_t", dim_info_dict=dim_info_dict, is_ptr=True + ) + + file_pairs = [] + has_bias = bias_ptr_arg is not None + for op_name, op in op_instance.items(): + config = emit_instance( + op, for_profiler=True, emit_kernel=emit_kernel, func_attrs=func_attrs + ) + config_name = common.extract_config_name(config) + name = "GemmInstance" + instance = common.INSTANCE_TEMPLATE.render( + config_name=config_name, name=name, 
config=config + ) + exec_program = common.EXEC_TEMPLATE.render( + indent=" ", + instance=name, + is_profiler=True, + support_split_k=support_split_k, + problem_args=problem_args_template.render(), + ) + input_output_checks = common.INPUT_OUTPUT_CHECKS_TEMPLATE.render( + input_ndims=ndims, + weight_ndims=ndims, + output_ndims=ndims, + ) + op_func = src_template.render( + instances=instance, + function_name="gemm", + input_ndims=2, + weight_ndims=2, + output_ndims=2, + shape_eval=shape_func, + input_output_checks=input_output_checks, + exec_paths=exec_program, + output_addr_calculator=output_addr_calculator, + support_split_k=support_split_k, + extra_code=extra_code, + ) + func_call = common.FUNC_CALL_TEMPLATE.render( + func_name="gemm", + a_ptr="memory_pool->RequestHalfTensorByIdx(0)", + b_ptr="memory_pool->RequestHalfTensorByIdx(1)", + has_bias=has_bias, + bias_ptr=bias_ptr_arg, + c_ptr="memory_pool->RequestHalfTensorByIdx(2)", + split_k="split_k", + adims=adims, + bdims=bdims, + cdims=cdims, + ) + # TODO: Render args_parse by caller. + args_parse = ( + args_parser_template + if isinstance(args_parser_template, str) + else args_parser_template.render() + ) + code = common.PROFILER_TEMPLATE.render( + op_func=op_func, + args_parse=args_parse, + func_call=func_call, + name=name, + tensor_decl=common.TENSOR_DECL_TEMPLATE.render( + name=name, has_bias=has_bias + ), + ) + common.add_profiler(file_pairs, workdir, op_type, op_name, code) + # build + common.build_profiler(file_pairs) diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr.py new file mode 100644 index 000000000..0fb211cb0 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr.py @@ -0,0 +1,229 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +GEMM Specialization for +C = GeMM(A, B) +where A[RowMajor][M, K], B[ColMajor][N, K] +""" +import jinja2 + +from ... import registry +from . 
import common +from .layout import RCR + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + + +ARGS_PARSER_TEMPLATE = jinja2.Template( + """ + int64_t M = std::atoi(argv[1]); + int64_t N = std::atoi(argv[2]); + int64_t K = std::atoi(argv[3]); + int64_t split_k = std::atoi(argv[4]); + + int64_t a_dim0 = M; + int64_t a_dim1 = K; + int64_t b_dim0 = N; + int64_t b_dim1 = K; + int64_t c_dim0 = M; + int64_t c_dim1 = N; +""" +) + +# used for real execution +PROBLEM_ARGS_TEMPLATE = jinja2.Template( + """ + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K}, + split_k, + {ElementComputeEpilogue(1), ElementComputeEpilogue(0)}, + (void*) (a_ptr + input_a_offset), + (void*) (b_ptr + input_b_offset), + (void*) (c_ptr + output_offset), + (void*) (c_ptr + output_offset), + input_a_batch_stride, + input_b_batch_stride, + /*output_batch_stride*/ M * N, + /*output_batch_stride*/ M * N, + input_a_stride, + input_b_stride, + output_stride, + output_stride +""" +) + + +# for profiler, no need to include TensorAccessor +PROFILER_PROBLEM_ARGS_TEMPLATE = jinja2.Template( + """ + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K}, + split_k, + {ElementComputeEpilogue(1), ElementComputeEpilogue(0)}, + (void*) a_ptr, + (void*) b_ptr, + (void*) c_ptr, + (void*) (c_ptr + output_offset), + M * K, + N * K, + M * N, + M * N, + K, + K, + N, + output_stride +""" +) + + +@registry.reg("cuda.gemm_rcr.config") +def gemm_rcr_config(func_attrs, dtype="float16"): + common.make_fproc_f16(func_attrs, RCR) + + +def common_gen_profiler( + func_attrs, + workdir, + dim_info_dict, + src_template, + problem_args_template, + bias_ptr_arg=None, + extra_code="", +): + output_addr_calculator = common.DEFAULT_OUTPUT_ADDR_CALCULATOR.render( + stride_dim="*b_dim0" + ) + common.gen_profiler( + func_attrs, + workdir, + dim_info_dict, + src_template, + problem_args_template, + ARGS_PARSER_TEMPLATE, + support_split_k=True, + output_addr_calculator=output_addr_calculator, + bias_ptr_arg=bias_ptr_arg, + extra_code=extra_code, + ) + + +@registry.reg("cuda.gemm_rcr.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + return common_gen_profiler( + func_attrs, + workdir, + dim_info_dict, + common.SRC_TEMPLATE, + PROFILER_PROBLEM_ARGS_TEMPLATE, + ) + + +def get_input_addr_calculator(func_attrs): + input_a_batch_stride_dim = "M * K" + input_a_stride_k_dim = "K" + input_a_offset = 0 + input_b_batch_stride_dim = "N * K" + input_b_stride_k_dim = "K" + input_b_offset = 0 + + if "input_accessors" in func_attrs: + input_a_accessor = func_attrs["input_accessors"][0] + input_b_accessor = func_attrs["input_accessors"][1] + if input_a_accessor.is_from_strided_tensor: + input_a_offset = input_a_accessor.offset + shapes = input_a_accessor.original_shapes + input_a_stride_k_dim = input_a_accessor.stride(len(shapes) - 2) + + if input_b_accessor.is_from_strided_tensor: + input_b_offset = input_b_accessor.offset + shapes = input_b_accessor.original_shapes + input_b_stride_k_dim = input_b_accessor.stride(len(shapes) - 2) + + input_addr_calculator = common.INPUT_ADDR_CALCULATOR.render( + input_a_batch_stride_dim=input_a_batch_stride_dim, + input_a_stride_dim=input_a_stride_k_dim, + input_a_offset_val=input_a_offset, + input_b_batch_stride_dim=input_b_batch_stride_dim, + input_b_stride_dim=input_b_stride_k_dim, + input_b_offset_val=input_b_offset, + ) + return input_addr_calculator + + +@registry.reg("cuda.gemm_rcr.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, +): + input_addr_calculator = 
get_input_addr_calculator(func_attrs) + input_ndims = len(func_attrs["input_accessors"][0].original_shapes) + weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) + output_ndims = len(func_attrs["output_accessors"][0].original_shapes) + problem_args = PROBLEM_ARGS_TEMPLATE.render() + return common.gen_function( + func_attrs, + common.SRC_TEMPLATE, + exec_cond_template, + problem_args, + input_ndims, + weight_ndims, + output_ndims, + dim_info_dict, + support_split_k=True, + input_addr_calculator=input_addr_calculator, + output_addr_calculator=common.OUTPUT_ADDR_CALCULATOR.render( + stride_dim="N", output_accessor=func_attrs["output_accessors"][0] + ), + ) + + +@registry.reg("cuda.gemm_rcr.func_decl") +def gen_function_decl(func_attrs): + func_name = func_attrs["name"] + input_ndims = len(func_attrs["input_accessors"][0].original_shapes) + weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) + return common.FUNC_DECL_TEMPLATE.render( + func_name=func_name, + input_ndims=input_ndims, + weight_ndims=weight_ndims, + support_split_k=True, + ) + + +@registry.reg("cuda.gemm_rcr.func_call") +def gen_function_call(func_attrs, indent=" "): + return common.gen_function_call(func_attrs, indent) + + +@registry.reg("cuda.gemm_rcr.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias.py new file mode 100644 index 000000000..f54c0ed2c --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias.py @@ -0,0 +1,158 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +GEMM Specialization for +C = GeMM(A, B) + bias +where A[RowMajor][M, K], B[ColMajor][N, K], bias[RowMajor][N] +""" +import jinja2 + +from ... import registry +from . 
import common, common_bias, gemm_rcr + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + + +# used for real execution +PROBLEM_ARGS_TEMPLATE = jinja2.Template( + """ + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K}, + split_k, + {ElementComputeEpilogue(1), ElementComputeEpilogue(1)}, + (void*) (a_ptr + input_a_offset), + (void*) (b_ptr + input_b_offset), + (void*) bias_ptr, + (void*) (c_ptr + output_offset), + input_a_batch_stride, + input_b_batch_stride, + /*bias_batch_stride*/ N, + /*output_batch_stride*/ M * N, + input_a_stride, + input_b_stride, + /*bias_stride*/ 0, + output_stride +""" +) + + +# for profiler, no need to include TensorAccessor +PROFILER_PROBLEM_ARGS_TEMPLATE = jinja2.Template( + """ + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K}, + split_k, + {ElementComputeEpilogue(1), ElementComputeEpilogue(1)}, + (void*) a_ptr, + (void*) b_ptr, + (void*) bias_ptr, + (void*) (c_ptr + output_offset), + M * K, + N * K, + N, + M * N, + K, + K, + 0, + output_stride +""" +) + + +@registry.reg("cuda.gemm_rcr_bias.config") +def gemm_rcr_config(func_attrs, dtype="float16"): + return gemm_rcr.gemm_rcr_config(func_attrs, dtype) + + +@registry.reg("cuda.gemm_rcr_bias.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + gemm_rcr.common_gen_profiler( + func_attrs, + workdir, + dim_info_dict, + common_bias.SRC_TEMPLATE, + PROFILER_PROBLEM_ARGS_TEMPLATE, + bias_ptr_arg="memory_pool->RequestHalfTensorByIdx(3)", + ) + + +@registry.reg("cuda.gemm_rcr_bias.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, +): + input_addr_calculator = gemm_rcr.get_input_addr_calculator(func_attrs) + input_ndims = len(func_attrs["input_accessors"][0].original_shapes) + weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) + output_ndims = len(func_attrs["output_accessors"][0].original_shapes) + problem_args = PROBLEM_ARGS_TEMPLATE.render() + return common.gen_function( + func_attrs, + common_bias.SRC_TEMPLATE, + exec_cond_template, + problem_args, + input_ndims, + weight_ndims, + output_ndims, + dim_info_dict, + support_split_k=True, + input_addr_calculator=input_addr_calculator, + output_addr_calculator=common.OUTPUT_ADDR_CALCULATOR.render( + stride_dim="N", output_accessor=func_attrs["output_accessors"][0] + ), + ) + + +@registry.reg("cuda.gemm_rcr_bias.func_decl") +def gen_function_decl(func_attrs): + func_name = func_attrs["name"] + input_ndims = len(func_attrs["input_accessors"][0].original_shapes) + weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) + return common_bias.FUNC_DECL_TEMPLATE.render( + func_name=func_name, + input_ndims=input_ndims, + weight_ndims=weight_ndims, + support_split_k=True, + ) + + +@registry.reg("cuda.gemm_rcr_bias.func_call") +def gen_function_call(func_attrs, indent=" "): + bias = func_attrs["inputs"][2] + return common.gen_function_call( + func_attrs, indent, bias_ptr_arg=bias._attrs["name"] + ) + + +@registry.reg("cuda.gemm_rcr_bias.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. 
+ """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add.py new file mode 100644 index 000000000..c2fc67191 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add.py @@ -0,0 +1,98 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +GEMM Specialization for +C = ADD(GeMM(A, B) + bias, D0) +where A[RowMajor][M, K], B[ColMajor][N, K], C[RowMajor][M, N] +bias[RowMajor][N], D0[RowMajor][M, N] +""" +from ... import registry +from . import common, common_bias_broadcast +from .layout import RCR + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + +UNARY_OP1 = "cutlass::epilogue::thread::Identity" +BINARY_OP1 = "cutlass::plus" +BINARY_OP2 = None +UNARY_OP2 = "cutlass::epilogue::thread::Identity" + + +@registry.reg("cuda.gemm_rcr_bias_add.config") +def gemm_rcr_config(func_attrs, dtype="float16"): + return common_bias_broadcast.gemm_bias_broadcast_config(func_attrs, RCR) + + +@registry.reg("cuda.gemm_rcr_bias_add.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + common_bias_broadcast.gen_profiler( + func_attrs, + workdir, + dim_info_dict, + RCR, + UNARY_OP1, + BINARY_OP1, + BINARY_OP2, + UNARY_OP2, + ) + + +@registry.reg("cuda.gemm_rcr_bias_add.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, +): + return common_bias_broadcast.gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, + RCR, + UNARY_OP1, + BINARY_OP1, + BINARY_OP2, + UNARY_OP2, + ) + + +@registry.reg("cuda.gemm_rcr_bias_add.func_decl") +def gen_function_decl(func_attrs): + return common_bias_broadcast.gen_function_decl(func_attrs) + + +@registry.reg("cuda.gemm_rcr_bias_add.func_call") +def gen_function_call(func_attrs, indent=" "): + return common_bias_broadcast.gen_function_call(func_attrs, indent) + + +@registry.reg("cuda.gemm_rcr_bias_add.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add_add.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add_add.py new file mode 100644 index 000000000..56511dbc1 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add_add.py @@ -0,0 +1,98 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +GEMM Specialization for +C = RELU(ADD(ADD(GeMM(A, B) + bias, D0), D1)) +where A[RowMajor][M, K], B[ColMajor][N, K], C[RowMajor][M, N] +bias[RowMajor][N], D0[RowMajor][M, N], D1[RowMajor][M, N] +""" +from ... import registry +from . import common, common_bias_broadcast +from .layout import RCR + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + +UNARY_OP1 = "cutlass::epilogue::thread::Identity" +BINARY_OP1 = "cutlass::plus" +BINARY_OP2 = "cutlass::plus" +UNARY_OP2 = "cutlass::epilogue::thread::Identity" + + +@registry.reg("cuda.gemm_rcr_bias_add_add.config") +def gemm_rcr_config(func_attrs, dtype="float16"): + return common_bias_broadcast.gemm_bias_broadcast_config(func_attrs, RCR) + + +@registry.reg("cuda.gemm_rcr_bias_add_add.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + common_bias_broadcast.gen_profiler( + func_attrs, + workdir, + dim_info_dict, + RCR, + UNARY_OP1, + BINARY_OP1, + BINARY_OP2, + UNARY_OP2, + ) + + +@registry.reg("cuda.gemm_rcr_bias_add_add.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, +): + return common_bias_broadcast.gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, + RCR, + UNARY_OP1, + BINARY_OP1, + BINARY_OP2, + UNARY_OP2, + ) + + +@registry.reg("cuda.gemm_rcr_bias_add_add.func_decl") +def gen_function_decl(func_attrs): + return common_bias_broadcast.gen_function_decl(func_attrs) + + +@registry.reg("cuda.gemm_rcr_bias_add_add.func_call") +def gen_function_call(func_attrs, indent=" "): + return common_bias_broadcast.gen_function_call(func_attrs, indent) + + +@registry.reg("cuda.gemm_rcr_bias_add_add.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add_add_relu.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add_add_relu.py new file mode 100644 index 000000000..f823baab2 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add_add_relu.py @@ -0,0 +1,98 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +""" +GEMM Specialization for +C = RELU(ADD(ADD(GeMM(A, B) + bias, D0), D1)) +where A[RowMajor][M, K], B[ColMajor][N, K], C[RowMajor][M, N] +bias[RowMajor][N], D0[RowMajor][M, N], D1[RowMajor][M, N] +""" +from ... import registry +from . import common, common_bias_broadcast +from .layout import RCR + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + +UNARY_OP1 = "cutlass::epilogue::thread::Identity" +BINARY_OP1 = "cutlass::plus" +BINARY_OP2 = "cutlass::plus" +UNARY_OP2 = "cutlass::epilogue::thread::ReLu" + + +@registry.reg("cuda.gemm_rcr_bias_add_add_relu.config") +def gemm_rcr_config(func_attrs, dtype="float16"): + return common_bias_broadcast.gemm_bias_broadcast_config(func_attrs, RCR) + + +@registry.reg("cuda.gemm_rcr_bias_add_add_relu.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + common_bias_broadcast.gen_profiler( + func_attrs, + workdir, + dim_info_dict, + RCR, + UNARY_OP1, + BINARY_OP1, + BINARY_OP2, + UNARY_OP2, + ) + + +@registry.reg("cuda.gemm_rcr_bias_add_add_relu.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, +): + return common_bias_broadcast.gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, + RCR, + UNARY_OP1, + BINARY_OP1, + BINARY_OP2, + UNARY_OP2, + ) + + +@registry.reg("cuda.gemm_rcr_bias_add_add_relu.func_decl") +def gen_function_decl(func_attrs): + return common_bias_broadcast.gen_function_decl(func_attrs) + + +@registry.reg("cuda.gemm_rcr_bias_add_add_relu.func_call") +def gen_function_call(func_attrs, indent=" "): + return common_bias_broadcast.gen_function_call(func_attrs, indent) + + +@registry.reg("cuda.gemm_rcr_bias_add_add_relu.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add_relu.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add_relu.py new file mode 100644 index 000000000..bd4f7da4b --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add_relu.py @@ -0,0 +1,98 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +GEMM Specialization for +C = RELU(ADD(GeMM(A, B) + bias, D0)) +where A[RowMajor][M, K], B[ColMajor][N, K], C[RowMajor][M, N] +bias[RowMajor][N], D0[RowMajor][M, N] +""" +from ... import registry +from . 
import common, common_bias_broadcast +from .layout import RCR + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + +UNARY_OP1 = "cutlass::epilogue::thread::Identity" +BINARY_OP1 = "cutlass::plus" +BINARY_OP2 = None +UNARY_OP2 = "cutlass::epilogue::thread::ReLu" + + +@registry.reg("cuda.gemm_rcr_bias_add_relu.config") +def gemm_rcr_config(func_attrs, dtype="float16"): + return common_bias_broadcast.gemm_bias_broadcast_config(func_attrs, RCR) + + +@registry.reg("cuda.gemm_rcr_bias_add_relu.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + common_bias_broadcast.gen_profiler( + func_attrs, + workdir, + dim_info_dict, + RCR, + UNARY_OP1, + BINARY_OP1, + BINARY_OP2, + UNARY_OP2, + ) + + +@registry.reg("cuda.gemm_rcr_bias_add_relu.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, +): + return common_bias_broadcast.gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, + RCR, + UNARY_OP1, + BINARY_OP1, + BINARY_OP2, + UNARY_OP2, + ) + + +@registry.reg("cuda.gemm_rcr_bias_add_relu.func_decl") +def gen_function_decl(func_attrs): + return common_bias_broadcast.gen_function_decl(func_attrs) + + +@registry.reg("cuda.gemm_rcr_bias_add_relu.func_call") +def gen_function_call(func_attrs, indent=" "): + return common_bias_broadcast.gen_function_call(func_attrs, indent) + + +@registry.reg("cuda.gemm_rcr_bias_add_relu.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_fast_gelu.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_fast_gelu.py new file mode 100644 index 000000000..f55e21cd8 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_fast_gelu.py @@ -0,0 +1,144 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +GEMM Specialization for C = fast_gelu(GeMM(A, B) + bias) +where A[RowMajor][M, K], B[ColMajor][N, K], bias[RowMajor][K], C[RowMajor][M, N] +""" +import jinja2 + +from ... import registry +from . 
import common, common_bias_activation + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + +EXTRA_CODE = jinja2.Template( + """ +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/constants.h" +#include "cutlass/complex.h" +#include "cutlass/array.h" +#include "cutlass/half.h" +#include "cutlass/functional.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/epilogue/thread/linear_combination_generic.h" + +namespace cutlass { +namespace epilogue { +namespace thread { + +template < + typename ElementOutput_, ///< Data type used to load and store tensors + int Count, ///< Number of elements computed per operation + ///< Usually it is 128/sizeof_bits, + ///< but we use 64 or 32 sometimes when there are not enough data to store + typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type + typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination + ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling + FloatRoundStyle Round = FloatRoundStyle::round_to_nearest +> +using LinearCombinationFastGELU = LinearCombinationGeneric; + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +""" +) + + +PROBLEM_ARGS_TEMPLATE = jinja2.Template( + """ + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K}, + split_k, + {ElementComputeEpilogue(1), ElementComputeEpilogue(1)}, + (void*) a_ptr, + (void*) b_ptr, + (void*) bias_ptr, + (void*) (c_ptr + output_offset), + M * K, + N * K, + N, + M * N, + K, + K, + 0, + output_stride +""" +) + + +@registry.reg("cuda.gemm_rcr_bias_fast_gelu.config") +def gemm_rcr_config(func_attrs, dtype="float16"): + return common_bias_activation.gemm_rcr_config(func_attrs, dtype) + + +@registry.reg("cuda.gemm_rcr_bias_fast_gelu.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + return common_bias_activation.gen_profiler( + func_attrs, + workdir, + dim_info_dict, + PROBLEM_ARGS_TEMPLATE, + extra_code=EXTRA_CODE.render(), + ) + + +@registry.reg("cuda.gemm_rcr_bias_fast_gelu.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, +): + return common_bias_activation.gen_function( + func_attrs, + PROBLEM_ARGS_TEMPLATE, + exec_cond_template, + dim_info_dict, + extra_code=EXTRA_CODE.render(), + ) + + +@registry.reg("cuda.gemm_rcr_bias_fast_gelu.func_decl") +def gen_function_decl(func_attrs): + return common_bias_activation.gen_function_decl(func_attrs) + + +@registry.reg("cuda.gemm_rcr_bias_fast_gelu.func_call") +def gen_function_call(func_attrs, indent=" "): + return common_bias_activation.gen_function_call(func_attrs, indent) + + +@registry.reg("cuda.gemm_rcr_bias_fast_gelu.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_gelu.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_gelu.py new file mode 100644 index 000000000..d16d769a1 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_gelu.py @@ -0,0 +1,106 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +GEMM Specialization for C = fast_gelu(GeMM(A, B) + bias) +where A[RowMajor][M, K], B[ColMajor][N, K], bias[RowMajor][K], C[RowMajor][M, N] +""" +import jinja2 + +from ... import registry +from . import common, common_bias_activation + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + + +PROBLEM_ARGS_TEMPLATE = jinja2.Template( + """ + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K}, + split_k, + {ElementComputeEpilogue(1), ElementComputeEpilogue(1)}, + (void*) a_ptr, + (void*) b_ptr, + (void*) bias_ptr, + (void*) (c_ptr + output_offset), + M * K, + N * K, + N, + M * N, + K, + K, + 0, + output_stride +""" +) + + +@registry.reg("cuda.gemm_rcr_bias_gelu.config") +def gemm_rcr_config(func_attrs, dtype="float16"): + return common_bias_activation.gemm_rcr_config(func_attrs, dtype) + + +@registry.reg("cuda.gemm_rcr_bias_gelu.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + return common_bias_activation.gen_profiler( + func_attrs, + workdir, + dim_info_dict, + PROBLEM_ARGS_TEMPLATE, + ) + + +@registry.reg("cuda.gemm_rcr_bias_gelu.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, +): + return common_bias_activation.gen_function( + func_attrs, + PROBLEM_ARGS_TEMPLATE, + exec_cond_template, + dim_info_dict, + ) + + +@registry.reg("cuda.gemm_rcr_bias_gelu.func_decl") +def gen_function_decl(func_attrs): + return common_bias_activation.gen_function_decl(func_attrs) + + +@registry.reg("cuda.gemm_rcr_bias_gelu.func_call") +def gen_function_call(func_attrs, indent=" "): + return common_bias_activation.gen_function_call(func_attrs, indent) + + +@registry.reg("cuda.gemm_rcr_bias_gelu.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_hardswish.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_hardswish.py new file mode 100644 index 000000000..6c22e1e3a --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_hardswish.py @@ -0,0 +1,106 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +""" +GEMM Specialization for C = hard_swish(GeMM(A, B) + bias) +where A[RowMajor][M, K], B[ColMajor][N, K], bias[RowMajor][K], C[RowMajor][M, N] +""" +import jinja2 + +from ... import registry +from . import common, common_bias_activation + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + + +PROBLEM_ARGS_TEMPLATE = jinja2.Template( + """ + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K}, + split_k, + {ElementComputeEpilogue(1), ElementComputeEpilogue(1)}, + (void*) a_ptr, + (void*) b_ptr, + (void*) bias_ptr, + (void*) (c_ptr + output_offset), + M * K, + N * K, + N, + M * N, + K, + K, + 0, + output_stride +""" +) + + +@registry.reg("cuda.gemm_rcr_bias_hardswish.config") +def gemm_rcr_config(func_attrs, dtype="float16"): + return common_bias_activation.gemm_rcr_config(func_attrs, dtype) + + +@registry.reg("cuda.gemm_rcr_bias_hardswish.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + return common_bias_activation.gen_profiler( + func_attrs, + workdir, + dim_info_dict, + PROBLEM_ARGS_TEMPLATE, + ) + + +@registry.reg("cuda.gemm_rcr_bias_hardswish.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, +): + return common_bias_activation.gen_function( + func_attrs, + PROBLEM_ARGS_TEMPLATE, + exec_cond_template, + dim_info_dict, + ) + + +@registry.reg("cuda.gemm_rcr_bias_hardswish.func_decl") +def gen_function_decl(func_attrs): + return common_bias_activation.gen_function_decl(func_attrs) + + +@registry.reg("cuda.gemm_rcr_bias_hardswish.func_call") +def gen_function_call(func_attrs, indent=" "): + return common_bias_activation.gen_function_call(func_attrs, indent) + + +@registry.reg("cuda.gemm_rcr_bias_hardswish.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_mul.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_mul.py new file mode 100644 index 000000000..f2049abef --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_mul.py @@ -0,0 +1,98 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +GEMM Specialization for +C = ADD(GeMM(A, B) + bias, D0) +where A[RowMajor][M, K], B[ColMajor][N, K], C[RowMajor][M, N] +bias[RowMajor][N], D0[RowMajor][M, N] +""" +from ... import registry +from . 
import common, common_bias_broadcast +from .layout import RCR + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + +UNARY_OP1 = "cutlass::epilogue::thread::Identity" +BINARY_OP1 = "cutlass::multiplies" +BINARY_OP2 = None +UNARY_OP2 = "cutlass::epilogue::thread::Identity" + + +@registry.reg("cuda.gemm_rcr_bias_mul.config") +def gemm_rcr_config(func_attrs, dtype="float16"): + return common_bias_broadcast.gemm_bias_broadcast_config(func_attrs, RCR) + + +@registry.reg("cuda.gemm_rcr_bias_mul.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + common_bias_broadcast.gen_profiler( + func_attrs, + workdir, + dim_info_dict, + RCR, + UNARY_OP1, + BINARY_OP1, + BINARY_OP2, + UNARY_OP2, + ) + + +@registry.reg("cuda.gemm_rcr_bias_mul.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, +): + return common_bias_broadcast.gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, + RCR, + UNARY_OP1, + BINARY_OP1, + BINARY_OP2, + UNARY_OP2, + ) + + +@registry.reg("cuda.gemm_rcr_bias_mul.func_decl") +def gen_function_decl(func_attrs): + return common_bias_broadcast.gen_function_decl(func_attrs) + + +@registry.reg("cuda.gemm_rcr_bias_mul.func_call") +def gen_function_call(func_attrs, indent=" "): + return common_bias_broadcast.gen_function_call(func_attrs, indent) + + +@registry.reg("cuda.gemm_rcr_bias_mul.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_mul_add.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_mul_add.py new file mode 100644 index 000000000..55400a029 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_mul_add.py @@ -0,0 +1,98 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +GEMM Specialization for +C = Add(Mul(GeMM(A, B) + bias, D0), D1), +where A[RowMajor][M, K], B[ColMajor][N, K], C[RowMajor][M, N] +bias[RowMajor][N], D0[RowMajor][M, N], D1[RowMajor][M, N] +""" +from ... import registry +from . 
import common, common_bias_broadcast +from .layout import RCR + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + +UNARY_OP1 = "cutlass::epilogue::thread::Identity" +BINARY_OP1 = "cutlass::multiplies" +BINARY_OP2 = "cutlass::plus" +UNARY_OP2 = "cutlass::epilogue::thread::Identity" + + +@registry.reg("cuda.gemm_rcr_bias_mul_add.config") +def gemm_rcr_config(func_attrs, dtype="float16"): + return common_bias_broadcast.gemm_bias_broadcast_config(func_attrs, RCR) + + +@registry.reg("cuda.gemm_rcr_bias_mul_add.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + common_bias_broadcast.gen_profiler( + func_attrs, + workdir, + dim_info_dict, + RCR, + UNARY_OP1, + BINARY_OP1, + BINARY_OP2, + UNARY_OP2, + ) + + +@registry.reg("cuda.gemm_rcr_bias_mul_add.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, +): + return common_bias_broadcast.gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, + RCR, + UNARY_OP1, + BINARY_OP1, + BINARY_OP2, + UNARY_OP2, + ) + + +@registry.reg("cuda.gemm_rcr_bias_mul_add.func_decl") +def gen_function_decl(func_attrs): + return common_bias_broadcast.gen_function_decl(func_attrs) + + +@registry.reg("cuda.gemm_rcr_bias_mul_add.func_call") +def gen_function_call(func_attrs, indent=" "): + return common_bias_broadcast.gen_function_call(func_attrs, indent) + + +@registry.reg("cuda.gemm_rcr_bias_mul_add.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_mul_tanh.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_mul_tanh.py new file mode 100644 index 000000000..3d5abf306 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_mul_tanh.py @@ -0,0 +1,98 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +GEMM Specialization for +C = TANH(Mul((GeMM(A, B) + bias), D0)) +where A[RowMajor][M, K], B[ColMajor][N, K], C[RowMajor][M, N] +bias[RowMajor][N], D0[RowMajor][M, N] +""" +from ... import registry +from . 
import common, common_bias_broadcast +from .layout import RCR + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + +UNARY_OP1 = "cutlass::epilogue::thread::Identity" +BINARY_OP1 = "cutlass::multiplies" +BINARY_OP2 = None +UNARY_OP2 = "cutlass::epilogue::thread::Tanh" + + +@registry.reg("cuda.gemm_rcr_bias_mul_tanh.config") +def gemm_rcr_config(func_attrs, dtype="float16"): + return common_bias_broadcast.gemm_bias_broadcast_config(func_attrs, RCR) + + +@registry.reg("cuda.gemm_rcr_bias_mul_tanh.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + common_bias_broadcast.gen_profiler( + func_attrs, + workdir, + dim_info_dict, + RCR, + UNARY_OP1, + BINARY_OP1, + BINARY_OP2, + UNARY_OP2, + ) + + +@registry.reg("cuda.gemm_rcr_bias_mul_tanh.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, +): + return common_bias_broadcast.gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, + RCR, + UNARY_OP1, + BINARY_OP1, + BINARY_OP2, + UNARY_OP2, + ) + + +@registry.reg("cuda.gemm_rcr_bias_mul_tanh.func_decl") +def gen_function_decl(func_attrs): + return common_bias_broadcast.gen_function_decl(func_attrs) + + +@registry.reg("cuda.gemm_rcr_bias_mul_tanh.func_call") +def gen_function_call(func_attrs, indent=" "): + return common_bias_broadcast.gen_function_call(func_attrs, indent) + + +@registry.reg("cuda.gemm_rcr_bias_mul_tanh.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_permute.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_permute.py new file mode 100644 index 000000000..2a4c75cbe --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_permute.py @@ -0,0 +1,117 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +GEMM with bias and permute epilogue fusion +""" + +from ... import registry +from ..gemm_universal import common +from . 
import common_bias, common_permute, gemm_rcr_bias, gemm_rcr_permute + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + +PROBLEM_ARGS_TEMPLATE = gemm_rcr_bias.PROFILER_PROBLEM_ARGS_TEMPLATE + + +@registry.reg("cuda.gemm_rcr_bias_permute.config") +def gemm_rcr_bias_permute_config(func_attrs, dtype="float16"): + return gemm_rcr_permute.gemm_rcr_permute_config(func_attrs, dtype) + + +@registry.reg("cuda.gemm_rcr_bias_permute.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + return gemm_rcr_permute.common_gen_profiler( + func_attrs, + workdir, + dim_info_dict, + common_bias.SRC_TEMPLATE, + PROBLEM_ARGS_TEMPLATE, + bias_ptr_arg="memory_pool->RequestHalfTensorByIdx(3)", + extra_code=common_permute.EXTRA_CODE.render(), + ) + + +@registry.reg("cuda.gemm_rcr_bias_permute.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, + problem_args_template=None, +): + if problem_args_template is None: + problem_args = PROBLEM_ARGS_TEMPLATE.render() + else: + problem_args = problem_args_template.render() + input_ndims = len(func_attrs["input_accessors"][0].original_shapes) + weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) + output_ndims = len(func_attrs["output_accessors"][0].original_shapes) + return common_permute.gen_function( + func_attrs, + common_bias.SRC_TEMPLATE, + exec_cond_template, + problem_args, + input_ndims, + weight_ndims, + output_ndims, + dim_info_dict, + emit_kernel=True, + support_split_k=True, + output_addr_calculator=common.OUTPUT_ADDR_CALCULATOR.render( + stride_dim="N", output_accessor=func_attrs["output_accessors"][0] + ), + extra_code=common_permute.EXTRA_CODE.render(), + ) + + +@registry.reg("cuda.gemm_rcr_bias_permute.func_decl") +def gen_function_decl(func_attrs): + func_name = func_attrs["name"] + input_ndims = len(func_attrs["input_accessors"][0].original_shapes) + weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) + return common_bias.FUNC_DECL_TEMPLATE.render( + func_name=func_name, + input_ndims=input_ndims, + weight_ndims=weight_ndims, + support_split_k=True, + ) + + +@registry.reg("cuda.gemm_rcr_bias_permute.func_call") +def gen_function_call(func_attrs, indent=" "): + bias = func_attrs["inputs"][2] + return common.gen_function_call( + func_attrs, indent, bias_ptr_arg=bias._attrs["name"] + ) + + +@registry.reg("cuda.gemm_rcr_bias_permute.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_relu.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_relu.py new file mode 100644 index 000000000..3a5940e7a --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_relu.py @@ -0,0 +1,107 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+GEMM Specialization for C = relu(GeMM(A, B) + bias)
+where A[RowMajor][M, K], B[ColMajor][N, K], bias[RowMajor][N], C[RowMajor][M, N]
+"""
+
+import jinja2
+
+from ... import registry
+from . import common, common_bias_activation
+
+# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703
+
+
+PROBLEM_ARGS_TEMPLATE = jinja2.Template(
+    """
+    cutlass::gemm::GemmUniversalMode::kGemm,
+    {M, N, K},
+    split_k,
+    {ElementComputeEpilogue(1), ElementComputeEpilogue(1), ElementComputeEpilogue(0)},
+    (void*) a_ptr,
+    (void*) b_ptr,
+    (void*) bias_ptr,
+    (void*) (c_ptr + output_offset),
+    M * K,
+    N * K,
+    N,
+    M * N,
+    K,
+    K,
+    0,
+    output_stride
+"""
+)
+
+
+@registry.reg("cuda.gemm_rcr_bias_relu.config")
+def gemm_rcr_config(func_attrs, dtype="float16"):
+    return common_bias_activation.gemm_rcr_config(func_attrs, dtype)
+
+
+@registry.reg("cuda.gemm_rcr_bias_relu.gen_profiler")
+def gen_profiler(func_attrs, workdir, dim_info_dict):
+    return common_bias_activation.gen_profiler(
+        func_attrs,
+        workdir,
+        dim_info_dict,
+        PROBLEM_ARGS_TEMPLATE,
+    )
+
+
+@registry.reg("cuda.gemm_rcr_bias_relu.gen_function")
+def gen_function(
+    func_attrs,
+    exec_cond_template,
+    dim_info_dict,
+):
+    return common_bias_activation.gen_function(
+        func_attrs,
+        PROBLEM_ARGS_TEMPLATE,
+        exec_cond_template,
+        dim_info_dict,
+    )
+
+
+@registry.reg("cuda.gemm_rcr_bias_relu.func_decl")
+def gen_function_decl(func_attrs):
+    return common_bias_activation.gen_function_decl(func_attrs)
+
+
+@registry.reg("cuda.gemm_rcr_bias_relu.func_call")
+def gen_function_call(func_attrs, indent=" "):
+    return common_bias_activation.gen_function_call(func_attrs, indent)
+
+
+@registry.reg("cuda.gemm_rcr_bias_relu.filter")
+def function_filter(cfg, func_attrs, ab_alignment):
+    """Generates function filter.
+
+    Parameters
+    ----------
+    cfg: str
+        The filename generated for profiler.
+    func_attrs : Dict
+        Stores the operation attributes.
+    ab_alignment:
+        Input alignments.
+
+    Returns
+    -------
+    bool
+        If input cfg should be filtered.
+    """
+    return common.function_filter(cfg, func_attrs, ab_alignment)
diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_sigmoid.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_sigmoid.py
new file mode 100644
index 000000000..719efbfa2
--- /dev/null
+++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_sigmoid.py
@@ -0,0 +1,107 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
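+#
+# Note: the PROBLEM_ARGS_TEMPLATE below passes two epilogue scalars
+# (alpha, beta) to the CUTLASS epilogue, whereas the ReLU specialization in
+# gemm_rcr_bias_relu.py passes a third scalar, presumably the ReLU threshold
+# expected by its epilogue functor. Illustrative comparison of the rendered
+# epilogue arguments:
+#
+#   sigmoid: {ElementComputeEpilogue(1), ElementComputeEpilogue(1)}
+#   relu:    {ElementComputeEpilogue(1), ElementComputeEpilogue(1), ElementComputeEpilogue(0)}
+#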
+# +""" +GEMM Specialization for +C = Sigmoid(GeMM(A, B) + bias) +where A[RowMajor][M, K], B[ColMajor][N, K], bias[RowMajor][N] +""" +import jinja2 + +from ... import registry +from . import common, common_bias_activation + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + + +PROBLEM_ARGS_TEMPLATE = jinja2.Template( + """ + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K}, + split_k, + {ElementComputeEpilogue(1), ElementComputeEpilogue(1)}, + (void*) a_ptr, + (void*) b_ptr, + (void*) bias_ptr, + (void*) (c_ptr + output_offset), + M * K, + N * K, + N, + M * N, + K, + K, + 0, + output_stride +""" +) + + +@registry.reg("cuda.gemm_rcr_bias_sigmoid.config") +def gemm_rcr_config(func_attrs, dtype="float16"): + return common_bias_activation.gemm_rcr_config(func_attrs, dtype) + + +@registry.reg("cuda.gemm_rcr_bias_sigmoid.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + return common_bias_activation.gen_profiler( + func_attrs, + workdir, + dim_info_dict, + PROBLEM_ARGS_TEMPLATE, + ) + + +@registry.reg("cuda.gemm_rcr_bias_sigmoid.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, +): + return common_bias_activation.gen_function( + func_attrs, + PROBLEM_ARGS_TEMPLATE, + exec_cond_template, + dim_info_dict, + ) + + +@registry.reg("cuda.gemm_rcr_bias_sigmoid.func_decl") +def gen_function_decl(func_attrs): + return common_bias_activation.gen_function_decl(func_attrs) + + +@registry.reg("cuda.gemm_rcr_bias_sigmoid.func_call") +def gen_function_call(func_attrs, indent=" "): + return common_bias_activation.gen_function_call(func_attrs, indent) + + +@registry.reg("cuda.gemm_rcr_bias_sigmoid.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_sigmoid_mul.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_sigmoid_mul.py new file mode 100644 index 000000000..b3b306f38 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_sigmoid_mul.py @@ -0,0 +1,98 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +GEMM Specialization for +C = Mul(Sigmoid(GeMM(A, B) + bias), D0) +where A[RowMajor][M, K], B[ColMajor][N, K], C[RowMajor][M, N] +bias[RowMajor][N], D0[RowMajor][M, N] +""" +from ... import registry +from . 
import common, common_bias_broadcast +from .layout import RCR + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + +UNARY_OP1 = "cutlass::epilogue::thread::Sigmoid" +BINARY_OP1 = "cutlass::multiplies" +BINARY_OP2 = None +UNARY_OP2 = "cutlass::epilogue::thread::Identity" + + +@registry.reg("cuda.gemm_rcr_bias_sigmoid_mul.config") +def gemm_rcr_config(func_attrs, dtype="float16"): + return common_bias_broadcast.gemm_bias_broadcast_config(func_attrs, RCR) + + +@registry.reg("cuda.gemm_rcr_bias_sigmoid_mul.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + common_bias_broadcast.gen_profiler( + func_attrs, + workdir, + dim_info_dict, + RCR, + UNARY_OP1, + BINARY_OP1, + BINARY_OP2, + UNARY_OP2, + ) + + +@registry.reg("cuda.gemm_rcr_bias_sigmoid_mul.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, +): + return common_bias_broadcast.gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, + RCR, + UNARY_OP1, + BINARY_OP1, + BINARY_OP2, + UNARY_OP2, + ) + + +@registry.reg("cuda.gemm_rcr_bias_sigmoid_mul.func_decl") +def gen_function_decl(func_attrs): + return common_bias_broadcast.gen_function_decl(func_attrs) + + +@registry.reg("cuda.gemm_rcr_bias_sigmoid_mul.func_call") +def gen_function_call(func_attrs, indent=" "): + return common_bias_broadcast.gen_function_call(func_attrs, indent) + + +@registry.reg("cuda.gemm_rcr_bias_sigmoid_mul.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_sigmoid_mul_tanh.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_sigmoid_mul_tanh.py new file mode 100644 index 000000000..66cad13c4 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_sigmoid_mul_tanh.py @@ -0,0 +1,98 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +GEMM Specialization for +C = TANH(Mul(Sigmoid(GeMM(A, B) + bias), D0)) +where A[RowMajor][M, K], B[ColMajor][N, K], C[RowMajor][M, N] +bias[RowMajor][N], D0[RowMajor][M, N] +""" +from ... import registry +from . 
import common, common_bias_broadcast +from .layout import RCR + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + +UNARY_OP1 = "cutlass::epilogue::thread::Sigmoid" +BINARY_OP1 = "cutlass::multiplies" +BINARY_OP2 = None +UNARY_OP2 = "cutlass::epilogue::thread::Tanh" + + +@registry.reg("cuda.gemm_rcr_bias_sigmoid_mul_tanh.config") +def gemm_rcr_config(func_attrs, dtype="float16"): + return common_bias_broadcast.gemm_bias_broadcast_config(func_attrs, RCR) + + +@registry.reg("cuda.gemm_rcr_bias_sigmoid_mul_tanh.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + common_bias_broadcast.gen_profiler( + func_attrs, + workdir, + dim_info_dict, + RCR, + UNARY_OP1, + BINARY_OP1, + BINARY_OP2, + UNARY_OP2, + ) + + +@registry.reg("cuda.gemm_rcr_bias_sigmoid_mul_tanh.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, +): + return common_bias_broadcast.gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, + RCR, + UNARY_OP1, + BINARY_OP1, + BINARY_OP2, + UNARY_OP2, + ) + + +@registry.reg("cuda.gemm_rcr_bias_sigmoid_mul_tanh.func_decl") +def gen_function_decl(func_attrs): + return common_bias_broadcast.gen_function_decl(func_attrs) + + +@registry.reg("cuda.gemm_rcr_bias_sigmoid_mul_tanh.func_call") +def gen_function_call(func_attrs, indent=" "): + return common_bias_broadcast.gen_function_call(func_attrs, indent) + + +@registry.reg("cuda.gemm_rcr_bias_sigmoid_mul_tanh.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_swish.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_swish.py new file mode 100644 index 000000000..688c9daf3 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_swish.py @@ -0,0 +1,107 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +GEMM Specialization for +C = swish(GeMM(A, B) + bias) +where A[RowMajor][M, K], B[ColMajor][N, K], bias[RowMajor][N] +""" +import jinja2 + +from ... import registry +from . 
import common, common_bias_activation + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + + +PROBLEM_ARGS_TEMPLATE = jinja2.Template( + """ + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K}, + split_k, + {ElementComputeEpilogue(1), ElementComputeEpilogue(1)}, + (void*) a_ptr, + (void*) b_ptr, + (void*) bias_ptr, + (void*) (c_ptr + output_offset), + M * K, + N * K, + N, + M * N, + K, + K, + 0, + output_stride +""" +) + + +@registry.reg("cuda.gemm_rcr_bias_swish.config") +def gemm_rcr_config(func_attrs, dtype="float16"): + return common_bias_activation.gemm_rcr_config(func_attrs, dtype) + + +@registry.reg("cuda.gemm_rcr_bias_swish.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + return common_bias_activation.gen_profiler( + func_attrs, + workdir, + dim_info_dict, + PROBLEM_ARGS_TEMPLATE, + ) + + +@registry.reg("cuda.gemm_rcr_bias_swish.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, +): + return common_bias_activation.gen_function( + func_attrs, + PROBLEM_ARGS_TEMPLATE, + exec_cond_template, + dim_info_dict, + ) + + +@registry.reg("cuda.gemm_rcr_bias_swish.func_decl") +def gen_function_decl(func_attrs): + return common_bias_activation.gen_function_decl(func_attrs) + + +@registry.reg("cuda.gemm_rcr_bias_swish.func_call") +def gen_function_call(func_attrs, indent=" "): + return common_bias_activation.gen_function_call(func_attrs, indent) + + +@registry.reg("cuda.gemm_rcr_bias_swish.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_tanh.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_tanh.py new file mode 100644 index 000000000..8a11c966f --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_tanh.py @@ -0,0 +1,144 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +GEMM Specialization for +C = tanh(GeMM(A, B) + bias) +where A[RowMajor][M, K], B[ColMajor][N, K], bias[RowMajor][N] +""" +import jinja2 + +from ... import registry +from . 
import common, common_bias_activation + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + +EXTRA_CODE = jinja2.Template( + """ +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/constants.h" +#include "cutlass/complex.h" +#include "cutlass/array.h" +#include "cutlass/half.h" +#include "cutlass/functional.h" +#include "cutlass/epilogue/thread/linear_combination_generic.h" + +namespace cutlass { +namespace epilogue { +namespace thread { + +template < + typename ElementOutput_, ///< Data type used to load and store tensors + int Count, ///< Number of elements computed per operation + ///< Usually it is 128/sizeof_bits, + ///< but we use 64 or 32 sometimes when there are not enough data to store + typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type + typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination + ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling + FloatRoundStyle Round = FloatRoundStyle::round_to_nearest +> +using LinearCombinationTanh = LinearCombinationGeneric; + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +""" +) + + +PROBLEM_ARGS_TEMPLATE = jinja2.Template( + """ + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K}, + split_k, + {ElementComputeEpilogue(1), ElementComputeEpilogue(1)}, + (void*) a_ptr, + (void*) b_ptr, + (void*) bias_ptr, + (void*) (c_ptr + output_offset), + M * K, + N * K, + N, + M * N, + K, + K, + 0, + output_stride +""" +) + + +@registry.reg("cuda.gemm_rcr_bias_tanh.config") +def gemm_rcr_config(func_attrs, dtype="float16"): + return common_bias_activation.gemm_rcr_config(func_attrs, dtype) + + +@registry.reg("cuda.gemm_rcr_bias_tanh.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + return common_bias_activation.gen_profiler( + func_attrs, + workdir, + dim_info_dict, + PROBLEM_ARGS_TEMPLATE, + extra_code=EXTRA_CODE.render(), + ) + + +@registry.reg("cuda.gemm_rcr_bias_tanh.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, +): + return common_bias_activation.gen_function( + func_attrs, + PROBLEM_ARGS_TEMPLATE, + exec_cond_template, + dim_info_dict, + extra_code=EXTRA_CODE.render(), + ) + + +@registry.reg("cuda.gemm_rcr_bias_tanh.func_decl") +def gen_function_decl(func_attrs): + return common_bias_activation.gen_function_decl(func_attrs) + + +@registry.reg("cuda.gemm_rcr_bias_tanh.func_call") +def gen_function_call(func_attrs, indent=" "): + return common_bias_activation.gen_function_call(func_attrs, indent) + + +@registry.reg("cuda.gemm_rcr_bias_tanh.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_permute.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_permute.py new file mode 100644 index 000000000..f2851db12 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_permute.py @@ -0,0 +1,220 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +GEMM Specialization for +C = permute(GeMM(A, B) + bias) +where A[RowMajor][M, K], B[ColMajor][N, K], bias[RowMajor][N] +""" +import jinja2 + +from ... import registry +from ..gemm_universal import common +from . import common_permute + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + + +ARGS_PARSER_TEMPLATE = jinja2.Template( + """ + int64_t M = std::atoi(argv[1]); + int64_t N = std::atoi(argv[2]); + int64_t K = std::atoi(argv[3]); + int64_t split_k = std::atoi(argv[4]); + + int64_t a_dim0 = M; + int64_t a_dim1 = K; + int64_t b_dim0 = N; + int64_t b_dim1 = K; + int64_t c_dim0 = M; + int64_t c_dim1 = N; +""" +) + + +PROBLEM_ARGS_TEMPLATE = jinja2.Template( + """ + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K}, + split_k, + {ElementComputeEpilogue(1), ElementComputeEpilogue(0)}, + (void*) a_ptr, + (void*) b_ptr, + (void*) c_ptr, + (void*) (c_ptr + output_offset), + M * K, + N * K, + M * N, + M * N, + K, + K, + N, + output_stride +""" +) + + +@registry.reg("cuda.gemm_rcr_permute.config") +def gemm_rcr_permute_config(func_attrs, dtype="float16"): + def fproc_f16(op): + import cutlass_lib + + return common_permute.default_fproc_f16( + op=op, + a_layout=cutlass_lib.library.LayoutType.RowMajor, + b_layout=cutlass_lib.library.LayoutType.ColumnMajor, + c_layout=cutlass_lib.library.LayoutType.RowMajor, + epiligue_name=func_attrs["epilogue"], + permute_layout=func_attrs["layout"], + ) + + func_attrs["op_instance"] = common_permute.extract_config(fproc_f16, func_attrs) + + +def common_gen_profiler( + func_attrs, + workdir, + dim_info_dict, + src_template, + problem_args_template, + bias_ptr_arg=None, + extra_code="", +): + output_addr_calculator = common.DEFAULT_OUTPUT_ADDR_CALCULATOR.render( + stride_dim="*b_dim0" + ) + common_permute.gen_profiler( + func_attrs, + workdir, + dim_info_dict, + src_template, + problem_args_template, + ARGS_PARSER_TEMPLATE, + emit_kernel=True, + support_split_k=True, + output_addr_calculator=output_addr_calculator, + bias_ptr_arg=bias_ptr_arg, + extra_code=extra_code, + ) + + +@registry.reg("cuda.gemm_rcr_permute.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + return common_gen_profiler( + func_attrs, + workdir, + dim_info_dict, + common.SRC_TEMPLATE, + PROBLEM_ARGS_TEMPLATE, + extra_code=common_permute.EXTRA_CODE.render(), + ) + + +@registry.reg("cuda.gemm_rcr_permute.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, + problem_args_template=None, +): + if problem_args_template is None: + problem_args = PROBLEM_ARGS_TEMPLATE.render() + else: + problem_args = problem_args_template.render() + input_ndims = len(func_attrs["input_accessors"][0].original_shapes) + weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) + output_ndims = len(func_attrs["output_accessors"][0].original_shapes) + return common_permute.gen_function( + func_attrs, + common.SRC_TEMPLATE, + exec_cond_template, + problem_args, + input_ndims, + weight_ndims, + output_ndims, + dim_info_dict, + emit_kernel=True, + support_split_k=True, + output_addr_calculator=common.OUTPUT_ADDR_CALCULATOR.render( + 
stride_dim="N", output_accessor=func_attrs["output_accessors"][0] + ), + extra_code=common_permute.EXTRA_CODE.render(), + ) + + +@registry.reg("cuda.gemm_rcr_permute.func_decl") +def gen_function_decl(func_attrs): + func_name = func_attrs["name"] + input_ndims = len(func_attrs["input_accessors"][0].original_shapes) + weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) + return common.FUNC_DECL_TEMPLATE.render( + func_name=func_name, + input_ndims=input_ndims, + weight_ndims=weight_ndims, + support_split_k=True, + ) + + +@registry.reg("cuda.gemm_rcr_permute.func_call") +def gen_function_call(func_attrs, indent=" "): + a = func_attrs["inputs"][0] + b = func_attrs["inputs"][1] + + output = func_attrs["outputs"][0] + has_bias = False + adims = [ + "&" + dim._attrs["name"] + for dim in func_attrs["input_accessors"][0].original_shapes + ] + bdims = [ + "&" + dim._attrs["name"] + for dim in func_attrs["input_accessors"][1].original_shapes + ] + cdims = [ + "&" + dim._attrs["name"] + for dim in func_attrs["output_accessors"][0].original_shapes + ] + return common.FUNC_CALL_TEMPLATE.render( + func_name=func_attrs["name"], + a_ptr=a._attrs["name"], + b_ptr=b._attrs["name"], + has_bias=has_bias, + c_ptr=output._attrs["name"], + split_k=func_attrs["split_k"], + adims=adims, + bdims=bdims, + cdims=cdims, + indent=indent, + ) + + +@registry.reg("cuda.gemm_rcr_permute.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rrr.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rrr.py new file mode 100644 index 000000000..0a3d109d6 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rrr.py @@ -0,0 +1,161 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +GEMM Specialization for +C = GeMM(A, B) +where A[RowMajor][M, K], B[RowMajor][K, N] +""" +import jinja2 + +from ... import registry +from . 
import common + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + + +ARGS_PARSER_TEMPLATE = jinja2.Template( + """ + int64_t M = std::atoi(argv[1]); + int64_t N = std::atoi(argv[2]); + int64_t K = std::atoi(argv[3]); + int64_t split_k = std::atoi(argv[4]); + + int64_t a_dim0 = M; + int64_t a_dim1 = K; + int64_t b_dim0 = K; + int64_t b_dim1 = N; + int64_t c_dim0 = M; + int64_t c_dim1 = N; +""" +) + +PROBLEM_ARGS_TEMPLATE = jinja2.Template( + """ + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K}, + split_k, + {ElementComputeEpilogue(1), ElementComputeEpilogue(0)}, + (void*) a_ptr, + (void*) b_ptr, + (void*) c_ptr, + (void*) (c_ptr + output_offset), + M * K, + N * K, + M * N, + M * N, + K, + N, + N, + output_stride, +""" +) + + +@registry.reg("cuda.gemm_rrr.config") +def gemm_rrr_config(func_attrs, dtype="float16"): + def fproc_f16(op): + import cutlass_lib + + return common.default_fproc_f16( + op=op, + a_layout=cutlass_lib.library.LayoutType.RowMajor, + b_layout=cutlass_lib.library.LayoutType.RowMajor, + c_layout=cutlass_lib.library.LayoutType.RowMajor, + epiligue_name=func_attrs["epilogue"], + ) + + func_attrs["op_instance"] = common.extract_config(fproc_f16) + + +@registry.reg("cuda.gemm_rrr.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + output_addr_calculator = common.DEFAULT_OUTPUT_ADDR_CALCULATOR.render( + stride_dim="N" + ) + common.gen_profiler( + func_attrs, + workdir, + dim_info_dict, + common.SRC_TEMPLATE, + PROBLEM_ARGS_TEMPLATE, + ARGS_PARSER_TEMPLATE, + support_split_k=True, + output_addr_calculator=output_addr_calculator, + ) + + +@registry.reg("cuda.gemm_rrr.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, +): + input_ndims = len(func_attrs["input_accessors"][0].original_shapes) + weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) + output_ndims = len(func_attrs["output_accessors"][0].original_shapes) + problem_args = PROBLEM_ARGS_TEMPLATE.render() + return common.gen_function( + func_attrs, + common.SRC_TEMPLATE, + exec_cond_template, + problem_args, + input_ndims, + weight_ndims, + output_ndims, + dim_info_dict, + support_split_k=True, + output_addr_calculator=common.OUTPUT_ADDR_CALCULATOR.render( + stride_dim="*b_dim1", output_accessor=func_attrs["output_accessors"][0] + ), + ) + + +@registry.reg("cuda.gemm_rrr.func_decl") +def gen_function_decl(func_attrs): + func_name = func_attrs["name"] + input_ndims = len(func_attrs["input_accessors"][0].original_shapes) + weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) + return common.FUNC_DECL_TEMPLATE.render( + func_name=func_name, + input_ndims=input_ndims, + weight_ndims=weight_ndims, + support_split_k=True, + ) + + +@registry.reg("cuda.gemm_rrr.func_call") +def gen_function_call(func_attrs, indent=" "): + return common.gen_function_call(func_attrs, indent) + + +@registry.reg("cuda.gemm_rrr.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. 
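+
+    Filtering is delegated to common.function_filter; for instance, a
+    profiler config whose alignment is incompatible with the given
+    ab_alignment would typically be dropped here.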
+ """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rrr_permute.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rrr_permute.py new file mode 100644 index 000000000..8653efab1 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rrr_permute.py @@ -0,0 +1,221 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +GEMM Specialization for +C = permute(GeMM(A, B) + bias) +where A[RowMajor][M, K], B[RowMajor][K, N], bias[RowMajor][N] +""" +import jinja2 + +from ... import registry +from ..gemm_universal import common +from . import common_permute + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + + +ARGS_PARSER_TEMPLATE = jinja2.Template( + """ + int64_t M = std::atoi(argv[1]); + int64_t N = std::atoi(argv[2]); + int64_t K = std::atoi(argv[3]); + int64_t split_k = std::atoi(argv[4]); + + int64_t a_dim0 = M; + int64_t a_dim1 = K; + int64_t b_dim0 = K; + int64_t b_dim1 = N; + int64_t c_dim0 = M; + int64_t c_dim1 = N; +""" +) + + +PROBLEM_ARGS_TEMPLATE = jinja2.Template( + """ + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K}, + split_k, + {ElementComputeEpilogue(1), ElementComputeEpilogue(0)}, + (void*) a_ptr, + (void*) b_ptr, + (void*) c_ptr, + (void*) (c_ptr + output_offset), + M * K, + N * K, + M * N, + M * N, + K, + N, + N, + output_stride, +""" +) + + +@registry.reg("cuda.gemm_rrr_permute.config") +def gemm_rrr_permute_config(func_attrs, dtype="float16"): + def fproc_f16(op): + import cutlass_lib + + return common_permute.default_fproc_f16( + op=op, + a_layout=cutlass_lib.library.LayoutType.RowMajor, + b_layout=cutlass_lib.library.LayoutType.RowMajor, + c_layout=cutlass_lib.library.LayoutType.RowMajor, + epiligue_name=func_attrs["epilogue"], + permute_layout=func_attrs["layout"], + ) + + func_attrs["op_instance"] = common_permute.extract_config(fproc_f16, func_attrs) + + +def common_gen_profiler( + func_attrs, + workdir, + dim_info_dict, + src_template, + problem_args_template, + bias_ptr_arg=None, + extra_code="", +): + output_addr_calculator = common.DEFAULT_OUTPUT_ADDR_CALCULATOR.render( + stride_dim="N" + ) + common_permute.gen_profiler( + func_attrs, + workdir, + dim_info_dict, + src_template, + problem_args_template, + ARGS_PARSER_TEMPLATE, + emit_kernel=True, + support_split_k=True, + output_addr_calculator=output_addr_calculator, + bias_ptr_arg=bias_ptr_arg, + extra_code=extra_code, + ) + + +@registry.reg("cuda.gemm_rrr_permute.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + return common_gen_profiler( + func_attrs, + workdir, + dim_info_dict, + common.SRC_TEMPLATE, + PROBLEM_ARGS_TEMPLATE, + extra_code=common_permute.EXTRA_CODE.render(), + ) + + +@registry.reg("cuda.gemm_rrr_permute.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, + problem_args_template=None, +): + if problem_args_template is None: + problem_args = PROBLEM_ARGS_TEMPLATE.render() + 
else: + problem_args = problem_args_template.render() + + input_ndims = len(func_attrs["input_accessors"][0].original_shapes) + weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) + output_ndims = len(func_attrs["output_accessors"][0].original_shapes) + return common_permute.gen_function( + func_attrs, + common.SRC_TEMPLATE, + exec_cond_template, + problem_args, + input_ndims, + weight_ndims, + output_ndims, + dim_info_dict, + emit_kernel=True, + support_split_k=True, + output_addr_calculator=common.OUTPUT_ADDR_CALCULATOR.render( + stride_dim="*b_dim1", output_accessor=func_attrs["output_accessors"][0] + ), + extra_code=common_permute.EXTRA_CODE.render(), + ) + + +@registry.reg("cuda.gemm_rrr_permute.func_decl") +def gen_function_decl(func_attrs): + func_name = func_attrs["name"] + input_ndims = len(func_attrs["input_accessors"][0].original_shapes) + weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) + return common.FUNC_DECL_TEMPLATE.render( + func_name=func_name, + input_ndims=input_ndims, + weight_ndims=weight_ndims, + support_split_k=True, + ) + + +@registry.reg("cuda.gemm_rrr_permute.func_call") +def gen_function_call(func_attrs, indent=" "): + a = func_attrs["inputs"][0] + b = func_attrs["inputs"][1] + + output = func_attrs["outputs"][0] + has_bias = False + adims = [ + "&" + dim._attrs["name"] + for dim in func_attrs["input_accessors"][0].original_shapes + ] + bdims = [ + "&" + dim._attrs["name"] + for dim in func_attrs["input_accessors"][1].original_shapes + ] + cdims = [ + "&" + dim._attrs["name"] + for dim in func_attrs["output_accessors"][0].original_shapes + ] + return common.FUNC_CALL_TEMPLATE.render( + func_name=func_attrs["name"], + a_ptr=a._attrs["name"], + b_ptr=b._attrs["name"], + has_bias=has_bias, + c_ptr=output._attrs["name"], + split_k=func_attrs["split_k"], + adims=adims, + bdims=bdims, + cdims=cdims, + indent=indent, + ) + + +@registry.reg("cuda.gemm_rrr_permute.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/group_common.py b/python/aitemplate/backend/cuda/gemm_universal/group_common.py new file mode 100644 index 000000000..6568b3c4f --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/group_common.py @@ -0,0 +1,974 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Common functions and templates for group-gemm-family kernels +""" +import re +from hashlib import sha1 +from typing import Any, Dict, List + +import jinja2 + +from ...common import tensor_accessor_codegen +from . 
import common + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + +DIM_DEFS_TEMPLATE = jinja2.Template( + """ +{% for dim_name in dim_names %} +{% set dim_value = dim_values[loop.index - 1] %} +{{indent}}int64_t {{dim_name}} = {{dim_value}}; +{% endfor %} +""" +) + + +GROUP_OUTPUT_ADDR_CALCULATOR = jinja2.Template( + """ + {% if output_accessor.is_contiguous %} + int64_t output_stride_{{group_id}} = GROUP_{{group_id}}_{{output_stride_dim}}; + int64_t output_offset_{{group_id}} = 0; + {% else %} + int64_t output_stride_{{group_id}} = {{output_accessor.actual_total_elements_from_stride_dim}}; + int64_t output_offset_{{group_id}} = {{output_accessor.offset}}; + {% endif %} +""" +) + + +GROUP_INPUT_A_ADDR_CALCULATOR = jinja2.Template( + """ + {% if input_a_accessor.is_contiguous %} + int64_t input_a_stride_{{group_id}} = GROUP_{{group_id}}_{{input_a_stride_dim}}; + int64_t input_a_offset_{{group_id}} = 0; + {% else %} + int64_t input_a_stride_{{group_id}} = {{input_a_accessor.actual_total_elements_from_stride_dim}}; + int64_t input_a_offset_{{group_id}} = {{input_a_accessor.offset}}; + {% endif %} +""" +) + + +INSTANCE_TEMPLATE = jinja2.Template( + """ +{{config}} +using {{name}} = cutlass::gemm::device::GemmGrouped<{{config_name}}>; +""" +) + + +FUNC_DECL_TEMPLATE = jinja2.Template( + """ +{{indent}}void {{func_name}}( +{{indent}} int, +{{indent}} int, +{{indent}} int64_t*, +{{indent}} int, +{{indent}} cutlass::half_t*, +{% for i in range(groups) %} +{{indent}} cutlass::half_t*, +{{indent}} cutlass::half_t*, +{{indent}} cutlass::half_t*, +{% if has_bias %} +{{indent}} cutlass::half_t*, +{% endif %} +{% endfor %} +{{indent}} uint8_t*, +{% for i in range(groups) %} +{{indent}} int64_t*, +{{indent}} int64_t*, +{{indent}} int64_t*, +{{indent}} int64_t*, +{{indent}} int64_t*, +{{indent}} int64_t*, +{% endfor %} +{{indent}} cudaStream_t +{{indent}}); +""" +) + + +FUNC_CALL_TEMPLATE = jinja2.Template( + """ +{{indent}}{{func_name}}( +{{indent}} device_properties.sharedMemPerMultiprocessor, +{{indent}} device_properties.multiProcessorCount, +{{indent}} &{{func_name}}_state, +{{indent}} {{problem_count}}, +{{indent}} {{device_args}}, +{% for operand in group_operands %} +{{indent}} {{operand[0]}}, +{{indent}} {{operand[1]}}, +{{indent}} {{operand[2]}}, +{% if has_bias %} +{{indent}} {{operand[3]}}, +{% endif %} +{% endfor %} +{{indent}} global_workspace, +{% for operand_dim in group_operand_dims %} +{{indent}} {{operand_dim[0]}}, +{{indent}} {{operand_dim[1]}}, +{{indent}} {{operand_dim[2]}}, +{{indent}} {{operand_dim[3]}}, +{{indent}} {{operand_dim[4]}}, +{{indent}} {{operand_dim[5]}}, +{% endfor %} +{{indent}} stream +{{indent}}); +""" +) + + +ADAPTOR_FUNCTION_TEMPLATE = jinja2.Template( + """ +{% if is_profiler %} +#include +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/kernel/gemm_grouped.h" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" +#include "cutlass/gemm/device/gemm_grouped.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_fill.h" + +#define CUTLASS_CHECK(status) \\ + { \\ + cutlass::Status error = status; \\ + if (error != cutlass::Status::kSuccess) { \\ + auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " + \\ + cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__); \\ + std::cerr << msg << std::endl; \\ + throw std::runtime_error(msg); \\ + } \\ + } + +{{instance}} + 
+{% endif %} + +{{indent}}template +{{indent}}void {{func_name}}_adapter( + int sharedMemPerMultiprocessor, + int multiProcessorCount, + uint8_t* workspace, + int problem_count, + cutlass::gemm::GemmCoord* problem_sizes_device, + cutlass::half_t **ptr_A, + cutlass::half_t **ptr_B, + cutlass::half_t **ptr_C, +{% if has_bias %} + cutlass::half_t **ptr_bias, +{% endif %} + int64_t* lda, + int64_t* ldb, + int64_t* ldc, +{% if has_bias %} + int64_t* ldd, +{% endif %} + int occupancy, + cudaStream_t stream) { + {{exec_program}} + throw std::runtime_error( + "Unsupported workload for this gemm specialization." + ); +} +""", + trim_blocks=True, + lstrip_blocks=True, +) + + +ADAPTER_CALL_TEMPLATE = jinja2.Template( + """ +{{indent}}{{func_name}}_adapter<{{instance}}>( + {{sharedMemPerMultiprocessor}}, + {{multiProcessorCount}}, + {{workspace}}, + {{problem_count}}, + {{problem_sizes_device}}, + {{ptr_A}}, + {{ptr_B}}, + {{ptr_C}}, +{% if has_bias %} + {{ptr_bias}}, +{% endif %} + {{lda}}, + {{ldb}}, + {{ldc}}, +{% if has_bias %} + {{ldd}}, +{% endif %} + {{instance}}::maximum_active_blocks(), + stream + ); +""", + trim_blocks=True, + lstrip_blocks=True, +) + + +EXEC_TEMPLATE = jinja2.Template( + """ +// TODO: cast to right dtype +{{indent}}using ElementComputeEpilogue = typename GEMMKind::ElementAccumulator; +{{indent}}// int smem_size = int(sizeof(typename GEMMKind::GemmKernel::SharedStorage)); +{{indent}}// int occupancy = std::min(2, int(sharedMemPerMultiprocessor / smem_size)); +{{indent}}int threadblock_count = multiProcessorCount * occupancy; +{{indent}}// Early exit +{{indent}}if (!threadblock_count) { +{{indent}} throw std::runtime_error( +{{indent}} "Active CUDA device lacks hardware resources to run CUTLASS Grouped GEMM kernel." +{{indent}} ); +{{indent}}} + + +{{indent}}typename GEMMKind::Arguments arguments{ + +{{problem_args}} + +{{indent}}}; +{{indent}}GEMMKind gemm_op; +{% if is_profiler %} +{{indent}}// Debug BGM: https://www.youtube.com/watch?v=rRwxfYlgG-M +{{indent}}size_t workspace_size = gemm_op.get_workspace_size(arguments); +{{indent}}cutlass::device_memory::allocation local_workspace(workspace_size); +{{indent}}workspace = local_workspace.get(); +{{indent}}GLOBAL_WORKSPACE_SIZE = workspace_size; +{% endif %} +{{indent}}// TODO: cutlass bug here +{{indent}}// auto status = gemm_op.can_implement(arguments); +{{indent}}// CUTLASS_CHECK(status); +{{indent}}auto status = gemm_op.initialize(arguments, workspace, stream); +{{indent}}CUTLASS_CHECK(status); +{{indent}}status = gemm_op(stream); +{{indent}}CUTLASS_CHECK(status); +{{indent}}return; +""" +) + + +SRC_TEMPLATE = jinja2.Template( + """ +#include +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/kernel/gemm_grouped.h" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" +#include "cutlass/gemm/device/gemm_grouped.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_fill.h" + +#define CUTLASS_CHECK(status) \\ + { \\ + cutlass::Status error = status; \\ + if (error != cutlass::Status::kSuccess) { \\ + auto msg = std::string("Got cutlass error: ") + cutlassGetStatusString(error) + \\ + " at: " + std::to_string(__LINE__); \\ + std::cerr << msg << std::endl; \\ + throw std::runtime_error(msg); \\ + } \\ + } + +namespace { +template +void copy(T* dst, T const* src, size_t count, cudaMemcpyKind kind) { + size_t bytes = 
count * cutlass::sizeof_bits::value / 8; + if (bytes == 0 && count > 0) + bytes = 1; + cudaError_t cuda_error = (cudaMemcpy(dst, src, bytes, kind)); + if (cuda_error != cudaSuccess) { + throw std::runtime_error("cudaMemcpy() failed"); + } +} +} // namespace + +{{instances}} + +{{func_adapter}} + +void {{function_name}} ( + int sharedMemPerMultiprocessor, + int multiProcessorCount, + int64_t* func_state, + int problem_count, + cutlass::half_t* device_args, + {% for operand in group_operands %} + cutlass::half_t* {{operand[0]}}, + cutlass::half_t* {{operand[1]}}, + cutlass::half_t* {{operand[2]}}, + {% if has_bias %} + cutlass::half_t* {{operand[3]}}, + {% endif %} + {% endfor %} + uint8_t* global_workspace, +{% for operand_dim in group_operand_dims %} + int64_t* {{operand_dim[0]}}, + int64_t* {{operand_dim[1]}}, + int64_t* {{operand_dim[2]}}, + int64_t* {{operand_dim[3]}}, + int64_t* {{operand_dim[4]}}, + int64_t* {{operand_dim[5]}}, +{% endfor %} + cudaStream_t stream) { + + {{shape_function}} + + if (!device_args) { + throw std::runtime_error("device_args is NULL!"); + } + // It's a bit tricky to check individual gemms in group_gemm cases, + // so let's rule out them all if any of the input/output tensors is zero-sized. + // We can re-visit this part if we hit any use case, e.g. one input of the + // gemm is zero-sized, but all others are non-zero-sized. +{% for operand_dim in group_operand_dims %} + if (*{{operand_dim[0]}} == 0 || *{{operand_dim[1]}} == 0 || + *{{operand_dim[2]}} == 0 || *{{operand_dim[3]}} == 0 || + *{{operand_dim[4]}} == 0 || *{{operand_dim[5]}} == 0) { + throw std::runtime_error("Zero-sized tensors are not supported yet"); + } +{% endfor %} +{% for operand in group_operands %} + if (!{{operand[0]}}) { + throw std::runtime_error("{{operand[0]}} is NULL!"); + } + if (!{{operand[1]}}) { + throw std::runtime_error("{{operand[1]}} is NULL!"); + } + if (!{{operand[2]}}) { + throw std::runtime_error("{{operand[2]}} is NULL!"); + } +{% if has_bias %} + if (!{{operand[3]}}) { + throw std::runtime_error("{{operand[3]}} is NULL!"); + } +{% endif %} + +{% endfor %} + + uint8_t* arg_ptr = (uint8_t*) device_args; + // problem_sizes_device: N * GemmCoord -> N * 3 * sizeof(int64_t) -> 32 * N + // ptrA/B/C/D: N * 8 for each + // lda/b/c/d: N * 8 for each + // total: N * 8 * 4 + N * 8 * 4 + N * 8 * 4 + // total: 3 * 32 * N + int offset = 0; + auto problem_sizes_device = + (cutlass::gemm::GemmCoord*)(arg_ptr + offset); + offset += 32 * problem_count; + + auto ptr_A = (cutlass::half_t**)(arg_ptr + offset); + offset += 8 * problem_count; + auto ptr_B = (cutlass::half_t**)(arg_ptr + offset); + offset += 8 * problem_count; + auto ptr_C = (cutlass::half_t**)(arg_ptr + offset); + offset += 8 * problem_count; + {% if has_bias %} + auto ptr_bias = (cutlass::half_t**)(arg_ptr + offset); + offset += 8 * problem_count; + {% endif %} + + auto lda = (int64_t*)(arg_ptr + offset); + offset += 8 * problem_count; + auto ldb = (int64_t*)(arg_ptr + offset); + offset += 8 * problem_count; + auto ldc = (int64_t*)(arg_ptr + offset); + {% if has_bias %} + offset += 8 * problem_count; + auto ldd = (int64_t*)(arg_ptr + offset); + {% endif %} + // offset += 8 * problem_count; + + if (*func_state != GROUP_0_AM) { + // need update + std::vector problem_sizes; + std::vector ptr_A_host; + std::vector ptr_B_host; + std::vector ptr_C_host; + {% if has_bias %} + std::vector ptr_bias_host; + {% endif %} + std::vector lda_host; + std::vector ldb_host; + std::vector ldc_host; + {% if has_bias %} + std::vector ldd_host; + 
{% endif %} + + {% for operand in group_operands %} + ptr_A_host.push_back({{operand[0]}} + input_a_offset_{{loop.index0}}); + ptr_B_host.push_back({{operand[1]}}); + ptr_C_host.push_back({{operand[2]}} + output_offset_{{loop.index0}}); + {% if has_bias %} + ptr_bias_host.push_back({{operand[3]}}); + {% endif %} + {% endfor %} + + // AM: 0 + // AK: 1 + // BN: 2 + {% for operand_dim in group_operand_dims %} + cutlass::gemm::GemmCoord problem_{{loop.index0}}( + GROUP_{{loop.index0}}_M, + GROUP_{{loop.index0}}_N, + GROUP_{{loop.index0}}_K); + problem_sizes.emplace_back(problem_{{loop.index0}}); + lda_host.push_back(input_a_stride_{{loop.index0}}); + ldb_host.push_back(GROUP_{{loop.index0}}_K); + {% if has_bias %} + ldc_host.push_back(0); + ldd_host.push_back(output_stride_{{loop.index0}}); + {% else %} + ldc_host.push_back(output_stride_{{loop.index0}}); + {% endif %} + {% endfor %} + + copy(problem_sizes_device, + problem_sizes.data(), + problem_count, cudaMemcpyHostToDevice); + + copy(ptr_A, + ptr_A_host.data(), + problem_count, cudaMemcpyHostToDevice); + + copy(ptr_B, + ptr_B_host.data(), + problem_count, cudaMemcpyHostToDevice); + + copy(ptr_C, + ptr_C_host.data(), + problem_count, cudaMemcpyHostToDevice); + + {% if has_bias %} + copy(ptr_bias, + ptr_bias_host.data(), + problem_count, cudaMemcpyHostToDevice); + {% endif %} + + copy(lda, + lda_host.data(), + problem_count, cudaMemcpyHostToDevice); + + copy(ldb, + ldb_host.data(), + problem_count, cudaMemcpyHostToDevice); + + copy(ldc, + ldc_host.data(), + problem_count, cudaMemcpyHostToDevice); + + {% if has_bias %} + copy(ldd, + ldd_host.data(), + problem_count, cudaMemcpyHostToDevice); + {% endif %} + + *func_state = GROUP_0_AM; + } + {{exec_paths}} +} + + +""" +) + + +ARGS_PARSER_TEMPLATE = jinja2.Template( + """ + int problem_count = std::atoi(argv[1]); + int64_t idx = 2; + std::vector problem_sizes; + while (idx < argc) { + int64_t M = std::atoi(argv[idx++]); + int64_t N = std::atoi(argv[idx++]); + int64_t K = std::atoi(argv[idx++]); + cutlass::gemm::GemmCoord problem(M, N, K); + problem_sizes.push_back(problem); + } +""" +) + + +TENSOR_DECL_TEMPLATE = jinja2.Template( + """ + cutlass::DeviceAllocation blob_A; + cutlass::DeviceAllocation blob_B; + cutlass::DeviceAllocation blob_C; +{% if has_bias %} + cutlass::DeviceAllocation blob_Bias; +{% endif %} + int64_t total_size_A = 0; + int64_t total_size_B = 0; + int64_t total_size_C = 0; +{% if has_bias %} + int64_t total_size_Bias = 0; +{% endif %} + + cutlass::DeviceAllocation problem_sizes_device; + + std::vector lda_host; + std::vector ldb_host; + std::vector ldc_host; +{% if has_bias %} + std::vector ldd_host; +{% endif %} + + + cutlass::DeviceAllocation lda; + cutlass::DeviceAllocation ldb; + cutlass::DeviceAllocation ldc; +{% if has_bias %} + cutlass::DeviceAllocation ldd; +{% endif %} + + std::vector ptr_A_host; + std::vector ptr_B_host; + std::vector ptr_C_host; +{% if has_bias %} + std::vector ptr_bias_host; +{% endif %} + + + cutlass::DeviceAllocation ptr_A; + cutlass::DeviceAllocation ptr_B; + cutlass::DeviceAllocation ptr_C; +{% if has_bias %} + cutlass::DeviceAllocation ptr_bias; +{% endif %} + + + for (auto & mnk : problem_sizes) { + int64_t M = mnk.m(); + int64_t N = mnk.n(); + int64_t K = mnk.k(); + lda_host.push_back(K); + ldb_host.push_back(K); +{% if has_bias %} + ldc_host.push_back(0); + ldd_host.push_back(N); +{% else %} + ldc_host.push_back(N); +{% endif %} + + total_size_A += M * K; + total_size_B += N * K; + total_size_C += M * N; +{% if has_bias %} + 
total_size_Bias += N; +{% endif %} + } + + blob_A.reset(total_size_A); + blob_B.reset(total_size_B); + blob_C.reset(total_size_C); +{% if has_bias %} + blob_Bias.reset(total_size_Bias); +{% endif %} + + int64_t offset_A = 0; + int64_t offset_B = 0; + int64_t offset_C = 0; +{% if has_bias %} + int64_t offset_Bias = 0; +{% endif %} + + for (int i = 0; i < problem_sizes.size(); ++i) { + auto & mnk = problem_sizes.at(i); + int64_t M = mnk.m(); + int64_t N = mnk.n(); + int64_t K = mnk.k(); + + ptr_A_host.push_back(blob_A.get() + offset_A); + ptr_B_host.push_back(blob_B.get() + offset_B); + ptr_C_host.push_back(blob_C.get() + offset_C); +{% if has_bias %} + ptr_bias_host.push_back(blob_Bias.get() + offset_Bias); +{% endif %} + offset_A += M * K; + offset_B += N * K; + offset_C += M * N; +{% if has_bias %} + offset_Bias += N; +{% endif %} + } + + + lda.reset(problem_count); + ldb.reset(problem_count); + ldc.reset(problem_count); +{% if has_bias %} + ldd.reset(problem_count); +{% endif %} + lda.copy_from_host(lda_host.data()); + ldb.copy_from_host(ldb_host.data()); + ldc.copy_from_host(ldc_host.data()); +{% if has_bias %} + ldd.copy_from_host(ldd_host.data()); +{% endif %} + + ptr_A.reset(problem_count); + ptr_B.reset(problem_count); + ptr_C.reset(problem_count); +{% if has_bias %} + ptr_bias.reset(problem_count); +{% endif %} + ptr_A.copy_from_host(ptr_A_host.data()); + ptr_B.copy_from_host(ptr_B_host.data()); + ptr_C.copy_from_host(ptr_C_host.data()); +{% if has_bias %} + ptr_bias.copy_from_host(ptr_bias_host.data()); +{% endif %} + + problem_sizes_device.reset(problem_count); + problem_sizes_device.copy_from_host(problem_sizes.data()); + +""" +) + + +def get_group_gemm_instance_template_params(op_def: str) -> List[str]: + """ + For a given op_def string generated by cutlass's group_gemm emiter, parse and + return the group_gemm instance's template parameters. + """ + params = re.findall( + r"cutlass::gemm::kernel::DefaultGemmUniversal<([\s\S]+)>::GemmKernel;", op_def + ) + assert len(params) == 1 + param = params[0] + gemm_universal_params = param.strip().split("\n") + gemm_universal_params = [param.strip(",") for param in gemm_universal_params] + assert len(gemm_universal_params) == 20, ( + "expected len(gemm_universal_params) to be 20, but got " + "{len(gemm_universal_params)}, {gemm_universal_params=}" + ) + return gemm_universal_params + + +def update_alignments_in_group_gemm_instance( + op_def: str, func_attrs: Dict[str, Any], for_profiler: bool +) -> str: + """ + update kAlignmentA, kAlignmentB, and epilogue_alignment in op_def, + which is a group_gemm instance emitted by the gemm instance emitter of cutlass. 
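+
+    Roughly speaking, the epilogue alignment already present in op_def is
+    clamped to the largest alignment supported by each output tensor
+    accessor (the min(...) loop below), presumably so that strided or
+    offset outputs are not emitted with over-aligned vectorized stores.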
+ """ + if for_profiler: + return op_def + + # TODO: adjust a_alignment, b_alignment based on input_accessors + + gemm_params = get_group_gemm_instance_template_params(op_def) + epilogue_align_idx = 12 + epilogue_curr_align = gemm_params[epilogue_align_idx].strip() + + output_accessors = func_attrs["output_accessors"] + epilogue_alignment = int(epilogue_curr_align) + for output_accessor in output_accessors: + epilogue_alignment = min( + epilogue_alignment, + tensor_accessor_codegen.find_max_alignment_for_accessor(output_accessor), + ) + + instance_lines = op_def.split("\n") + # a_align_idx + 4 in the full instance string + idx_offset = 4 + + epilogue_curr_align_line = instance_lines[epilogue_align_idx + idx_offset] + assert epilogue_curr_align == epilogue_curr_align_line.strip( + " ," + ), f"expected {epilogue_curr_align=} equal to {epilogue_curr_align_line=}" + instance_lines[epilogue_align_idx + idx_offset] = epilogue_curr_align_line.replace( + epilogue_curr_align, str(epilogue_alignment) + ) + return "\n".join(instance_lines) + + +def group_gemm_instance(op_def: str, func_attrs: Dict[str, Any], for_profiler: bool): + # TODO: This is a dirty thing need to add an extra emitter to clean this up + op_def = update_alignments_in_group_gemm_instance(op_def, func_attrs, for_profiler) + tmp = op_def.replace("DefaultGemmUniversal", "DefaultGemmGrouped") + tmp = tmp.replace("false,", "") + # force output to be row major + # cutlass lib can't generate row major output kernels + tmp = re.sub( + r"cutlass::layout::ColumnMajor,\n", "cutlass::layout::RowMajor,\n", tmp + ) + tmp = re.sub( + r"GemmIdentityThreadblockSwizzle<\d>", + "GemmBatchedIdentityThreadblockSwizzle", + tmp, + ) + tmp = re.sub( + r"cutlass::arch::OpMultiplyAdd", + "cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,\n" + + "cutlass::arch::OpMultiplyAdd", + tmp, + ) + return tmp + + +def gen_profiler( + func_attrs, + workdir, + shape_template, + problem_args_template, + has_bias=False, + output_addr_calculator="", +): + op_type = func_attrs["op"] + op_instance = func_attrs["op_instance"] + + file_pairs = [] + for op_name, op in op_instance.items(): + config = common.emit_instance( + op, + for_profiler=True, + f_instance_convertor=group_gemm_instance, + emit_kernel=True, + ) + config_name = common.extract_config_name(config) + name = "GemmInstance" + instance = INSTANCE_TEMPLATE.render( + config_name=config_name, name=name, config=config + ) + + # instance = instance + exec_program = EXEC_TEMPLATE.render( + indent=" ", is_profiler=True, problem_args=problem_args_template.render() + ) + op_func = ADAPTOR_FUNCTION_TEMPLATE.render( + instance=instance, + is_profiler=True, + func_name=name, + indent=" ", + exec_program=exec_program, + has_bias=has_bias, + ) + func_call = ADAPTER_CALL_TEMPLATE.render( + func_name=name, + instance=name, + sharedMemPerMultiprocessor="device_properties.sharedMemPerMultiprocessor", + multiProcessorCount="device_properties.multiProcessorCount", + workspace="global_workspace", + problem_count="problem_count", + problem_sizes_device="problem_sizes_device.get()", + ptr_A="ptr_A.get()", + ptr_B="ptr_B.get()", + ptr_C="ptr_C.get()", + has_bias=has_bias, + ptr_bias="ptr_bias.get()", + lda="lda.get()", + ldb="ldb.get()", + ldc="ldc.get()", + ldd="ldd.get()", + ) + code = common.PROFILER_TEMPLATE.render( + op_func=op_func, + args_parse=ARGS_PARSER_TEMPLATE.render(), + func_call=func_call, + name=name, + tensor_decl=TENSOR_DECL_TEMPLATE.render(name=name, has_bias=has_bias), + ) + common.add_profiler(file_pairs, 
workdir, op_type, op_name, code) + # build + common.build_profiler(file_pairs) + + +def gen_function( + func_attrs, + exec_cond_template, + shape_eval_template, + problem_args_template, + has_bias=False, +): + problem_args = problem_args_template.render() + func_name = func_attrs["name"] + exec_path = func_attrs["exec_path"] + op_instance = func_attrs["op_instance"] + inst_def_flag = set() + instances = {} + instance_decl = "" + emit_kernel = True + for key, value in exec_path.items(): + fname = "f" + sha1(key.encode()).hexdigest() + algo = value.algo + if algo not in inst_def_flag: + config = common.emit_instance( + op_instance[algo], + for_profiler=False, + f_instance_convertor=group_gemm_instance, + emit_kernel=emit_kernel, + func_attrs=func_attrs, + ) + inst_def_flag.add(algo) + else: + raise ValueError(f"Algo {algo} already in inst_def_flags") + + inst = INSTANCE_TEMPLATE.render( + config=config, + name=fname, + config_name=common.extract_config_name(config), + ) + instances[key] = inst + instance_decl += inst + kwargs = {} + kwargs["indent"] = " " + kwargs["dtype"] = "int64_t " + group_operand_dims = [] + output_addr_cals = [] + input_a_addr_cals = [] + num_inputs_per_group = 3 if has_bias else 2 + + for i in range(func_attrs["groups"]): + dim_names = [] + for j in range(6): + dim_names.append("*dim_{group}_{dim}".format(group=i, dim=j)) + group_operand_dims.append(dim_names) + output_addr_cal = GROUP_OUTPUT_ADDR_CALCULATOR.render( + group_id=i, + output_stride_dim="CN", + output_accessor=func_attrs["output_accessors"][i], + ) + output_addr_cals.append(output_addr_cal) + input_a_addr_cal = GROUP_INPUT_A_ADDR_CALCULATOR.render( + group_id=i, + input_a_stride_dim="AK", + input_a_accessor=func_attrs["input_accessors"][i * num_inputs_per_group], + ) + input_a_addr_cals.append(input_a_addr_cal) + kwargs["group_operand_dims"] = group_operand_dims + kwargs["output_addr_cals"] = output_addr_cals + kwargs["input_a_addr_cals"] = input_a_addr_cals + shape_func = shape_eval_template.render(**kwargs) + exec_paths = "" + # + for key, _ in instances.items(): + fname = "f" + sha1(key.encode()).hexdigest() + program = ADAPTER_CALL_TEMPLATE.render( + indent=" ", + func_name=func_name, + instance=fname, + sharedMemPerMultiprocessor="sharedMemPerMultiprocessor", + multiProcessorCount="multiProcessorCount", + workspace="global_workspace", + problem_count=func_attrs["groups"], + problem_sizes_device="problem_sizes_device", + ptr_A="ptr_A", + ptr_B="ptr_B", + ptr_C="ptr_C", + has_bias=has_bias, + ptr_bias="ptr_bias", + lda="lda", + ldb="ldb", + ldc="ldc", + ldd="ldd", + ) + exec_inst = exec_cond_template.render(indent=" ", cond=key, program=program) + exec_paths += exec_inst + + exec_program = EXEC_TEMPLATE.render( + indent=" ", is_profiler=False, problem_args=problem_args + ) + adapter_func = ADAPTOR_FUNCTION_TEMPLATE.render( + func_name=func_name, exec_program=exec_program, has_bias=has_bias + ) + group_operands = [] + group_operand_dims = [] + for i in range(func_attrs["groups"]): + operand = [] + operand.append("ptr_{group}_a".format(group=i)) + operand.append("ptr_{group}_b".format(group=i)) + operand.append("ptr_{group}_c".format(group=i)) + if has_bias: + operand.append("ptr_{group}_bias".format(group=i)) + dims = [] + for j in range(6): + dims.append("dim_{group}_{dim}".format(group=i, dim=j)) + group_operands.append(operand) + group_operand_dims.append(dims) + + return SRC_TEMPLATE.render( + instances=instance_decl, + func_adapter=adapter_func, + function_name=func_name, + 
shape_function=shape_func, + group_operands=group_operands, + group_operand_dims=group_operand_dims, + exec_paths=exec_paths, + has_bias=has_bias, + ) + + +def gen_function_call(func_attrs, ndims, has_bias=False, indent=" "): + group_operands = [] + group_operand_dims = [] + output_accessors = [a.is_contiguous for a in func_attrs["output_accessors"]] + with_single_strided_output = False + if "output_stride_dim" in func_attrs: + output_accessors = list(set(output_accessors)) + # we only support two cases: either all outputs are contiguous or none + # of them are + assert len(output_accessors) == 1 + with_single_strided_output = not output_accessors[0] + for i in range(func_attrs["groups"]): + a = func_attrs["inputs"][i * ndims] + b = func_attrs["inputs"][i * ndims + 1] + if has_bias: + bias = func_attrs["inputs"][i * ndims + 2] + c_idx = 0 if with_single_strided_output else i + c = func_attrs["outputs"][c_idx] + input_a_accessor = func_attrs["input_accessors"][i * ndims] + input_b_accessor = func_attrs["input_accessors"][i * ndims + 1] + output_accessor = func_attrs["output_accessors"][i] + + ashape = input_a_accessor.original_shapes + bshape = input_b_accessor.original_shapes + cshape = output_accessor.original_shapes + operands = [] + operand_dims = [] + operands.append(a._attrs["name"]) + operands.append(b._attrs["name"]) + operands.append(c._attrs["name"]) + if has_bias: + operands.append(bias._attrs["name"]) + operand_dims.append("&" + ashape[0]._attrs["name"]) + operand_dims.append("&" + ashape[1]._attrs["name"]) + operand_dims.append("&" + bshape[0]._attrs["name"]) + operand_dims.append("&" + bshape[1]._attrs["name"]) + operand_dims.append("&" + cshape[0]._attrs["name"]) + operand_dims.append("&" + cshape[1]._attrs["name"]) + group_operands.append(operands) + group_operand_dims.append(operand_dims) + device_args = f'reinterpret_cast(unique_workspace + {func_attrs["unique_workspace_offset"]})' + return FUNC_CALL_TEMPLATE.render( + func_name=func_attrs["name"], + problem_count=func_attrs["groups"], + device_args=device_args, + group_operands=group_operands, + group_operand_dims=group_operand_dims, + indent=indent, + has_bias=has_bias, + ) diff --git a/python/aitemplate/backend/cuda/gemm_universal/group_common_bias.py b/python/aitemplate/backend/cuda/gemm_universal/group_common_bias.py new file mode 100644 index 000000000..2b556fc83 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/group_common_bias.py @@ -0,0 +1,76 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Common codegen functions for group_gemm_bias-family kernels. +""" +import jinja2 + +from . 
import group_common + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + + +PROBLEM_ARGS_TEMPLATE = jinja2.Template( + """ + problem_sizes_device, + problem_count, + threadblock_count, + {ElementComputeEpilogue(1), ElementComputeEpilogue(1)}, + ptr_A, + ptr_B, + ptr_bias, + ptr_C, + lda, + ldb, + ldc, + ldd +""" +) + + +def gen_profiler( + func_attrs, + workdir, + shape_template, +): + group_common.gen_profiler( + func_attrs, workdir, shape_template, PROBLEM_ARGS_TEMPLATE, has_bias=True + ) + + +def gen_function( + func_attrs, + exec_cond_template, + shape_eval_template, +): + return group_common.gen_function( + func_attrs, + exec_cond_template, + shape_eval_template, + PROBLEM_ARGS_TEMPLATE, + has_bias=True, + ) + + +def gen_function_decl(func_attrs): + func_name = func_attrs["name"] + return group_common.FUNC_DECL_TEMPLATE.render( + func_name=func_name, groups=func_attrs["groups"], has_bias=True + ) + + +def gen_function_call(func_attrs, indent=" "): + ndims = 3 + return group_common.gen_function_call(func_attrs, ndims, has_bias=True) diff --git a/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr.py b/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr.py new file mode 100644 index 000000000..354039b40 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr.py @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Codegen functions for group_gemm_rcr. +""" +import jinja2 + +from ... import registry +from . import common, group_common +from .layout import RCR + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + +PROBLEM_ARGS_TEMPLATE = jinja2.Template( + """ + problem_sizes_device, + problem_count, + threadblock_count, + {ElementComputeEpilogue(1), ElementComputeEpilogue(0)}, + ptr_A, + ptr_B, + ptr_C, + ptr_C, + lda, + ldb, + ldc, + ldc +""" +) + + +@registry.reg("cuda.group_gemm_rcr.config") +def group_rcr_config(func_attrs, dtype="float16"): + common.make_fproc_f16(func_attrs, RCR) + + +@registry.reg("cuda.group_gemm_rcr.gen_profiler") +def gen_profiler(func_attrs, workdir, shape_template): + group_common.gen_profiler( + func_attrs, workdir, shape_template, PROBLEM_ARGS_TEMPLATE + ) + + +@registry.reg("cuda.group_gemm_rcr.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + shape_eval_template, +): + return group_common.gen_function( + func_attrs, + exec_cond_template, + shape_eval_template, + PROBLEM_ARGS_TEMPLATE, + ) + + +@registry.reg("cuda.group_gemm_rcr.func_decl") +def gen_function_decl(func_attrs): + func_name = func_attrs["name"] + return group_common.FUNC_DECL_TEMPLATE.render( + func_name=func_name, groups=func_attrs["groups"] + ) + + +@registry.reg("cuda.group_gemm_rcr.func_call") +def gen_function_call(func_attrs, indent=" "): + ndims = 2 + return group_common.gen_function_call(func_attrs, ndims) + + +@registry.reg("cuda.group_gemm_rcr.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. 
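The two argument packs differ only in the epilogue scalars and in which pointer feeds the C operand: with bias, beta is 1 and ptr_bias is bound as the C source while ptr_C receives D; without bias, beta is 0 and ptr_C serves both roles. A minimal NumPy sketch of the assumed per-group math:

    import numpy as np

    # Assumed epilogue: D = alpha * (A @ B^T) + beta * C. Without bias the pack passes
    # beta = 0 and reuses ptr_C for both C and D; with bias it passes beta = 1 and
    # binds ptr_bias to the C operand, broadcast across the M rows by the kernel.
    M, N, K = 4, 8, 16
    rng = np.random.default_rng(0)
    A = rng.standard_normal((M, K), dtype=np.float32)
    B = rng.standard_normal((N, K), dtype=np.float32)  # RCR: B is stored as [N, K]
    bias = rng.standard_normal(N).astype(np.float32)
    D_no_bias = A @ B.T          # beta = 0: no C term
    D_bias = A @ B.T + bias      # beta = 1 with C = bias, broadcast over rows
    print(D_no_bias.shape, D_bias.shape)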
+ + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr_bias.py b/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr_bias.py new file mode 100644 index 000000000..c292c3e1d --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr_bias.py @@ -0,0 +1,75 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Codegen functions for group_gemm_rcr_bias. +""" +from ... import registry +from . import common, group_common_bias, group_gemm_rcr + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + + +@registry.reg("cuda.group_gemm_rcr_bias.config") +def group_rcr_config(func_attrs, dtype="float16"): + group_gemm_rcr.group_rcr_config(func_attrs, dtype) + + +@registry.reg("cuda.group_gemm_rcr_bias.gen_profiler") +def gen_profiler(func_attrs, workdir, shape_template): + group_common_bias.gen_profiler(func_attrs, workdir, shape_template) + + +@registry.reg("cuda.group_gemm_rcr_bias.gen_function") +def gen_function( + func_attrs, + exec_cond_remplate, + shape_eval_template, +): + return group_common_bias.gen_function( + func_attrs, + exec_cond_remplate, + shape_eval_template, + ) + + +@registry.reg("cuda.group_gemm_rcr_bias.func_decl") +def gen_function_decl(func_attrs): + return group_common_bias.gen_function_decl(func_attrs) + + +@registry.reg("cuda.group_gemm_rcr_bias.func_call") +def gen_function_call(func_attrs, indent=" "): + return group_common_bias.gen_function_call(func_attrs, indent) + + +@registry.reg("cuda.group_gemm_rcr_bias.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr_bias_relu.py b/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr_bias_relu.py new file mode 100644 index 000000000..9345c26e4 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr_bias_relu.py @@ -0,0 +1,75 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
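group_gemm_rcr_bias and the relu/sigmoid variants below are thin wrappers that re-register the shared group_common_bias code paths under their own op names. A stand-alone sketch of this decorator-based registry pattern (a toy re-implementation for illustration, not AITemplate's actual registry module):

    from typing import Callable, Dict

    _REGISTRY: Dict[str, Callable] = {}

    def reg(name: str):
        """Register a backend codegen function under a dotted op key."""
        def deco(fn: Callable) -> Callable:
            _REGISTRY[name] = fn
            return fn
        return deco

    @reg("cuda.group_gemm_rcr_bias.func_decl")
    def gen_function_decl(func_attrs):
        # A real implementation would delegate to group_common_bias.gen_function_decl.
        return f"void {func_attrs['name']}(...);"

    # The compiler can then look up codegen for an op by name at build time:
    print(_REGISTRY["cuda.group_gemm_rcr_bias.func_decl"]({"name": "group_gemm_rcr_bias_0"}))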
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Codegen functions for group_gemm_rcr_bias_relu. +""" +from ... import registry +from . import common, group_common_bias, group_gemm_rcr + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + + +@registry.reg("cuda.group_gemm_rcr_bias_relu.config") +def group_rcr_config(func_attrs, dtype="float16"): + group_gemm_rcr.group_rcr_config(func_attrs, dtype) + + +@registry.reg("cuda.group_gemm_rcr_bias_relu.gen_profiler") +def gen_profiler(func_attrs, workdir, shape_template): + group_common_bias.gen_profiler(func_attrs, workdir, shape_template) + + +@registry.reg("cuda.group_gemm_rcr_bias_relu.gen_function") +def gen_function( + func_attrs, + exec_cond_remplate, + shape_eval_template, +): + return group_common_bias.gen_function( + func_attrs, + exec_cond_remplate, + shape_eval_template, + ) + + +@registry.reg("cuda.group_gemm_rcr_bias_relu.func_decl") +def gen_function_decl(func_attrs): + return group_common_bias.gen_function_decl(func_attrs) + + +@registry.reg("cuda.group_gemm_rcr_bias_relu.func_call") +def gen_function_call(func_attrs, indent=" "): + return group_common_bias.gen_function_call(func_attrs, indent) + + +@registry.reg("cuda.group_gemm_rcr_bias_relu.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr_bias_sigmoid.py b/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr_bias_sigmoid.py new file mode 100644 index 000000000..e247bbe2a --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr_bias_sigmoid.py @@ -0,0 +1,75 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Codegen functions for group_gemm_rcr_bias_sigmoid. +""" +from ... import registry +from . 
import common, group_common_bias, group_gemm_rcr + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + + +@registry.reg("cuda.group_gemm_rcr_bias_sigmoid.config") +def group_rcr_config(func_attrs, dtype="float16"): + group_gemm_rcr.group_rcr_config(func_attrs, dtype) + + +@registry.reg("cuda.group_gemm_rcr_bias_sigmoid.gen_profiler") +def gen_profiler(func_attrs, workdir, shape_template): + group_common_bias.gen_profiler(func_attrs, workdir, shape_template) + + +@registry.reg("cuda.group_gemm_rcr_bias_sigmoid.gen_function") +def gen_function( + func_attrs, + exec_cond_remplate, + shape_eval_template, +): + return group_common_bias.gen_function( + func_attrs, + exec_cond_remplate, + shape_eval_template, + ) + + +@registry.reg("cuda.group_gemm_rcr_bias_sigmoid.func_decl") +def gen_function_decl(func_attrs): + return group_common_bias.gen_function_decl(func_attrs) + + +@registry.reg("cuda.group_gemm_rcr_bias_sigmoid.func_call") +def gen_function_call(func_attrs, indent=" "): + return group_common_bias.gen_function_call(func_attrs, indent) + + +@registry.reg("cuda.group_gemm_rcr_bias_sigmoid.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/layout.py b/python/aitemplate/backend/cuda/gemm_universal/layout.py new file mode 100644 index 000000000..8bab2b98e --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/layout.py @@ -0,0 +1,79 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +GeMM layout classes. 
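For reference, the assumed per-group math of the fused bias/activation variants registered above, written against PyTorch with illustrative shapes:

    import torch

    # Assumed per-group semantics (the activation is applied in the GEMM epilogue,
    # so no separate elementwise op is emitted):
    #   group_gemm_rcr_bias_relu:    y_i = relu(A_i @ B_i.T + bias_i)
    #   group_gemm_rcr_bias_sigmoid: y_i = sigmoid(A_i @ B_i.T + bias_i)
    a = torch.randn(8, 16)    # [M, K]
    b = torch.randn(32, 16)   # [N, K] (column-major B in RCR terms)
    bias = torch.randn(32)
    y_relu = torch.relu(a @ b.t() + bias)
    y_sigmoid = torch.sigmoid(a @ b.t() + bias)
    print(y_relu.shape, y_sigmoid.shape)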
+""" + +from dataclasses import dataclass + +# pylint: disable=C0415 + + +@dataclass +class Layout: + m = "M" + n = "N" + k = "K" + + +@dataclass +class RCR(Layout): + """ + Layout: A[RowMajor], B[ColumnMajor], C[RowMajor] + """ + + cutlass_layout_a = "cutlass::layout::RowMajor" + cutlass_layout_b = "cutlass::layout::ColumnMajor" + cutlass_layout_c = "cutlass::layout::RowMajor" + stride_a = "K" + stride_b = "K" + stride_c = "N" + + args_parser = """ + int64_t a_dim0 = M; + int64_t a_dim1 = K; + int64_t b_dim0 = N; + int64_t b_dim1 = K; + int64_t c_dim0 = M; + int64_t c_dim1 = N; +""" + + @staticmethod + def fproc_op(op): + import cutlass_lib + + row_major = cutlass_lib.library.LayoutType.RowMajor + op.C.layout = row_major + + @staticmethod + def fcond_op(op): + import cutlass_lib + + row_major = cutlass_lib.library.LayoutType.RowMajor + col_major = cutlass_lib.library.LayoutType.ColumnMajor + return op.A.layout == row_major and op.B.layout == col_major + + @staticmethod + def cutlass_lib_layouts(): + """ + return [layout_a, layout_b, layout_c] in the form of cutlass_lib definitions + """ + import cutlass_lib + + return [ + cutlass_lib.library.LayoutType.RowMajor, + cutlass_lib.library.LayoutType.ColumnMajor, + cutlass_lib.library.LayoutType.RowMajor, + ] diff --git a/python/aitemplate/backend/cuda/gemm_universal/perm021fc_ccr.py b/python/aitemplate/backend/cuda/gemm_universal/perm021fc_ccr.py new file mode 100644 index 000000000..580a3b005 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/perm021fc_ccr.py @@ -0,0 +1,124 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Codegen functions for perm021fc_ccr, which computes +[b, m, n] = bmm([b, k, m], [1, n, k]). +""" +from ... import registry +from . 
import bmm_common, common + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + + +def _get_problem_info(**kwargs): + problem_args = { + "bias_ptr": "c_ptr", + "a_batch_stride": "M * K", + "b_batch_stride": "0", + "bias_batch_stride": "M * N", + "c_batch_stride": "M * N", + "lda": "M", + "ldb": "K", + "ldbias": "N", + "ldc": "N", + } + for k, v in kwargs.items(): + problem_args[k] = v + + bmm_problem_info = bmm_common.Bmm_problem_info(**problem_args) + return bmm_problem_info + + +@registry.reg("cuda.perm021fc_ccr.config") +def gemm_ccr_config(func_attrs, dtype="float16"): + def fproc_f16(op): + import cutlass_lib + + return common.default_fproc_f16( + op=op, + a_layout=cutlass_lib.library.LayoutType.ColumnMajor, + b_layout=cutlass_lib.library.LayoutType.ColumnMajor, + c_layout=cutlass_lib.library.LayoutType.RowMajor, + epiligue_name=func_attrs["epilogue"], + ) + + func_attrs["op_instance"] = common.extract_config(fproc_f16) + + +@registry.reg("cuda.perm021fc_ccr.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + args_parser = bmm_common.ARGS_PARSER_TEMPLATE.render( + a_dims=["B", "K", "M"], b_dims=["1", "N", "K"], c_dims=["B", "M", "N"] + ) + + mm_info = _get_problem_info(alpha_value=func_attrs.get("alpha", 1)) + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + + bmm_common.gen_profiler( + func_attrs, + workdir, + dim_info_dict, + common.SRC_TEMPLATE, + problem_args, + args_parser, + ) + + +@registry.reg("cuda.perm021fc_ccr.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, +): + mm_info = _get_problem_info(alpha_value=func_attrs.get("alpha", 1)) + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + + return bmm_common.gen_function( + func_attrs, + exec_cond_template, + problem_args, + dim_info_dict, + ) + + +@registry.reg("cuda.perm021fc_ccr.func_decl") +def gen_function_decl(func_attrs): + return bmm_common.gen_function_decl(func_attrs) + + +@registry.reg("cuda.perm021fc_ccr.func_call") +def gen_function_call(func_attrs, indent=" "): + return bmm_common.gen_function_call(func_attrs, indent) + + +@registry.reg("cuda.perm021fc_ccr.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/perm021fc_ccr_bias.py b/python/aitemplate/backend/cuda/gemm_universal/perm021fc_ccr_bias.py new file mode 100644 index 000000000..b4f320de9 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/perm021fc_ccr_bias.py @@ -0,0 +1,130 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
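A hedged PyTorch reference for perm021fc_ccr as described in its docstring, assuming A of shape [B, K, M] and a single shared weight of shape [N, K]:

    import torch

    # Assumed reference: out[b, m, n] = bmm(A[b, k, m], B[1, n, k]), i.e. A is read
    # column-major per batch and the single B acts as a shared fully connected weight.
    B_, M, N, K = 4, 8, 16, 32
    a = torch.randn(B_, K, M)
    w = torch.randn(1, N, K)
    out = torch.einsum("bkm,nk->bmn", a, w.squeeze(0))
    # Equivalent formulation via explicit permute + matmul:
    ref = a.permute(0, 2, 1) @ w.squeeze(0).t()
    print(out.shape, torch.allclose(out, ref, atol=1e-5))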
+# +""" +Codegen functions for perm021fc_ccr_bias, which computes +[b, m, n] = bmm([b, k, m], [1, n, k]) + bias[n]. +""" +from ... import registry +from . import bmm_common, common, common_bias, perm021fc_ccr + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + + +def _get_problem_info(**kwargs): + problem_args = { + "beta_value": 1, + "bias_ptr": "bias_ptr", + "a_batch_stride": "M * K", + "b_batch_stride": "0", + "bias_batch_stride": "0", + "c_batch_stride": "M * N", + "lda": "M", + "ldb": "K", + "ldbias": "0", + "ldc": "N", + } + for k, v in kwargs.items(): + problem_args[k] = v + + bmm_problem_info = bmm_common.Bmm_problem_info(**problem_args) + return bmm_problem_info + + +@registry.reg("cuda.perm021fc_ccr_bias.config") +def gemm_rcr_config(func_attrs, dtype="float16"): + return perm021fc_ccr.gemm_ccr_config(func_attrs, dtype) + + +@registry.reg("cuda.perm021fc_ccr_bias.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + args_parser = bmm_common.ARGS_PARSER_TEMPLATE.render( + a_dims=["B", "K", "M"], b_dims=["1", "N", "K"], c_dims=["B", "M", "N"] + ) + + mm_info = _get_problem_info(alpha_value=func_attrs.get("alpha", 1)) + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + + bmm_common.gen_profiler( + func_attrs, + workdir, + dim_info_dict, + common_bias.SRC_TEMPLATE, + problem_args, + args_parser, + bias_ptr_arg="memory_pool->RequestHalfTensorByIdx(3)", + ) + + +@registry.reg("cuda.perm021fc_ccr_bias.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, +): + mm_info = _get_problem_info(alpha_value=func_attrs.get("alpha", 1)) + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + input_ndims = len(func_attrs["input_accessors"][0].original_shapes) + weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) + output_ndims = len(func_attrs["output_accessors"][0].original_shapes) + + return common.gen_function( + func_attrs, + common_bias.SRC_TEMPLATE, + exec_cond_template, + problem_args, + input_ndims=input_ndims, + weight_ndims=weight_ndims, + output_ndims=output_ndims, + dim_info_dict=dim_info_dict, + ) + + +@registry.reg("cuda.perm021fc_ccr_bias.func_decl") +def gen_function_decl(func_attrs): + func_name = func_attrs["name"] + input_ndims = len(func_attrs["input_accessors"][0].original_shapes) + weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) + return common_bias.FUNC_DECL_TEMPLATE.render( + func_name=func_name, input_ndims=input_ndims, weight_ndims=weight_ndims + ) + + +@registry.reg("cuda.perm021fc_ccr_bias.func_call") +def gen_function_call(func_attrs, indent=" "): + bias = func_attrs["inputs"][2] + return bmm_common.gen_function_call( + func_attrs, indent, bias_ptr_arg=bias._attrs["name"] + ) + + +@registry.reg("cuda.perm021fc_ccr_bias.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. 
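In the bias variant, beta_value is 1 and both ldbias and bias_batch_stride are 0, so the kernel re-reads the same [N]-length bias row for every output row of every batch. A short reference of the assumed math:

    import torch

    # Assumed reference for perm021fc_ccr_bias: out = bmm(A[b,k,m], B[1,n,k]) + bias[n].
    # ldbias = 0 / bias_batch_stride = 0 in the problem info above means the epilogue
    # reads the same bias row everywhere, i.e. a broadcast over [B, M].
    B_, M, N, K = 2, 4, 8, 16
    a = torch.randn(B_, K, M)
    w = torch.randn(N, K)
    bias = torch.randn(N)
    out = torch.einsum("bkm,nk->bmn", a, w) + bias
    print(out.shape)  # torch.Size([2, 4, 8])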
+ """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/perm021fc_ccr_bias_permute.py b/python/aitemplate/backend/cuda/gemm_universal/perm021fc_ccr_bias_permute.py new file mode 100644 index 000000000..5631bf3ca --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/perm021fc_ccr_bias_permute.py @@ -0,0 +1,165 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Common functions and templates for perm021_ccr_bias_permute, which computes +(A.permute(0, 2, 1)[col] @ B[col] + Bias).permute(0, 2, 1) +""" +from ... import registry + +from ..gemm_universal import common + +from . import ( + bmm_common, + bmm_permute_common, + common_bias, + common_permute, + perm021fc_ccr_bias, +) + + +EXTRA_CODE = """ + +#include "cutlass/gemm/device/gemm_universal_with_perm.h" + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/coord.h" +#include "cutlass/tensor_coord.h" + +namespace cutlass { +namespace layout { + +template +class Tensor3DPermute021BMM { + public: + using Index = int32_t; + using LongIndex = int64_t; + + Index col_permute; + Index row_permute; + Index stride_permute; + + private: + MatrixCoord extent_; + + public: + CUTLASS_HOST_DEVICE + Tensor3DPermute021BMM() {} + + CUTLASS_HOST_DEVICE + Tensor3DPermute021BMM(MatrixCoord extent) : extent_(extent) {} + + CUTLASS_HOST_DEVICE + void compute(Index col_init, Index row_init, Index stride_init, Index BMM_batch_idx) { + // Permute as torch.permute(X1, [0, 2, 1]) -> 3D Tensor indices as [i,j,k], the dimension of X is [D0, D1, D2], after permutation the dim of X1 is [D0, D2, D1]. 
+ // printf("BMM batch index: %d\t GEMM_m, GEMM_n = %d, %d\\n", BMM_batch_idx, extent_.row(), extent_.column()); + + int k = col_init; + int j = row_init; + int i = BMM_batch_idx; + + col_permute = j; + row_permute = k; + stride_permute = stride_init / extent_.column() * extent_.row(); // stride in Bytes + } +}; + +} // namespace layout +} // namespace cutlass +""" + + +@registry.reg("cuda.perm021fc_ccr_bias_permute.config") +def config(func_attrs, dtype="float16"): + def fproc_f16(op): + import cutlass_lib + + return common_permute.default_fproc_f16( + op=op, + a_layout=cutlass_lib.library.LayoutType.ColumnMajor, + b_layout=cutlass_lib.library.LayoutType.ColumnMajor, + c_layout=cutlass_lib.library.LayoutType.RowMajor, + epiligue_name=func_attrs["epilogue"], + permute_layout=func_attrs["layout"], + ) + + func_attrs["op_instance"] = common_permute.extract_config(fproc_f16, func_attrs) + + +@registry.reg("cuda.perm021fc_ccr_bias_permute.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + return perm021fc_ccr_bias.gen_profiler(func_attrs, workdir, dim_info_dict) + + +@registry.reg("cuda.perm021fc_ccr_bias_permute.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, +): + mm_info = perm021fc_ccr_bias._get_problem_info( + alpha_value=func_attrs.get("alpha", 1) + ) + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + + return bmm_permute_common.gen_function( + func_attrs, + exec_cond_template, + problem_args, + dim_info_dict, + extra_code=EXTRA_CODE, + has_bias=True, + ) + + +@registry.reg("cuda.perm021fc_ccr_bias_permute.func_decl") +def gen_function_decl(func_attrs): + func_name = func_attrs["name"] + input_ndims = len(func_attrs["input_accessors"][0].original_shapes) + weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) + return common_bias.FUNC_DECL_TEMPLATE.render( + func_name=func_name, input_ndims=input_ndims, weight_ndims=weight_ndims + ) + + +@registry.reg("cuda.perm021fc_ccr_bias_permute.func_call") +def gen_function_call(func_attrs, indent=" "): + bias = func_attrs["inputs"][2] + return bmm_common.gen_function_call( + func_attrs, indent, bias_ptr_arg=bias._attrs["name"] + ) + + +@registry.reg("cuda.perm021fc_ccr_bias_permute.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/perm021fc_crc.py b/python/aitemplate/backend/cuda/gemm_universal/perm021fc_crc.py new file mode 100644 index 000000000..35a9ef77d --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/perm021fc_crc.py @@ -0,0 +1,127 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
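The Tensor3DPermute021BMM epilogue above swaps the row/column indices at store time, so the kernel writes its result already transposed in the last two dimensions instead of running a separate permute op. A hedged reference of the assumed end-to-end math:

    import torch

    # Assumed reference for perm021fc_ccr_bias_permute: same GEMM + bias as
    # perm021fc_ccr_bias, with the output stored as its (0, 2, 1) permutation.
    B_, M, N, K = 2, 4, 8, 16
    a = torch.randn(B_, K, M)
    w = torch.randn(N, K)
    bias = torch.randn(N)
    y = torch.einsum("bkm,nk->bmn", a, w) + bias  # [B, M, N]
    fused = y.permute(0, 2, 1)                    # what the permute epilogue emits: [B, N, M]
    print(fused.shape)  # torch.Size([2, 8, 4])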
+# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Codegen functions for perm021fc_crc, which computes +[b, n, m](col) = bmm([1, k, n](col), [b, k, m](row)). +""" +from ... import registry +from . import bmm_common, common + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + + +def _get_problem_info(**kwargs): + problem_args = { + "problem_size": "{N, M, K}", + "bias_ptr": "c_ptr", + "a_batch_stride": "0", + "b_batch_stride": "K * M", + "bias_batch_stride": "M * N", + "c_batch_stride": "M * N", + "lda": "N", + "ldb": "M", + "ldbias": "N", + "ldc": "N", + } + for k, v in kwargs.items(): + problem_args[k] = v + + bmm_problem_info = bmm_common.Bmm_problem_info(**problem_args) + return bmm_problem_info + + +@registry.reg("cuda.perm021fc_crc.config") +def gemm_crc_config(func_attrs, dtype="float16"): + def fproc_f16(op): + import cutlass_lib + + return common.default_fproc_f16( + op=op, + a_layout=cutlass_lib.library.LayoutType.ColumnMajor, + b_layout=cutlass_lib.library.LayoutType.RowMajor, + c_layout=cutlass_lib.library.LayoutType.ColumnMajor, + epiligue_name=func_attrs["epilogue"], + ) + + func_attrs["op_instance"] = common.extract_config(fproc_f16) + + +@registry.reg("cuda.perm021fc_crc.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + args_parser = bmm_common.ARGS_PARSER_TEMPLATE.render( + a_dims=["1", "K", "N"], b_dims=["B", "K", "M"], c_dims=["B", "M", "N"] + ) + + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render( + mm_info=_get_problem_info(alpha_value=func_attrs.get("alpha", 1), beta_value=0) + ) + + bmm_common.gen_profiler( + func_attrs, + workdir, + dim_info_dict, + common.SRC_TEMPLATE, + problem_args, + args_parser, + ) + + +@registry.reg("cuda.perm021fc_crc.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, +): + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render( + mm_info=_get_problem_info(alpha_value=func_attrs.get("alpha", 1), beta_value=0) + ) + + return bmm_common.gen_function( + func_attrs, + exec_cond_template, + problem_args, + dim_info_dict, + ) + + +@registry.reg("cuda.perm021fc_crc.func_decl") +def gen_function_decl(func_attrs): + return bmm_common.gen_function_decl(func_attrs) + + +@registry.reg("cuda.perm021fc_crc.func_call") +def gen_function_call(func_attrs, indent=" "): + return bmm_common.gen_function_call(func_attrs, indent) + + +@registry.reg("cuda.perm021fc_crc.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/perm021fc_crc_bias.py b/python/aitemplate/backend/cuda/gemm_universal/perm021fc_crc_bias.py new file mode 100644 index 000000000..187a0c6c1 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/perm021fc_crc_bias.py @@ -0,0 +1,133 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
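A hedged reference for perm021fc_crc, keeping in mind that a column-major [B, N, M] output occupies the same memory as a row-major [B, M, N] tensor (consistent with ldc = N and c_batch_stride = M * N in the problem info above):

    import torch

    # Assumed reference: out[b, n, m] (column-major) =
    # bmm(W[1, k, n] (column-major), A[b, k, m] (row-major)).
    B_, M, N, K = 2, 4, 8, 16
    w = torch.randn(1, K, N)
    a = torch.randn(B_, K, M)
    out = torch.einsum("kn,bkm->bnm", w.squeeze(0), a)
    print(out.shape)  # torch.Size([2, 8, 4])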
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Codegen functions for perm021fc_crc_bias, which computes +[b, n, m](col) = bmm([1, k, n](col), [b, k, m](row)) + bias[n]. +""" +from ... import registry +from . import bmm_common, common, common_bias, perm021fc_crc + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + + +def _get_problem_info(**kwargs): + problem_args = { + "beta_value": 1, + "problem_size": "{N, M, K}", + "bias_ptr": "bias_ptr", + "a_batch_stride": "0", + "b_batch_stride": "K * M", + "bias_batch_stride": "0", + "c_batch_stride": "M * N", + "lda": "N", + "ldb": "M", + "ldbias": "0", + "ldc": "N", + } + for k, v in kwargs.items(): + problem_args[k] = v + + bmm_problem_info = bmm_common.Bmm_problem_info(**problem_args) + return bmm_problem_info + + +@registry.reg("cuda.perm021fc_crc_bias.config") +def gemm_rcr_config(func_attrs, dtype="float16"): + return perm021fc_crc.gemm_crc_config(func_attrs, dtype) + + +@registry.reg("cuda.perm021fc_crc_bias.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + args_parser = bmm_common.ARGS_PARSER_TEMPLATE.render( + a_dims=["1", "K", "N"], b_dims=["B", "K", "M"], c_dims=["B", "M", "N"] + ) + + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render( + mm_info=_get_problem_info(alpha_value=func_attrs.get("alpha", 1)) + ) + + bmm_common.gen_profiler( + func_attrs, + workdir, + dim_info_dict, + common_bias.SRC_TEMPLATE, + problem_args, + args_parser, + bias_ptr_arg="memory_pool->RequestHalfTensorByIdx(3)", + ) + + +@registry.reg("cuda.perm021fc_crc_bias.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, +): + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render( + mm_info=_get_problem_info(alpha_value=func_attrs.get("alpha", 1)) + ) + input_ndims = len(func_attrs["input_accessors"][0].original_shapes) + weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) + output_ndims = len(func_attrs["output_accessors"][0].original_shapes) + + return common.gen_function( + func_attrs, + common_bias.SRC_TEMPLATE, + exec_cond_template, + problem_args, + input_ndims=input_ndims, + weight_ndims=weight_ndims, + output_ndims=output_ndims, + dim_info_dict=dim_info_dict, + ) + + +@registry.reg("cuda.perm021fc_crc_bias.func_decl") +def gen_function_decl(func_attrs): + func_name = func_attrs["name"] + input_ndims = len(func_attrs["input_accessors"][0].original_shapes) + weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) + return common_bias.FUNC_DECL_TEMPLATE.render( + func_name=func_name, input_ndims=input_ndims, weight_ndims=weight_ndims + ) + + +@registry.reg("cuda.perm021fc_crc_bias.func_call") +def gen_function_call(func_attrs, indent=" "): + bias = func_attrs["inputs"][2] + return bmm_common.gen_function_call( + func_attrs, indent, bias_ptr_arg=bias._attrs["name"] + ) + + +@registry.reg("cuda.perm021fc_crc_bias.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. 
+ + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rcr.py b/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rcr.py new file mode 100644 index 000000000..fe0ffe9cd --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rcr.py @@ -0,0 +1,179 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Codegen functions for perm102_bmm_rcr, which computes +C[m, b, n](row) = bmm(A[m, b, k](row), B[b, n, k](col)) +""" +from ... import registry +from . import bmm_common, common + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + + +def _get_default_problem_info(**kwargs): + problem_args = { + "bias_ptr": "c_ptr", + "a_batch_stride": "K", + "b_batch_stride": "N * K", + "bias_batch_stride": "N", + "c_batch_stride": "N", + "lda": "K * B", + "ldb": "K", + "ldbias": "N * B", + "ldc": "N * B", + } + for k, v in kwargs.items(): + problem_args[k] = v + + bmm_problem_info = bmm_common.Bmm_problem_info(**problem_args) + return bmm_problem_info + + +# Currently only has output Tensor Accessor support. +def _get_strided_problem_info(func_attrs): + return bmm_common.Bmm_problem_info( + alpha_value=func_attrs.get("alpha", 1), + a_ptr="a_ptr", + b_ptr="b_ptr", + bias_ptr="(c_ptr + output_offset)", + c_ptr="(c_ptr + output_offset)", + a_batch_stride="K", + b_batch_stride="N * K", + bias_batch_stride="output_batch_stride", + c_batch_stride="output_batch_stride", + lda="K * B", + ldb="K", + ldbias="output_stride", + ldc="output_stride", + ) + + +def get_output_addr_calculator(func_attrs): + output_batch_stride_dim = "N" + output_stride_dim = "N * B" + output_offset = 0 + + if "output_accessors" in func_attrs: + output_accessor = func_attrs["output_accessors"][0] + if output_accessor.is_from_strided_tensor: + output_offset = output_accessor.offset + if not output_accessor.is_contiguous: + output_stride_dim = output_accessor.stride(0) + original_shapes = output_accessor.original_shapes + actual_shapes = output_accessor.actual_shapes + if len(actual_shapes) == 2 and actual_shapes[0] == original_shapes[0]: + # x = perm102_bmm_xxx(a, b) # [m, b, n] + # y = x.reshape()[x[0], -1] # [m, b * n] + # z = cat()(y0, y1, ..., yn, dim=-1) + output_batch_stride_dim = "N" + else: + raise NotImplementedError( + "Other strided fusion cases are not supported." 
+ ) + + output_addr_calculator = bmm_common.OUTPUT_ADDR_CALCULATOR.render( + output_batch_stride_dim=output_batch_stride_dim, + output_stride_dim=output_stride_dim, + output_offset_val=output_offset, + ) + + return output_addr_calculator + + +@registry.reg("cuda.perm102_bmm_rcr.config") +def gemm_rcr_config(func_attrs, dtype="float16"): + def fproc_f16(op): + import cutlass_lib + + return common.default_fproc_f16( + op=op, + a_layout=cutlass_lib.library.LayoutType.RowMajor, + b_layout=cutlass_lib.library.LayoutType.ColumnMajor, + c_layout=cutlass_lib.library.LayoutType.RowMajor, + epiligue_name=func_attrs["epilogue"], + ) + + func_attrs["op_instance"] = common.extract_config(fproc_f16) + + +@registry.reg("cuda.perm102_bmm_rcr.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + args_parser = bmm_common.ARGS_PARSER_TEMPLATE.render( + a_dims=["M", "B", "K"], b_dims=["B", "N", "K"], c_dims=["M", "B", "N"] + ) + + mm_info = _get_default_problem_info(alpha_value=func_attrs.get("alpha", 1)) + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + + bmm_common.gen_profiler( + func_attrs, + workdir, + dim_info_dict, + common.SRC_TEMPLATE, + problem_args, + args_parser, + ) + + +@registry.reg("cuda.perm102_bmm_rcr.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, +): + bmm_problem_info = _get_strided_problem_info(func_attrs) + + # broadcasting is not supported + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=bmm_problem_info) + + return bmm_common.gen_function( + func_attrs, + exec_cond_template, + problem_args, + dim_info_dict, + "", # input_addr_calculator + get_output_addr_calculator(func_attrs), + ) + + +@registry.reg("cuda.perm102_bmm_rcr.func_decl") +def gen_function_decl(func_attrs): + return bmm_common.gen_function_decl(func_attrs) + + +@registry.reg("cuda.perm102_bmm_rcr.func_call") +def gen_function_call(func_attrs, indent=" "): + return bmm_common.gen_function_call(func_attrs, indent) + + +@registry.reg("cuda.perm102_bmm_rcr.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rcr_bias.py b/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rcr_bias.py new file mode 100644 index 000000000..8c34ecd48 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rcr_bias.py @@ -0,0 +1,155 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Codegen functions for perm102_bmm_rcr_bias, which computes +C[m, b, n](row) = bmm(A[m, b, k](row), B[b, n, k](col)) + bias[n]. +""" +from ... import registry +from . 
import bmm_common, common, common_bias, perm102_bmm_rcr +from .perm102_bmm_rcr import get_output_addr_calculator + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + + +def _get_default_problem_info(**kwargs): + problem_args = { + "beta_value": 1, + "bias_ptr": "bias_ptr", + "a_batch_stride": "K", + "b_batch_stride": "N * K", + "bias_batch_stride": "N", + "c_batch_stride": "N", + "lda": "K * B", + "ldb": "K", + "ldbias": "0", + "ldc": "N * B", + } + for k, v in kwargs.items(): + problem_args[k] = v + + bmm_problem_info = bmm_common.Bmm_problem_info(**problem_args) + return bmm_problem_info + + +# Currently only has output Tensor Accessor support. +def _get_strided_problem_info(func_attrs): + return bmm_common.Bmm_problem_info( + alpha_value=func_attrs.get("alpha", 1), + beta_value=1, + a_ptr="(a_ptr)", + b_ptr="(b_ptr)", + bias_ptr="(bias_ptr)", + c_ptr="(c_ptr + output_offset)", + a_batch_stride="K", + b_batch_stride="N * K", + bias_batch_stride="N", + c_batch_stride="output_batch_stride", + lda="K * B", + ldb="K", + ldbias="0", + ldc="output_stride", + ) + + +@registry.reg("cuda.perm102_bmm_rcr_bias.config") +def gemm_rcr_config(func_attrs, dtype="float16"): + return perm102_bmm_rcr.gemm_rcr_config(func_attrs, dtype) + + +@registry.reg("cuda.perm102_bmm_rcr_bias.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + args_parser = bmm_common.ARGS_PARSER_TEMPLATE.render( + a_dims=["M", "B", "K"], b_dims=["B", "N", "K"], c_dims=["M", "B", "N"] + ) + + mm_info = _get_default_problem_info(alpha_value=func_attrs.get("alpha", 1)) + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + + bmm_common.gen_profiler( + func_attrs, + workdir, + dim_info_dict, + common_bias.SRC_TEMPLATE, + problem_args, + args_parser, + bias_ptr_arg="memory_pool->RequestHalfTensorByIdx(3)", + ) + + +@registry.reg("cuda.perm102_bmm_rcr_bias.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, +): + bmm_problem_info = _get_strided_problem_info(func_attrs) + + # broadcasting is not supported + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=bmm_problem_info) + + input_ndims = len(func_attrs["input_accessors"][0].original_shapes) + weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) + output_ndims = len(func_attrs["output_accessors"][0].original_shapes) + + return common.gen_function( + func_attrs, + common_bias.SRC_TEMPLATE, + exec_cond_template, + problem_args, + input_ndims=input_ndims, + weight_ndims=weight_ndims, + output_ndims=output_ndims, + dim_info_dict=dim_info_dict, + output_addr_calculator=get_output_addr_calculator(func_attrs), + ) + + +@registry.reg("cuda.perm102_bmm_rcr_bias.func_decl") +def gen_function_decl(func_attrs): + func_name = func_attrs["name"] + input_ndims = len(func_attrs["input_accessors"][0].original_shapes) + weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) + return common_bias.FUNC_DECL_TEMPLATE.render( + func_name=func_name, input_ndims=input_ndims, weight_ndims=weight_ndims + ) + + +@registry.reg("cuda.perm102_bmm_rcr_bias.func_call") +def gen_function_call(func_attrs, indent=" "): + bias = func_attrs["inputs"][2] + return bmm_common.gen_function_call( + func_attrs, indent, bias_ptr_arg=bias._attrs["name"] + ) + + +@registry.reg("cuda.perm102_bmm_rcr_bias.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. 
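For the perm102 family (including the rrr variants that follow, which differ only in B being stored row-major as [B, K, N]), a sketch of the assumed math and of the strided-output fusion that get_output_addr_calculator enables:

    import torch

    # Assumed semantics: C[m, b, n] = bmm(A[m, b, k], B[b, n, k]); the batch dimension
    # sits in the middle of A and C, which is why lda / ldc are "K * B" / "N * B" above.
    M, B_, N, K = 4, 3, 8, 16
    a = torch.randn(M, B_, K)
    w = torch.randn(B_, N, K)  # rcr: B stored as [B, N, K]
    c = torch.einsum("mbk,bnk->mbn", a, w)

    # The output TensorAccessor path lets this [M, B, N] result be written directly
    # into a slice of a larger concatenated buffer viewed as [M, B * N]; the kernel
    # only needs a bigger output_stride and a non-zero output_offset.
    big = torch.empty(M, 2 * B_ * N)  # hypothetical concat target
    big[:, : B_ * N] = c.reshape(M, -1)
    print(c.shape, big.shape)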
+ func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rrr.py b/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rrr.py new file mode 100644 index 000000000..e4a3d7d1b --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rrr.py @@ -0,0 +1,148 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Codegen functions for perm102_bmm_rrr, which computes +C[m, b, n](row) = bmm(A[m, b, k](row), B[b, k, n](row)) +""" +from ... import registry +from . import bmm_common, common +from .perm102_bmm_rcr import get_output_addr_calculator + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + + +def _get_default_problem_info(**kwargs): + problem_args = { + "bias_ptr": "c_ptr", + "a_batch_stride": "K", + "b_batch_stride": "N * K", + "bias_batch_stride": "N", + "c_batch_stride": "N", + "lda": "K * B", + "ldb": "N", + "ldbias": "N * B", + "ldc": "N * B", + } + for k, v in kwargs.items(): + problem_args[k] = v + + bmm_problem_info = bmm_common.Bmm_problem_info(**problem_args) + return bmm_problem_info + + +# Currently only has output Tensor Accessor support. 
+def _get_strided_problem_info(func_attrs): + return bmm_common.Bmm_problem_info( + alpha_value=func_attrs.get("alpha", 1), + a_ptr="(a_ptr)", + b_ptr="(b_ptr)", + bias_ptr="(c_ptr + output_offset)", + c_ptr="(c_ptr + output_offset)", + a_batch_stride="K", + b_batch_stride="N * K", + bias_batch_stride="output_batch_stride", + c_batch_stride="output_batch_stride", + lda="K * B", + ldb="N", + ldbias="output_stride", + ldc="output_stride", + ) + + +@registry.reg("cuda.perm102_bmm_rrr.config") +def gemm_rrr_config(func_attrs, dtype="float16"): + def fproc_f16(op): + import cutlass_lib + + return common.default_fproc_f16( + op=op, + a_layout=cutlass_lib.library.LayoutType.RowMajor, + b_layout=cutlass_lib.library.LayoutType.RowMajor, + c_layout=cutlass_lib.library.LayoutType.RowMajor, + epiligue_name=func_attrs["epilogue"], + ) + + func_attrs["op_instance"] = common.extract_config(fproc_f16) + + +@registry.reg("cuda.perm102_bmm_rrr.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + args_parser = bmm_common.ARGS_PARSER_TEMPLATE.render( + a_dims=["M", "B", "K"], b_dims=["B", "K", "N"], c_dims=["M", "B", "N"] + ) + + mm_info = _get_default_problem_info(alpha_value=func_attrs.get("alpha", 1)) + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + + bmm_common.gen_profiler( + func_attrs, + workdir, + dim_info_dict, + common.SRC_TEMPLATE, + problem_args, + args_parser, + ) + + +@registry.reg("cuda.perm102_bmm_rrr.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, +): + bmm_problem_info = _get_strided_problem_info(func_attrs) + + # broadcasting is not supported + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=bmm_problem_info) + + return bmm_common.gen_function( + func_attrs, + exec_cond_template, + problem_args, + dim_info_dict, + "", # input_addr_calculator + get_output_addr_calculator(func_attrs), + ) + + +@registry.reg("cuda.perm102_bmm_rrr.func_decl") +def gen_function_decl(func_attrs): + return bmm_common.gen_function_decl(func_attrs) + + +@registry.reg("cuda.perm102_bmm_rrr.func_call") +def gen_function_call(func_attrs, indent=" "): + return bmm_common.gen_function_call(func_attrs, indent) + + +@registry.reg("cuda.perm102_bmm_rrr.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rrr_bias.py b/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rrr_bias.py new file mode 100644 index 000000000..f7435c071 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rrr_bias.py @@ -0,0 +1,155 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Codegen functions for perm102_bmm_rrr_bias, which computes +C[m, b, n](row) = bmm(A[m, b, k](row), B[b, k, n](row)) + bias[n] +""" +from ... import registry +from . import bmm_common, common, common_bias, perm102_bmm_rrr +from .perm102_bmm_rcr import get_output_addr_calculator + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + + +def _get_default_problem_info(**kwargs): + problem_args = { + "beta_value": 1, + "bias_ptr": "bias_ptr", + "a_batch_stride": "K", + "b_batch_stride": "N * K", + "bias_batch_stride": "N", + "c_batch_stride": "N", + "lda": "K * B", + "ldb": "N", + "ldbias": "0", + "ldc": "N * B", + } + for k, v in kwargs.items(): + problem_args[k] = v + + bmm_problem_info = bmm_common.Bmm_problem_info(**problem_args) + return bmm_problem_info + + +# Currently only has output Tensor Accessor support. +def _get_strided_problem_info(func_attrs): + return bmm_common.Bmm_problem_info( + alpha_value=func_attrs.get("alpha", 1), + beta_value=1, + a_ptr="(a_ptr)", + b_ptr="(b_ptr)", + bias_ptr="(bias_ptr)", + c_ptr="(c_ptr + output_offset)", + a_batch_stride="K", + b_batch_stride="N * K", + bias_batch_stride="N", + c_batch_stride="output_batch_stride", + lda="K * B", + ldb="N", + ldbias="0", + ldc="output_stride", + ) + + +@registry.reg("cuda.perm102_bmm_rrr_bias.config") +def gemm_rrr_config(func_attrs, dtype="float16"): + return perm102_bmm_rrr.gemm_rrr_config(func_attrs, dtype) + + +@registry.reg("cuda.perm102_bmm_rrr_bias.gen_profiler") +def gen_profiler(func_attrs, workdir, dim_info_dict): + args_parser = bmm_common.ARGS_PARSER_TEMPLATE.render( + a_dims=["M", "B", "K"], b_dims=["B", "K", "N"], c_dims=["M", "B", "N"] + ) + + mm_info = _get_default_problem_info(alpha_value=func_attrs.get("alpha", 1)) + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + + bmm_common.gen_profiler( + func_attrs, + workdir, + dim_info_dict, + common_bias.SRC_TEMPLATE, + problem_args, + args_parser, + bias_ptr_arg="memory_pool->RequestHalfTensorByIdx(3)", + ) + + +@registry.reg("cuda.perm102_bmm_rrr_bias.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, +): + bmm_problem_info = _get_strided_problem_info(func_attrs) + + # broadcasting is not supported + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=bmm_problem_info) + + input_ndims = len(func_attrs["input_accessors"][0].original_shapes) + weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) + output_ndims = len(func_attrs["output_accessors"][0].original_shapes) + + return common.gen_function( + func_attrs, + common_bias.SRC_TEMPLATE, + exec_cond_template, + problem_args, + input_ndims=input_ndims, + weight_ndims=weight_ndims, + output_ndims=output_ndims, + dim_info_dict=dim_info_dict, + output_addr_calculator=get_output_addr_calculator(func_attrs), + ) + + +@registry.reg("cuda.perm102_bmm_rrr_bias.func_decl") +def gen_function_decl(func_attrs): + func_name = func_attrs["name"] + input_ndims = len(func_attrs["input_accessors"][0].original_shapes) + weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) + return common_bias.FUNC_DECL_TEMPLATE.render( + func_name=func_name, input_ndims=input_ndims, weight_ndims=weight_ndims + ) + + +@registry.reg("cuda.perm102_bmm_rrr_bias.func_call") +def gen_function_call(func_attrs, indent=" "): + bias = func_attrs["inputs"][2] + return bmm_common.gen_function_call( + func_attrs, indent, 
+        func_attrs, indent, bias_ptr_arg=bias._attrs["name"]
+    )
+
+
+@registry.reg("cuda.perm102_bmm_rrr_bias.filter")
+def function_filter(cfg, func_attrs, ab_alignment):
+    """Generates function filter.
+
+    Parameters
+    ----------
+    cfg: str
+        The filename generated for profiler.
+    func_attrs : Dict
+        Stores the operation attributes.
+    ab_alignment:
+        Input alignments.
+
+    Returns
+    -------
+    bool
+        If input cfg should be filtered.
+    """
+    return common.function_filter(cfg, func_attrs, ab_alignment)
diff --git a/python/aitemplate/backend/cuda/groupnorm/__init__.py b/python/aitemplate/backend/cuda/groupnorm/__init__.py
new file mode 100644
index 000000000..ee950628c
--- /dev/null
+++ b/python/aitemplate/backend/cuda/groupnorm/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from . import groupnorm, groupnorm_swish
+
+__all__ = ["groupnorm", "groupnorm_swish"]
diff --git a/python/aitemplate/backend/cuda/groupnorm/groupnorm.py b/python/aitemplate/backend/cuda/groupnorm/groupnorm.py
new file mode 100644
index 000000000..e26d8cd62
--- /dev/null
+++ b/python/aitemplate/backend/cuda/groupnorm/groupnorm.py
@@ -0,0 +1,38 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import Any, Dict
+
+from ... import registry
+
+from .groupnorm_common import (
+    groupnorm_gen_func_call,
+    groupnorm_gen_func_decl,
+    groupnorm_gen_function,
+)
+
+
+@registry.reg("cuda.groupnorm.gen_function")
+def gen_function(func_attrs: Dict[str, Any]) -> str:
+    return groupnorm_gen_function(func_attrs)
+
+
+@registry.reg("cuda.groupnorm.func_decl")
+def func_decl(func_attrs: Dict[str, Any]) -> str:
+    return groupnorm_gen_func_decl(func_attrs)
+
+
+@registry.reg("cuda.groupnorm.func_call")
+def gen_func_call(func_attrs: Dict[str, Any], indent=" ") -> str:
+    return groupnorm_gen_func_call(func_attrs, indent)
diff --git a/python/aitemplate/backend/cuda/groupnorm/groupnorm_common.py b/python/aitemplate/backend/cuda/groupnorm/groupnorm_common.py
new file mode 100644
index 000000000..5b075783c
--- /dev/null
+++ b/python/aitemplate/backend/cuda/groupnorm/groupnorm_common.py
@@ -0,0 +1,179 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+Common codegen functions for group_norm.
+"""
+
+import os
+from typing import Any, Dict, List
+
+import jinja2
+
+from ...target import Target
+
+FUNC_CALL_FP16_PARAM_TEMPLATE = jinja2.Template(
+    "reinterpret_cast<half*>(&({{name}}->raw()))"
+)
+
+FUNC_SIGNATURE = jinja2.Template(
+    """
+cudaError_t {{func_name}}(half* output,
+                          half* input,
+                          half* gamma,
+                          half* beta,
+                          int N,
+                          const float eps,
+                          const int max_smem_size,
+                          cudaStream_t stream)
+    """
+)
+
+FUNC_DECL = jinja2.Template(
+    """
+    {{func_signature}};
+    """
+)
+
+FUNC_CALL_TEMPLATE = jinja2.Template(
+    """
+{{indent}}{
+{{indent}} {{func_name}}(
+{{indent}} {{output}}, {{input}}, {{gamma}}, {{beta}}, {{N}},
+{{indent}} {{eps}}, max_smem_size, stream /* default stream */
+{{indent}} );
+{{indent}}}
+    """
+)
+
+
+FUNC_TEMPLATE = jinja2.Template(
+    """
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+#include <iostream>
+
+#include <type_traits>
+#include "cutlass/arch/memory_sm80.h"
+#include "cutlass/cutlass.h"
+#include "cutlass/fast_math.h"
+#include "logging.h"
+
+
+{{gamma_beta_const_defs}}
+
+namespace {
+
+{{custom_libs}}
+
+}  // namespace
+
+{{func_signature}}
+{
+  return invokeGroupNorm<{{FuseSwish}}, {{H}}, {{W}}, {{C}}, {{G}}>(
+      output,
+      input,
+      gamma,
+      beta,
+      N,
+      eps,
+      max_smem_size,
+      stream);
+}
+    """
+)
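+# Illustrative note: for a hypothetical op named "groupnorm_3" with an output
+# tensor "output_0" and a batch dimension "batch_0", FUNC_CALL_TEMPLATE above
+# renders roughly
+#   {
+#     groupnorm_3(
+#       reinterpret_cast<half*>(&(output_0->raw())), ..., batch_0,
+#       1e-05, max_smem_size, stream /* default stream */
+#     );
+#   }
+# where the pointer casts come from FUNC_CALL_FP16_PARAM_TEMPLATE.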
+
+
+def get_input_names(func_attrs: Dict[str, Any]) -> List[str]:
+    """
+    Return a list of rendered name strings for inputs. It returns nullptr
+    for gamma and beta if they are None.
+    """
+    inputs = func_attrs["inputs"]
+    x = inputs[0]
+    gamma = None
+    beta = None
+
+    idx = 1
+    if func_attrs["gamma_constant"] is None:
+        gamma = inputs[idx]
+        idx += 1
+    if func_attrs["beta_constant"] is None:
+        beta = inputs[idx]
+        idx += 1
+
+    input_name = FUNC_CALL_FP16_PARAM_TEMPLATE.render(name=x._attrs["name"])
+    if gamma is None:
+        gamma_name = "nullptr"
+    else:
+        gamma_name = FUNC_CALL_FP16_PARAM_TEMPLATE.render(name=gamma._attrs["name"])
+    if beta is None:
+        beta_name = "nullptr"
+    else:
+        beta_name = FUNC_CALL_FP16_PARAM_TEMPLATE.render(name=beta._attrs["name"])
+
+    return (input_name, gamma_name, beta_name)
+
+
+def groupnorm_gen_function(func_attrs: Dict[str, Any]) -> str:
+    use_swish = True if "swish" in func_attrs["name"] else False
+    input_shape = func_attrs["inputs"][0].shape()
+
+    H = input_shape[1].value()
+    W = input_shape[2].value()
+    C = input_shape[3].value()
+    G = func_attrs["num_groups"]
+
+    return FUNC_TEMPLATE.render(
+        custom_libs=Target.current().get_custom_libs(
+            os.path.dirname(__file__), "groupnorm_kernel.cuh"
+        ),
+        func_signature=FUNC_SIGNATURE.render(func_name=func_attrs["name"]),
+        FuseSwish="true" if use_swish else "false",
+        H=H,
+        W=W,
+        C=C,
+        G=G,
+    )
+
+
+def groupnorm_gen_func_decl(func_attrs: Dict[str, Any]) -> str:
+    return FUNC_DECL.render(
+        func_signature=FUNC_SIGNATURE.render(func_name=func_attrs["name"]).strip()
+    )
+
+
+def groupnorm_gen_func_call(func_attrs: Dict[str, Any], indent=" ") -> str:
+    output_name = ""
+    assert len(func_attrs["outputs"]) == 1
+    assert 1 <= len(
+        func_attrs["inputs"]
+    ), "expected at least 1 inputs but got {}".format(len(func_attrs["inputs"]))
+
+    output_name = FUNC_CALL_FP16_PARAM_TEMPLATE.render(
+        name=func_attrs["outputs"][0]._attrs["name"]
+    )
+    (input_name, gamma_name, beta_name) = get_input_names(func_attrs)
+    input_shape = func_attrs["inputs"][0]._attrs["shape"]
+    eps = func_attrs["eps"]
+    return FUNC_CALL_TEMPLATE.render(
+        func_name=func_attrs["name"],
+        output=output_name,
+        input=input_name,
+        gamma=gamma_name,
+        beta=beta_name,
+        N=input_shape[0]._attrs["name"],
+        eps=eps,
+        indent=indent,
+    )
diff --git a/python/aitemplate/backend/cuda/groupnorm/groupnorm_kernel.cuh b/python/aitemplate/backend/cuda/groupnorm/groupnorm_kernel.cuh
new file mode 100644
index 000000000..6a235589c
--- /dev/null
+++ b/python/aitemplate/backend/cuda/groupnorm/groupnorm_kernel.cuh
@@ -0,0 +1,561 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+#ifndef GROUPNORM_KERNEL_CUH
+#define GROUPNORM_KERNEL_CUH
+
+#define FINAL_MASK 0xffffffff
+
+#ifndef GROUP_NORM_CUDA_CHECK
+#define GROUP_NORM_CUDA_CHECK(expr)                                        \
+  do {                                                                     \
+    cudaError_t status = (expr);                                           \
+    if (status != cudaSuccess) {                                           \
+      std::cerr << "CUDA error: " << cudaGetErrorString(status) << " at "  \
+                << __FILE__ << ": " << __LINE__ << std::endl;              \
+      return status;                                                       \
+    }                                                                      \
+  } while (0)
+#endif
+
+#ifndef GROUP_NORM_CUDA_CHECK_LAUNCH
+#define GROUP_NORM_CUDA_CHECK_LAUNCH() GROUP_NORM_CUDA_CHECK(cudaGetLastError())
+#endif
+
+__inline__ __device__ float sigmoid(float val) {
+  return (cutlass::fast_tanh(val * 0.5f) + 1.0f) * 0.5f;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// The Groupnorm implementation below is based on OneFlow's Layernorm
+// implementation at:
+// https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/cuda/layer_norm.cuh
+
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#define __AIT_GN_USE_FAST_MATH 1
+template <typename T>
+__forceinline__ __device__ T Div(T a, T b);
+
+template <>
+__forceinline__ __device__ float Div<float>(float a, float b) {
+#ifdef __AIT_GN_USE_FAST_MATH
+  return __fdividef(a, b);
+#else
+  return a / b;
+#endif
+}
+
+template <>
+__forceinline__ __device__ half Div<half>(half a, half b) {
+  return __hdiv(a, b);
+}
+
+template <typename T>
+__forceinline__ __device__ T Rsqrt(T x);
+
+template <>
+__forceinline__ __device__ float Rsqrt<float>(float x) {
+#ifdef __AIT_GN_USE_FAST_MATH
+  return __frsqrt_rn(x);
+#else
+  return rsqrt(x);
+#endif
+}
+
+template <>
+__forceinline__ __device__ half Rsqrt<half>(half x) {
+  return hrsqrt(x);
+}
+
+#undef __AIT_GN_USE_FAST_MATH
+
+template <typename T>
+inline __device__ void WelfordCombine(T val, T* mean, T* m2, int* count) {
+  // Use Welford's online algorithm to compute mean and variance.
+  // For more details you can refer to:
+  // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
+  *count += 1;
+  T delta1 = val - *mean;
+  *mean += Div(delta1, static_cast<T>(*count));
+  T delta2 = val - *mean;
+  *m2 += delta1 * delta2;
+}
+
+template <typename T>
+inline __device__ void WelfordCombine(
+    T b_mean,
+    T b_m2,
+    int b_count,
+    T* mean,
+    T* m2,
+    int* count) {
+  if (b_count == 0) {
+    return;
+  }
+  int new_count = *count + b_count;
+  T nb_over_n = Div((T)b_count, (T)new_count);
+  T delta = b_mean - *mean;
+  *mean += delta * nb_over_n;
+  *m2 += b_m2 + delta * delta * (T)(*count) * (T)(nb_over_n);
+  *count = new_count;
+}
+
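+// Illustrative note on the merge step above: combining two partial Welford
+// results (mean_a, m2_a, n_a) and (mean_b, m2_b, n_b) uses
+//   n     = n_a + n_b
+//   delta = mean_b - mean_a
+//   mean  = mean_a + delta * n_b / n
+//   m2    = m2_a + m2_b + delta^2 * n_a * n_b / n
+// which is what the warp and block reductions below accumulate pairwise with
+// __shfl_down_sync.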
+constexpr int kWarpSize = 32;
+
+template <typename T, int thread_group_width = kWarpSize>
+__inline__ __device__ void WelfordWarpReduce(
+    T thread_mean,
+    T thread_m2,
+    int thread_count,
+    T* mean,
+    T* m2,
+    int* count) {
+  *mean = thread_mean;
+  *m2 = thread_m2;
+  *count = thread_count;
+  for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
+    T b_mean = __shfl_down_sync(0xffffffff, *mean, mask, thread_group_width);
+    T b_m2 = __shfl_down_sync(0xffffffff, *m2, mask, thread_group_width);
+    int b_count =
+        __shfl_down_sync(0xffffffff, *count, mask, thread_group_width);
+    WelfordCombine(b_mean, b_m2, b_count, mean, m2, count);
+  }
+}
+
+template <typename T>
+__inline__ __device__ void WelfordBlockAllReduce(
+    T thread_mean,
+    T thread_m2,
+    int thread_count,
+    T* result_mean,
+    T* result_m2,
+    int* result_count) {
+  __shared__ T mean_shared[kWarpSize];
+  __shared__ T m2_shared[kWarpSize];
+  __shared__ int count_shared[kWarpSize];
+  __shared__ T mean_result_broadcast;
+  __shared__ T m2_result_broadcast;
+  __shared__ int count_result_broadcast;
+  const int lid = threadIdx.x % kWarpSize;
+  const int wid = threadIdx.x / kWarpSize;
+  T warp_mean = 0;
+  T warp_m2 = 0;
+  int warp_count = 0;
+  WelfordWarpReduce(
+      thread_mean, thread_m2, thread_count, &warp_mean, &warp_m2, &warp_count);
+  __syncthreads();
+  if (lid == 0) {
+    mean_shared[wid] = warp_mean;
+    m2_shared[wid] = warp_m2;
+    count_shared[wid] = warp_count;
+  }
+  __syncthreads();
+  if (wid == 0) {
+    if (threadIdx.x < blockDim.x / kWarpSize) {
+      warp_mean = mean_shared[lid];
+      warp_m2 = m2_shared[lid];
+      warp_count = count_shared[lid];
+    } else {
+      warp_mean = static_cast<T>(0);
+      warp_m2 = static_cast<T>(0);
+      warp_count = static_cast<int>(0);
+    }
+    __syncwarp();
+    T block_mean = 0;
+    T block_m2 = 0;
+    int block_count = 0;
+    WelfordWarpReduce(
+        warp_mean, warp_m2, warp_count, &block_mean, &block_m2, &block_count);
+    if (lid == 0) {
+      mean_result_broadcast = block_mean;
+      m2_result_broadcast = block_m2;
+      count_result_broadcast = block_count;
+    }
+  }
+  __syncthreads();
+  *result_mean = mean_result_broadcast;
+  *result_m2 = m2_result_broadcast;
+  *result_count = count_result_broadcast;
+}
+
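+// Illustrative note on the kernel below: each thread block normalizes one
+// (batch element, group) slice, with blockIdx.x selecting the batch element
+// and blockIdx.y the group. T is a packed vector of sizeof(T) / sizeof(half)
+// half values, so the offsets and strides passed in are expressed in units
+// of T.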
+template <typename T, typename ComputeType, bool FuseSwish>
+__global__ void groupnorm_welford_fp16(
+    T* output,
+    T* input,
+    T* gamma,
+    T* beta,
+    const float eps,
+    const int64_t elems_per_block,
+    const int64_t elems_per_group_channel,
+    const int64_t batch_stride,
+    const int64_t group_stride,
+    const int64_t num_rows,
+    const int64_t row_stride) {
+  // all the numbers and strides are counted with respect to type T
+  constexpr int vec_size = sizeof(T) / sizeof(half);
+
+  const int tid = threadIdx.x;
+  const int bid = blockIdx.x;
+  const int gid = blockIdx.y; // index of group
+  const int64_t batch_offset = bid * batch_stride;
+  const int64_t group_offset = gid * group_stride;
+  const int64_t offset = batch_offset + group_offset;
+
+  // the first input of this thread
+  const T* t_input = input + offset;
+
+  ComputeType thread_mean = ComputeType(0.0);
+  ComputeType thread_m2 = ComputeType(0.0);
+  int thread_count = 0;
+#pragma unroll
+  for (int row_id = tid; row_id < num_rows; row_id += blockDim.x) {
+#pragma unroll
+    for (int i = 0; i < elems_per_group_channel; i++) {
+      const T* local_input = t_input + i + row_id * row_stride;
+      const half* half_ptr = reinterpret_cast<const half*>(local_input);
+#pragma unroll
+      for (int j = 0; j < vec_size; ++j) {
+        WelfordCombine(
+            __half2float(half_ptr[j]), &thread_mean, &thread_m2, &thread_count);
+      }
+    }
+  }
+  ComputeType row_mean = (ComputeType)(0.0f);
+  ComputeType row_m2 = (ComputeType)(0.0f);
+  int row_count = 0;
+  if (blockDim.x <= 32) {
+    WelfordWarpReduce(
+        thread_mean, thread_m2, thread_count, &row_mean, &row_m2, &row_count);
+  } else {
+    WelfordBlockAllReduce(
+        thread_mean, thread_m2, thread_count, &row_mean, &row_m2, &row_count);
+  }
+  ComputeType row_variance = Div(row_m2, static_cast<ComputeType>(row_count));
+  ComputeType row_inv_var = Rsqrt(row_variance + static_cast<ComputeType>(eps));
+
+  float local_row_mean;
+  if (std::is_same<ComputeType, half>::value) {
+    local_row_mean = __half2float(row_mean);
+  } else if (std::is_same<ComputeType, float>::value) {
+    local_row_mean = row_mean;
+  }
+  float local_row_inv_var;
+  if (std::is_same<ComputeType, half>::value) {
+    local_row_inv_var = __half2float(row_inv_var);
+  } else if (std::is_same<ComputeType, float>::value) {
+    local_row_inv_var = row_inv_var;
+  }
+
+  const T* t_gamma = gamma + group_offset;
+  const T* t_beta = beta + group_offset;
+  // the first output of this thread
+  T* t_output = output + offset;
+#pragma unroll
+  for (int row_id = tid; row_id < num_rows; row_id += blockDim.x) {
+#pragma unroll
+    for (int i = 0; i < elems_per_group_channel; i++) {
+      const T* local_input = t_input + i + row_id * row_stride;
+      const half* input_half_ptr = reinterpret_cast<const half*>(local_input);
+
+      T* local_output = t_output + i + row_id * row_stride;
+      T tmp_output;
+      half* output_half_ptr = reinterpret_cast<half*>(&tmp_output);
+
+      const T* local_gamma = t_gamma + i;
+      const T* local_beta = t_beta + i;
+      const half* gamma_half_ptr = reinterpret_cast<const half*>(local_gamma);
+      const half* beta_half_ptr = reinterpret_cast<const half*>(local_beta);
+
+#pragma unroll
+      for (int j = 0; j < vec_size; ++j) {
+        float local_val = __half2float(input_half_ptr[j]);
+        float local_gamma = __half2float(gamma_half_ptr[j]);
+        float local_beta = __half2float(beta_half_ptr[j]);
+        float out_val = (local_val - local_row_mean) * local_row_inv_var;
+        out_val = out_val * local_gamma + local_beta;
+        out_val = FuseSwish ? out_val * sigmoid(out_val) : out_val;
+        output_half_ptr[j] = __float2half_rn(out_val);
+      }
+      *local_output = tmp_output;
+    }
+  }
+}
+
+// End of the Groupnorm implementation that is based on OneFlow's Layernorm
+////////////////////////////////////////////////////////////////////////////////
+
+template <typename T>
+struct SumOp {
+  __device__ __forceinline__ T operator()(const T& a, const T& b) const {
+    return a + b;
+  }
+};
+
+template