Upstream the rest of the ROCm backend (facebookincubator#806)
Summary:
CC ipiszy

Pull Request resolved: facebookincubator#806

Reviewed By: aakhundov

Differential Revision: D47762985

Pulled By: ipiszy

fbshipit-source-id: 7ceb626579dc9901b416bcbf6873a51803214076
fsx950223 authored and facebook-github-bot committed Jul 25, 2023
1 parent ecf6037 commit ce6697d
Showing 14 changed files with 774 additions and 68 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/composable_kernel
3 changes: 3 additions & 0 deletions python/aitemplate/backend/rocm/__init__.py
@@ -17,8 +17,10 @@
ROCm backend init.
"""
from aitemplate.backend.rocm import lib_template, target_def, utils
+from aitemplate.backend.rocm.attention import *
from aitemplate.backend.rocm.common import *
from aitemplate.backend.rocm.conv2d import *
+from aitemplate.backend.rocm.embedding import *
from aitemplate.backend.rocm.gemm import *
from aitemplate.backend.rocm.pool2d import *
from aitemplate.backend.rocm.view_ops import *
@@ -27,4 +29,5 @@
from aitemplate.backend.rocm.normalization import softmax
from aitemplate.backend.rocm.upsample import *
from aitemplate.backend.rocm.vision_ops import *
+from aitemplate.backend.rocm.padding import *
from aitemplate.backend.rocm.normalization import groupnorm, groupnorm_swish, layernorm
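Note that these wildcard imports do real work: each backend module registers its codegen functions with the central registry at import time, so adding the attention, embedding, and padding imports here is what exposes those ops to the ROCm target. A minimal sketch of that pattern, with a hypothetical _REGISTRY dict standing in for aitemplate.backend.registry:

_REGISTRY = {}

def reg(key):
    # Decorator that records a codegen function under a string key.
    def decorator(func):
        _REGISTRY[key] = func
        return func
    return decorator

@reg("rocm.nhwc3to4.func_decl")
def gen_function_decl(func_attrs):
    return "void " + func_attrs["name"] + "(...);"

# Importing the module ran the decorator, so a lookup now succeeds:
print(_REGISTRY["rocm.nhwc3to4.func_decl"]({"name": "pad_nhwc3to4"}))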
7 changes: 4 additions & 3 deletions python/aitemplate/backend/rocm/gemm/common.py
@@ -94,11 +94,12 @@
#include "ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp"
{% elif "bias" in gemm_flag or has_d0 %}
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp"
-{% elif gemm_flag in ["permute", "bias_permute"] %}
+{% if gemm_flag == "bias_permute" %}
#include "ck/tensor_operation/gpu/device/impl/device_gemm_bias_e_permute_xdl.hpp"
-#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
-{% elif gemm_flag in ["bias_permute_m2n3", "bias_permute_m3n2"] %}
+#include "ck/tensor_operation/gpu/device/impl/gemm_specialization.hpp"
+{% elif gemm_flag in ["bias_permute_m2n3", "bias_permute_m3n2"] %}
#include "ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp"
+{% endif %}
{% endif %}
"""
)
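For readers unfamiliar with this codegen style: the hunk above edits a Jinja template whose rendered output is C++ source, so the branch taken on gemm_flag decides which composable_kernel headers are emitted. A hedged sketch with shortened include paths (not the real template):

import jinja2

INCLUDES = jinja2.Template(
    """
{% if gemm_flag == "bias_permute" %}
#include "device_gemm_bias_e_permute_xdl.hpp"
{% elif gemm_flag in ["bias_permute_m2n3", "bias_permute_m3n2"] %}
#include "device_batched_contraction_multiple_d_xdl_cshuffle.hpp"
{% endif %}
"""
)

# Renders only the include matching the flag:
print(INCLUDES.render(gemm_flag="bias_permute_m2n3").strip())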
12 changes: 6 additions & 6 deletions python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_hardswish.py
@@ -26,7 +26,7 @@
# pylint: disable=C0415,W0613


@registry.reg("rocm.gemm_rcr_bias_swish.config")
@registry.reg("rocm.gemm_rcr_bias_hardswish.config")
def gemm_config(func_attrs, dtype="float16"):
"""Extract (operation name, operation instance) pair from
all operation candidates.
Expand All @@ -49,7 +49,7 @@ def gemm_config(func_attrs, dtype="float16"):
common.make_fproc_f16(func_attrs, RCR, op_kind, extra_kind)


@registry.reg("rocm.gemm_rcr_bias_swish.gen_profiler")
@registry.reg("rocm.gemm_rcr_bias_hardswish.gen_profiler")
def gemm_gen_profiler(func_attrs, workdir, dim_info_dict):
"""Generates standalone executables for profiler.
Expand All @@ -72,7 +72,7 @@ def gemm_gen_profiler(func_attrs, workdir, dim_info_dict):
)


@registry.reg("rocm.gemm_rcr_bias_swish.gen_function")
@registry.reg("rocm.gemm_rcr_bias_hardswish.gen_function")
def gemm_gen_function(func_attrs, exec_cond_template, dim_info_dict):
"""Generates function body.
Expand Down Expand Up @@ -106,7 +106,7 @@ def gemm_gen_function(func_attrs, exec_cond_template, dim_info_dict):
)


@registry.reg("rocm.gemm_rcr_bias_swish.func_decl")
@registry.reg("rocm.gemm_rcr_bias_hardswish.func_decl")
def gemm_gen_function_decl(func_attrs):
"""Generates function declarations.
Expand All @@ -124,7 +124,7 @@ def gemm_gen_function_decl(func_attrs):
return common.gen_function_decl(func_name=func_name, gemm_flag="bias_hardswish")


@registry.reg("rocm.gemm_rcr_bias_swish.func_call")
@registry.reg("rocm.gemm_rcr_bias_hardswish.func_call")
def gemm_gen_function_call(func_attrs, indent=" "):
"""Generates function call.
Expand All @@ -143,7 +143,7 @@ def gemm_gen_function_call(func_attrs, indent=" "):
return common.gen_function_call(func_attrs, indent, gemm_flag="bias_hardswish")


@registry.reg("rocm.gemm_rcr_bias_swish.filter")
@registry.reg("rocm.gemm_rcr_bias_hardswish.filter")
def gemm_function_filter(cfg, func_attrs, x_shape):
"""Generates function filter.
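The rename above is more than cosmetic: AITemplate resolves codegen by string key, so registering the hardswish functions under the swish keys would shadow the real gemm_rcr_bias_swish entries. A sketch of the lookup convention (hypothetical helper; the real dispatch lives in aitemplate.backend.registry):

def lookup(registry, backend, op_name, stage):
    # Keys follow the "<backend>.<op>.<stage>" convention used above.
    key = "{}.{}.{}".format(backend, op_name, stage)
    if key not in registry:
        raise RuntimeError("no codegen registered for " + key)
    return registry[key]

# e.g. lookup(reg, "rocm", "gemm_rcr_bias_hardswish", "gen_function")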
6 changes: 3 additions & 3 deletions python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_swish.py
@@ -54,10 +54,10 @@
};
template <>
__host__ __device__ constexpr void
-operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
+operator()<ck::half_t>(ck::half_t& y, const ck::half_t& x0, const ck::half_t& x1) const
{
-    const half_t a = x0 + x1;
-    y = a / (type_convert<half_t>(1.0) + type_convert<half_t>(exp(ck::type_convert<float>(-a))));
+    const ck::half_t a = x0 + x1;
+    y = a / (ck::type_convert<ck::half_t>(1.0) + ck::type_convert<ck::half_t>(exp(ck::type_convert<float>(-a))));
};
};
} // namespace
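As a sanity check on the epilogue math above: y = a / (1 + exp(-a)) is exactly swish(a) = a * sigmoid(a) with a = x0 + x1. A plain-Python mirror (float64 rather than fp16, so only an approximation of the kernel's arithmetic):

import math

def swish_ref(a):
    # swish(a) = a * sigmoid(a)
    return a * (1.0 / (1.0 + math.exp(-a)))

def epilogue(x0, x1):
    # Mirrors the operator() body above.
    a = x0 + x1
    return a / (1.0 + math.exp(-a))

assert abs(swish_ref(0.7) - epilogue(0.4, 0.3)) < 1e-12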
2 changes: 2 additions & 0 deletions python/aitemplate/backend/rocm/lib_template.py
@@ -42,6 +42,8 @@ def void_ptr_decl(name, dtype="float16", indent=" "):
        type_string = "int64_t*"
    elif dtype == "bool":
        type_string = "bool*"
+    elif dtype == "int32":
+        type_string = "int*"
    else:
        raise NotImplementedError
    return PTR_TEMPLATE.render(name=name, dtype=type_string, indent=indent)
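For context, void_ptr_decl maps an AITemplate dtype string to the C pointer type used in generated code; this commit adds the int32 branch. A standalone mirror of just the branches visible in the hunk (the float16 case and PTR_TEMPLATE are omitted, so treat this as illustrative, not the real API):

def void_ptr_decl_type(dtype):
    # Branches shown in the diff above; unknown dtypes still raise.
    if dtype == "int64":
        return "int64_t*"
    elif dtype == "bool":
        return "bool*"
    elif dtype == "int32":
        return "int*"
    else:
        raise NotImplementedError(dtype)

assert void_ptr_decl_type("int32") == "int*"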
20 changes: 20 additions & 0 deletions python/aitemplate/backend/rocm/padding/__init__.py
@@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
ROCm padding init
"""
from . import nhwc3to4, nhwc3to8, pad_last_dim

__all__ = ["nhwc3to8", "pad_last_dim", "nhwc3to4"]
225 changes: 225 additions & 0 deletions python/aitemplate/backend/rocm/padding/nhwc3to4.py
@@ -0,0 +1,225 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
ROCm codegen for the nhwc3to4 op
"""
import jinja2

from ... import registry
from ...backend_spec import ROCMSpec

# pylint: disable=C0301,W0613,W0612

FUNC_DECL_TEMPLATE = jinja2.Template(
"""
void {{func_name}}(
void*,
void*,
int64_t*,
int64_t*,
int64_t*,
int64_t*,
int64_t*,
int64_t*,
hipStream_t
);
"""
)

FUNC_CALL_TEMPLATE = jinja2.Template(
"""
{{indent}}{{func_name}}(
{{indent}} {{in_ptr}},
{{indent}} {{out_ptr}},
{{indent}} {{p_batch}},
{{indent}} {{p_in_h}},
{{indent}} {{p_in_w}},
{{indent}} {{p_out_batch}},
{{indent}} {{p_out_h}},
{{indent}} {{p_out_w}},
{{indent}} stream
{{indent}});
"""
)


EXEC_TEMPLATE = jinja2.Template(
"""
{{indent}}nhwc3to4_launcher<{{elem_input_type}}>(
{{indent}} static_cast<const {{elem_input_type}}*>(in_ptr),
{{indent}} static_cast<{{elem_input_type}}*>(out_ptr),
{{indent}} NI,
{{indent}} HI,
{{indent}} WI,
{{indent}} stream
{{indent}});
{{indent}}return;
"""
)

SRC_TEMPLATE = jinja2.Template(
"""
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
// fast kernel for c_in = 3 & c_out = 4
template <typename Tio, typename Telement, int element_in_Tio>
__global__ void nhwc_padding_channel_3To4_kernel(const int32_t n,
const int32_t h,
const int32_t w,
const Tio *input,
Tio *output,
const int32_t max_output_element,
const int32_t max_input_element,
const Tio zero_io,
const Telement zero_element){
__shared__ Tio shm[192];
const int tidx = blockIdx.x * 192 + threadIdx.x;
const int threadidx = threadIdx.x;
shm[threadIdx.x] = tidx >= max_input_element ? zero_io : input[tidx];
__syncthreads();
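// Each 192-thread block stages 192 Tio vectors of 3-channel data in shared
// memory and writes 256 Tio vectors of 4-channel data (a 3:4 expansion);
// e.g. for half elements one int4 holds 8 values, so a block expands
// 512 pixels x 3 channels into 512 pixels x 4 channels.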
const int output_offset = blockIdx.x * 256;
const int lower_bound = max_output_element < output_offset + 256 ? max_output_element : output_offset + 256;
for (int i = output_offset + threadidx, j = threadidx ; i < lower_bound ; i+=192, j+=192)
{
const Telement* shm_element = (const Telement*)shm + j*3*element_in_Tio/4;
Telement array[element_in_Tio];
#pragma unroll
for (int k = 0 ; k < element_in_Tio ; k++)
array[k] = ((k+1)%4 == 0) ? zero_element : shm_element[(k > 3) ? (k - 1) : k];
output[i] = *((const Tio *)array);
}
}
template <typename ElemT>
void nhwc3to4_launcher(const ElemT* in_ptr,
ElemT* out_ptr,
int NI,
int HI,
int WI,
hipStream_t stream) {
dim3 block(192);
const int nhw = NI * HI * WI;
const int nhwc = nhw * 3;
CHECK_EQ(nhw % 8, 0);
const int element_in_Tio = sizeof(int4) / sizeof(ElemT);
const int max_input_element = nhwc / element_in_Tio;
const int max_output_element = nhw * 4 / element_in_Tio;
const int4 zero_io = {0, 0, 0, 0};
const ElemT zero_element = static_cast<ElemT>(0.0f);
dim3 grid((nhwc + 192 * element_in_Tio - 1)/(192 * element_in_Tio));
nhwc_padding_channel_3To4_kernel<int4, ElemT, element_in_Tio><<<grid, block, 0, stream>>>
(NI, HI, WI,
(const int4 *)in_ptr,
(int4 *)out_ptr,
max_output_element,
max_input_element,
zero_io,
zero_element);
}
void {{function_name}} (
void* in_ptr,
void* out_ptr,
int64_t* batch,
int64_t* in_h,
int64_t* in_w,
int64_t* out_batch,
int64_t* out_h,
int64_t* out_w,
hipStream_t stream
) {
{{shape_function}}
{{exec_paths}}
}
"""
)


@registry.reg("rocm.nhwc3to4.gen_function")
def gen_function(func_attrs, template_path, shape_eval_template, shape_save_template):
"""
Parameters
----------
func_attrs : [type]
[description]
template_path : [type]
[description]
shape_eval_template : [type]
[description]
shape_save_template : [type]
[description]
Returns
-------
[type]
[description]
"""
    func_name = func_attrs["name"]
    backend_spec = ROCMSpec()
    elem_input_type = backend_spec.dtype_to_backend_type(
        func_attrs["inputs"][0]._attrs["dtype"]
    )
    shape_eval_func = shape_eval_template.render(
        indent=" ",
        dtype="int64_t ",
        x_dim0="*batch",
        x_dim1="*in_h",
        x_dim2="*in_w",
    )
    shape_save_func = shape_save_template.render(
        indent=" ",
        y_dim0="*out_batch",
        y_dim1="*out_h",
        y_dim2="*out_w",
    )
    shape_func = shape_eval_func + shape_save_func
    exec_paths = EXEC_TEMPLATE.render(elem_input_type=elem_input_type)
    return SRC_TEMPLATE.render(
        function_name=func_name,
        elem_input_type=elem_input_type,
        shape_function=shape_func,
        exec_paths=exec_paths,
    )


@registry.reg("rocm.nhwc3to4.func_decl")
def gen_function_decl(func_attrs):
    func_name = func_attrs["name"]
    return FUNC_DECL_TEMPLATE.render(func_name=func_name)


@registry.reg("rocm.nhwc3to4.func_call")
def gen_function_call(func_attrs, indent=" "):
    x = func_attrs["inputs"][0]
    xshape = x._attrs["shape"]
    y = func_attrs["outputs"][0]
    yshape = y._attrs["shape"]
    return FUNC_CALL_TEMPLATE.render(
        func_name=func_attrs["name"],
        in_ptr=x._attrs["name"],
        out_ptr=y._attrs["name"],
        p_batch="&" + xshape[0]._attrs["name"],
        p_in_h="&" + xshape[1]._attrs["name"],
        p_in_w="&" + xshape[2]._attrs["name"],
        p_out_batch="&" + yshape[0]._attrs["name"],
        p_out_h="&" + yshape[1]._attrs["name"],
        p_out_w="&" + yshape[2]._attrs["name"],
        indent=indent,
    )
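To make the kernel's contract concrete: nhwc3to4 zero-pads the channel dimension of an NHWC tensor from 3 to 4. A numpy reference of the expected output (an editor's sketch, not part of the commit):

import numpy as np

def nhwc3to4_ref(x):
    # Zero-pad NHWC channels 3 -> 4, matching the kernel's layout.
    n, h, w, c = x.shape
    assert c == 3
    out = np.zeros((n, h, w, 4), dtype=x.dtype)
    out[..., :3] = x
    return out

x = np.arange(2 * 2 * 2 * 3, dtype=np.float16).reshape(2, 2, 2, 3)
y = nhwc3to4_ref(x)
assert y.shape == (2, 2, 2, 4) and float(y[..., 3].sum()) == 0.0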
