diff --git a/src/argument.cpp b/src/argument.cpp index ba8e821d632..733e27acac1 100644 --- a/src/argument.cpp +++ b/src/argument.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -102,6 +102,24 @@ void argument::assign_buffer(std::function d) })(s); } +std::vector flatten(const std::vector& args) +{ + std::vector result; + for(const auto& arg : args) + { + if(arg.get_shape().type() == shape::tuple_type) + { + auto subs = flatten(arg.get_sub_objects()); + result.insert(result.end(), subs.begin(), subs.end()); + } + else + { + result.push_back(arg); + } + } + return result; +} + std::vector to_shapes(const std::vector& args) { std::vector shapes; diff --git a/src/cpp_generator.cpp b/src/cpp_generator.cpp index da292b83cce..433ccaadb5b 100644 --- a/src/cpp_generator.cpp +++ b/src/cpp_generator.cpp @@ -38,6 +38,7 @@ inline namespace MIGRAPHX_INLINE_NS { cpp_generator::function& cpp_generator::function::set_body(const module& m, const cpp_generator::generate_module_callback& g) { + const std::string prefix = "zz"; std::unordered_map names; std::stringstream ss; @@ -53,12 +54,13 @@ cpp_generator::function::set_body(const module& m, const cpp_generator::generate } else if(ins->name() == "@return") { - assert(ins->inputs().size() == 1); - return_ins = ins->inputs().front(); + names[ins] = prefix + "return"; + ss << "auto " << names[ins] << " = " << g(ins, names) << ";\n"; + return_ins = ins; } else { - std::string n = "z" + std::to_string(names.size()); + std::string n = prefix + std::to_string(names.size()); names[ins] = n; ss << "auto " << n << " = " << g(ins, names) << ";\n"; } @@ -125,6 +127,7 @@ struct cpp_generator_impl std::function fmap = nullptr; std::function fresult = nullptr; std::unordered_map point_op_map = {}; + bool always_return_tuple = false; }; cpp_generator::cpp_generator() : impl(std::make_unique()) {} @@ -142,6 +145,8 @@ void cpp_generator::fmap(const std::function& f) { imp void cpp_generator::fresult(const std::function& f) { impl->fresult = f; } +void cpp_generator::always_return_tuple(bool b) { impl->always_return_tuple = b; } + void cpp_generator::add_point_op(const std::string& op_name, const std::string& code) { impl->point_op_map[op_name] = code; @@ -222,6 +227,13 @@ cpp_generator::function cpp_generator::generate_module(const module& m, }); return shape::cpp_type(ins->get_shape().type()) + "(" + string_literal + ")"; } + if(ins->name() == "@return") + { + // TODO: Customize the make_tuple call + if(impl->always_return_tuple or ins->inputs().size() != 1) + return "make_tuple(" + join_strings(to_args(ins->inputs(), names), ", ") + ")"; + return names.at(ins->inputs().front()); + } auto s = g(ins, names); if(impl->fresult) return impl->fresult(ins->get_shape()) + '(' + s + ')'; diff --git a/src/fuse_pointwise.cpp b/src/fuse_pointwise.cpp index 10b560ff52e..1f2e5429e52 100644 --- a/src/fuse_pointwise.cpp +++ b/src/fuse_pointwise.cpp @@ -113,18 +113,17 @@ static void create_pointwise_modules(module_pass_manager& mpm) } } -static std::vector append_pointwise_module(instruction_ref ins, - instruction_ref output) +static module::with_inputs append_pointwise_module(instruction_ref ins, instruction_ref output) { assert(contains(output->inputs(), ins)); - module_ref pm = ins->module_inputs().at(0); + module pm = *ins->module_inputs().at(0); module_ref xm = output->module_inputs().at(0); - auto last = std::prev(pm->end()); + auto last = std::prev(pm.end()); assert(last->name() == "@return"); assert(last->inputs().size() == 1); - assert(pm->get_parameter_names().size() == ins->inputs().size()); + assert(pm.get_parameter_names().size() == ins->inputs().size()); assert(xm->get_parameter_names().size() == output->inputs().size()); std::vector inputs = ins->inputs(); @@ -134,8 +133,8 @@ static std::vector append_pointwise_module(instruction_ref ins, for(auto i : range(inputs.size())) { auto input = inputs[i]; - auto param = pm->get_parameter("x" + std::to_string(i)); - assert(param != pm->end()); + auto param = pm.get_parameter("x" + std::to_string(i)); + assert(param != pm.end()); input_map[input] = param; } // Add the new parameter and additional inputs @@ -157,20 +156,20 @@ static std::vector append_pointwise_module(instruction_ref ins, else { map_ins[param] = - pm->add_parameter("x" + std::to_string(inputs.size()), {input->get_shape().type()}); + pm.add_parameter("x" + std::to_string(inputs.size()), {input->get_shape().type()}); inputs.push_back(input); input_map[input] = map_ins[param]; } } - pm->replace_return(pm->insert_instructions(last, xm, &map_ins)); - return inputs; + pm.replace_return(pm.insert_instructions(last, xm, &map_ins)); + return {std::move(pm), inputs}; } -static bool find_pointwise_modules(module& m) +static bool find_pointwise_modules(module_pass_manager& mpm) { bool changed = false; - auto last = std::prev(m.end()); - for(auto ins : iterator_for(m)) + auto last = std::prev(mpm.get_module().end()); + for(auto ins : iterator_for(mpm.get_module())) { if(ins->name() != "pointwise") continue; @@ -183,10 +182,11 @@ static bool find_pointwise_modules(module& m) continue; auto input = *it; - auto new_inputs = append_pointwise_module(input, ins); - m.replace_instruction(input, input->get_operator(), new_inputs, input->module_inputs()); - m.replace_instruction(ins, input); - m.move_instruction(input, ins); + auto fused = append_pointwise_module(input, ins); + auto name = fused.mod.name(); + mpm.rename_module(name, name + ":" + ins->module_inputs().front()->name() + "-deleted"); + auto* new_pm = mpm.create_module(name, std::move(fused.mod)); + mpm.get_module().replace_instruction(ins, input->get_operator(), fused.inputs, {new_pm}); changed = true; } @@ -213,7 +213,7 @@ void fuse_pointwise::apply(module_pass_manager& mpm) const for(int i = 0; i < 8; i++) { mpm.run_pass(rewrite_reshapes{}); - if(not find_pointwise_modules(mpm.get_module())) + if(not find_pointwise_modules(mpm)) break; mpm.run_pass(dead_code_elimination{}); } diff --git a/src/include/migraphx/argument.hpp b/src/include/migraphx/argument.hpp index 30c0df40c56..f07f2f0c8f8 100644 --- a/src/include/migraphx/argument.hpp +++ b/src/include/migraphx/argument.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -117,6 +117,8 @@ struct MIGRAPHX_EXPORT argument : raw_data data_t m_data{}; }; +std::vector flatten(const std::vector& args); + MIGRAPHX_EXPORT std::vector to_shapes(const std::vector& args); MIGRAPHX_EXPORT void migraphx_to_value(value& v, const argument& a); MIGRAPHX_EXPORT void migraphx_from_value(const value& v, argument& a); diff --git a/src/include/migraphx/cpp_generator.hpp b/src/include/migraphx/cpp_generator.hpp index ef052558051..9f34ba159d8 100644 --- a/src/include/migraphx/cpp_generator.hpp +++ b/src/include/migraphx/cpp_generator.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -95,6 +95,8 @@ struct MIGRAPHX_EXPORT cpp_generator void fresult(const std::function& f); + void always_return_tuple(bool b = true); + void add_point_op(const std::string& op_name, const std::string& code); std::string generate_point_op(const operation& op, const std::vector& args); diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index f9d41121159..254a4da8a07 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -202,6 +202,20 @@ struct MIGRAPHX_EXPORT module instruction_ref begin() const; instruction_ref end() const; + struct compute_shapes_options + { + std::string name = "compute_shapes"; + bool strict_type = false; + bool strict_lens = false; + std::vector scalar_const_out_lens = {}; + }; + + /// Compute a new ouput shape by replacing each parameter with input + /// shapes passed in. + std::vector compute_shapes(const std::vector& inputs, + compute_shapes_options options) const; + std::vector compute_shapes(const std::vector& inputs) const; + std::vector get_output_shapes() const; instruction_ref validate() const; diff --git a/src/include/migraphx/op/pointwise.hpp b/src/include/migraphx/op/pointwise.hpp index 6e7d6f92f30..c76276a9f08 100644 --- a/src/include/migraphx/op/pointwise.hpp +++ b/src/include/migraphx/op/pointwise.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -45,23 +45,18 @@ struct pointwise { MIGRAPHX_THROW("should have one submodule."); } - auto* pm = mods.front(); - if(pm->get_output_shapes().size() != 1) - MIGRAPHX_THROW("pointwise should have only one output."); if(inputs.empty()) MIGRAPHX_THROW("pointwise should have at least one input"); + auto* pm = mods.front(); auto pnames = pm->get_parameter_names(); - std::sort(pnames.begin(), pnames.end()); check_shapes{inputs, *this}.has(pnames.size()).same_dims(); - auto type = pm->get_output_shapes().front().type(); - - // Scalar output if all inputs are scalar - if(inputs.front().elements() == 1 and - all_of(inputs, [](const auto& s) { return s.scalar(); })) - return shape{type}; - - return shape::from_permutation(type, inputs.front().lens(), find_permutation(inputs)); + auto result = pm->compute_shapes( + inputs, + {.name = name(), .strict_type = true, .scalar_const_out_lens = inputs.front().lens()}); + if(result.size() == 1) + return result.front(); + return shape{result}; } argument compute(const shape& output_shape, @@ -75,7 +70,7 @@ struct pointwise auto pnames = pm->get_parameter_names(); std::sort(pnames.begin(), pnames.end()); - par_for(output_shape.elements(), [&](auto i) { + par_for(args[0].get_shape().elements(), [&](auto i) { std::unordered_map params; std::transform( @@ -86,8 +81,15 @@ struct pointwise [&](auto&& name, auto&& arg) { return std::make_pair(name, arg.element(i)); }); auto results = run(pm, params); - assert(results.size() == 1); - visit_all(output, results.front())([&](auto out, auto x) { out[i] = x.front(); }); + assert(results.size() == output.get_sub_objects().size() or + (results.size() == 1 and output.get_sub_objects().empty())); + std::vector outputs; + if(results.size() == 1) + outputs = {output.share()}; + else + outputs = output.share().get_sub_objects(); + for(auto j : range(results.size())) + visit_all(outputs[j], results[j])([&](auto out, auto x) { out[i] = x.front(); }); }); return output; } diff --git a/src/include/migraphx/pass_manager.hpp b/src/include/migraphx/pass_manager.hpp index fdbdc123a12..8ba2f64925e 100644 --- a/src/include/migraphx/pass_manager.hpp +++ b/src/include/migraphx/pass_manager.hpp @@ -40,6 +40,7 @@ struct module_pass_manager virtual module& get_module() = 0; virtual module* create_module(const std::string& name) = 0; virtual module* create_module(const std::string& name, module m) = 0; + virtual void rename_module(const std::string& old_name, const std::string& new_name) = 0; virtual module* get_common_parent() = 0; virtual module* get_root_module() = 0; virtual void run_pass(const pass& p) = 0; diff --git a/src/include/migraphx/program.hpp b/src/include/migraphx/program.hpp index 5a4884faca4..e86ba628656 100644 --- a/src/include/migraphx/program.hpp +++ b/src/include/migraphx/program.hpp @@ -154,6 +154,7 @@ struct MIGRAPHX_EXPORT program std::unordered_multimap get_module_tree(); void remove_module(const std::string& name); + void rename_module(const std::string& old_name, const std::string& new_name); void remove_unused_modules(); private: diff --git a/src/include/migraphx/raw_data.hpp b/src/include/migraphx/raw_data.hpp index 91a9deb20e8..19373bab6d4 100644 --- a/src/include/migraphx/raw_data.hpp +++ b/src/include/migraphx/raw_data.hpp @@ -191,6 +191,14 @@ struct raw_data : raw_data_base ss << static_cast(*this); return ss.str(); } + + template + std::vector to_vector() const + { + std::vector result(static_cast(*this).get_shape().elements()); + this->visit([&](auto x) { result.assign(x.begin(), x.end()); }); + return result; + } }; namespace detail { diff --git a/src/include/migraphx/shape.hpp b/src/include/migraphx/shape.hpp index b84dbaa8728..d2b25b091fc 100644 --- a/src/include/migraphx/shape.hpp +++ b/src/include/migraphx/shape.hpp @@ -428,6 +428,9 @@ struct MIGRAPHX_EXPORT shape std::shared_ptr impl; }; +/// Flatten subshapes to a single vector of non-tuple type of shapes +std::vector flatten(const std::vector& shapes); + MIGRAPHX_EXPORT void migraphx_to_value(value& v, const shape& s); MIGRAPHX_EXPORT void migraphx_from_value(const value& v, shape& s); diff --git a/src/module.cpp b/src/module.cpp index 8c28c15d22d..a90328db6f4 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -663,6 +663,71 @@ std::vector module::get_output_shapes() const } } +std::vector module::compute_shapes(const std::vector& inputs, + compute_shapes_options options) const +{ + auto params = this->get_parameter_names(); + std::sort(params.begin(), params.end()); + std::unordered_map ins_shapes; + std::unordered_map adjusted_param_shapes; + std::transform(inputs.begin(), + inputs.end(), + params.begin(), + std::inserter(adjusted_param_shapes, adjusted_param_shapes.end()), + [](auto ps, auto name) { return std::make_pair(name, ps); }); + for(auto ins : iterator_for(*this)) + { + if(ins->name() == "@param") + { + ins_shapes[ins] = + adjusted_param_shapes[any_cast(ins->get_operator()).parameter]; + if(options.strict_type and ins->get_shape().type() != ins_shapes[ins].type()) + { + MIGRAPHX_THROW(options.name + ": Mismatched type: expected " + + ins->get_shape().type_string() + " but passed " + + ins_shapes[ins].type_string()); + } + if(options.strict_lens and ins->get_shape().lens() != ins_shapes[ins].lens()) + { + MIGRAPHX_THROW(options.name + ": Mismatched lens: expected {" + + to_string_range(ins->get_shape().lens()) + "} but passed {" + + to_string_range(ins_shapes[ins].lens()) + "}"); + } + } + else if(ins->name() == "@literal") + { + if(not options.scalar_const_out_lens.empty() and ins->get_shape().scalar()) + { + std::vector strides(options.scalar_const_out_lens.size()); + ins_shapes[ins] = + shape{ins->get_shape().type(), options.scalar_const_out_lens, strides}; + } + else + { + ins_shapes[ins] = ins->get_shape(); + } + } + else + { + std::vector input_shapes; + input_shapes.resize(ins->inputs().size()); + std::transform(ins->inputs().begin(), + ins->inputs().end(), + input_shapes.begin(), + [&](auto in) { return ins_shapes.at(in); }); + if(ins->name() == "@return") + return input_shapes; + ins_shapes[ins] = ins->get_operator().compute_shape(input_shapes); + } + } + MIGRAPHX_THROW("No return found in the submodule"); +} + +std::vector module::compute_shapes(const std::vector& inputs) const +{ + return compute_shapes(inputs, {}); +} + std::vector module::get_returns() const { auto last = std::prev(this->end()); diff --git a/src/pass_manager.cpp b/src/pass_manager.cpp index 3748e094773..af8f0e4d6a7 100644 --- a/src/pass_manager.cpp +++ b/src/pass_manager.cpp @@ -105,6 +105,15 @@ struct module_pm : module_pass_manager return prog->create_module(name, std::move(m)); } + virtual void rename_module(const std::string& old_name, const std::string& new_name) override + { + assert(prog); + assert(mod); + assert( + any_of(mod->get_sub_modules(), [&](module_ref sm) { return sm->name() == old_name; })); + prog->rename_module(old_name, new_name); + } + virtual module* get_common_parent() override { return common_parent; } virtual module* get_root_module() override diff --git a/src/program.cpp b/src/program.cpp index 91935fba0e0..ec7b66d178f 100644 --- a/src/program.cpp +++ b/src/program.cpp @@ -1209,6 +1209,17 @@ void program::remove_module(const std::string& name) impl->modules.erase(name); } +void program::rename_module(const std::string& old_name, const std::string& new_name) +{ + assert(old_name != new_name); + assert(contains(impl->modules, old_name)); + assert(not contains(impl->modules, new_name)); + auto node = impl->modules.extract(old_name); + node.key() = new_name; + node.mapped().set_name(new_name); + impl->modules.insert(std::move(node)); +} + void program::remove_unused_modules() { std::vector unused; diff --git a/src/shape.cpp b/src/shape.cpp index 79294debcb1..073089c2cd2 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -728,6 +728,24 @@ shape::type_t shape::parse_type(const std::string& s) const std::vector& shape::sub_shapes() const { return impl->m_shapes; } +std::vector flatten(const std::vector& shapes) +{ + std::vector result; + for(const auto& s : shapes) + { + if(s.type() == shape::tuple_type) + { + auto subs = flatten(s.sub_shapes()); + result.insert(result.end(), subs.begin(), subs.end()); + } + else + { + result.push_back(s); + } + } + return result; +} + void migraphx_to_value(value& v, const shape& s) { value result; diff --git a/src/targets/gpu/code_object_op.cpp b/src/targets/gpu/code_object_op.cpp index 67c9d59472e..3f640e59d63 100644 --- a/src/targets/gpu/code_object_op.cpp +++ b/src/targets/gpu/code_object_op.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -40,7 +40,7 @@ shape code_object_op::compute_shape(std::vector inputs) const std::transform(einputs.begin(), einputs.end(), einputs.begin(), [](const shape& s) { return s.normalize_standard(); }); - if(einputs != inputs) + if(einputs != flatten(inputs)) MIGRAPHX_THROW("Input shapes have changed: [" + to_string_range(einputs) + "] -> [" + to_string_range(inputs) + "]"); return output; @@ -48,9 +48,10 @@ shape code_object_op::compute_shape(std::vector inputs) const argument code_object_op::compute(context& ctx, const shape&, const std::vector& args) const { - std::vector kargs(args.size()); + auto fargs = flatten(args); + std::vector kargs(fargs.size()); std::transform( - args.begin(), args.end(), kargs.begin(), [](const argument& a) { return a.data(); }); + fargs.begin(), fargs.end(), kargs.begin(), [](const argument& a) { return a.data(); }); auto [start, stop] = ctx.get_perf_events(); k.launch(ctx.get_stream().get(), global, local, std::move(kargs), start, stop); return args[get_output_arg(args.size())]; diff --git a/src/targets/gpu/compile_gen.cpp b/src/targets/gpu/compile_gen.cpp index 2469c6571a8..d5148dcebcc 100644 --- a/src/targets/gpu/compile_gen.cpp +++ b/src/targets/gpu/compile_gen.cpp @@ -180,12 +180,16 @@ std::string make_transformer_args(std::vector transformers) return join_strings(std::move(transformers), ", "); } -void generate_pointwise(cpp_generator& gg, const module& pm, const std::string& name) +static void generate_pointwise(cpp_generator& gg, + const module& pm, + const std::string& name, + bool always_return_tuple = false) { module m = pm; run_passes(m, {rewrite_quantization{}, optimize_module{}}); m.sort(); cpp_generator g; + g.always_return_tuple(always_return_tuple); g.fmap([](const std::string& fname) { return "migraphx::" + fname; }); g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})"); g.add_point_op("prelu", "${function:where}(${0} < 0, ${0} * ${1}, ${0})"); @@ -202,10 +206,10 @@ void generate_pointwise(cpp_generator& gg, const module& pm, const std::string& .set_generic_types(m) .set_name(name)); } -std::string generate_pointwise(const module& pm, const std::string& name) +std::string generate_pointwise(const module& pm, const std::string& name, bool always_return_tuple) { cpp_generator g; - generate_pointwise(g, pm, name); + generate_pointwise(g, pm, name, always_return_tuple); return g.str(); } diff --git a/src/targets/gpu/compile_pointwise.cpp b/src/targets/gpu/compile_pointwise.cpp index c004066e49f..ee682cf2c3a 100644 --- a/src/targets/gpu/compile_pointwise.cpp +++ b/src/targets/gpu/compile_pointwise.cpp @@ -36,7 +36,7 @@ namespace gpu { operation compile_pointwise(context& ctx, const std::vector& in_shapes, const_module_ref pm) { - auto pf = gen::generate_pointwise(*pm, "inner_pointwise"); + auto pf = gen::generate_pointwise(*pm, "inner_pointwise", true); std::string lambda = "MIGRAPHX_LIFT(inner_pointwise)"; auto kernel_name = gen::generate_name_from_ops(*pm, "kernel"); return gpu::compile_op("pointwise", diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index a0a16512358..09da4758418 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -150,59 +150,11 @@ struct mlir_op if(inputs.size() < 2) MIGRAPHX_THROW("should have at least two inputs."); - auto type = mod->get_output_shapes().front().type(); - auto mod_params = mod->get_parameter_names(); - std::sort(mod_params.begin(), mod_params.end()); - std::unordered_map mod_ins_shapes; - std::unordered_map adjusted_mod_param_shapes; - std::transform(inputs.begin(), - inputs.end(), - mod_params.begin(), - std::inserter(adjusted_mod_param_shapes, adjusted_mod_param_shapes.end()), - [](auto ps, auto name) { return std::make_pair(name, ps); }); - for(auto ins : iterator_for(*mod)) - { - if(ins->name() == "@param") - { - mod_ins_shapes[ins] = - adjusted_mod_param_shapes[any_cast(ins->get_operator()) - .parameter]; - if(ins->get_shape().type() != mod_ins_shapes[ins].type()) - { - MIGRAPHX_THROW( - "MLIR_OP: adjusted mod parameter doesn't have the same type lens as " - "original input. Type changed from : " + - ins->get_shape().type_string() + " to " + - mod_ins_shapes[ins].type_string()); - } - if(ins->get_shape().lens() != mod_ins_shapes[ins].lens()) - { - MIGRAPHX_THROW("MLIR_OP: adjusted mod parameter doesn't have the same lens as " - "original input. Lens changed from " + - to_string_range(ins->get_shape().lens()) + " to " + - to_string_range(mod_ins_shapes[ins].lens())); - } - } - else if(ins->name() == "@literal") - { - mod_ins_shapes[ins] = ins->get_shape(); - } - else if(ins->name() == "@return") - { - return mod_ins_shapes[ins->inputs().at(0)].with_type(type); - } - else - { - std::vector input_shapes; - input_shapes.resize(ins->inputs().size()); - std::transform(ins->inputs().begin(), - ins->inputs().end(), - input_shapes.begin(), - [&](auto in) { return mod_ins_shapes[in]; }); - mod_ins_shapes[ins] = ins->get_operator().compute_shape(input_shapes); - } - } - MIGRAPHX_THROW("No return found in the submodule"); + auto result = + mod->compute_shapes(inputs, {.name = name(), .strict_type = true, .strict_lens = true}); + if(result.size() == 1) + return result.front(); + return shape{result}; } }; MIGRAPHX_REGISTER_OP(mlir_op); diff --git a/src/targets/gpu/include/migraphx/gpu/compile_gen.hpp b/src/targets/gpu/include/migraphx/gpu/compile_gen.hpp index 0ed50920584..e8f0be9f2d9 100644 --- a/src/targets/gpu/include/migraphx/gpu/compile_gen.hpp +++ b/src/targets/gpu/include/migraphx/gpu/compile_gen.hpp @@ -72,7 +72,8 @@ std::string make_transformer_args(Ts... xs) return make_transformer_args({xs.str()...}); } -std::string generate_pointwise(const module& pm, const std::string& name); +std::string +generate_pointwise(const module& pm, const std::string& name, bool always_return_tuple = false); std::string generate_reduce(module m, const std::string& name); diff --git a/src/targets/gpu/jit/pointwise.cpp b/src/targets/gpu/jit/pointwise.cpp index 03b6660b5bf..42beb6de070 100644 --- a/src/targets/gpu/jit/pointwise.cpp +++ b/src/targets/gpu/jit/pointwise.cpp @@ -48,7 +48,7 @@ extern "C" { MIGRAPHX_GLOBAL void ${kernel}(${params}) { auto idx = make_index(); - pointwise(idx, ${transformers})(${lambda}, ${args}); + pointwise<${noutputs}>(idx, ${transformers})(${lambda}, ${args}); } } @@ -71,22 +71,25 @@ struct pointwise_compiler : compiler operation compile_op(context& ctx, const std::vector& inputs, const value& v) const { hip_compile_options options; - options.inputs = inputs; + options.inputs = flatten(inputs); options.output = inputs.back(); - options.virtual_inputs = reduce_dims(normalize_permutation(inputs)); + options.virtual_inputs = reduce_dims(normalize_permutation(options.inputs)); options.emplace_param("-Wno-float-equal"); auto axis = find_fast_axis(options.virtual_inputs); auto vec = vectorize::elements(ctx, axis, options.virtual_inputs); options.kernel_name = v.get("kernel", "kernel"); options.set_launch_params( - v, compute_global_for(ctx, options.output.elements() / vec.size, 256)); - auto src = interpolate_string(pointwise_kernel, - {{"kernel", options.kernel_name}, - {"params", enum_params(inputs.size(), "void * private_p")}, - {"args", enum_params(inputs.size(), "private_p")}, - {"lambda", v.at("lambda").to()}, - {"transformers", make_transformer_args(vec)}, - {"preamble", v.get("preamble", std::string{})}}); + v, compute_global_for(ctx, options.inputs.front().elements() / vec.size, 256)); + auto noutputs = options.inputs.size() - inputs.size() + 1; + auto src = + interpolate_string(pointwise_kernel, + {{"kernel", options.kernel_name}, + {"params", enum_params(options.inputs.size(), "void * private_p")}, + {"args", enum_params(options.inputs.size(), "private_p")}, + {"lambda", v.at("lambda").to()}, + {"transformers", make_transformer_args(vec)}, + {"noutputs", std::to_string(noutputs)}, + {"preamble", v.get("preamble", std::string{})}}); return compile_hip_code_object(src, options); } @@ -94,10 +97,10 @@ struct pointwise_compiler : compiler { if(contains({"layout", "contiguous"}, op.name())) { - return compile_op( - ctx, - to_shapes(ins->inputs()), - {{"lambda", "[](auto x) { return x; }"}, {"kernel", op.name() + "_kernel"}}); + return compile_op(ctx, + to_shapes(ins->inputs()), + {{"lambda", "[](auto x) { return make_tuple(x); }"}, + {"kernel", op.name() + "_kernel"}}); } else { diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp index 019350c54e0..fab865c0587 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -291,16 +291,39 @@ inline constexpr auto transform_args() return make_transform([](auto f, auto... xs) { return f(xs...); }); } -// Rotate the first argument to the last argument -inline constexpr auto rotate_last() +// Rotate the last N arguments to the first N arguments +template +constexpr auto rotate_last() { return make_transform([](auto f, auto... xs) { return sequence_c([&](auto... is) { constexpr auto size = sizeof...(is); - return f(arg_c<(is + size - 1) % size>()(xs...)...); + return f(arg_c<(is + size - N) % size>()(xs...)...); }); }); } +inline constexpr auto rotate_last() { return rotate_last<1>(); } + +// Pack the first N arguments +template +constexpr auto pack_first() +{ + return make_transform([](auto f, auto... xs) { + return sequence_c([&](auto... is) { + return sequence_c([&](auto... js) { + return f(pack(arg_c()(xs...)...), arg_c()(xs...)...); + }); + }); + }); +} + +// Rotate the last N arguments as the first argument packed +template +constexpr auto rotate_and_pack_last() +{ + return transform_args(rotate_last(), pack_first()); +} + } // namespace migraphx #endif // MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/pointwise.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/pointwise.hpp index 4b5f9fc865c..e7dc2fd845e 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/pointwise.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/pointwise.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -30,21 +30,29 @@ #include #include #include +#include namespace migraphx { -template -__device__ void pointwise_tensor(index idx, F f, T out, Ts... xs) +template +__device__ void pointwise_tensor(index idx, F f, Output out, T x, Ts... xs) { - idx.global_stride(out.get_shape().elements(), - [&](auto i) { out[i] = implicit_conversion(f(xs[i]...)); }); + idx.global_stride(x.get_shape().elements(), [&](auto i) { + auto r = f(x[i], xs[i]...); + out([&](auto... outs) { + r([&](auto... rs) { + static_assert(sizeof...(outs) == sizeof...(rs)); + swallow{(outs[i] = implicit_conversion(rs))...}; + }); + }); + }); } -template +template __device__ auto pointwise(index idx, Transforms... transforms) { return [=](auto f, auto*... ps) { - auto t = transform_args(make_tensors(), rotate_last(), transforms...); + auto t = transform_args(make_tensors(), transforms..., rotate_and_pack_last()); t(ps...)([&](auto... xs) { pointwise_tensor(idx, f, xs...); }); }; } diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/print.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/print.hpp index 8d197570ce2..a1242453516 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/print.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/print.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -215,42 +215,55 @@ inline __device__ auto coutln() return make_printer([](auto f) { f(); }, [] { printf("\n"); }); } -template -__device__ void print_each(F f, Ts... xs) +template +__device__ void unsafe_print_each(Stream s, T x, Ts... xs) { - each_args([&](auto x) { f() << x; }, xs...); + s << x; + each_args([&](auto xx) { s << ' ' << xx; }, xs...); } -template -__device__ void print_each_once(F f, Ts... xs) +template +__device__ void print_each(Stream s, Ts... xs) +{ + auto idx = make_index(); + for(auto i = 0; i < idx.nglobal(); i++) + { + if(i == idx.global) + unsafe_print_each(s, xs...); + __syncthreads(); + } +} + +template +__device__ void print_each_once(Stream s, Ts... xs) { auto idx = make_index(); if(idx.global == 0) - print_each(f, xs...); + unsafe_print_each(s, xs...); } template __device__ void print(Ts... xs) { - print_each(&cout, xs...); + print_each(cout(), xs...); } template __device__ void print_once(Ts... xs) { - print_each_once(&cout, xs...); + print_each_once(cout(), xs...); } template __device__ void println(Ts... xs) { - print_each(&cout, xs..., '\n'); + print_each(cout(), xs..., '\n'); } template __device__ void println_once(Ts... xs) { - print_each_once(&cout, xs..., '\n'); + print_each_once(cout(), xs..., '\n'); } } // namespace migraphx diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/tuple.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/tuple.hpp new file mode 100644 index 00000000000..eceefa4714f --- /dev/null +++ b/src/targets/gpu/kernels/include/migraphx/kernels/tuple.hpp @@ -0,0 +1,164 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ +#ifndef MIGRAPHX_GUARD_KERNELS_TUPLE_HPP +#define MIGRAPHX_GUARD_KERNELS_TUPLE_HPP + +#include + +namespace migraphx { + +namespace tuple_detail { + +template +struct element_storage +{ + [[no_unique_address]] T element; +}; + +template +constexpr const auto& get_element(const element_storage& x) +{ + return x.element; +} + +template +constexpr auto& get_element(element_storage& x) +{ + return x.element; +} + +template +struct tuple_storage; + +template +struct tuple_storage, Ts...> : element_storage... +{ + template + constexpr tuple_storage(Us... ys) : element_storage{ys}... + { + } + + template + constexpr auto operator()(F f) const + { + return f(static_cast&>(*this).element...); + } + + template + constexpr auto operator()(F f) + { + return f(static_cast&>(*this).element...); + } + + template + constexpr auto& operator[](IntegralConstant i) + { + static_assert(i < sizeof...(Ts), "Out of bounds tuple access"); + return get_element(*this); + } + + template + constexpr auto& operator[](IntegralConstant i) const + { + static_assert(i < sizeof...(Ts), "Out of bounds tuple access"); + return get_element(*this); + } + + constexpr index_constant size() const { return {}; } + constexpr auto empty() const { return size() == _c<0>; } +}; + +template +using tuple_base = tuple_detail::tuple_storage::type, Ts...>; + +} // namespace tuple_detail + +// NOLINTNEXTLINE +#define MIGRAPHX_DEVICE_TUPLE_OP(op, binary_op) \ + template \ + constexpr tuple& operator op(const tuple& rhs) \ + { \ + (*this)( \ + [&](auto&... xs) { rhs([&](const auto&... ys) { swallow{((xs op ys), 0)...}; }); }); \ + return *this; \ + } \ + template \ + friend constexpr auto operator binary_op(const tuple& lhs, const tuple& rhs) \ + { \ + using result = tuple() binary_op declval())...>; \ + return lhs([&](auto&... xs) { \ + return rhs([&](const auto&... ys) { return result{xs op ys...}; }); \ + }); \ + } + +template +struct tuple : tuple_detail::tuple_base +{ + using base = tuple_detail::tuple_base; + + template + constexpr tuple(Us... ys) : base(ys...) + { + } + + MIGRAPHX_DEVICE_TUPLE_OP(+=, +) + MIGRAPHX_DEVICE_TUPLE_OP(-=, -) + MIGRAPHX_DEVICE_TUPLE_OP(*=, *) + MIGRAPHX_DEVICE_TUPLE_OP(/=, /) + MIGRAPHX_DEVICE_TUPLE_OP(%=, %) + MIGRAPHX_DEVICE_TUPLE_OP(&=, &) + MIGRAPHX_DEVICE_TUPLE_OP(|=, |) + MIGRAPHX_DEVICE_TUPLE_OP(^=, ^) + + friend constexpr bool operator==(const tuple& x, const tuple& y) + { + return x([&](const auto&... xs) { + return y([&](const auto&... ys) { return ((xs == ys) and ...); }); + }); + } + friend constexpr bool operator!=(const tuple& x, const tuple& y) { return not(x == y); } + friend constexpr bool operator<(const tuple& x, const tuple& y) + { + return x([&](const auto&... xs) { + return y([&](const auto&... ys) { + fold([&](auto a, auto b) { return a == 0 ? b() : 0; })(0, [&] { + return (xs < ys) ? -1 : (ys < xs) ? 1 : 0; + }...); + }); + }); + } + friend constexpr bool operator>(const tuple& x, const tuple& y) { return y < x; } + friend constexpr bool operator<=(const tuple& x, const tuple& y) { return not(x > y); } + friend constexpr bool operator>=(const tuple& x, const tuple& y) { return not(x < y); } +}; + +template +constexpr tuple make_tuple(Ts... xs) +{ + return {xs...}; +} + +} // namespace migraphx +#endif // MIGRAPHX_GUARD_KERNELS_TUPLE_HPP diff --git a/test/fuse_pointwise.cpp b/test/fuse_pointwise.cpp index 42b2f0f419c..5edb9e2916b 100644 --- a/test/fuse_pointwise.cpp +++ b/test/fuse_pointwise.cpp @@ -116,12 +116,10 @@ TEST_CASE(double_add_without_return) auto x = mm->add_parameter("x", s); auto y = mm->add_parameter("y", s); auto z = mm->add_parameter("z", s); - auto fadd = - add_pointwise(p2, "main:pointwise0", {x, y, z}, [=](auto* pm, const auto& inputs) { - auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); - return pm->add_instruction(migraphx::make_op("add"), add1, inputs[2]); - }); - mm->add_instruction(migraphx::make_op("identity"), fadd); + add_pointwise(p2, "main:pointwise0", {x, y, z}, [=](auto* pm, const auto& inputs) { + auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); + return pm->add_instruction(migraphx::make_op("add"), add1, inputs[2]); + }); } EXPECT(p1.sort() == p2.sort()); } diff --git a/test/gpu/jit.cpp b/test/gpu/jit.cpp index 58ad9bad437..249a5aea9fa 100644 --- a/test/gpu/jit.cpp +++ b/test/gpu/jit.cpp @@ -294,7 +294,7 @@ TEST_CASE(compile_pointwise) migraphx::gpu::context ctx; auto co = migraphx::gpu::compile_op( - "pointwise", ctx, {input, input}, {{"lambda", "[](auto x) { return x + 1; }"}}); + "pointwise", ctx, {input, input}, {{"lambda", "[](auto x) { return make_tuple(x + 1); }"}}); migraphx::program p; auto* mm = p.get_main_module(); diff --git a/test/op_shape_test.cpp b/test/op_shape_test.cpp index 1f961115d0c..333a7f9ef63 100644 --- a/test/op_shape_test.cpp +++ b/test/op_shape_test.cpp @@ -2466,6 +2466,20 @@ TEST_CASE(pointwise_no_output) EXPECT(test::throws([&] { mm->add_instruction(migraphx::make_op("pointwise"), args, {&m}); })); } +TEST_CASE(pointwise_strict_type) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::module pm; + { + auto x = pm.add_parameter("x", s.with_type(migraphx::shape::half_type)); + pm.add_return({x}); + } + auto x = mm->add_parameter("x", s); + EXPECT(test::throws([&] { mm->add_instruction(migraphx::make_op("pointwise"), {x}, {&pm}); })); +} + TEST_CASE(pooling_shape0) { migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}}; diff --git a/test/ref/pointwise.cpp b/test/ref/pointwise.cpp index a6c61664f96..cbfd8a2ea54 100644 --- a/test/ref/pointwise.cpp +++ b/test/ref/pointwise.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -38,9 +38,12 @@ TEST_CASE(pointwise_test) auto l1 = mm->add_literal(migraphx::literal{s, {-1, 0, 1}}); auto l2 = mm->add_literal(migraphx::literal{s, {1, 2, 3}}); auto* pm = p.create_module("pointwise"); - auto x1 = pm->add_parameter("x1", {migraphx::shape::float_type}); - auto x2 = pm->add_parameter("x2", {migraphx::shape::float_type}); - pm->add_instruction(migraphx::make_op("add"), x1, x2); + { + auto x1 = pm->add_parameter("x1", {migraphx::shape::float_type}); + auto x2 = pm->add_parameter("x2", {migraphx::shape::float_type}); + auto add = pm->add_instruction(migraphx::make_op("add"), x1, x2); + pm->add_return({add}); + } mm->add_instruction(migraphx::make_op("pointwise"), {l1, l2}, {pm}); p.compile(migraphx::make_target("ref")); auto result = p.eval({}).back(); @@ -49,3 +52,28 @@ TEST_CASE(pointwise_test) std::vector gold = {0, 2, 4}; EXPECT(migraphx::verify::verify_rms_range(results_vector, gold)); } + +TEST_CASE(pointwise_multi_out_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3}}; + auto a1 = mm->add_literal(migraphx::literal{s, {-1, 0, 1}}); + auto a2 = mm->add_literal(migraphx::literal{s, {1, 16, 3}}); + auto* pm = p.create_module("pointwise"); + { + auto x1 = pm->add_parameter("x1", {migraphx::shape::float_type}); + auto x2 = pm->add_parameter("x2", {migraphx::shape::float_type}); + auto add = pm->add_instruction(migraphx::make_op("add"), x1, x2); + auto sqrt = pm->add_instruction(migraphx::make_op("sqrt"), add); + pm->add_return({add, sqrt}); + } + mm->add_instruction(migraphx::make_op("pointwise"), {a1, a2}, {pm}); + p.compile(migraphx::make_target("ref")); + auto results = p.eval({}).back().get_sub_objects(); + + std::vector gold1 = {0, 16, 4}; + std::vector gold2 = {0, 4, 2}; + EXPECT(results[0].to_vector() == gold1); + EXPECT(results[1].to_vector() == gold2); +} diff --git a/test/verify/test_pointwise_multi_out.cpp b/test/verify/test_pointwise_multi_out.cpp new file mode 100644 index 00000000000..427cccf405e --- /dev/null +++ b/test/verify/test_pointwise_multi_out.cpp @@ -0,0 +1,56 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_pointwise_multi_out : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3}}; + auto z1 = mm->add_parameter("z1", s); + auto z2 = mm->add_parameter("z2", s); + auto* pm = p.create_module("pointwise"); + { + auto x1 = pm->add_parameter("x1", {migraphx::shape::float_type}); + auto x2 = pm->add_parameter("x2", {migraphx::shape::float_type}); + auto add = pm->add_instruction(migraphx::make_op("add"), x1, x2); + auto abs = pm->add_instruction(migraphx::make_op("abs"), add); + auto sqrt = pm->add_instruction(migraphx::make_op("sqrt"), abs); + pm->add_return({add, sqrt}); + } + pm->set_bypass(); + auto pw = mm->add_instruction(migraphx::make_op("pointwise"), {z1, z2}, {pm}); + auto e0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), pw); + auto e1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), pw); + auto sub = mm->add_instruction(migraphx::make_op("sub"), e0, e1); + mm->add_return({sub}); + return p; + } +};