Skip to content

Commit

Permalink
[OpenCL][Textures] Always use SSA for texture loading (#14397)
Browse files Browse the repository at this point in the history
* [OpenCL][Textures] Always use SSA for texture loading

In some cases we must use SSA for textures loading but we didn't do
that. Example of such cases:
1. Storing texture (NCHW4c) directly (w/o temporary buffer) to the
   output buffer (NCHW). In this case we have to use SSA because we
   need to get only one channel from the pixel. In case of storing to
   the local buffer the SSA was used because the buffer was allocated
   in kernel and the logic was written that if the buffer was allocated
   then we should use SSA. But if we store the same texture directly to
   the output buffer then SSA wasn't used and this OpenCL code wasn't
   compiled.
2. Casting texture (NCHW4c) to another data type and then storing it to
   the buffer (NCHW). The SSA for textures was disabled in case of cast
   operation. As a result it was necessary to take an channel from the
   pixel but we got the vector data type (e.g. float4) and then we
   tried to cast it to scalar data type. This code also wasn't
   compiled.

In this PR SSA form was enabled for all cases when `texture2d_load` is
used. The relevant tests cases were added.

* Add regression test on injective

* Fix lit

* Add skip for FP16 test

* Add additional test cases

* Fix lint

* Apply comment

* Fix lint

---------

Co-authored-by: Andrey Malyshev <[email protected]>
  • Loading branch information
echuraev and elvin-n authored Mar 30, 2023
1 parent 5cca18b commit 4011280
Show file tree
Hide file tree
Showing 5 changed files with 469 additions and 50 deletions.
48 changes: 8 additions & 40 deletions src/target/source/codegen_opencl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -382,33 +382,6 @@ std::string CodeGenOpenCL::CastTo(std::string value, DataType target) {
return os.str();
}

void CodeGenOpenCL::VisitStmt_(const BufferStoreNode* op) {
if (auto call = op->value.as<CallNode>()) {
if (call->op.same_as(builtin::texture2d_load())) {
need_texture_ssa_ = false;
// If storing a texture load into a buffer, don't use an
// intermediate local unless the buffer allocation is a
// single element selected from the texture read.
auto it = allocation_size_.find(op->buffer->data.get());
if (it != allocation_size_.end() && it->second == 1) {
need_texture_ssa_ = true;
}
}
}
CodeGenC::VisitStmt_(op);
need_texture_ssa_ = true;
}

void CodeGenOpenCL::VisitExpr_(const CastNode* op, std::ostream& os) {
if (auto call = op->value.as<CallNode>()) {
if (call->op.same_as(builtin::texture2d_load())) {
need_texture_ssa_ = false;
}
}
CodeGenC::VisitExpr_(op, os);
need_texture_ssa_ = true;
}

void CodeGenOpenCL::VisitStmt_(const AllocateNode* op) {
allocation_size_.insert({op->buffer_var.get(), op->ConstantAllocationSize() * op->dtype.lanes()});
CodeGenC::VisitStmt_(op);
Expand Down Expand Up @@ -472,20 +445,15 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) {
this->PrintExpr(op->args[2], ss);
ss << ")))";

// Only use local SSA if texture is not already being stored
if (need_texture_ssa_) {
std::string rhs = SSAGetID(ss.str(), op->dtype.with_lanes(4));
if (op->args.back().as<RampNode>()) {
os << rhs;
} else {
os << "((";
this->PrintType(op->dtype.with_lanes(1), os);
os << "*)&" << rhs << ")[";
this->PrintExpr(op->args.back(), os);
os << "]";
}
std::string rhs = SSAGetID(ss.str(), op->dtype.with_lanes(4));
if (op->args.back().as<RampNode>()) {
os << rhs;
} else {
os << ss.str();
os << "((";
this->PrintType(op->dtype.with_lanes(1), os);
os << "*)&" << rhs << ")[";
this->PrintExpr(op->args.back(), os);
os << "]";
}
} else if (op->op.same_as(builtin_call_extern_)) {
auto func = Downcast<StringImm>(op->args[0]);
Expand Down
5 changes: 0 additions & 5 deletions src/target/source/codegen_opencl.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,7 @@ class CodeGenOpenCL final : public CodeGenC {
void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const CastNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*)
void VisitStmt_(const BufferStoreNode* op) final; // NOLINT(*)

// overload min and max to avoid ambiguous call errors
void VisitExpr_(const MinNode* op, std::ostream& os) final;
Expand All @@ -86,9 +84,6 @@ class CodeGenOpenCL final : public CodeGenC {
// Whether to enable sampler or sampler-less texture reads,
// where the choice depends on the OpenCL version used.
bool enable_compliant_texture_reads_{false};
// Key to disable use of texture SSA in certain scenarios. For example,
// when loaded value is stored directly to a user declared l-value buffer
bool need_texture_ssa_{true};
// Mapping from buffer to allocation size.
// Useful to track when a scalar store of a vectorized texture load is required.
std::unordered_map<const Object*, size_t> allocation_size_;
Expand Down
85 changes: 85 additions & 0 deletions tests/python/relay/opencl_texture/test_injection_texture.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import re
import pytest
import tvm
import numpy as np
from tvm import relay
from tvm.relay import testing
from tvm.contrib import utils
from utils.adreno_utils import gpu_preprocess, build_run_compare


dtype = tvm.testing.parameter("float32")


@tvm.testing.requires_opencl
@tvm.testing.parametrize_targets("opencl -device=adreno")
def test_layout_transform_to_block_nchw4c(remote, target, dtype):
"""Verification of the case NCHW->NCHW4c"""
input_shape = (1, 32, 720, 1280)
A = relay.var("data", shape=input_shape, dtype=dtype)
lt = relay.layout_transform(A, "NCHW", "NCHW4c")
mod = relay.Function([A], lt)

build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target)


@tvm.testing.requires_opencl
@tvm.testing.parametrize_targets("opencl -device=adreno")
def test_layout_transform_to_block_nchw(remote, target, dtype):
"""Verification of the case NCHW4c->NCHW"""
input_shape = (1, 36, 1, 1, 4)
A = relay.var("data", shape=input_shape, dtype=dtype)
lt = relay.layout_transform(A, "NCHW4c", "NCHW")
mod = relay.Function([A], lt)

build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target)


@tvm.testing.requires_opencl
@tvm.testing.parametrize_targets("opencl -device=adreno")
def test_layout_transform_to_block_nhwc4c(remote, target, dtype):
"""Verification of the case NHWC->NHWC4c"""
input_shape = (1, 1, 1, 144)
A = relay.var("data", shape=input_shape, dtype=dtype)
lt = relay.layout_transform(A, "NHWC", "NHWC4c")
mod = relay.Function([A], lt)

build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target)


@pytest.mark.skipif(
tvm.testing.utils.IS_IN_CI, reason="Skip because GPU in CI doesn't support FP16"
)
@tvm.testing.requires_opencl
@tvm.testing.parametrize_targets("opencl -device=adreno")
def test_layout_transform_to_block_nhwc(remote, target, dtype):
"""Verification of the case NHWC4c->NHWC"""
input_shape = (1, 80, 80, 36, 4)
A = relay.var("data", shape=input_shape, dtype=dtype)
mean = relay.mean(A, axis=[1, 2], keepdims=True)
cast = relay.cast(mean, "float16")
lt = relay.layout_transform(cast, "NHWC4c", "NHWC")
mod = relay.Function([A], lt)

build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target)


if __name__ == "__main__":
test_layout_transform_to_block_nhwc(None, "opencl -device=adreno", "float16")
6 changes: 1 addition & 5 deletions tests/python/unittest/test_target_codegen_opencl.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,4 @@ def check_type_casting(ctx, n, dtype):


if __name__ == "__main__":
test_opencl_ternary_expression()
test_opencl_inf_nan()
test_opencl_max()
test_opencl_erf()
test_opencl_type_casting()
tvm.testing.main()
Loading

0 comments on commit 4011280

Please sign in to comment.