diff --git a/binaries/aot_model_compiler.cc b/binaries/aot_model_compiler.cc
index 0f8b8ee1718327..7d2d68a61f17c7 100644
--- a/binaries/aot_model_compiler.cc
+++ b/binaries/aot_model_compiler.cc
@@ -36,6 +36,10 @@ C10_DEFINE_string(
     "Input memory format."
     "If multiple inputs needed, use semicolon to separate."
     "Supported values: contiguous, channels_last");
+C10_DEFINE_string(
+    dynamic_dims,
+    "",
+    "Comma separated dimensions of input tensors that can be dynamic");
 C10_DEFINE_string(method_name, "forward", "The name of the method.");
 C10_DEFINE_string(
     output_llvm,
@@ -68,6 +72,7 @@ c10::Dict<c10::IValue, c10::IValue> createCompileSpec() {
   method_spec.insert("sizes", FLAGS_input_dims);
   method_spec.insert("types", FLAGS_input_types);
   method_spec.insert("memory_formats", FLAGS_input_memory_formats);
+  method_spec.insert("dynamic_sizes", FLAGS_dynamic_dims);
   method_spec.insert("asmfile", FLAGS_output_llvm);
   method_spec.insert("model_name", FLAGS_model_name);
   method_spec.insert("model_version", FLAGS_model_version);
diff --git a/test/mobile/nnc/test_nnc_backend.cpp b/test/mobile/nnc/test_nnc_backend.cpp
index f7adcb62459ff7..35bf60f2cca791 100644
--- a/test/mobile/nnc/test_nnc_backend.cpp
+++ b/test/mobile/nnc/test_nnc_backend.cpp
@@ -23,7 +23,9 @@ c10::Dict<c10::IValue, c10::IValue> create_compile_spec(
     const std::string& method_name,
     const std::string& model_name,
     const std::string& input_shapes,
-    const std::string& input_types) {
+    const std::string& input_types,
+    const std::string& memory_formats,
+    const std::string& dynamic_sizes) {
   c10::Dict<c10::IValue, c10::IValue> method_spec(
       c10::StringType::get(), c10::AnyType::get());
 
@@ -33,6 +35,8 @@ c10::Dict<c10::IValue, c10::IValue> create_compile_spec(
   method_spec.insert("model_version", "v1");
   method_spec.insert("asmfile", "fake_nnc_model.s");
   method_spec.insert("arch", "x86-64");
+  method_spec.insert("memory_formats", memory_formats);
+  method_spec.insert("dynamic_sizes", dynamic_sizes);
 
   c10::Dict<c10::IValue, c10::IValue> compile_spec(
       c10::StringType::get(), c10::AnyType::get());
@@ -63,7 +67,7 @@ REGISTER_NNC_KERNEL(
 
 TEST(NNCBackendTest, AOTCompileThenExecute) {
   torch::jit::Module m("m");
-  auto param = torch::ones({});
+  auto param = torch::ones({1});
   m.register_parameter("param", param, false);
   m.define(R"(
     def forward(self, input):
@@ -77,7 +81,7 @@ TEST(NNCBackendTest, AOTCompileThenExecute) {
 
   // Compile the model with NNC.
   auto compile_spec = create_compile_spec(
-      "forward", "_add_kernel_nnc_fake_model", "4,4", "float");
+      "forward", "_add_kernel_nnc_fake_model", "4,4", "float", "", "");
   auto any_dict_ty =
       c10::DictType::create(c10::StringType::get(), c10::AnyType::get());
   auto frozen_m = torch::jit::freeze_module(m.clone());
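[Editor's note] Callers of create_compile_spec must now pass six arguments. A minimal sketch of a spec that also exercises the two new fields; the "contiguous" format and the "4" entry are illustrative values, not taken from this patch:

// Hypothetical variant of the call above, with non-empty memory-format and
// dynamic-size strings; both are parsed later in aot_compiler.cpp.
auto dynamic_compile_spec = create_compile_spec(
    "forward",                     // method name
    "_add_kernel_nnc_fake_model",  // model name
    "4,4",                         // example input shape used for compilation
    "float",                       // input dtype
    "contiguous",                  // stored under "memory_formats"
    "4");                          // stored under "dynamic_sizes"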
diff --git a/torch/csrc/jit/mobile/nnc/aot_compiler.cpp b/torch/csrc/jit/mobile/nnc/aot_compiler.cpp
index 60152d861d23d6..2b5a7bc7937f20 100644
--- a/torch/csrc/jit/mobile/nnc/aot_compiler.cpp
+++ b/torch/csrc/jit/mobile/nnc/aot_compiler.cpp
@@ -43,9 +43,18 @@ std::vector<int64_t> getConstSizes(const BufPtr b) {
 
 // Construct input-specs vector from the inputs of the original graph
 std::vector<mobile::nnc::InputSpec> toInputSpecs(
-    const std::shared_ptr<Graph>& g) {
+    const std::shared_ptr<tensorexpr::TensorExprKernel>& kernel) {
+  const std::shared_ptr<Graph>& g = kernel->graph();
   std::vector<mobile::nnc::InputSpec> specs;
-  for (auto v : g->inputs()) {
+
+  // Graph inputs include scalar values for symbolic shapes, for which we
+  // don't need input specs. These scalar values come last among the graph
+  // inputs.
+  auto num_inputs =
+      g->inputs().size() - kernel->getSymbolicShapeInputs().size();
+
+  for (int i = 0; i < num_inputs; i++) {
+    auto v = g->inputs()[i];
     const auto& t = v->type();
     mobile::nnc::InputSpec spec;
     TORCH_CHECK(t->kind() == TypeKind::TensorType, "Unsupported input type");
@@ -120,7 +129,7 @@ std::unique_ptr<Function> compileMethod(
     const std::vector<at::ScalarType>& types) {
   auto func = std::make_unique<Function>();
   func->set_name(method_name);
-  func->set_input_specs(toInputSpecs(kernel->graph()));
+  func->set_input_specs(toInputSpecs(kernel));
 
   auto params = c10::impl::GenericList(c10::AnyType::get());
   auto const_descriptors = kernel->getConstantDescriptors();
@@ -177,18 +186,33 @@ std::pair<std::unique_ptr<Function>, const std::string> aotCompile(
     std::shared_ptr<Graph>& g,
     const std::vector<std::vector<int64_t>>& sizes,
     const std::vector<at::ScalarType>& types,
-    const std::string& kernel_func_name) {
+    const std::string& kernel_func_name,
+    const std::vector<int64_t>& symbolic_ind) {
   GRAPH_DEBUG("Input sizes ", sizes);
   GRAPH_DEBUG("Input types ", types);
   GRAPH_DEBUG("Method name ", method_name);
   GRAPH_DEBUG("Kernel func name ", kernel_func_name);
-
-  std::shared_ptr<tensorexpr::TensorExprKernel> kernel =
-      std::make_shared<tensorexpr::TensorExprKernel>(
-          TensorExprKernel(g, kernel_func_name));
+  GRAPH_DEBUG("Symbolic indices ", symbolic_ind);
+
+  std::shared_ptr<tensorexpr::TensorExprKernel> kernel;
+  std::vector<torch::jit::StrideInput> stride_desc = {
+      torch::jit::StrideInput::TENSOR_CONT};
+  std::unordered_map<
+      const torch::jit::Value*,
+      std::vector<torch::jit::StrideInput>>
+      symbolic_strides;
+  if (!symbolic_ind.empty()) {
+    for (auto i : g->inputs()) {
+      symbolic_strides[i] = stride_desc;
+    }
+    for (auto o : g->outputs()) {
+      symbolic_strides[o] = stride_desc;
+    }
+  }
+  kernel = std::make_shared<tensorexpr::TensorExprKernel>(TensorExprKernel(
+      g, kernel_func_name, {}, symbolic_ind, false, symbolic_strides));
 
   const std::string compiled_assembly = kernel->getCodeText();
-
   auto func = compileMethod(kernel, method_name, sizes, types);
   return std::make_pair(std::move(func), compiled_assembly);
 }
@@ -271,6 +295,17 @@ std::vector<at::MemoryFormat> parseInputMemoryFormats(
   return memFormats;
 }
 
+std::vector<int64_t> parseInputDynamicShapes(
+    const std::string& dynamic_dims_s) {
+  std::vector<std::string> dynamic_dims_list = split(',', dynamic_dims_s);
+  std::vector<int64_t> dynamic_dims;
+  dynamic_dims.reserve(dynamic_dims_list.size());
+  for (const auto& dim : dynamic_dims_list) {
+    dynamic_dims.push_back(c10::stoi(dim));
+  }
+  return dynamic_dims;
+}
+
 std::string getNncKernelId(
     const std::string& model_name,
     const std::string& model_version,
@@ -288,9 +323,12 @@ std::string getNncKernelFuncName(
   return "nnc_" + model_name + "_" + model_version + "_" + method_name;
 }
 
-std::shared_ptr<Graph> preprocessGraphPasses(
+// Preprocess the graph and return the processed graph and
+// symbolic values if dynamic input shapes are specified.
+std::pair<std::shared_ptr<Graph>, std::vector<int64_t>> preprocessGraphPasses(
     std::shared_ptr<Graph>& graph,
-    const std::vector<c10::optional<at::Tensor>>& example_inputs) {
+    const std::vector<c10::optional<at::Tensor>>& example_inputs,
+    const std::vector<int64_t>& dynamic_sizes) {
   GRAPH_DEBUG("Before preprocessing graph passes: ", *graph);
   torch::jit::RemoveTensorMutation(graph);
   torch::jit::EliminateDeadCode(graph->block());
@@ -321,8 +359,12 @@ std::shared_ptr<Graph> preprocessGraphPasses(
   RemoveTensorMutation(graph);
   EliminateDeadCode(graph);
   LowerAllTuples(graph);
+
+  auto sym_val =
+      torch::jit::tensorexpr::makeShapesSymbolic(graph, dynamic_sizes);
+
   GRAPH_DEBUG("After preprocessing graph passes: ", *graph);
-  return graph;
+  return std::make_pair(graph, sym_val);
 }
 
 std::vector<c10::optional<at::Tensor>> generateExampleInputs(
@@ -335,8 +377,7 @@ std::vector<c10::optional<at::Tensor>> generateExampleInputs(
     const auto dtype = at::dtype(inputTypes[i]);
    const auto memory_format = inputMemoryFormats[i];
     example_inputs.emplace_back(
-        at::rand(inputShapes[i], at::TensorOptions(dtype))
-            .contiguous(memory_format));
+        at::rand(inputShapes[i]).to(dtype).contiguous(memory_format));
   }
   return example_inputs;
 }
@@ -364,6 +405,8 @@ c10::IValue preprocess(
 
   auto sizes = parseInputShapes(*method_spec.at("sizes").toString());
   auto types = parseInputTypes(*method_spec.at("types").toString());
+  auto dynamic_sizes =
+      parseInputDynamicShapes(*method_spec.at("dynamic_sizes").toString());
 
   std::string memory_formats_str = method_spec.contains("memory_formats")
       ? (*method_spec.at("memory_formats").toString()).string()
@@ -374,12 +417,20 @@ c10::IValue preprocess(
       : parseInputMemoryFormats(memory_formats_str);
 
   auto example_inputs = generateExampleInputs(sizes, types, memory_formats);
-  graph = preprocessGraphPasses(graph, example_inputs);
+  auto preprocessed =
+      preprocessGraphPasses(graph, example_inputs, dynamic_sizes);
 
   auto kernel_func_name =
       getNncKernelFuncName(model_name, model_version, method_name);
+  auto processed_graph = preprocessed.first;
+  auto sym_values = preprocessed.second;
   auto compiled = torch::jit::mobile::nnc::aotCompile(
-      method_name, graph, sizes, types, kernel_func_name);
+      method_name,
+      processed_graph,
+      sizes,
+      types,
+      kernel_func_name,
+      sym_values);
   writeOutputLlvmAssembly(compiled.second, asmfile_name);
   auto func = std::move(compiled.first);
   func->set_nnc_kernel_id(
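[Editor's note] Taken together, the pieces above compose roughly as follows. This is an illustrative sketch of the data flow, not code from the patch; it assumes graph, example_inputs, sizes, types, and kernel_func_name are set up as in preprocess(), and the "2,3" string is an invented flag value:

// Hypothetical driver: parse the flag, make shapes symbolic, then compile.
auto dynamic_sizes = parseInputDynamicShapes("2,3");  // -> {2, 3}
auto preprocessed =
    preprocessGraphPasses(graph, example_inputs, dynamic_sizes);
auto compiled = torch::jit::mobile::nnc::aotCompile(
    "forward",
    preprocessed.first,    // graph rewritten by makeShapesSymbolic
    sizes,
    types,
    kernel_func_name,
    preprocessed.second);  // symbolic values backing the extra scalar inputs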
diff --git a/torch/csrc/jit/mobile/nnc/context.cpp b/torch/csrc/jit/mobile/nnc/context.cpp
index d56d06614085db..2fe30a65b4f610 100644
--- a/torch/csrc/jit/mobile/nnc/context.cpp
+++ b/torch/csrc/jit/mobile/nnc/context.cpp
@@ -41,7 +41,17 @@ c10::IValue InputSpec::serialize() const {
 }
 
 bool InputSpec::validate(const at::Tensor& input) const {
-  return input.sizes() == sizes_ && input.scalar_type() == dtype_;
+  if (sizes_.size() != input.sizes().size() || input.scalar_type() != dtype_) {
+    return false;
+  }
+  auto spec_sizes = sizes_;
+  for (int i = 0; i < spec_sizes.size(); i++) {
+    // InputSpec size 0 means that the dimension is dynamic
+    if (spec_sizes[i] != 0 && spec_sizes[i] != input.sizes()[i]) {
+      return false;
+    }
+  }
+  return true;
 }
 
 OutputSpec::OutputSpec(const c10::IValue& value) {
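[Editor's note] A minimal sketch of the new wildcard semantics, assuming the public sizes_ and dtype_ fields declared in context.h; the shapes are invented for illustration:

#include <ATen/ATen.h>
#include <torch/csrc/jit/mobile/nnc/context.h>

// A spec whose first dimension is dynamic (0) and whose second is fixed (4).
bool demo_input_spec_wildcard() {
  torch::jit::mobile::nnc::InputSpec spec;
  spec.sizes_ = {0, 4};
  spec.dtype_ = c10::ScalarType::Float;

  bool a = spec.validate(at::rand({2, 4}));  // true: dim 0 may be anything
  bool b = spec.validate(at::rand({8, 4}));  // true: larger batch, same spec
  bool c = spec.validate(at::rand({2, 5}));  // false: static dim 1 mismatch
  bool d = spec.validate(at::rand({4}));     // false: rank mismatch
  return a && b && !c && !d;
}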
@@ -136,6 +146,14 @@ Function::Function(const c10::IValue& value) {
 
   // memory_plan_
   memory_plan_ = MemoryPlan(dict.at("memory_plan"));
+
+  // symbolic shape positions
+  for (const auto& sym_shape_pos :
+       dict.at("sym_shape_pos").toTupleRef().elements()) {
+    auto sym_shape_elements = sym_shape_pos.toTupleRef().elements();
+    sym_shape_positions_.emplace_back(
+        sym_shape_elements[0].toInt(), sym_shape_elements[1].toInt());
+  }
 }
 
 c10::IValue Function::serialize() const {
@@ -185,18 +203,20 @@ void Function::init_execution_state() const {
   ExecutionState state;
   memory_plan_.allocate(&state);
 
-  // The arguments vector consists of 4 sections: inputs, outputs, parameters
-  // and buffers.
+  // The arguments vector consists of 5 sections: inputs, symbolic shapes,
+  // outputs, parameters and buffers.
   auto input_args = input_specs_.size();
+  auto sym_shape_args = sym_shape_positions_.size();
   auto output_args = output_specs_.size();
   auto param_args = parameters_.size();
   auto buffer_args = state.preallocations_.size();
 
   auto& arguments = state.arguments_;
-  arguments.reserve(input_args + output_args + param_args + buffer_args);
+  arguments.reserve(
+      input_args + sym_shape_args + output_args + param_args + buffer_args);
 
   // Keep empty slots to fill in inputs/outputs pointers at execution time.
-  arguments.resize(input_args + output_args);
+  arguments.resize(input_args + sym_shape_args + output_args);
 
   // Fill in parameters as untyped raw pointers.
   // The underlying storage of the parameters should be owned by `parameters_`,
@@ -233,7 +253,7 @@ c10::impl::GenericList Function::run(
 
   // Fill in input tensors.
   TORCH_CHECK(
-      input_specs_.size() == (inputs.size() + sym_shape_positions_.size()),
+      input_specs_.size() == inputs.size(),
       "Input size doesn't match the spec, expect: ",
       input_specs_.size(),
       " actual: ",