Skip to content

Commit

Permalink
Sklearn frontend updates (#157)
Browse files Browse the repository at this point in the history
* Add different path for multiple transform functions of different models and InverseLabelTransform support

* typo fixes and copy the value in identity transformation

* add shapefunc copy and inverse transform of NALabelEncoder

* typo fix

* add import tvm.testing

* add more opeartor support

* adding test cases - test

* fix shape mismatch for NALabelEncoder when using dynamic shapes

* update and merging for more operator support in sklearn frontend

* update and merging for more operator support in sklearn frontend: robustordinalencoder

* remove debug prints

* reformat with black

* pylint format fixes

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: srinidhigoud <[email protected]>
  • Loading branch information
3 people authored and Trevor Morris committed Nov 3, 2020
1 parent e41640d commit 77e3b98
Show file tree
Hide file tree
Showing 3 changed files with 330 additions and 43 deletions.
256 changes: 228 additions & 28 deletions python/tvm/relay/frontend/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# pylint: disable=import-outside-toplevel

import numpy as np
import tvm
from tvm import relay
from tvm.ir import IRModule

from ... import nd as _nd
Expand Down Expand Up @@ -107,68 +107,268 @@ def _ThresholdOneHotEncoder(op, inexpr, dshape, dtype, columns=None):

def _RobustStandardScaler(op, inexpr, dshape, dtype, columns=None):
"""
Sagemaker-Scikit-Learn-Extension Transformer:
Standardize features by removing the mean and scaling to unit variance
Sagemaker-Scikit-Learn-Extension Transformer:
Standardize features by removing the mean and scaling to unit variance.
"""
scaler = op.scaler_
ret = _op.subtract(inexpr, _op.const(np.array(scaler.mean_, dtype), dtype))
ret = _op.divide(ret, _op.const(np.array(scaler.scale_, dtype), dtype))
return ret

def _ColumnTransformer(op, inexpr, dshape, dtype, columns=None):

def _ColumnTransformer(op, inexpr, dshape, dtype, func_name, columns=None):
"""
Scikit-Learn Compose:
Applies transformers to columns of an array
"""
out = []
for _, pipe, cols in op.transformers_:
mod = pipe.steps[0][1]
out.append(sklearn_op_to_relay(mod, inexpr, dshape, dtype, cols))
out.append(sklearn_op_to_relay(mod, inexpr, dshape, dtype, func_name, cols))

return _op.concatenate(out, axis=1)


def _InverseLabelTransformer(op, inexpr, dshape, dtype, columns=None):
"""
Identity transformation of the label data. The conversion to string happens in runtime.
"""
return _op.copy(inexpr)


def _RobustOrdinalEncoder(op, inexpr, dshape, dtype, columns=None):
"""
Sagemaker-Scikit-Learn-Extension Transformer:
Encode categorical features as an integer array additional feature of handling unseen values.
The input to this transformer should be an array-like of integers or strings, denoting the
values taken on by categorical (discrete) features. The features are converted to ordinal
integers. This results in a single column of integers (0 to n_categories - 1) per feature.
"""
if columns:
column_indices = _op.const(columns)
inexpr = _op.take(inexpr, indices=column_indices, axis=1)

num_cat = len(op.categories_)
cols = _op.split(inexpr, num_cat, axis=1)

out = []
for i in range(num_cat):
category = op.categories_[i]
cat_tensor = _op.const(np.array(category, dtype=dtype))
tiled_col = _op.tile(cols[i], (1, len(category)))
one_hot_mask = _op.equal(tiled_col, cat_tensor)
one_hot = _op.cast(one_hot_mask, dtype)

offset = _op.const(np.arange(-1, len(category) - 1, dtype=dtype))
zeros = _op.full_like(one_hot, _op.const(0, dtype=dtype))
ordinal_col = _op.where(one_hot_mask, _op.add(one_hot, offset), zeros)
ordinal = _op.expand_dims(_op.sum(ordinal_col, axis=1), -1)

seen_mask = _op.cast(_op.sum(one_hot, axis=1), dtype="bool")
seen_mask = _op.expand_dims(seen_mask, -1)
extra_class = _op.full_like(ordinal, _op.const(len(category), dtype=dtype))
robust_ordinal = _op.where(seen_mask, ordinal, extra_class)
out.append(robust_ordinal)

ret = _op.concatenate(out, axis=1)
return ret


def _RobustLabelEncoder(op, inexpr, dshape, dtype, columns=None):
"""
Sagemaker-Scikit-Learn-Extension Transformer:
Encode target labels with value between 0 and n_classes-1.
"""
if columns:
column_indices = _op.const(columns)
inexpr = _op.take(inexpr, indices=column_indices, axis=1)

class_mask = []
for i in range(len(op.classes_)):
val = (
_op.const(i, dtype) if is_inverse else _op.const(np.array(op.classes_[i], dtype), dtype)
)
class_mask.append(_op.equal(inexpr, val))
for i in range(len(op.classes_)):
if is_inverse:
label_mask = _op.full_like(
inexpr, _op.const(np.array(op.classes_[i], dtype), dtype=dtype)
)
else:
label_mask = _op.full_like(inexpr, _op.const(i, dtype=dtype))

if i == 0:
out = _op.where(class_mask[i], label_mask, inexpr)
continue
out = _op.where(class_mask[i], label_mask, out)

if op.fill_unseen_labels:
unseen_mask = class_mask[0]
for mask in class_mask[1:]:
unseen_mask = _op.logical_or(unseen_mask, mask)
unseen_mask = _op.logical_not(unseen_mask)
unseen_label = (
_op.const(-1, dtype=dtype)
if is_inverse
else _op.const(np.array(len(op.classes_)), dtype=dtype)
)
label_mask = _op.full_like(inexpr, unseen_label)
out = _op.where(unseen_mask, label_mask, out)

return out


def _NALabelEncoder(op, inexpr, dshape, dtype, columns=None):
"""
Sagemaker-Scikit-Learn-Extension Transformer:
Encoder for transforming labels to NA values which encode all non-float and non-finite values
as NA values.
"""
if columns:
column_indices = _op.const(columns)
inexpr = _op.take(inexpr, indices=column_indices, axis=1)

flattened_inexpr = _op.reshape(inexpr, newshape=(-1, 1))
# Hardcoded flattened shape to be (?, 1)
flattened_dshape = (relay.Any(), 1)
ri_out = _RobustImputer(op.model_, flattened_inexpr, flattened_dshape, dtype)
ret = _op.reshape(ri_out, newshape=-1)
return ret


def _RobustStandardScaler(op, inexpr, dshape, dtype, columns=None):
"""
Sagemaker-Scikit-Learn-Extension Transformer:
Standardize features by removing the mean and scaling to unit variance.
"""
if columns:
column_indices = _op.const(columns)
inexpr = _op.take(inexpr, indices=column_indices, axis=1)

scaler = op.scaler_
ret = _op.subtract(inexpr, _op.const(np.array(scaler.mean_, dtype), dtype))
ret = _op.divide(ret, _op.const(np.array(scaler.scale_, dtype), dtype))
return ret


def _KBinsDiscretizer(op, inexpr, dshape, dtype, columns=None):
"""
Scikit-Learn Transformer:
Bin continuous data into intervals.
"""
if columns:
column_indices = _op.const(columns)
inexpr = _op.take(inexpr, indices=column_indices, axis=1)

bin_edges = np.transpose(np.vstack(op.bin_edges_))
out = _op.full_like(inexpr, _op.const(0, dtype=dtype))

for i in range(1, len(bin_edges) - 1):
indices_mask = _op.full_like(inexpr, _op.const(i, dtype=dtype))
bin_edge = _op.const(bin_edges[i])
bin_mask = _op.greater_equal(inexpr, bin_edge)
out = _op.where(bin_mask, indices_mask, out)

return out


def _TfidfVectorizer(op, inexpr, dshape, dtype, columns=None):
"""
Scikit-Learn Transformer:
Transform a count matrix to a normalized tf or tf-idf representation.
"""
if op.use_idf:
idf = _op.const(np.array(op.idf_, dtype=dtype), dtype=dtype)
tfidf = _op.multiply(idf, inexpr)
if op.sublinear_tf:
tfidf = _op.add(tfidf, _op.const(1, dtype))
ret = _op.nn.l2_normalize(tfidf, eps=0.0001, axis=[1])
else:
ret = _op.nn.l2_normalize(inexpr, eps=0.0001, axis=[1])

return ret


def _PCA(op, inexpr, dshape, dtype, columns=None):
"""
Scikit-Learn Transformer:
PCA transformation with existing eigen vector.
"""
eigvec = _op.const(np.array(op.components_, dtype))
ret = _op.nn.dense(inexpr, eigvec)
return ret


_convert_map = {
'ColumnTransformer':_ColumnTransformer,
'SimpleImputer': _SimpleImputer,
'RobustImputer': _RobustImputer,
'RobustStandardScaler': _RobustStandardScaler,
'ThresholdOneHotEncoder': _ThresholdOneHotEncoder
"ColumnTransformer": {"transform": _ColumnTransformer},
"SimpleImputer": {"transform": _SimpleImputer},
"RobustImputer": {"transform": _RobustImputer},
"RobustStandardScaler": {"transform": _RobustStandardScaler},
"ThresholdOneHotEncoder": {"transform": _ThresholdOneHotEncoder},
"NALabelEncoder": {"transform": _NALabelEncoder, "inverse_transform": _InverseLabelTransformer},
"RobustLabelEncoder": {"inverse_transform": _InverseLabelTransformer},
"RobustOrdinalEncoder": {"transform": _RobustOrdinalEncoder},
"KBinsDiscretizer": {"transform": _KBinsDiscretizer},
"TfidfVectorizer": {"transform": _TfidfVectorizer},
"PCA": {"transform": _PCA},
}

def sklearn_op_to_relay(op, inexpr, dshape, dtype, columns=None):

def sklearn_op_to_relay(op, inexpr, dshape, dtype, func_name, columns=None):
"""
Convert Sklearn Ops to Relay Ops.
"""
classname = type(op).__name__
return _convert_map[classname](op, inexpr, dshape, dtype, columns)

def from_sklearn(model,
shape=None,
dtype="float32",
columns=None):
if classname not in _convert_map:
raise NameError("Model {} not supported in scikit-learn frontend".format(classname))
if func_name not in _convert_map[classname]:
raise NameError(
"Function {} of Model {} not supported in scikit-learn frontend".format(
func_name, classname
)
)

if classname == "ColumnTransformer":
return _convert_map[classname][func_name](op, inexpr, dshape, dtype, func_name, columns)

return _convert_map[classname][func_name](op, inexpr, dshape, dtype, columns)


def from_sklearn(model, shape=None, dtype="float32", func_name="transform", columns=None):
"""
Import scikit-learn model to Relay.
"""
try:
import sklearn
except ImportError as e:
raise ImportError(
"Unable to import scikit-learn which is required {}".format(e))

inexpr = _expr.var('input', shape=shape, dtype=dtype)
outexpr = sklearn_op_to_relay(model, inexpr, shape, dtype, columns)
raise ImportError("Unable to import scikit-learn which is required {}".format(e))

inexpr = _expr.var("input", shape=shape, dtype=dtype)
outexpr = sklearn_op_to_relay(model, inexpr, shape, dtype, func_name, columns)

func = _function.Function(analysis.free_vars(outexpr), outexpr)
return IRModule.from_expr(func), []

def from_auto_ml(model,
shape=None,
dtype="float32"):

def from_auto_ml(model, shape=None, dtype="float32", func_name="transform"):
"""
Import scikit-learn model to Relay.
"""
try:
import sklearn
except ImportError as e:
raise ImportError(
"Unable to import scikit-learn which is required {}".format(e))

outexpr = _expr.var('input', shape=shape, dtype=dtype)
for _, transformer in model.feature_transformer.steps:
outexpr = sklearn_op_to_relay(transformer, outexpr, shape, dtype, None)
outexpr = _expr.var("input", shape=shape, dtype=dtype)

if func_name == "transform":
for _, transformer in model.feature_transformer.steps:
outexpr = sklearn_op_to_relay(transformer, outexpr, shape, dtype, func_name, None)
else:
transformer = model.target_transformer
outexpr = sklearn_op_to_relay(transformer, outexpr, shape, dtype, func_name, None)

func = _function.Function(analysis.free_vars(outexpr), outexpr)
return IRModule.from_expr(func), []
2 changes: 1 addition & 1 deletion python/tvm/relay/op/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,4 +275,4 @@ def elemwise_shape_func(attrs, inputs, _):
register_shape_func("isnan", False, elemwise_shape_func)
register_shape_func("isinf", False, elemwise_shape_func)
register_shape_func("where", False, elemwise_shape_func)

register_shape_func("copy", False, elemwise_shape_func)
Loading

0 comments on commit 77e3b98

Please sign in to comment.