From 02820ad283077c54d257b9384998c8e94821a296 Mon Sep 17 00:00:00 2001 From: masahi Date: Fri, 9 Dec 2022 20:00:57 +0900 Subject: [PATCH 01/12] [FQ2I] Support converting `dense` -> `add` to `qnn.dense` -> `add` -> `requantize` (#13578) * wip * hack to convert size-1 scale and zp tensors to scalar * fix to binary op fast path * check output zp * add assert * add comment * lint * clean up beta handling * use regular binary op only for 32 bit add (bias addition) * do float(beta) when we know that beta is not None * restore original beta handling code to avoid mul by 1 * add comment on overflow --- python/tvm/relay/frontend/onnx.py | 5 ++- .../transform/fake_quantization_to_integer.py | 31 +++++++++++++ .../fake_quantization_to_integer.cc | 2 +- .../test_pass_fake_quantization_to_integer.py | 43 ++++++++++++++++--- 4 files changed, 72 insertions(+), 9 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index d185d143c7a61..62f0f4b2dd255 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1406,7 +1406,10 @@ def _impl_v1(cls, inputs, attr, params): inputs[0] *= _expr.const(alpha, dtype=dtype) out = _op.nn.dense(inputs[0], inputs[1], units=channels) if len(inputs) == 3: - out = out + _expr.const(beta, dtype=dtype) * inputs[2] + if beta != 1.0: + out += _expr.const(float(beta), dtype=dtype) * inputs[2] + else: + out += inputs[2] return out diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index 46bdd94ace1a2..84b1f33e98cc8 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -502,6 +502,37 @@ def register_binary_qnn(op_name, op): def binary(expr, type_map): left, right, left_t, right_t, out_t = get_binary_types(expr, type_map) + + if ( + op_name == "add" + and approx_equal(left_t.scale, right_t.scale) + and approx_equal(left_t.zero_point, right_t.zero_point) + and tvm.ir.structural_equal(left_t.dtype, right_t.dtype) + and left_t.dtype == "int32" + and approx_equal(left_t.scale, out_t.scale) + and approx_equal(left_t.zero_point, out_t.zero_point) + and np.all(out_t.zero_point.data.numpy() == 0) + ): + # If this add op comes after conv2d or dense, out_t.scale and out_t.zero_point + # can be a vector, which is not supported by QNN binary operators. + # In particular, the pattern of an `add` op following `dense`, where the addition is + # really a bias addtion, can come up often. We identify that pattern and convert it to + # `qnn.dense` -> `add`. + # To avoid overflow, we do this conversion only when the input data type is 32 bit (bias + # addition is typically done in 32 bit). 
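+        # Illustratively, the rewrite this branch enables (the pattern named in
+        # this patch) is
+        #   dequantize -> dense -> add(bias) -> quantize
+        # being converted to
+        #   qnn.dense -> add (32 bit bias) -> requantize.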
+ return [left + right, left_t] + + assert ( + len(out_t.scale.data.shape) == 0 + ), "The output scale needs to be a scalar, but got a tensor of shape {}".format( + out_t.scale.data.shape + ) + assert ( + len(out_t.zero_point.data.shape) == 0 + ), "The output zero point needs to be a scalar, but got a tensor of shape {}".format( + out_t.zero_point.data.shape + ) + out = op( left, right, diff --git a/src/relay/transforms/fake_quantization_to_integer.cc b/src/relay/transforms/fake_quantization_to_integer.cc index eb176df5c978d..31353d5aa25e4 100644 --- a/src/relay/transforms/fake_quantization_to_integer.cc +++ b/src/relay/transforms/fake_quantization_to_integer.cc @@ -193,7 +193,7 @@ class SubgraphMutator : public ExprMutator { return Mutate(expr); } catch (std::exception& e) { if (hard_fail_) { - throw e; + LOG(FATAL) << e.what(); } else { DLOG(INFO) << "Ran into an error rewriting a subgraph, skipping" << expr << std::endl; return expr; diff --git a/tests/python/relay/test_pass_fake_quantization_to_integer.py b/tests/python/relay/test_pass_fake_quantization_to_integer.py index 569bd9d7d6532..d384635e42e55 100644 --- a/tests/python/relay/test_pass_fake_quantization_to_integer.py +++ b/tests/python/relay/test_pass_fake_quantization_to_integer.py @@ -154,6 +154,41 @@ def test_fake_quantize_dense_per_channel(): compare_fq_to_int(op, [x_np, w_np], allow_rounding_error=True) +def test_fake_quantize_dense_bias(): + out_dtype = "int8" + x = relay.var("x", shape=[128, 64], dtype="int8") + w = relay.var("w", shape=[256, 64], dtype="int8") + bias = relay.var("bias", shape=[256], dtype="int32") + one = relay.const(1.0) + zero = relay.const(0) + w_scale = np.random.random([256]).astype("float32") + + op = relay.op.nn.dense( + relay.qnn.op.dequantize(x, relay.const(2.0), zero), + relay.qnn.op.dequantize( + w, + relay.const(w_scale), + zero, + axis=0, + ), + units=256, + ) + + op += relay.qnn.op.dequantize( + bias, + relay.const(2.0 * w_scale), + zero, + ) + + op = relay.qnn.op.quantize(op, one, zero, out_dtype=out_dtype) + + x_np = np.random.randint(-128, 127, size=[128, 64], dtype="int8") + w_np = np.random.randint(-128, 127, size=[256, 64], dtype="int8") + bias_np = np.random.randint(-128, 127, size=[256], dtype="int32") + + compare_fq_to_int(op, [x_np, w_np, bias_np], allow_rounding_error=True) + + def test_fake_quantize_batch_matmul(): for out_dtype in ["int8", "uint8"]: x = relay.var("x", shape=[1, 128, 64], dtype="int8") @@ -976,15 +1011,9 @@ def test_fq_qat_positive_nothing_to_do(): op1 = relay.qnn.op.quantize( relay.const(1.0), relay.const(12.0), relay.const(0), out_dtype="int32" ) - op2 = relay.qnn.op.add( + op2 = relay.op.add( op0, op1, - relay.const(12.0), - relay.const(0), - relay.const(12.0), - relay.const(0), - relay.const(12.0), - relay.const(0), ) expected_expr = relay.qnn.op.requantize( op2, relay.const(12.0), relay.const(0), relay.const(1.0), relay.const(0), out_dtype="int8" From 0dc26dd87052ca7c0245a9eb26110e83a96982b1 Mon Sep 17 00:00:00 2001 From: driazati <9407960+driazati@users.noreply.github.com> Date: Fri, 9 Dec 2022 21:15:28 -0700 Subject: [PATCH 02/12] [ci][docker] Allow usage of ECR images in PRs (#13590) This fixes `ecr_pull` so that `docker-images.ini` can be updated with Docker images from a previous CI run for testing purposes Example run: https://ci.tlcpack.ai/blue/organizations/jenkins/tvm-cortexm/detail/PR-13590/4/pipeline/#step-80-log-9 --- ci/jenkins/generated/arm_jenkinsfile.groovy | 46 +++++++++- .../generated/cortexm_jenkinsfile.groovy | 46 +++++++++- 
ci/jenkins/generated/cpu_jenkinsfile.groovy | 46 +++++++++- .../generated/docker_jenkinsfile.groovy | 87 +++++++++---------- ci/jenkins/generated/gpu_jenkinsfile.groovy | 46 +++++++++- .../generated/hexagon_jenkinsfile.groovy | 46 +++++++++- ci/jenkins/generated/i386_jenkinsfile.groovy | 46 +++++++++- ci/jenkins/generated/lint_jenkinsfile.groovy | 46 +++++++++- .../generated/minimal_jenkinsfile.groovy | 46 +++++++++- ci/jenkins/generated/riscv_jenkinsfile.groovy | 46 +++++++++- ci/jenkins/generated/wasm_jenkinsfile.groovy | 46 +++++++++- .../templates/docker_jenkinsfile.groovy.j2 | 41 --------- ci/jenkins/templates/utils/Prepare.groovy.j2 | 44 +++++++++- ci/scripts/jenkins/determine_docker_images.py | 6 +- 14 files changed, 520 insertions(+), 118 deletions(-) diff --git a/ci/jenkins/generated/arm_jenkinsfile.groovy b/ci/jenkins/generated/arm_jenkinsfile.groovy index f1bcc786b72e0..0fc71b430ca03 100644 --- a/ci/jenkins/generated/arm_jenkinsfile.groovy +++ b/ci/jenkins/generated/arm_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2022-12-06T20:56:42.365592 +// Generated at 2022-12-09T15:39:24.387114 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -205,8 +205,7 @@ def docker_init(image) { if (image.contains("amazonaws.com")) { // If this string is in the image name it's from ECR and needs to be pulled // with the right credentials - // ecr_pull(image) - sh "echo Pulling from AWS is not implemented && exit 1" + ecr_pull(image) } else { sh( script: """ @@ -219,6 +218,47 @@ def docker_init(image) { } } +def ecr_pull(full_name) { + aws_account_id = sh( + returnStdout: true, + script: 'aws sts get-caller-identity | grep Account | cut -f4 -d\\"', + label: 'Get AWS ID' + ).trim() + + try { + withEnv([ + "AWS_ACCOUNT_ID=${aws_account_id}", + 'AWS_DEFAULT_REGION=us-west-2', + "AWS_ECR_REPO=${aws_account_id}.dkr.ecr.us-west-2.amazonaws.com"]) { + sh( + script: ''' + set -eux + aws ecr get-login-password --region $AWS_DEFAULT_REGION | docker login --username AWS --password-stdin $AWS_ECR_REPO + ''', + label: 'Log in to ECR' + ) + sh( + script: """ + set -eux + . 
${jenkins_scripts_root}/retry.sh + retry 5 docker pull ${full_name} + """, + label: 'Pull image from ECR' + ) + } + } finally { + withEnv([ + "AWS_ACCOUNT_ID=${aws_account_id}", + 'AWS_DEFAULT_REGION=us-west-2', + "AWS_ECR_REPO=${aws_account_id}.dkr.ecr.us-west-2.amazonaws.com"]) { + sh( + script: 'docker logout $AWS_ECR_REPO', + label: 'Clean up login credentials' + ) + } + } +} + def should_skip_slow_tests(pr_number) { withCredentials([string( credentialsId: 'tvm-bot-jenkins-reader', diff --git a/ci/jenkins/generated/cortexm_jenkinsfile.groovy b/ci/jenkins/generated/cortexm_jenkinsfile.groovy index 4b5ba2e104f4c..25846f5b4b5ec 100644 --- a/ci/jenkins/generated/cortexm_jenkinsfile.groovy +++ b/ci/jenkins/generated/cortexm_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2022-12-06T20:56:42.204393 +// Generated at 2022-12-09T15:39:24.437899 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -205,8 +205,7 @@ def docker_init(image) { if (image.contains("amazonaws.com")) { // If this string is in the image name it's from ECR and needs to be pulled // with the right credentials - // ecr_pull(image) - sh "echo Pulling from AWS is not implemented && exit 1" + ecr_pull(image) } else { sh( script: """ @@ -219,6 +218,47 @@ def docker_init(image) { } } +def ecr_pull(full_name) { + aws_account_id = sh( + returnStdout: true, + script: 'aws sts get-caller-identity | grep Account | cut -f4 -d\\"', + label: 'Get AWS ID' + ).trim() + + try { + withEnv([ + "AWS_ACCOUNT_ID=${aws_account_id}", + 'AWS_DEFAULT_REGION=us-west-2', + "AWS_ECR_REPO=${aws_account_id}.dkr.ecr.us-west-2.amazonaws.com"]) { + sh( + script: ''' + set -eux + aws ecr get-login-password --region $AWS_DEFAULT_REGION | docker login --username AWS --password-stdin $AWS_ECR_REPO + ''', + label: 'Log in to ECR' + ) + sh( + script: """ + set -eux + . 
${jenkins_scripts_root}/retry.sh + retry 5 docker pull ${full_name} + """, + label: 'Pull image from ECR' + ) + } + } finally { + withEnv([ + "AWS_ACCOUNT_ID=${aws_account_id}", + 'AWS_DEFAULT_REGION=us-west-2', + "AWS_ECR_REPO=${aws_account_id}.dkr.ecr.us-west-2.amazonaws.com"]) { + sh( + script: 'docker logout $AWS_ECR_REPO', + label: 'Clean up login credentials' + ) + } + } +} + def should_skip_slow_tests(pr_number) { withCredentials([string( credentialsId: 'tvm-bot-jenkins-reader', diff --git a/ci/jenkins/generated/cpu_jenkinsfile.groovy b/ci/jenkins/generated/cpu_jenkinsfile.groovy index 378b20db51b08..f9ede00399a20 100644 --- a/ci/jenkins/generated/cpu_jenkinsfile.groovy +++ b/ci/jenkins/generated/cpu_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2022-12-06T20:56:42.393957 +// Generated at 2022-12-09T15:39:24.540570 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -205,8 +205,7 @@ def docker_init(image) { if (image.contains("amazonaws.com")) { // If this string is in the image name it's from ECR and needs to be pulled // with the right credentials - // ecr_pull(image) - sh "echo Pulling from AWS is not implemented && exit 1" + ecr_pull(image) } else { sh( script: """ @@ -219,6 +218,47 @@ def docker_init(image) { } } +def ecr_pull(full_name) { + aws_account_id = sh( + returnStdout: true, + script: 'aws sts get-caller-identity | grep Account | cut -f4 -d\\"', + label: 'Get AWS ID' + ).trim() + + try { + withEnv([ + "AWS_ACCOUNT_ID=${aws_account_id}", + 'AWS_DEFAULT_REGION=us-west-2', + "AWS_ECR_REPO=${aws_account_id}.dkr.ecr.us-west-2.amazonaws.com"]) { + sh( + script: ''' + set -eux + aws ecr get-login-password --region $AWS_DEFAULT_REGION | docker login --username AWS --password-stdin $AWS_ECR_REPO + ''', + label: 'Log in to ECR' + ) + sh( + script: """ + set -eux + . 
${jenkins_scripts_root}/retry.sh + retry 5 docker pull ${full_name} + """, + label: 'Pull image from ECR' + ) + } + } finally { + withEnv([ + "AWS_ACCOUNT_ID=${aws_account_id}", + 'AWS_DEFAULT_REGION=us-west-2', + "AWS_ECR_REPO=${aws_account_id}.dkr.ecr.us-west-2.amazonaws.com"]) { + sh( + script: 'docker logout $AWS_ECR_REPO', + label: 'Clean up login credentials' + ) + } + } +} + def should_skip_slow_tests(pr_number) { withCredentials([string( credentialsId: 'tvm-bot-jenkins-reader', diff --git a/ci/jenkins/generated/docker_jenkinsfile.groovy b/ci/jenkins/generated/docker_jenkinsfile.groovy index 050ef2983e43d..9e1946c194e60 100644 --- a/ci/jenkins/generated/docker_jenkinsfile.groovy +++ b/ci/jenkins/generated/docker_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2022-12-07T07:10:24.637792 +// Generated at 2022-12-09T15:39:24.508775 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -205,8 +205,7 @@ def docker_init(image) { if (image.contains("amazonaws.com")) { // If this string is in the image name it's from ECR and needs to be pulled // with the right credentials - // ecr_pull(image) - sh "echo Pulling from AWS is not implemented && exit 1" + ecr_pull(image) } else { sh( script: """ @@ -219,6 +218,47 @@ def docker_init(image) { } } +def ecr_pull(full_name) { + aws_account_id = sh( + returnStdout: true, + script: 'aws sts get-caller-identity | grep Account | cut -f4 -d\\"', + label: 'Get AWS ID' + ).trim() + + try { + withEnv([ + "AWS_ACCOUNT_ID=${aws_account_id}", + 'AWS_DEFAULT_REGION=us-west-2', + "AWS_ECR_REPO=${aws_account_id}.dkr.ecr.us-west-2.amazonaws.com"]) { + sh( + script: ''' + set -eux + aws ecr get-login-password --region $AWS_DEFAULT_REGION | docker login --username AWS --password-stdin $AWS_ECR_REPO + ''', + label: 'Log in to ECR' + ) + sh( + script: """ + set -eux + . ${jenkins_scripts_root}/retry.sh + retry 5 docker pull ${full_name} + """, + label: 'Pull image from ECR' + ) + } + } finally { + withEnv([ + "AWS_ACCOUNT_ID=${aws_account_id}", + 'AWS_DEFAULT_REGION=us-west-2', + "AWS_ECR_REPO=${aws_account_id}.dkr.ecr.us-west-2.amazonaws.com"]) { + sh( + script: 'docker logout $AWS_ECR_REPO', + label: 'Clean up login credentials' + ) + } + } +} + def should_skip_slow_tests(pr_number) { withCredentials([string( credentialsId: 'tvm-bot-jenkins-reader', @@ -544,47 +584,6 @@ def ecr_push(full_name) { return ecr_name } -def ecr_pull(full_name) { - aws_account_id = sh( - returnStdout: true, - script: 'aws sts get-caller-identity | grep Account | cut -f4 -d\\"', - label: 'Get AWS ID' - ).trim() - - try { - withEnv([ - "AWS_ACCOUNT_ID=${aws_account_id}", - 'AWS_DEFAULT_REGION=us-west-2', - "AWS_ECR_REPO=${aws_account_id}.dkr.ecr.us-west-2.amazonaws.com"]) { - sh( - script: ''' - set -eux - aws ecr get-login-password --region $AWS_DEFAULT_REGION | docker login --username AWS --password-stdin $AWS_ECR_REPO - ''', - label: 'Log in to ECR' - ) - sh( - script: """ - set -eux - . 
ci/scripts/retry.sh - retry 5 docker pull ${full_name} - """, - label: 'Pull image from ECR' - ) - } - } finally { - withEnv([ - "AWS_ACCOUNT_ID=${aws_account_id}", - 'AWS_DEFAULT_REGION=us-west-2', - "AWS_ECR_REPO=${aws_account_id}.dkr.ecr.us-west-2.amazonaws.com"]) { - sh( - script: 'docker logout $AWS_ECR_REPO', - label: 'Clean up login credentials' - ) - } - } -} - def build_image(image_name) { hash = sh( returnStdout: true, diff --git a/ci/jenkins/generated/gpu_jenkinsfile.groovy b/ci/jenkins/generated/gpu_jenkinsfile.groovy index 48a6619cbab19..bebc0c4c22a5e 100644 --- a/ci/jenkins/generated/gpu_jenkinsfile.groovy +++ b/ci/jenkins/generated/gpu_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2022-12-07T07:10:24.840515 +// Generated at 2022-12-09T15:39:24.455336 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -205,8 +205,7 @@ def docker_init(image) { if (image.contains("amazonaws.com")) { // If this string is in the image name it's from ECR and needs to be pulled // with the right credentials - // ecr_pull(image) - sh "echo Pulling from AWS is not implemented && exit 1" + ecr_pull(image) } else { sh( script: """ @@ -219,6 +218,47 @@ def docker_init(image) { } } +def ecr_pull(full_name) { + aws_account_id = sh( + returnStdout: true, + script: 'aws sts get-caller-identity | grep Account | cut -f4 -d\\"', + label: 'Get AWS ID' + ).trim() + + try { + withEnv([ + "AWS_ACCOUNT_ID=${aws_account_id}", + 'AWS_DEFAULT_REGION=us-west-2', + "AWS_ECR_REPO=${aws_account_id}.dkr.ecr.us-west-2.amazonaws.com"]) { + sh( + script: ''' + set -eux + aws ecr get-login-password --region $AWS_DEFAULT_REGION | docker login --username AWS --password-stdin $AWS_ECR_REPO + ''', + label: 'Log in to ECR' + ) + sh( + script: """ + set -eux + . 
${jenkins_scripts_root}/retry.sh + retry 5 docker pull ${full_name} + """, + label: 'Pull image from ECR' + ) + } + } finally { + withEnv([ + "AWS_ACCOUNT_ID=${aws_account_id}", + 'AWS_DEFAULT_REGION=us-west-2', + "AWS_ECR_REPO=${aws_account_id}.dkr.ecr.us-west-2.amazonaws.com"]) { + sh( + script: 'docker logout $AWS_ECR_REPO', + label: 'Clean up login credentials' + ) + } + } +} + def should_skip_slow_tests(pr_number) { withCredentials([string( credentialsId: 'tvm-bot-jenkins-reader', diff --git a/ci/jenkins/generated/hexagon_jenkinsfile.groovy b/ci/jenkins/generated/hexagon_jenkinsfile.groovy index e5397eee3a9cf..c2f39a0d084ba 100644 --- a/ci/jenkins/generated/hexagon_jenkinsfile.groovy +++ b/ci/jenkins/generated/hexagon_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2022-12-06T20:56:42.338377 +// Generated at 2022-12-09T15:39:24.369191 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -205,8 +205,7 @@ def docker_init(image) { if (image.contains("amazonaws.com")) { // If this string is in the image name it's from ECR and needs to be pulled // with the right credentials - // ecr_pull(image) - sh "echo Pulling from AWS is not implemented && exit 1" + ecr_pull(image) } else { sh( script: """ @@ -219,6 +218,47 @@ def docker_init(image) { } } +def ecr_pull(full_name) { + aws_account_id = sh( + returnStdout: true, + script: 'aws sts get-caller-identity | grep Account | cut -f4 -d\\"', + label: 'Get AWS ID' + ).trim() + + try { + withEnv([ + "AWS_ACCOUNT_ID=${aws_account_id}", + 'AWS_DEFAULT_REGION=us-west-2', + "AWS_ECR_REPO=${aws_account_id}.dkr.ecr.us-west-2.amazonaws.com"]) { + sh( + script: ''' + set -eux + aws ecr get-login-password --region $AWS_DEFAULT_REGION | docker login --username AWS --password-stdin $AWS_ECR_REPO + ''', + label: 'Log in to ECR' + ) + sh( + script: """ + set -eux + . 
${jenkins_scripts_root}/retry.sh + retry 5 docker pull ${full_name} + """, + label: 'Pull image from ECR' + ) + } + } finally { + withEnv([ + "AWS_ACCOUNT_ID=${aws_account_id}", + 'AWS_DEFAULT_REGION=us-west-2', + "AWS_ECR_REPO=${aws_account_id}.dkr.ecr.us-west-2.amazonaws.com"]) { + sh( + script: 'docker logout $AWS_ECR_REPO', + label: 'Clean up login credentials' + ) + } + } +} + def should_skip_slow_tests(pr_number) { withCredentials([string( credentialsId: 'tvm-bot-jenkins-reader', diff --git a/ci/jenkins/generated/i386_jenkinsfile.groovy b/ci/jenkins/generated/i386_jenkinsfile.groovy index 876670acebba9..ae66fbe3e48ce 100644 --- a/ci/jenkins/generated/i386_jenkinsfile.groovy +++ b/ci/jenkins/generated/i386_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2022-12-06T20:56:42.288840 +// Generated at 2022-12-09T15:39:24.421467 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -205,8 +205,7 @@ def docker_init(image) { if (image.contains("amazonaws.com")) { // If this string is in the image name it's from ECR and needs to be pulled // with the right credentials - // ecr_pull(image) - sh "echo Pulling from AWS is not implemented && exit 1" + ecr_pull(image) } else { sh( script: """ @@ -219,6 +218,47 @@ def docker_init(image) { } } +def ecr_pull(full_name) { + aws_account_id = sh( + returnStdout: true, + script: 'aws sts get-caller-identity | grep Account | cut -f4 -d\\"', + label: 'Get AWS ID' + ).trim() + + try { + withEnv([ + "AWS_ACCOUNT_ID=${aws_account_id}", + 'AWS_DEFAULT_REGION=us-west-2', + "AWS_ECR_REPO=${aws_account_id}.dkr.ecr.us-west-2.amazonaws.com"]) { + sh( + script: ''' + set -eux + aws ecr get-login-password --region $AWS_DEFAULT_REGION | docker login --username AWS --password-stdin $AWS_ECR_REPO + ''', + label: 'Log in to ECR' + ) + sh( + script: """ + set -eux + . 
${jenkins_scripts_root}/retry.sh + retry 5 docker pull ${full_name} + """, + label: 'Pull image from ECR' + ) + } + } finally { + withEnv([ + "AWS_ACCOUNT_ID=${aws_account_id}", + 'AWS_DEFAULT_REGION=us-west-2', + "AWS_ECR_REPO=${aws_account_id}.dkr.ecr.us-west-2.amazonaws.com"]) { + sh( + script: 'docker logout $AWS_ECR_REPO', + label: 'Clean up login credentials' + ) + } + } +} + def should_skip_slow_tests(pr_number) { withCredentials([string( credentialsId: 'tvm-bot-jenkins-reader', diff --git a/ci/jenkins/generated/lint_jenkinsfile.groovy b/ci/jenkins/generated/lint_jenkinsfile.groovy index 3aaea4436fcba..f8dccc863590c 100644 --- a/ci/jenkins/generated/lint_jenkinsfile.groovy +++ b/ci/jenkins/generated/lint_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2022-12-06T20:56:42.313954 +// Generated at 2022-12-09T15:39:24.476946 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -205,8 +205,7 @@ def docker_init(image) { if (image.contains("amazonaws.com")) { // If this string is in the image name it's from ECR and needs to be pulled // with the right credentials - // ecr_pull(image) - sh "echo Pulling from AWS is not implemented && exit 1" + ecr_pull(image) } else { sh( script: """ @@ -219,6 +218,47 @@ def docker_init(image) { } } +def ecr_pull(full_name) { + aws_account_id = sh( + returnStdout: true, + script: 'aws sts get-caller-identity | grep Account | cut -f4 -d\\"', + label: 'Get AWS ID' + ).trim() + + try { + withEnv([ + "AWS_ACCOUNT_ID=${aws_account_id}", + 'AWS_DEFAULT_REGION=us-west-2', + "AWS_ECR_REPO=${aws_account_id}.dkr.ecr.us-west-2.amazonaws.com"]) { + sh( + script: ''' + set -eux + aws ecr get-login-password --region $AWS_DEFAULT_REGION | docker login --username AWS --password-stdin $AWS_ECR_REPO + ''', + label: 'Log in to ECR' + ) + sh( + script: """ + set -eux + . 
${jenkins_scripts_root}/retry.sh + retry 5 docker pull ${full_name} + """, + label: 'Pull image from ECR' + ) + } + } finally { + withEnv([ + "AWS_ACCOUNT_ID=${aws_account_id}", + 'AWS_DEFAULT_REGION=us-west-2', + "AWS_ECR_REPO=${aws_account_id}.dkr.ecr.us-west-2.amazonaws.com"]) { + sh( + script: 'docker logout $AWS_ECR_REPO', + label: 'Clean up login credentials' + ) + } + } +} + def should_skip_slow_tests(pr_number) { withCredentials([string( credentialsId: 'tvm-bot-jenkins-reader', diff --git a/ci/jenkins/generated/minimal_jenkinsfile.groovy b/ci/jenkins/generated/minimal_jenkinsfile.groovy index f8a59ef5734d9..6c4abb0bd5af5 100644 --- a/ci/jenkins/generated/minimal_jenkinsfile.groovy +++ b/ci/jenkins/generated/minimal_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2022-12-06T20:56:42.235080 +// Generated at 2022-12-09T15:39:24.492813 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -205,8 +205,7 @@ def docker_init(image) { if (image.contains("amazonaws.com")) { // If this string is in the image name it's from ECR and needs to be pulled // with the right credentials - // ecr_pull(image) - sh "echo Pulling from AWS is not implemented && exit 1" + ecr_pull(image) } else { sh( script: """ @@ -219,6 +218,47 @@ def docker_init(image) { } } +def ecr_pull(full_name) { + aws_account_id = sh( + returnStdout: true, + script: 'aws sts get-caller-identity | grep Account | cut -f4 -d\\"', + label: 'Get AWS ID' + ).trim() + + try { + withEnv([ + "AWS_ACCOUNT_ID=${aws_account_id}", + 'AWS_DEFAULT_REGION=us-west-2', + "AWS_ECR_REPO=${aws_account_id}.dkr.ecr.us-west-2.amazonaws.com"]) { + sh( + script: ''' + set -eux + aws ecr get-login-password --region $AWS_DEFAULT_REGION | docker login --username AWS --password-stdin $AWS_ECR_REPO + ''', + label: 'Log in to ECR' + ) + sh( + script: """ + set -eux + . 
${jenkins_scripts_root}/retry.sh + retry 5 docker pull ${full_name} + """, + label: 'Pull image from ECR' + ) + } + } finally { + withEnv([ + "AWS_ACCOUNT_ID=${aws_account_id}", + 'AWS_DEFAULT_REGION=us-west-2', + "AWS_ECR_REPO=${aws_account_id}.dkr.ecr.us-west-2.amazonaws.com"]) { + sh( + script: 'docker logout $AWS_ECR_REPO', + label: 'Clean up login credentials' + ) + } + } +} + def should_skip_slow_tests(pr_number) { withCredentials([string( credentialsId: 'tvm-bot-jenkins-reader', diff --git a/ci/jenkins/generated/riscv_jenkinsfile.groovy b/ci/jenkins/generated/riscv_jenkinsfile.groovy index eb62c3731f799..7b9bbe7ad3997 100644 --- a/ci/jenkins/generated/riscv_jenkinsfile.groovy +++ b/ci/jenkins/generated/riscv_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2022-12-06T20:56:42.442689 +// Generated at 2022-12-09T15:39:24.405262 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -205,8 +205,7 @@ def docker_init(image) { if (image.contains("amazonaws.com")) { // If this string is in the image name it's from ECR and needs to be pulled // with the right credentials - // ecr_pull(image) - sh "echo Pulling from AWS is not implemented && exit 1" + ecr_pull(image) } else { sh( script: """ @@ -219,6 +218,47 @@ def docker_init(image) { } } +def ecr_pull(full_name) { + aws_account_id = sh( + returnStdout: true, + script: 'aws sts get-caller-identity | grep Account | cut -f4 -d\\"', + label: 'Get AWS ID' + ).trim() + + try { + withEnv([ + "AWS_ACCOUNT_ID=${aws_account_id}", + 'AWS_DEFAULT_REGION=us-west-2', + "AWS_ECR_REPO=${aws_account_id}.dkr.ecr.us-west-2.amazonaws.com"]) { + sh( + script: ''' + set -eux + aws ecr get-login-password --region $AWS_DEFAULT_REGION | docker login --username AWS --password-stdin $AWS_ECR_REPO + ''', + label: 'Log in to ECR' + ) + sh( + script: """ + set -eux + . 
${jenkins_scripts_root}/retry.sh + retry 5 docker pull ${full_name} + """, + label: 'Pull image from ECR' + ) + } + } finally { + withEnv([ + "AWS_ACCOUNT_ID=${aws_account_id}", + 'AWS_DEFAULT_REGION=us-west-2', + "AWS_ECR_REPO=${aws_account_id}.dkr.ecr.us-west-2.amazonaws.com"]) { + sh( + script: 'docker logout $AWS_ECR_REPO', + label: 'Clean up login credentials' + ) + } + } +} + def should_skip_slow_tests(pr_number) { withCredentials([string( credentialsId: 'tvm-bot-jenkins-reader', diff --git a/ci/jenkins/generated/wasm_jenkinsfile.groovy b/ci/jenkins/generated/wasm_jenkinsfile.groovy index d43c7f9d24e4d..8c8ee03886998 100644 --- a/ci/jenkins/generated/wasm_jenkinsfile.groovy +++ b/ci/jenkins/generated/wasm_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2022-12-06T20:56:42.420989 +// Generated at 2022-12-09T15:39:24.526394 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -205,8 +205,7 @@ def docker_init(image) { if (image.contains("amazonaws.com")) { // If this string is in the image name it's from ECR and needs to be pulled // with the right credentials - // ecr_pull(image) - sh "echo Pulling from AWS is not implemented && exit 1" + ecr_pull(image) } else { sh( script: """ @@ -219,6 +218,47 @@ def docker_init(image) { } } +def ecr_pull(full_name) { + aws_account_id = sh( + returnStdout: true, + script: 'aws sts get-caller-identity | grep Account | cut -f4 -d\\"', + label: 'Get AWS ID' + ).trim() + + try { + withEnv([ + "AWS_ACCOUNT_ID=${aws_account_id}", + 'AWS_DEFAULT_REGION=us-west-2', + "AWS_ECR_REPO=${aws_account_id}.dkr.ecr.us-west-2.amazonaws.com"]) { + sh( + script: ''' + set -eux + aws ecr get-login-password --region $AWS_DEFAULT_REGION | docker login --username AWS --password-stdin $AWS_ECR_REPO + ''', + label: 'Log in to ECR' + ) + sh( + script: """ + set -eux + . ${jenkins_scripts_root}/retry.sh + retry 5 docker pull ${full_name} + """, + label: 'Pull image from ECR' + ) + } + } finally { + withEnv([ + "AWS_ACCOUNT_ID=${aws_account_id}", + 'AWS_DEFAULT_REGION=us-west-2', + "AWS_ECR_REPO=${aws_account_id}.dkr.ecr.us-west-2.amazonaws.com"]) { + sh( + script: 'docker logout $AWS_ECR_REPO', + label: 'Clean up login credentials' + ) + } + } +} + def should_skip_slow_tests(pr_number) { withCredentials([string( credentialsId: 'tvm-bot-jenkins-reader', diff --git a/ci/jenkins/templates/docker_jenkinsfile.groovy.j2 b/ci/jenkins/templates/docker_jenkinsfile.groovy.j2 index db3e6159b82a8..07ae49811337d 100644 --- a/ci/jenkins/templates/docker_jenkinsfile.groovy.j2 +++ b/ci/jenkins/templates/docker_jenkinsfile.groovy.j2 @@ -61,47 +61,6 @@ def ecr_push(full_name) { return ecr_name } -def ecr_pull(full_name) { - aws_account_id = sh( - returnStdout: true, - script: 'aws sts get-caller-identity | grep Account | cut -f4 -d\\"', - label: 'Get AWS ID' - ).trim() - - try { - withEnv([ - "AWS_ACCOUNT_ID=${aws_account_id}", - 'AWS_DEFAULT_REGION={{ aws_default_region }}', - "AWS_ECR_REPO=${aws_account_id}.{{ aws_ecr_url }}"]) { - sh( - script: ''' - set -eux - aws ecr get-login-password --region $AWS_DEFAULT_REGION | docker login --username AWS --password-stdin $AWS_ECR_REPO - ''', - label: 'Log in to ECR' - ) - sh( - script: """ - set -eux - . 
ci/scripts/retry.sh - retry 5 docker pull ${full_name} - """, - label: 'Pull image from ECR' - ) - } - } finally { - withEnv([ - "AWS_ACCOUNT_ID=${aws_account_id}", - 'AWS_DEFAULT_REGION={{ aws_default_region }}', - "AWS_ECR_REPO=${aws_account_id}.{{ aws_ecr_url }}"]) { - sh( - script: 'docker logout $AWS_ECR_REPO', - label: 'Clean up login credentials' - ) - } - } -} - def build_image(image_name) { hash = sh( returnStdout: true, diff --git a/ci/jenkins/templates/utils/Prepare.groovy.j2 b/ci/jenkins/templates/utils/Prepare.groovy.j2 index b295bb4308530..d5aebdc07008d 100644 --- a/ci/jenkins/templates/utils/Prepare.groovy.j2 +++ b/ci/jenkins/templates/utils/Prepare.groovy.j2 @@ -75,8 +75,7 @@ def docker_init(image) { if (image.contains("amazonaws.com")) { // If this string is in the image name it's from ECR and needs to be pulled // with the right credentials - // ecr_pull(image) - sh "echo Pulling from AWS is not implemented && exit 1" + ecr_pull(image) } else { sh( script: """ @@ -89,6 +88,47 @@ def docker_init(image) { } } +def ecr_pull(full_name) { + aws_account_id = sh( + returnStdout: true, + script: 'aws sts get-caller-identity | grep Account | cut -f4 -d\\"', + label: 'Get AWS ID' + ).trim() + + try { + withEnv([ + "AWS_ACCOUNT_ID=${aws_account_id}", + 'AWS_DEFAULT_REGION={{ aws_default_region }}', + "AWS_ECR_REPO=${aws_account_id}.{{ aws_ecr_url }}"]) { + sh( + script: ''' + set -eux + aws ecr get-login-password --region $AWS_DEFAULT_REGION | docker login --username AWS --password-stdin $AWS_ECR_REPO + ''', + label: 'Log in to ECR' + ) + sh( + script: """ + set -eux + . ${jenkins_scripts_root}/retry.sh + retry 5 docker pull ${full_name} + """, + label: 'Pull image from ECR' + ) + } + } finally { + withEnv([ + "AWS_ACCOUNT_ID=${aws_account_id}", + 'AWS_DEFAULT_REGION={{ aws_default_region }}', + "AWS_ECR_REPO=${aws_account_id}.{{ aws_ecr_url }}"]) { + sh( + script: 'docker logout $AWS_ECR_REPO', + label: 'Clean up login credentials' + ) + } + } +} + def should_skip_slow_tests(pr_number) { withCredentials([string( credentialsId: 'tvm-bot-jenkins-reader', diff --git a/ci/scripts/jenkins/determine_docker_images.py b/ci/scripts/jenkins/determine_docker_images.py index 78da9a354629e..41003958dd61b 100755 --- a/ci/scripts/jenkins/determine_docker_images.py +++ b/ci/scripts/jenkins/determine_docker_images.py @@ -32,6 +32,7 @@ PAGE_SIZE = 25 TEST_DATA = None IMAGE_TAGS_FILE = REPO_ROOT / "ci" / "jenkins" / "docker-images.ini" +TVM_CI_ECR = "477529581014.dkr.ecr.us-west-2.amazonaws.com" def docker_api(url: str, use_pagination: bool = False) -> Dict[str, Any]: @@ -111,7 +112,10 @@ def image_exists(spec: str) -> bool: name_dir.mkdir(exist_ok=True) images_to_use = {} for filename, spec in images.items(): - if image_exists(spec): + if spec.startswith(TVM_CI_ECR): + logging.info(f"{spec} is from ECR") + images_to_use[filename] = spec + elif image_exists(spec): logging.info(f"{spec} found in tlcpack") images_to_use[filename] = spec else: From 970110302da46204428686ddc1098cd0bd9b7cd3 Mon Sep 17 00:00:00 2001 From: wrongtest Date: Mon, 12 Dec 2022 11:53:11 +0800 Subject: [PATCH 03/12] [TIR][Schedule] Support for specific consumer block targeting in cache_write (#13510) Add optional consumer blocks to cache_write. 
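A minimal usage sketch (assuming `mod` is an IRModule whose producer block is
named "A" and one of its consumers is named "B", mirroring the new unit tests):

    from tvm import tir

    sch = tir.Schedule(mod)
    block_a = sch.get_block("A")
    block_b = sch.get_block("B")
    # Only block "B" reads from the new cache buffer; any other consumer
    # keeps reading the original output buffer of "A".
    sch.cache_write(block_a, 0, "global", consumer_blocks=[block_b])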
--- include/tvm/tir/schedule/schedule.h | 4 +- python/tvm/tir/schedule/schedule.py | 11 +- src/tir/schedule/concrete_schedule.cc | 11 +- src/tir/schedule/concrete_schedule.h | 4 +- src/tir/schedule/primitive.h | 4 +- .../schedule/primitive/cache_read_write.cc | 72 ++++++++++-- src/tir/schedule/traced_schedule.cc | 8 +- src/tir/schedule/traced_schedule.h | 4 +- .../test_tir_schedule_cache_read_write.py | 103 ++++++++++++++++++ 9 files changed, 198 insertions(+), 23 deletions(-) diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index c4838f2eb8aa0..8b22c173a3d8e 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -399,10 +399,12 @@ class ScheduleNode : public runtime::Object { * \param block_rv The producer of the buffer * \param write_buffer_index The index of the buffer in block's write region * \param storage_scope The target storage scope + * \param consumer_blocks An optional list of consumers to read from cache directly. * \return The cache stage block. */ virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, - const String& storage_scope) = 0; + const String& storage_scope, + const Array consumer_blocks = {}) = 0; /*! * \brief Create 2 blocks that read&write a buffer region into a read/write cache. * It requires the the target block both read & write the target buffer. diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 5ff9d71313965..48850012cbb7f 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -1110,6 +1110,7 @@ def cache_write( block: Union[BlockRV, str], write_buffer_index: Union[int, str, Buffer], storage_scope: str, + consumer_blocks=None, ) -> BlockRV: """Create a block that reads a buffer region into a write cache. It requires: @@ -1130,6 +1131,9 @@ def cache_write( storage_scope: str The target storage scope. + consumer_blocks: Optional[List[Union[BlockRV, str]]] + An optional list of consumers that should read directly from the cache. + If not specified, all consumers will read from the original buffer. Returns ------- @@ -1179,6 +1183,11 @@ def after_cache_write(a: T.handle, b: T.handle) -> None: B[vi, vj] = B_local[vi, vj] """ + if consumer_blocks is None: + consumer_blocks = [] + + # Convert any string block names into Block RVs. + consumer_blocks = [self._normalize_block_arg(b) for b in consumer_blocks] block = self._normalize_block_arg(block) if not isinstance(write_buffer_index, int): @@ -1186,7 +1195,7 @@ def after_cache_write(a: T.handle, b: T.handle) -> None: block, write_buffer_index, required_buffer_type="write" ) return _ffi_api.ScheduleCacheWrite( # type: ignore # pylint: disable=no-member - self, block, write_buffer_index, storage_scope + self, block, write_buffer_index, storage_scope, consumer_blocks ) @type_checked diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 7ae0185b425c8..163c72eb07771 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -552,10 +552,17 @@ BlockRV ConcreteScheduleNode::CacheRead(const BlockRV& block_rv, int read_buffer } BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer_index, - const String& storage_scope) { + const String& storage_scope, + const Array consumer_blocks) { StmtSRef result{nullptr}; + // Create a new array of SRefs from the consumer block list. 
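+  // Each consumer BlockRV is resolved to its StmtSRef so that the TIR-level
+  // CacheWrite primitive can identify the consumer blocks.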
+ Array consumer_block_refs = {}; + for (BlockRV block : consumer_blocks) { + consumer_block_refs.push_back(this->GetSRef(block)); + } TVM_TIR_SCHEDULE_BEGIN(); - result = tir::CacheWrite(state_, this->GetSRef(block_rv), write_buffer_index, storage_scope); + result = tir::CacheWrite(state_, this->GetSRef(block_rv), write_buffer_index, storage_scope, + consumer_block_refs); TVM_TIR_SCHEDULE_END("cache-write", this->error_render_level_); this->state_->DebugVerify(); return CreateRV(result); diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 2381870760a0b..899775f2a15d6 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -114,8 +114,8 @@ class ConcreteScheduleNode : public ScheduleNode { /******** Schedule: Insert cache stages ********/ BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index, const String& storage_scope, const Array consumer_blocks = {}) override; - BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, - const String& storage_scope) override; + BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope, + const Array consumer_blocks = {}) override; Array CacheInplace(const BlockRV& block_rv, int read_buffer_index, const String& storage_scope) override; Array CacheIndex(const BlockRV& block_rv, int write_buffer_index) override; diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 38931aa271473..9e7f77f55ea5c 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -263,10 +263,12 @@ TVM_DLL StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int r * \param block_sref The producer of the buffer * \param write_buffer_index The index of the buffer in block's write region * \param storage_scope The target storage scope + * \param consumer_blocks Array of blocks that consume the cache. * \return The cache stage block. */ TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index, - const String& storage_scope); + const String& storage_scope, + const Array consumer_blocks = {}); /*! *! * \brief Create 2 blocks that read&write a buffer region into a read/write cache. diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc index 27244f1575922..4174a6699e066 100644 --- a/src/tir/schedule/primitive/cache_read_write.cc +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -382,21 +382,34 @@ class CacheLocDetector : public StmtVisitor { * writer block of the buffer being applied cache_read or cache_write \param scope_sref The sref * of the scope block of the cached block \param info The cache stage info. 
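+ * \note If info->consumer_blocks is non-empty, cache_read restricts the related
+ *       blocks to exactly those consumers, while cache_write skips them, since
+ *       they will read from the cache buffer directly.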
*/ + template static void Detect(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_sref, CacheStageInfo* info) { std::vector related_blocks; // If consumer is specified, skip detecting the others - if (info->consumer_blocks.size() > 0) { - for (StmtSRef consumer : info->consumer_blocks) { - related_blocks.emplace_back(consumer); + if (is_cache_read) { + if (info->consumer_blocks.size() > 0) { + for (StmtSRef consumer : info->consumer_blocks) { + related_blocks.emplace_back(consumer); + } + } else { + for (const Dependency& def : self->GetBlockScope(scope_sref)->GetDepsBySrc(block_sref)) { + if (def->kind == DepKind::kRAW) { + related_blocks.push_back(def->dst); + } + } } } else { for (const Dependency& def : self->GetBlockScope(scope_sref)->GetDepsBySrc(block_sref)) { if (def->kind == DepKind::kRAW) { + if (info->consumer_blocks.count(def->dst)) { + continue; + } related_blocks.push_back(def->dst); } } } + if (!related_blocks.empty()) { CacheLocDetector detector(self, block_sref, scope_sref, related_blocks); detector(GetRef(scope_sref->stmt)); @@ -739,6 +752,30 @@ class CacheWriteRewriter : public StmtExprMutator { Stmt VisitStmt_(const BlockNode* block) final { Block old_stmt = GetRef(block); + + // Check if this block is one of the specified cache consumers. + // update the read buffer to the cache. + for (StmtSRef consumer_sref : info_->consumer_blocks) { + const BlockNode* consumer_node = TVM_SREF_TO_BLOCK(consumer_sref); + Block consumer_block = GetRef(consumer_node); + if (old_stmt.same_as(consumer_block)) { + Array reads = + ReplaceBuffer(block->reads, info_->write_buffer, info_->read_buffer); + Array match_buffers = + ReplaceBuffer(block->match_buffers, info_->write_buffer, info_->read_buffer); + if (!reads.same_as(block->reads) || !match_buffers.same_as(block->match_buffers)) { + auto n = CopyOnWrite(block); + n->reads = std::move(reads); + n->match_buffers = std::move(match_buffers); + n->body = VisitStmt(block->body); + Block new_consumer = Block(n); + info_->block_reuse.Set(old_stmt, new_consumer); + return std::move(new_consumer); + } + return std::move(old_stmt); + } + } + // We only mutate the block which generates info->write_buffer if (block != writer_block_sref_->stmt && block != scope_sref_->stmt && !under_writer_block_) { return std::move(old_stmt); @@ -1160,7 +1197,7 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff StmtSRef parent_sref = GetRef(write_block_sref->parent); // Detect insert position - CacheLocDetector::Detect(self, write_block_sref, scope_sref, &info); + CacheLocDetector::Detect(self, write_block_sref, scope_sref, &info); cache_region = RelaxBufferRegion(self, region, write_block_sref, parent_sref, info.loc_sref); } else { // Case 2. The buffer is the input block for the scope. @@ -1190,7 +1227,7 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff } StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index, - const String& storage_scope) { + const String& storage_scope, const Array consumer_blocks) { /*! * Check: * - The index is in the array of block reading region @@ -1219,6 +1256,14 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu // Create the corresponding buffer allocation info.alloc = info.read_buffer; + // info.consumer_blocks indicates which buffers should consume the cache. 
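+  // Each listed consumer is recorded together with its nested child blocks,
+  // so passing an outer block also covers the blocks it contains.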
+ for (auto consumer : consumer_blocks) { + info.consumer_blocks.insert(consumer); + for (auto child : tir::GetChildBlocks(self, consumer)) { + info.consumer_blocks.insert(child); + } + } + // Step 3. Check the only writer block. ICHECK_EQ(block_sref.get(), GetOnlyWriteBlock(self, scope_sref, write_buffer).get()); @@ -1226,7 +1271,7 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu BufferRegion region = GetBufferRegionFromBuffer(block->writes, write_buffer).value(); StmtSRef parent_sref = GetRef(block_sref->parent); // Detect insert position - CacheLocDetector::Detect(self, block_sref, scope_sref, &info); + CacheLocDetector::Detect(self, block_sref, scope_sref, &info); BufferRegion cache_region = RelaxBufferRegion(self, region, block_sref, parent_sref, info.loc_sref); @@ -1468,21 +1513,26 @@ struct CacheWriteTraits : public UnpackedInstTraits { static constexpr bool kIsPure = false; private: - static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumInputs = 2; static constexpr size_t kNumAttrs = 2; static constexpr size_t kNumDecisions = 0; - static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer write_buffer_index, + static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, + Array consumer_blocks, Integer write_buffer_index, String storage_scope) { - return sch->CacheWrite(block, write_buffer_index->value, storage_scope); + return sch->CacheWrite(block, write_buffer_index->value, storage_scope, consumer_blocks); } - static String UnpackedAsPython(Array outputs, String block, Integer write_buffer_index, - String storage_scope) { + static String UnpackedAsPython(Array outputs, String block, Array consumer_blocks, + Integer write_buffer_index, String storage_scope) { PythonAPICall py("cache_write"); py.Input("block", block); py.Input("write_buffer_index", write_buffer_index->value); py.Input("storage_scope", storage_scope); + // Only write out consumer blocks if provided. 
+ if (!consumer_blocks.empty()) { + py.Input("consumer_blocks", consumer_blocks); + } py.SingleOutput(outputs); return py.Str(); } diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 00941b48575d8..70559608e7893 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -296,12 +296,14 @@ BlockRV TracedScheduleNode::CacheRead(const BlockRV& block_rv, int read_buffer_i } BlockRV TracedScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer_index, - const String& storage_scope) { - BlockRV result = ConcreteScheduleNode::CacheWrite(block_rv, write_buffer_index, storage_scope); + const String& storage_scope, + const Array consumer_blocks) { + BlockRV result = ConcreteScheduleNode::CacheWrite(block_rv, write_buffer_index, storage_scope, + consumer_blocks); static const InstructionKind& kind = InstructionKind::Get("CacheWrite"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, - /*inputs=*/{block_rv}, + /*inputs=*/{block_rv, consumer_blocks}, /*attrs=*/{Integer(write_buffer_index), storage_scope}, /*outputs=*/{result})); return result; diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 80257f644f6b1..c54574e9c9ff8 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -74,8 +74,8 @@ class TracedScheduleNode : public ConcreteScheduleNode { /******** Schedule: Insert cache stages ********/ BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index, const String& storage_scope, const Array consumer_blocks = {}) final; - BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, - const String& storage_scope) final; + BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope, + const Array consumer_blocks = {}) final; Array CacheInplace(const BlockRV& block_rv, int read_buffer_index, const String& storage_scope) final; BlockRV ReIndex(const BlockRV& block_rv, int buffer_index, diff --git a/tests/python/unittest/test_tir_schedule_cache_read_write.py b/tests/python/unittest/test_tir_schedule_cache_read_write.py index 3476ca0830561..28c9a13700bf8 100644 --- a/tests/python/unittest/test_tir_schedule_cache_read_write.py +++ b/tests/python/unittest/test_tir_schedule_cache_read_write.py @@ -858,6 +858,81 @@ def cache_write_multi_consumer() -> None: C[vi] = A[vi] +@T.prim_func +def cache_write_multi_consumer_B_consume_cache(): + A = T.alloc_buffer([128], dtype="float32") + B = T.alloc_buffer([128], dtype="float32") + C = T.alloc_buffer([128], dtype="float32") + A_global = T.alloc_buffer([128], dtype="float32") + for i in T.serial(8): + for j in T.serial(16): + with T.block("A"): + vi = T.axis.spatial(128, i * 16 + j) + A_global[vi] = 1.0 + for j in T.serial(16): + with T.block("B"): + vi = T.axis.spatial(128, i * 16 + j) + B[vi] = A_global[vi] + 1.0 + for ax0 in T.serial(128): + with T.block("A_global"): + v0 = T.axis.spatial(128, ax0) + A[v0] = A_global[v0] + for i in T.serial(128): + with T.block("C"): + vi = T.axis.spatial(128, i) + C[vi] = A[vi] + + +@T.prim_func +def cache_write_multi_consumer_C_consume_cache(): + A = T.alloc_buffer([128], dtype="float32") + B = T.alloc_buffer([128], dtype="float32") + C = T.alloc_buffer([128], dtype="float32") + A_global = T.alloc_buffer([128], dtype="float32") + for i in T.serial(8): + for j in T.serial(16): + with T.block("A"): + vi = T.axis.spatial(128, i * 16 + j) + A_global[vi] = T.float32(1) + for ax0 in T.serial(16): + with T.block("A_global"): + 
v0 = T.axis.spatial(128, i * 16 + ax0) + A[v0] = A_global[v0] + for j in T.serial(16): + with T.block("B"): + vi = T.axis.spatial(128, i * 16 + j) + B[vi] = A[vi] + T.float32(1) + for i in T.serial(128): + with T.block("C"): + vi = T.axis.spatial(128, i) + C[vi] = A_global[vi] + + +@T.prim_func +def cache_write_multi_consumer_all_consume_cache(): + A = T.alloc_buffer([128], dtype="float32") + B = T.alloc_buffer([128], dtype="float32") + C = T.alloc_buffer([128], dtype="float32") + A_global = T.alloc_buffer([128], dtype="float32") + for i in T.serial(8): + for j in T.serial(16): + with T.block("A"): + vi = T.axis.spatial(128, i * 16 + j) + A_global[vi] = T.float32(1) + for j in T.serial(16): + with T.block("B"): + vi = T.axis.spatial(128, i * 16 + j) + B[vi] = A_global[vi] + T.float32(1) + for i in T.serial(128): + with T.block("C"): + vi = T.axis.spatial(128, i) + C[vi] = A_global[vi] + for ax0 in T.serial(128): + with T.block("A_global"): + v0 = T.axis.spatial(128, ax0) + A[v0] = A_global[v0] + + @T.prim_func def continuous_cache_write(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) @@ -1113,6 +1188,34 @@ def test_cache_write_location(use_block_name): tvm.ir.assert_structural_equal(cache_write_multi_consumer, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=func_multi_consumer) + # Test that specific consumer block targetting works. + # B read cache buffer and C read original output buffer + sch = tir.Schedule(func_multi_consumer, debug_mask="all") + block_a = "A" if use_block_name else sch.get_block("A") + block_b = "B" if use_block_name else sch.get_block("B") + sch.cache_write(block_a, 0, "global", consumer_blocks=[block_b]) + tvm.ir.assert_structural_equal(cache_write_multi_consumer_B_consume_cache, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=func_multi_consumer) + + # Test that specific consumer block targetting works. + # B read original output buffer and C read cache buffer + sch = tir.Schedule(func_multi_consumer, debug_mask="all") + block_a = "A" if use_block_name else sch.get_block("A") + block_c = "C" if use_block_name else sch.get_block("C") + sch.cache_write(block_a, 0, "global", consumer_blocks=[block_c]) + tvm.ir.assert_structural_equal(cache_write_multi_consumer_C_consume_cache, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=func_multi_consumer) + + # Test that specific consumer block targetting works. 
+ # B and C read cache buffer + sch = tir.Schedule(func_multi_consumer, debug_mask="all") + block_a = "A" if use_block_name else sch.get_block("A") + block_b = "B" if use_block_name else sch.get_block("B") + block_c = "C" if use_block_name else sch.get_block("C") + sch.cache_write(block_a, 0, "global", consumer_blocks=[block_b, block_c]) + tvm.ir.assert_structural_equal(cache_write_multi_consumer_all_consume_cache, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=func_multi_consumer) + def test_continuous_cache_write(use_block_name): sch = tir.Schedule(elementwise, debug_mask="all") From ae07437a32c8addadfec4002426e03a6b2bc8781 Mon Sep 17 00:00:00 2001 From: wrongtest Date: Mon, 12 Dec 2022 18:53:15 +0800 Subject: [PATCH 04/12] [LLVM] Fix get tm allow_missing check pos (#13591) Fix get tm allow_missing check pos --- src/target/llvm/llvm_instance.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/target/llvm/llvm_instance.cc b/src/target/llvm/llvm_instance.cc index 44454fc6b92d4..2aa190ad708ea 100644 --- a/src/target/llvm/llvm_instance.cc +++ b/src/target/llvm/llvm_instance.cc @@ -297,9 +297,9 @@ llvm::TargetMachine* LLVMTargetInfo::GetOrCreateTargetMachine(bool allow_missing llvm_instance->createTargetMachine(triple_, cpu_, GetTargetFeatureString(), target_options_, reloc_model_, code_model_, opt_level_); target_machine_ = std::unique_ptr(tm); - if (!allow_missing) { - ICHECK(target_machine_ != nullptr) << error; - } + } + if (!allow_missing) { + ICHECK(target_machine_ != nullptr) << error; } return target_machine_.get(); } From 760b10ae0ec387f6dfb7945b7368c71371996142 Mon Sep 17 00:00:00 2001 From: masahi Date: Mon, 12 Dec 2022 22:38:57 +0900 Subject: [PATCH 05/12] [Torch] Stable diffusion support (#13594) * add baddbmm conversion * fix * suppress lint --- python/tvm/relay/frontend/pytorch.py | 18 +++++++++++++++++- tests/python/frontend/pytorch/test_forward.py | 16 +++++++++++++++- 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 30f14b490b1bb..b9d167ad2d865 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1863,6 +1863,13 @@ def chunk(self, inputs, input_types): return _op.split(data, indeces, axis) + def baddbmm(self, inputs, _): + input = inputs[0] + batch1, batch2 = inputs[1:3] + beta = _expr.const(float(inputs[3])) + alpha = _expr.const(float(inputs[4])) + return beta * input + alpha * _op.nn.batch_matmul(batch1, batch2, transpose_b=False) + def matmul(self, inputs, input_types): inputs_0 = inputs[0] @@ -2565,7 +2572,14 @@ def numel(self, inputs, input_types): return _op.ndarray_size(inputs[0]) def empty(self, inputs, input_types): - shape = inputs[0] + shape = [] + for s in inputs[0]: + if isinstance(s, _expr.Constant): + shape.append(s.data.numpy().item()) + else: + assert isinstance(s, int) + shape.append(s) + return _op.zeros(shape, _convert_dtype_value(inputs[1])) def empty_like(self, inputs, input_types): @@ -3621,6 +3635,7 @@ def create_convert_map(self): "aten::unsafe_chunk": self.chunk, "aten::matmul": self.matmul, "aten::bmm": self.matmul, + "aten::baddbmm": self.baddbmm, "aten::expand": self.expand, "aten::Int": self.int, "prim::NumToTensor": self.numtotensor, @@ -4587,6 +4602,7 @@ def from_pytorch( if inp.type().kind() == "TupleType" or inp.type().kind() == "ListType": enable_lower_all_tuples = False break + _run_jit_passes(graph, enable_lower_all_tuples) if custom_convert_map: diff --git 
a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 36bb5bede475e..35242fbf7dde0 100755 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=import-self, invalid-name, unused-argument +# pylint: disable=import-self, invalid-name, unused-argument, missing-function-docstring """Unit tests for various models and operators""" import os import platform @@ -5038,5 +5038,19 @@ def _test_multinomial(num_samples): ) +@tvm.testing.uses_gpu +def test_baddbmm(): + def test_fn(alpha, beta): + return lambda inp, batch1, batch2: torch.baddbmm( + inp, batch1, batch2, beta=beta, alpha=alpha + ) + + M = torch.randn(10, 3, 5) + batch1 = torch.randn(10, 3, 4) + batch2 = torch.randn(10, 4, 5) + + verify_model(test_fn(0.5, 1.0), [M, batch1, batch2]) + + if __name__ == "__main__": tvm.testing.main() From fe1d7ad4f2b322aec7aa6358f63488d953614ccd Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Tue, 13 Dec 2022 00:02:27 +0300 Subject: [PATCH 06/12] [OpenCL][CI] Enable OpenCL cpp tests in CI (#13400) * [OpenCL][CI] Enable OpenCL cpp tests in CI * Add building gtest for OpenCL in GPU build * Fix CI build * Change OpenCL cpp tests build approach * Fix lint * Try to enable test in CI * Update version of gpu docker image * Change script mod --- CMakeLists.txt | 15 ----- ci/jenkins/docker-images.ini | 2 +- ci/jenkins/generated/gpu_jenkinsfile.groovy | 8 +++ .../templates/gpu_jenkinsfile.groovy.j2 | 8 +++ cmake/modules/OpenCL.cmake | 12 +++- tests/cpp-runtime/opencl/run_gtests.cc | 60 ------------------- .../contrib/test_opencl/test_run_gtests.py | 56 ----------------- tests/scripts/ci.py | 1 + tests/scripts/task_config_build_gpu.sh | 1 + tests/scripts/task_opencl_cpp_unittest.sh | 39 ++++++++++++ 10 files changed, 69 insertions(+), 133 deletions(-) delete mode 100644 tests/cpp-runtime/opencl/run_gtests.cc delete mode 100644 tests/python/contrib/test_opencl/test_run_gtests.py create mode 100755 tests/scripts/task_opencl_cpp_unittest.sh diff --git a/CMakeLists.txt b/CMakeLists.txt index 736d516fa1f60..119bf8325c8ca 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -638,18 +638,6 @@ if(BUILD_FOR_HEXAGON AND DEFINED USE_HEXAGON_GTEST AND EXISTS ${USE_HEXAGON_GTES include_directories("${USE_HEXAGON_GTEST}/include") endif() -if(USE_OPENCL AND DEFINED USE_OPENCL_GTEST AND EXISTS ${USE_OPENCL_GTEST}) - include(FetchContent) - FetchContent_Declare(googletest SOURCE_DIR "${USE_OPENCL_GTEST}") - set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) - FetchContent_MakeAvailable(googletest) - target_link_libraries(tvm_runtime PUBLIC gtest) - target_link_libraries(tvm PUBLIC gtest) - include_directories("${USE_OPENCL_GTEST}/include") - include_directories("${USE_OPENCL_GTEST}/googletest/include") - message(STATUS "Found OpenCL gtest at ${USE_OPENCL_GTEST}") -endif() - # Set flags for clang include(cmake/modules/ClangFlags.cmake) set(CRC16_INCLUDE_PATH "3rdparty/libcrc/include") @@ -715,9 +703,6 @@ install(TARGETS tvm_runtime EXPORT ${PROJECT_NAME}Targets DESTINATION lib${LIB_S if(BUILD_FOR_HEXAGON AND DEFINED USE_HEXAGON_GTEST AND EXISTS ${USE_HEXAGON_GTEST}) install(TARGETS gtest EXPORT ${PROJECT_NAME}Targets DESTINATION lib${LIB_SUFFIX}) endif() -if(USE_OPENCL AND DEFINED USE_OPENCL_GTEST AND EXISTS ${USE_OPENCL_GTEST}) - install(TARGETS gtest EXPORT 
${PROJECT_NAME}Targets DESTINATION lib${LIB_SUFFIX}) -endif() if (INSTALL_DEV) install( diff --git a/ci/jenkins/docker-images.ini b/ci/jenkins/docker-images.ini index 119a43218642c..40e1b8a1313fc 100644 --- a/ci/jenkins/docker-images.ini +++ b/ci/jenkins/docker-images.ini @@ -20,7 +20,7 @@ ci_arm: tlcpack/ci-arm:20221013-060115-61c9742ea ci_cortexm: tlcpack/ci-cortexm:20221013-060115-61c9742ea ci_cpu: tlcpack/ci-cpu:20221013-060115-61c9742ea -ci_gpu: tlcpack/ci-gpu:20221019-060125-0b4836739 +ci_gpu: tlcpack/ci-gpu:20221128-070141-ae4fd7df7 ci_hexagon: tlcpack/ci-hexagon:20221013-060115-61c9742ea ci_i386: tlcpack/ci-i386:20221013-060115-61c9742ea ci_lint: tlcpack/ci-lint:20221013-060115-61c9742ea diff --git a/ci/jenkins/generated/gpu_jenkinsfile.groovy b/ci/jenkins/generated/gpu_jenkinsfile.groovy index bebc0c4c22a5e..a5609697af469 100644 --- a/ci/jenkins/generated/gpu_jenkinsfile.groovy +++ b/ci/jenkins/generated/gpu_jenkinsfile.groovy @@ -614,6 +614,14 @@ def shard_run_unittest_GPU_1_of_3() { make_standalone_crt(ci_gpu, 'build') make_cpp_tests(ci_gpu, 'build') cpp_unittest(ci_gpu) + sh ( + script: "${docker_run} ${ci_gpu} python3 ./tests/scripts/task_build.py --sccache-bucket tvm-sccache-prod --cmake-target opencl-cpptest --build-dir build", + label: 'Make OpenCL cpp unit tests', + ) + sh ( + script: "${docker_run} ${ci_gpu} ./tests/scripts/task_opencl_cpp_unittest.sh", + label: 'Run OpenCL cpp unit tests', + ) micro_cpp_unittest(ci_gpu) sh ( script: "${docker_run} ${ci_gpu} ./tests/scripts/task_python_unittest_gpuonly.sh", diff --git a/ci/jenkins/templates/gpu_jenkinsfile.groovy.j2 b/ci/jenkins/templates/gpu_jenkinsfile.groovy.j2 index 2a9e7236d26df..40698131a7833 100644 --- a/ci/jenkins/templates/gpu_jenkinsfile.groovy.j2 +++ b/ci/jenkins/templates/gpu_jenkinsfile.groovy.j2 @@ -63,6 +63,14 @@ make_standalone_crt(ci_gpu, 'build') make_cpp_tests(ci_gpu, 'build') cpp_unittest(ci_gpu) + sh ( + script: "${docker_run} ${ci_gpu} python3 ./tests/scripts/task_build.py --sccache-bucket tvm-sccache-prod --cmake-target opencl-cpptest --build-dir build", + label: 'Make OpenCL cpp unit tests', + ) + sh ( + script: "${docker_run} ${ci_gpu} ./tests/scripts/task_opencl_cpp_unittest.sh", + label: 'Run OpenCL cpp unit tests', + ) micro_cpp_unittest(ci_gpu) {% else %} {{ m.download_artifacts(tag='gpu') }} diff --git a/cmake/modules/OpenCL.cmake b/cmake/modules/OpenCL.cmake index e738df7c564ca..1e1041efe3860 100644 --- a/cmake/modules/OpenCL.cmake +++ b/cmake/modules/OpenCL.cmake @@ -59,9 +59,19 @@ if(USE_OPENCL) endif() if(DEFINED USE_OPENCL_GTEST AND EXISTS ${USE_OPENCL_GTEST}) - file_glob_append(RUNTIME_OPENCL_SRCS + include(FetchContent) + FetchContent_Declare(googletest SOURCE_DIR "${USE_OPENCL_GTEST}") + set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) + FetchContent_MakeAvailable(googletest) + install(TARGETS gtest EXPORT ${PROJECT_NAME}Targets DESTINATION lib${LIB_SUFFIX}) + + message(STATUS "Found OpenCL gtest at ${USE_OPENCL_GTEST}") + + tvm_file_glob(GLOB_RECURSE OPENCL_TEST_SRCS "${CMAKE_SOURCE_DIR}/tests/cpp-runtime/opencl/*.cc" ) + add_executable(opencl-cpptest ${OPENCL_TEST_SRCS}) + target_link_libraries(opencl-cpptest PRIVATE gtest_main tvm_runtime) endif() list(APPEND RUNTIME_SRCS ${RUNTIME_OPENCL_SRCS}) else() diff --git a/tests/cpp-runtime/opencl/run_gtests.cc b/tests/cpp-runtime/opencl/run_gtests.cc deleted file mode 100644 index ffe86a7f52c0c..0000000000000 --- a/tests/cpp-runtime/opencl/run_gtests.cc +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation 
(ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include -#include -#include - -#include -#include - -#include "../src/support/utils.h" - -namespace tvm { -namespace runtime { -namespace cl { - -TVM_REGISTER_GLOBAL("opencl.run_gtests").set_body([](TVMArgs args, TVMRetValue* rv) { - // gtest args are passed into this packed func as a singular string - // split gtest args using delimiter and build argument vector - std::vector parsed_args = tvm::support::Split(args[0], ' '); - std::vector argv; - - // add executable name - argv.push_back(const_cast("opencl_run_gtests")); - - // add parsed arguments - for (size_t i = 0; i < parsed_args.size(); ++i) { - argv.push_back(const_cast(parsed_args[i].data())); - } - - // end of parsed arguments - argv.push_back(nullptr); - - // set argument count - int argc = argv.size() - 1; - - // initialize gtest with arguments and run - ::testing::InitGoogleTest(&argc, argv.data()); - *rv = RUN_ALL_TESTS(); -}); - -} // namespace cl -} // namespace runtime -} // namespace tvm diff --git a/tests/python/contrib/test_opencl/test_run_gtests.py b/tests/python/contrib/test_opencl/test_run_gtests.py deleted file mode 100644 index ee59086b25f15..0000000000000 --- a/tests/python/contrib/test_opencl/test_run_gtests.py +++ /dev/null @@ -1,56 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -import os -import pytest -import numpy as np - -import tvm -from tvm import rpc - - -# use pytest -sv to observe gtest output -# use --gtest_args to pass arguments to gtest -# for example to run all "foo" tests twice and observe gtest output run -# pytest -sv --gtests_args="--gtest_filter=*foo* --gtest_repeat=2" -@tvm.testing.requires_opencl -@pytest.mark.skipif(tvm.testing.utils.IS_IN_CI, reason="failed due to nvidia libOpencl in the CI") -def test_run_gtests(gtest_args): - if ( - "TVM_TRACKER_HOST" in os.environ - and "TVM_TRACKER_PORT" in os.environ - and "TVM_TRACKER_KEY" in os.environ - ): - rpc_tracker_host = os.environ["TVM_TRACKER_HOST"] - rpc_tracker_port = os.environ["TVM_TRACKER_PORT"] - rpc_tracker_port = int(rpc_tracker_port) - rpc_key = os.environ["TVM_TRACKER_KEY"] - tracker = rpc.connect_tracker(rpc_tracker_host, rpc_tracker_port) - rpc_connection = tracker.request(rpc_key, priority=0, session_timeout=600) - else: - rpc_connection = rpc.LocalSession() - - try: - func = rpc_connection.get_function("opencl.run_gtests") - except: - print( - "This test requires TVM Runtime to be built with a OpenCL gtest version using OpenCL API cmake flag -DUSE_OPENCL_GTEST=/path/to/opencl/googletest/gtest" - ) - raise - - gtest_error_code = func(gtest_args) - np.testing.assert_equal(gtest_error_code, 0) diff --git a/tests/scripts/ci.py b/tests/scripts/ci.py index 6799f68d43b73..b11ee538dc68e 100755 --- a/tests/scripts/ci.py +++ b/tests/scripts/ci.py @@ -593,6 +593,7 @@ def add_subparser( "run unit tests", [ "./tests/scripts/task_java_unittest.sh", + "./tests/scripts/task_opencl_cpp_unittest.sh", "./tests/scripts/task_python_unittest_gpuonly.sh", "./tests/scripts/task_python_integration_gpuonly.sh", ], diff --git a/tests/scripts/task_config_build_gpu.sh b/tests/scripts/task_config_build_gpu.sh index ca5f3e935c08d..90c91fb990be6 100755 --- a/tests/scripts/task_config_build_gpu.sh +++ b/tests/scripts/task_config_build_gpu.sh @@ -29,6 +29,7 @@ echo set\(USE_CUDA ON\) >> config.cmake echo set\(USE_VULKAN ON\) >> config.cmake echo set\(USE_OPENGL ON\) >> config.cmake echo set\(USE_OPENCL ON\) >> config.cmake +echo set\(USE_OPENCL_GTEST \"/googletest\"\) >> config.cmake echo set\(USE_MICRO ON\) >> config.cmake echo set\(USE_MICRO_STANDALONE_RUNTIME ON\) >> config.cmake echo set\(USE_LLVM \"/usr/bin/llvm-config-9 --link-static\"\) >> config.cmake diff --git a/tests/scripts/task_opencl_cpp_unittest.sh b/tests/scripts/task_opencl_cpp_unittest.sh new file mode 100755 index 0000000000000..7ea6ea470db7c --- /dev/null +++ b/tests/scripts/task_opencl_cpp_unittest.sh @@ -0,0 +1,39 @@ +#!/usr/bin/env bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +set -euxo pipefail + +if [ $# -gt 0 ]; then + BUILD_DIR="$1" +elif [ -n "${TVM_BUILD_PATH:-}" ]; then + # TVM_BUILD_PATH may contain multiple space-separated paths. If + # so, use the first one. + BUILD_DIR=$(IFS=" "; set -- $TVM_BUILD_PATH; echo $1) +else + BUILD_DIR=build +fi + + +# to avoid CI thread throttling. +export TVM_BIND_THREADS=0 +export OMP_NUM_THREADS=1 + +pushd "${BUILD_DIR}" +# run cpp test executable +./opencl-cpptest +popd From 51431d5a8c2c0fd956577ea974d0bee4a057e583 Mon Sep 17 00:00:00 2001 From: padreofthegame <97688606+padreofthegame@users.noreply.github.com> Date: Mon, 12 Dec 2022 23:17:33 +0100 Subject: [PATCH 07/12] [Relay] Bug fix in relay.squeeze function for issue #12400 (#12684) [Relay] Bug fix in relay.squeeze function. Also added functionality for parameter axis of type int --- python/tvm/relay/op/transform.py | 26 +++++++++++++++++++++----- tests/python/relay/test_op_level3.py | 7 ++++++- 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index e7ae5f7d83157..024da84cbfd8e 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -204,23 +204,39 @@ def squeeze(data, axis=None): Parameters ---------- - data : tvm.relay.Expr + data : relay.Expr The input data to the operator. - axis : None or List[int] or Expr + axis : Union[None, int, Tuple[int], List[int]] or Expr The set of axes to remove. - If axis = None, remove all axis of dimensions 1. + If axis = None, remove all axes of dimension 1. If any specified axis has dimension that does not equal 1, it is an error. Returns ------- - result : tvm.relay.Expr + result : relay.Expr The squeezed result. """ if isinstance(axis, Constant): - axis = list(axis.data.numpy()) + if axis.data.shape: + axis = list(axis.data.numpy()) + else: + axis = [axis.data.numpy().item()] if isinstance(axis, Expr): return _dyn_make.squeeze(data, axis) + if isinstance(axis, int): + axis = [axis] + if isinstance(axis, (tuple, list)): + tempaxis = [] + for tmpax in axis: + if isinstance(tmpax, _expr.IntImm): + tempaxis.append(tmpax.value) + else: + try: + tempaxis.append(int(tmpax)) + except ValueError as err: + raise RuntimeError("Unrecognized axis type: %s" % err) + axis = tempaxis return _make.squeeze(data, axis) diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index c3b3215e84e42..c96bc940f920c 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -210,13 +210,18 @@ class TestSqueeze: ((1, 3, 2, 5), "float32", None), ((1, 3, 1), "float32", [0]), ((1, 2, 1, 2, 1), "float32", [0, 2]), + ((1, 3, 1), "float32", 2), + ((1, 3, 1), "float32", []), ) def test_squeeze(self, shape, dtype, axis): x = relay.var("x", relay.TensorType(shape, dtype)) squeeze = relay.squeeze(x, axis=axis) - np_axis = tuple(axis) if axis is not None else None + if isinstance(axis, int): + np_axis = (axis,) + else: + np_axis = tuple(axis) if axis is not None else None data = np.random.random_sample(shape).astype(dtype) op_res = create_executor().evaluate(squeeze, {x: relay.const(data)}) From ec9fcc0dac9c09dcb5ef2c56f016c8433c978db6 Mon Sep 17 00:00:00 2001 From: masahi Date: Tue, 13 Dec 2022 14:31:36 +0900 Subject: [PATCH 08/12] [Relay] Fix `CombineParallelDense` slicing axis (#13597) The current implementation of `CombineParallelDense` is hardcoded to slice along the last axis after the combined dense. 
I hit an error using this pass on the stable diffusion UNet, since it has a combined group where the dense is followed by `expand_dims` which changes the slicing axis (see https://github.com/masahi/torchscript-to-tvm/blob/master/stable-diffusion/compile.py for repro) ``` %76 = concatenate(%74) /* ty=Tensor[(20160, 1280), float32] */; %79 = concatenate(%77) /* ty=Tensor[(20160), float32] */; %78 = nn.dense(%75, %76, units=20160) /* ty=Tensor[(2, 20160), float32] */; %80 = nn.bias_add(%78, %79, axis=-1) /* ty=Tensor[(2, 20160), float32] */; %81 = expand_dims(%80, axis=2) /* ty=Tensor[(2, 20160, 1), float32] */; %82 = expand_dims(%81, axis=3) /* ty=Tensor[(2, 20160, 1, 1), float32] */; ``` The correct way to generate `strided_slice`: ``` %84 = strided_slice(%82, begin=[0, 0, 0, 0], end=[-1, 320, -1, -1], strides=[1, 1, 1, 1], slice_mode="size", axes=None) /* ty=Tensor[(2, 320, 1, 1), float32] */; ``` As I documented in the code, this fix is probably not 100% fail-proof. I think this is a difficult problem, since it requires tracking how the original output-channel axis of the combined dense moves across shape-changing operations like `reshape /transpose / split`. But this is at least "more correct" than the current implementation, so I'm submitting this fix as is for now. With this fix, `CombineParallelDense` works successfully on the stable diffusion UNet, and it reduces the number of `nn.dense` from 184 to 100. --- .../transforms/combine_parallel_dense.cc | 43 +++++++++++----- .../relay/test_pass_combine_parallel_dense.py | 51 ++++++++++++++++--- 2 files changed, 74 insertions(+), 20 deletions(-) diff --git a/src/relay/transforms/combine_parallel_dense.cc b/src/relay/transforms/combine_parallel_dense.cc index 7cf102b5bcab7..e5f7e0b975f4a 100644 --- a/src/relay/transforms/combine_parallel_dense.cc +++ b/src/relay/transforms/combine_parallel_dense.cc @@ -195,23 +195,40 @@ class ParallelDenseToDenseCombiner : public ParallelOpCombiner { void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth, ExprSubstMap* subst_map) { int index = 0; + const auto dense_op = Op::Get("nn.dense"); for (const auto& branch : branches) { const CallNode* call = branch[depth]; auto& out_shape = call->type_as()->shape; - auto out_dims = tir::as_const_int(out_shape[out_shape.size() - 1]); - ICHECK(out_dims != nullptr); - Array begin; - Array end; - Array strides; - for (size_t k = 0; k < out_shape.size() - 1; ++k) { - begin.push_back(0); - end.push_back(-1); - strides.push_back(1); + + const CallNode* dense = branch[0]; + ICHECK(dense->op.same_as(dense_op)); + auto& dense_shape = dense->type_as()->shape; + auto dense_out_dims = tir::as_const_int(dense_shape[1]); + ICHECK(dense_out_dims != nullptr); + + // dense can be followed by shape-changing operations, so the slicing axis is + // not necessarily the last one. + // TODO(masahi): The following logic is incorrect if (1) there is no axis in + // out_shape[i] that directly corresponds to the output channel of dense or (2) there + // is another axis that happens to have the same size as the output channel of dense. + // Such cases might arise due to reshape / transpose / split etc. Revisit this logic + // when we encounter them in practice. 
+ auto slice_axis = -1; + for (size_t i = out_shape.size() - 1; i >= 0; --i) { + ICHECK(tir::as_const_int(out_shape[i])); + if (*tir::as_const_int(out_shape[i]) == *dense_out_dims) { + slice_axis = i; + break; + } } - begin.push_back(index); - end.push_back(*out_dims); - strides.push_back(1); - index += *out_dims; + ICHECK(slice_axis != -1); + + Array begin(out_shape.size(), 0); + Array end(out_shape.size(), -1); + Array strides(out_shape.size(), 1); + begin.Set(slice_axis, index); + end.Set(slice_axis, *dense_out_dims); + index += *dense_out_dims; auto slice = MakeStridedSlice(data, begin, end, strides, "size"); subst_map->insert({GetRef(branch[depth]), slice}); } diff --git a/tests/python/relay/test_pass_combine_parallel_dense.py b/tests/python/relay/test_pass_combine_parallel_dense.py index cd946ab593bf2..2494c1a550cd3 100644 --- a/tests/python/relay/test_pass_combine_parallel_dense.py +++ b/tests/python/relay/test_pass_combine_parallel_dense.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import tvm -from tvm import te +import tvm.testing from tvm import relay from tvm.relay import transform @@ -359,10 +359,47 @@ def check(i, j, k, scale1, scale2, newshape1, newshape2): check(100, 200, 300, 0.5, 0.25, (1, 1, 20000), (1, 1, 40000)) +def test_combine_parallel_dense_expand_dims(): + """Verify that the correct slice axis is selected after the combined dense.""" + + def before(x, w1, w2): + args = [x, w1, w2] + y1 = relay.nn.dense(x, w1) + y1 = relay.expand_dims(y1, axis=2) + + y2 = relay.nn.dense(x, w2) + y2 = relay.expand_dims(y2, axis=2) + + y = relay.Tuple((y1, y2)) + return relay.Function(args, y) + + def expected(x, w1, w2): + args = [x, w1, w2] + w_stacked = relay.concatenate((w1, w2), axis=0) + y = relay.nn.dense(x, w_stacked, units=24) + y = relay.expand_dims(y, axis=2) + + strides = [1, 1, 1] + y1 = relay.strided_slice( + y, begin=[0, 0, 0], end=[-1, 16, -1], strides=strides, slice_mode="size" + ) + y2 = relay.strided_slice( + y, begin=[0, 16, 0], end=[-1, 8, -1], strides=strides, slice_mode="size" + ) + y = relay.Tuple((y1, y2)) + return relay.Function(args, y) + + x = relay.var("x", shape=(2, 32)) + w1 = relay.var("w1", shape=(16, 32)) + w2 = relay.var("w2", shape=(8, 32)) + + y_before = before(x, w1, w2) + combine_pass = transform.CombineParallelDense(min_num_branches=2, to_batch=False) + y = run_opt_pass(y_before, combine_pass) + y_expected = expected(x, w1, w2) + y_expected = run_opt_pass(y_expected, transform.InferType()) + tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True) + + if __name__ == "__main__": - test_combine_parallel_dense() - test_combine_parallel_dense_biasadd() - test_combine_parallel_dense_biasadd_scale_reshape() - test_combine_parallel_dense_flat() - test_combine_parallel_dense_flat_biasadd() - test_combine_parallel_dense_flat_biasadd_scale_reshape() + tvm.testing.main() From b7015bb388a1d8f84b89e57709831e74c032d941 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 13 Dec 2022 02:05:01 -0500 Subject: [PATCH 09/12] [Fix] Task scheduler error prompt upon build/run failure (#13601) --- src/meta_schedule/task_scheduler/task_scheduler.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc index 69a70f63c5c01..9d859947e4fee 100644 --- a/src/meta_schedule/task_scheduler/task_scheduler.cc +++ b/src/meta_schedule/task_scheduler/task_scheduler.cc @@ -120,7 +120,9 @@ void 
TaskCleanUp(TaskRecordNode* self, int task_id, const Array& r std::string err = error_msg.value(); TVM_PY_LOG(INFO, logger) << std::fixed << std::setprecision(4) // << "[Task #" << task_id << ": " << name << "] Trial #" << trials - << ": Error in building:\n" + << ": Error in " + << (builder_result->error_msg.defined() ? "building" : "running") + << ":\n" << err << "\n" << tir::AsTVMScript(sch->mod()) << "\n" << Concat(sch->trace().value()->AsPython(false), "\n"); From 1d9863470e0e97413d05b98f2852dc7de60611a0 Mon Sep 17 00:00:00 2001 From: masahi Date: Tue, 13 Dec 2022 20:10:13 +0900 Subject: [PATCH 10/12] [TIR] Fix PlanAndUpdateBufferAllocationLocation not visiting constant buffer (#13605) * Fix PlanAndUpdateBufferAllocationLocation not visiting constant buffer * add comment --- .../plan_update_buffer_allocation_location.cc | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc index 4c63d3393fd89..11d8330ec8fe7 100644 --- a/src/tir/transforms/plan_update_buffer_allocation_location.cc +++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc @@ -61,24 +61,35 @@ class BufferAllocateOrderCollector : public StmtExprVisitor { } private: + bool find(const Buffer& buf) { + return std::find(buffer_alloc_recorder_.begin(), buffer_alloc_recorder_.end(), buf) != + buffer_alloc_recorder_.end(); + } + void VisitStmt_(const BlockNode* op) final { for (const Buffer& buffer : op->alloc_buffers) { buffer_alloc_recorder_.push_back(buffer); } + // Also visit match_buffers to collect constant buffers associated with AllocateConst nodes. + // These buffers only appear in read and match_buffer regions. + for (const auto& region : op->match_buffers) { + if (!find(region->source->buffer)) { + buffer_alloc_recorder_.push_back(region->source->buffer); + } + } + StmtExprVisitor::VisitStmt_(op); } void VisitExpr_(const BufferLoadNode* op) final { - if (std::find(buffer_alloc_recorder_.begin(), buffer_alloc_recorder_.end(), op->buffer) == - buffer_alloc_recorder_.end()) { + if (!find(op->buffer)) { buffer_alloc_recorder_.push_back(op->buffer); } StmtExprVisitor::VisitExpr_(op); } void VisitStmt_(const BufferStoreNode* op) final { - if (std::find(buffer_alloc_recorder_.begin(), buffer_alloc_recorder_.end(), op->buffer) == - buffer_alloc_recorder_.end()) { + if (!find(op->buffer)) { buffer_alloc_recorder_.push_back(op->buffer); } StmtExprVisitor::VisitStmt_(op); From 12311dcdefd7f2213ce5ce78f2c590444a04b32d Mon Sep 17 00:00:00 2001 From: Farshid Salemi Parizi Date: Tue, 13 Dec 2022 07:44:08 -0800 Subject: [PATCH 11/12] [Hexagon] Enable depthwise conv2d NHWC with an HWIO kernel layout (#13414) Enable depthwise conv2d NHWC with HWIO kernel layout. The default kernel layout is HWOI, matched to previous behavior. 
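For illustration only (this snippet is not part of the patch): a minimal Relay-level sketch of what the change enables, assuming a float32 NHWC input, a depthwise multiplier of 1, an `llvm` target chosen only so the example builds anywhere, and a TVM build that already includes this patch. A depthwise `nn.conv2d` whose kernel is already laid out as HWIO can be lowered through the NHWC depthwise strategy directly, without first transposing the kernel to HWOI.

```
import tvm
from tvm import relay

channels = 32
data = relay.var("data", shape=(1, 16, 16, channels), dtype="float32")
# HWIO kernel for a depthwise conv with multiplier 1: (kh, kw, 1, channels).
weight = relay.var("weight", shape=(3, 3, 1, channels), dtype="float32")

out = relay.nn.conv2d(
    data,
    weight,
    kernel_size=(3, 3),
    padding=(1, 1),
    groups=channels,  # groups == input channels makes this a depthwise conv
    channels=channels,
    data_layout="NHWC",
    kernel_layout="HWIO",
)
mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))

# Building exercises the shared depthwise_conv2d_nhwc compute updated by this
# patch; the Hexagon target takes the same code path through its own strategy.
lib = relay.build(mod, target="llvm")
```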
--- python/tvm/relay/op/strategy/arm_cpu.py | 2 +- python/tvm/relay/op/strategy/hexagon.py | 3 +-- python/tvm/relay/op/strategy/x86.py | 3 +-- python/tvm/topi/nn/depthwise_conv2d.py | 19 ++++++++++++++++--- 4 files changed, 19 insertions(+), 8 deletions(-) diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index 261b979dedaf3..c8d51bc23c82d 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -318,7 +318,7 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): else: logger.warning("depthwise_conv2d with layout NHWC is not optimized for arm cpu.") strategy.add_implementation( - wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc), + wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc, need_kernel_layout=True), wrap_topi_schedule(conv2d_generic.schedule_depthwise_conv2d_nhwc), name="depthwise_conv2d_nhwc.generic", ) diff --git a/python/tvm/relay/op/strategy/hexagon.py b/python/tvm/relay/op/strategy/hexagon.py index c1d64f2fe143c..f42503a1477c1 100644 --- a/python/tvm/relay/op/strategy/hexagon.py +++ b/python/tvm/relay/op/strategy/hexagon.py @@ -86,9 +86,8 @@ def conv2d_strategy_hexagon(attrs, inputs, out_type, target): name="depthwise_conv2d_nchw.hexagon", ) elif layout == "NHWC": - assert kernel_layout == "HWOI" strategy.add_implementation( - wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc), + wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc, need_kernel_layout=True), wrap_topi_schedule(topi.hexagon.schedule_depthwise_conv2d_nhwc), name="depthwise_conv2d_nhwc.hexagon", ) diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index 3e59209f58228..7ff4dbc0ad1b7 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -228,13 +228,12 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): assert _OIHWio_matcher.match(kernel_layout) # check if kernel is OIHWio return depthwise_conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target) elif layout == "NHWC": - assert kernel_layout == "HWOI" if (not need_auto_scheduler_layout) and (not need_meta_schedule_layout): logger.warning( "depthwise_conv2d NHWC layout is not optimized for x86 with autotvm." ) strategy.add_implementation( - wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc), + wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc, need_kernel_layout=True), wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nhwc), name="depthwise_conv2d_nhwc.generic", ) diff --git a/python/tvm/topi/nn/depthwise_conv2d.py b/python/tvm/topi/nn/depthwise_conv2d.py index 48ffb8c6d9ffc..7c446a23a8139 100644 --- a/python/tvm/topi/nn/depthwise_conv2d.py +++ b/python/tvm/topi/nn/depthwise_conv2d.py @@ -19,6 +19,7 @@ from __future__ import absolute_import as _abs from collections import namedtuple import tvm +import numpy as np from tvm import te from .dilate import dilate @@ -211,7 +212,9 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=No return Output -def depthwise_conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype=None): +def depthwise_conv2d_nhwc( + Input, Filter, stride, padding, dilation, kernel_layout="HWOI", out_dtype=None +): """Depthwise convolution nhwc forward operator. 
Parameters @@ -252,8 +255,14 @@ def depthwise_conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype=No dilation_h, dilation_w = dilation batch, in_height, in_width, in_channel = Input.shape + # shape of dilated kernel - filter_height, filter_width, filter_channel, channel_multiplier = Filter.shape + if kernel_layout == "HWIO": + filter_height, filter_width, channel_multiplier, filter_channel = Filter.shape + kernel_permutation = [0, 1, 3, 2] + else: + filter_height, filter_width, filter_channel, channel_multiplier = Filter.shape + kernel_permutation = [0, 1, 2, 3] dilated_kernel_h = (filter_height - 1) * dilation_h + 1 dilated_kernel_w = (filter_width - 1) * dilation_w + 1 @@ -285,7 +294,11 @@ def depthwise_conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype=No idxdiv(c, channel_multiplier), ].astype(out_dtype) * Filter[ - di, dj, idxdiv(c, channel_multiplier), idxmod(c, channel_multiplier) + tuple( + np.array( + [di, dj, idxdiv(c, channel_multiplier), idxmod(c, channel_multiplier)] + )[kernel_permutation] + ) ].astype(out_dtype) ), axis=[di, dj], From c547bbb13d2c42b7c447dfaee74734e2f7ffe18c Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Tue, 13 Dec 2022 09:37:03 -0800 Subject: [PATCH 12/12] [Relay][Frontend][Onnx] SequenceAt and SplitToSequence Operators (#13602) * Add support for SequenceAt and SplitToSequence to onnx importer * Formatting * Change keepdims comparison * Only unify non-tuples in If --- python/tvm/relay/frontend/onnx.py | 79 ++++++++++++++++++++++ python/tvm/relay/op/_transform.py | 2 - tests/python/frontend/onnx/test_forward.py | 31 +++++---- 3 files changed, 98 insertions(+), 14 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 62f0f4b2dd255..3470099100d46 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -4008,6 +4008,23 @@ def _impl_v1(cls, inputs, attr, params): for var in else_free_vars: graph_scope._nodes.update({var.name_hint: var}) + # Sometimes pytorch to onnx will insert silly if statements that produce dynamic ranks. + # Often these dont contribute anything. If we see a dynamic rank output, try to unify + # them so we can continue without breaking. + if not isinstance(then_expr, _expr.Tuple) and not isinstance(else_expr, _expr.Tuple): + then_shape = infer_shape(then_expr) + else_shape = infer_shape(else_expr) + if len(then_shape) != len(else_shape): + warning_msg = ( + "If statement produced outputs with different rank. " + "Attempting to unify ranks but this may produce incorrect results." + ) + warnings.warn(warning_msg) + if len(then_shape) < len(else_shape): + then_expr = _op.broadcast_to_like(then_expr, else_expr) + else: + else_expr = _op.broadcast_to_like(else_expr, then_expr) + # Now we can construct the relay if statement and return. ret = _expr.If(cond, then_expr, else_expr) if len(then_branch.output) > 1: @@ -5565,6 +5582,66 @@ def _impl_v11(cls, inputs, attr, params): return _op.concatenate(inputs[0], axis=axis) +class SplitToSequence(OnnxOpConverter): + """Operator converter for split to sequence op.""" + + @classmethod + def _impl_v11(cls, inputs, attr, params): + axis = attr.get("axis", 0) + keepdims = attr.get("keepdims", 1) + + input_tensor = inputs[0] + input_shape = infer_shape(input_tensor) + split = inputs[1] + + # If split is not provided, we split all values along axis. + if split is None: + output = _op.split(input_tensor, input_shape[axis], axis=axis) + # If keepdims is 0, then we need to squeeze off the axis. 
+ if not keepdims: + output = [_op.squeeze(tensor_slice, axis=[axis]) for tensor_slice in output] + return _expr.Tuple(list(output)) + + # Otherwise, split based on provided split value. + else: + # For now we only support constant valued split. + assert isinstance( + split, _expr.Constant + ), "Only constant split supported for SplitToSequence" + split = split.data.numpy() + if len(split.shape) == 1 and split.shape[0] > 1: + # If split is a 1D tensor, it must be converted to indices for relay compatibility. + split = np.cumsum(split) + # Remove final invalid index. + split = split[:-1] + else: + # Otherwise get split as an integer. + split = int(split) + + output = _op.split(input_tensor, split, axis=axis) + + # If keepdims is set to 0 remove split axis. Note that this is + # an inconsistency with the onnx spec but is needed for pytorch compatibility. + if not keepdims: + output = [_op.squeeze(tensor_slice, axis=[axis]) for tensor_slice in output] + return _expr.Tuple(list(output)) + + +class SequenceAt(OnnxOpConverter): + """Operator converter for sequence at op.""" + + @classmethod + def _impl_v11(cls, inputs, attr, params): + input_sequence = inputs[0] + position = inputs[1] + assert isinstance( + position, _expr.Constant + ), "Only constant position supported for SequenceAt" + # Convert position to integer. + position = int(position.data.numpy()) + return input_sequence[position] + + # compatible operators that do NOT require any conversion. _identity_list = [] @@ -5793,6 +5870,8 @@ def _get_convert_map(opset): "SequenceConstruct": SequenceConstruct.get_converter(opset), "SequenceInsert": SequenceInsert.get_converter(opset), "ConcatFromSequence": ConcatFromSequence.get_converter(opset), + "SplitToSequence": SplitToSequence.get_converter(opset), + "SequenceAt": SequenceAt.get_converter(opset), } diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 5b7e342c4b4ed..d4e4a527835a3 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -432,8 +432,6 @@ def _concatenate_shape_func(inputs, axis): for i in const_range(ndim): if i != axis: out[i] = inputs[0][i] - for j in const_range(1, len(inputs)): - assert out[i] == inputs[j][i], "Dims mismatch in the inputs of concatenate." else: out[i] = int64(0) for j in const_range(len(inputs)): diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 211d7f798aba7..dcd4f2defbe82 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -7043,7 +7043,7 @@ def verify_linear_regressor(a_shape, c_shape, i_shape, targets=1, batch=1): def test_sequence(target, dev): """test_sequence""" - def verify_sequence_ops(tensor_shape, num_tensors, axis=0, position=None, new_axis=None): + def verify_sequence_ops(tensor_shape, num_tensors, axis=0, position=0, new_axis=None): tensor_shape = list(tensor_shape) tensor_values = [] for i in range(num_tensors): @@ -7062,20 +7062,30 @@ def verify_sequence_ops(tensor_shape, num_tensors, axis=0, position=None, new_ax outputs=["sequence"], ) - insert_inputs = ["sequence", input_tensor_names[0]] - position_node = None - if position is not None: - insert_inputs.append("position") - position_node = make_constant_node("position", TensorProto.INT32, (), [position]) + position_node = make_constant_node("position", TensorProto.INT32, (), [position]) # Test sequence insertion. 
insert_node = helper.make_node( - "SequenceInsert", inputs=insert_inputs, outputs=["inserted_sequence"] + "SequenceInsert", + inputs=["sequence", input_tensor_names[0], "position"], + outputs=["inserted_sequence"], ) # Test sequence concatenation. concat_node = helper.make_node( - "ConcatFromSequence", inputs=["inserted_sequence"], outputs=["output"], axis=axis + "ConcatFromSequence", + inputs=["inserted_sequence"], + outputs=["concat_sequence"], + axis=axis, + ) + + # Test splitting a tensor into a sequence. + split_node = helper.make_node( + "SplitToSequence", inputs=["concat_sequence"], outputs=["split_sequence"], axis=axis + ) + + at_node = helper.make_node( + "SequenceAt", inputs=["split_sequence", "position"], outputs=["output"] ) if new_axis is not None: @@ -7097,10 +7107,7 @@ def verify_sequence_ops(tensor_shape, num_tensors, axis=0, position=None, new_ax output_shape[axis] = (num_tensors + 1) * output_shape[axis] graph_outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, output_shape)] - graph_nodes = [] - if position_node is not None: - graph_nodes.append(position_node) - graph_nodes += [construct_node, insert_node, concat_node] + graph_nodes = [position_node, construct_node, insert_node, concat_node, split_node, at_node] graph = helper.make_graph( graph_nodes,