Fix Where op type reduction processing #9033

Merged (3 commits) on Sep 13, 2021
7 changes: 2 additions & 5 deletions onnxruntime/core/providers/op_kernel_type_control.h
@@ -3,16 +3,11 @@

#pragma once

#include <cstdint>
#include <tuple>

#include "boost/mp11.hpp"

#include "core/common/type_list.h"
#include "core/common/type_set_utils.h"

#include "core/framework/data_types.h"

/**
* These utilities provide a way to control what types are enabled for an Op kernel implementation.
* At a high level, we have the notion of default, required, allowed, and enabled type sets.
@@ -471,5 +466,7 @@ struct EnabledTypes {
* using Dispatcher = onnxruntime::utils::MLTypeCallDispatcherFromTypeList<MyOpFirstInputEnabledTypes>;
*/

#include "core/framework/data_types.h" // for types that might be used in type specifications

// all allowed type specifications should be contained in the following file
#include "core/providers/op_kernel_type_control_overrides.inc"
@@ -240,6 +240,19 @@ def from_config_entry(self, entry: str):
self._output_types[int(o_str)] = set(values)


class Input1TypedRegistrationProcessor(DefaultTypeUsageProcessor):
'''
Processor for operators where the second input type is used in a typed kernel registration.
'''
def __init__(self, domain: str, optype: str):
# init with tracking of input 1 only.
super().__init__(domain, optype, inputs=[1], outputs=[])

def is_typed_registration_needed(self, type_in_registration: str,
globally_allowed_types: typing.Optional[typing.Set[str]]):
return self.is_input_type_enabled(type_in_registration, 1, globally_allowed_types)


class Output0TypedRegistrationProcessor(DefaultTypeUsageProcessor):
'''
Processor for operators where the first output type is used in a typed kernel registration.
@@ -339,8 +352,7 @@ def add(processor):
'Scatter', 'ScatterElements', 'ScatterND', 'Shrink', 'Sigmoid', 'Sign', 'Sin',
'Softmax', 'Split', 'SplitToSequence', 'Sqrt', 'Sum',
'Tanh', 'TopK', 'Transpose',
'Unique',
'Where']
'Unique']

# ops that are used to manipulate shapes or indices so require int32_t and int64_t to be available
default_processor_onnx_ops_requiring_ints_for_input_0 = ['Add',
@@ -392,6 +404,9 @@ def add(processor):
onnx_random_ops = ['RandomNormal', 'RandomNormalLike', 'RandomUniform', 'RandomUniformLike', 'Multinomial']
[add(DefaultTypeUsageProcessor('ai.onnx', op, inputs=[], outputs=[0])) for op in onnx_random_ops]

# Where always has a boolean first input so track the second input type for typed registration
add(Input1TypedRegistrationProcessor('ai.onnx', 'Where'))

# we only support 'float' as input for [Dynamic]QuantizeLinear so just track the output type
# as that's what is used in the typed registration
add(Output0TypedRegistrationProcessor('ai.onnx', 'QuantizeLinear'))
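
For readers skimming the diff: the new processor exists because Where's condition input (input 0) is always bool, so only the data input (input 1) matters for the typed kernel registration. Below is a minimal standalone sketch of that decision for illustration only; the helper name and type strings are hypothetical, and the real logic defers to DefaultTypeUsageProcessor.is_input_type_enabled, which also honours per-model required types.

from typing import Optional, Set


def where_registration_needed(type_in_registration: str,
                              globally_allowed_types: Optional[Set[str]]) -> bool:
    # With no global type restriction, every typed registration stays enabled.
    if globally_allowed_types is None:
        return True
    # Otherwise keep the registration only if the data (second) input type is allowed;
    # the bool condition input never participates in the decision.
    return type_in_registration in globally_allowed_types


# Example: only float models are required, so the float Where kernel is kept
# and the int64_t one can be excluded.
print(where_registration_needed('float', {'float'}))     # True
print(where_registration_needed('int64_t', {'float'}))   # False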