Skip to content

Commit

Permalink
add an extra tensor for datetime inputs (#189)
Browse files Browse the repository at this point in the history
  • Loading branch information
CloudManX authored Feb 10, 2021
1 parent ce987f5 commit aec7a70
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion python/tvm/relay/frontend/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,7 @@ def _DateTimeVectorizer(op, inexpr, dshape, dtype, columns=None):

INPUT_FLOAT = 0
INPUT_STRING = 1
INPUT_DATETIME = 2

column_transformer_op_types = {
"RobustImputer": INPUT_FLOAT,
Expand All @@ -589,6 +590,7 @@ def _DateTimeVectorizer(op, inexpr, dshape, dtype, columns=None):
"RobustStandardScaler": INPUT_FLOAT,
"RobustOrdinalEncoder": INPUT_STRING,
"ThresholdOneHotEncoder": INPUT_STRING,
"DateTimeVectorizer": INPUT_DATETIME,
}


Expand Down Expand Up @@ -644,7 +646,8 @@ def from_auto_ml(model, shape=None, dtype="float32", func_name="transform"):
if func_name == "transform":
inexpr_float = _expr.var("input_float", shape=shape, dtype=dtype)
inexpr_string = _expr.var("input_string", shape=shape, dtype=dtype)
inexpr = [inexpr_float, inexpr_string]
inexpr_datetime = _expr.var("input_datetime", shape=shape, dtype=dtype)
inexpr = [inexpr_float, inexpr_string, inexpr_datetime]

if type(model.feature_transformer.steps[0][1]).__name__ != "ColumnTransformer":
raise NameError(
Expand Down

0 comments on commit aec7a70

Please sign in to comment.