Skip to content

Commit

Permalink
add clip
Browse files Browse the repository at this point in the history
Signed-off-by: xavier dupré <[email protected]>
  • Loading branch information
sdpython committed Mar 12, 2021
1 parent fb2065b commit 23bb91d
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 15 deletions.
23 changes: 12 additions & 11 deletions onnxmltools/convert/lightgbm/operator_converters/LightGbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import numbers
import numpy as np
from collections import Counter
from ...common._apply_operation import apply_div, apply_reshape, apply_sub, apply_cast, apply_identity
from ...common._apply_operation import (
apply_div, apply_reshape, apply_sub, apply_cast, apply_identity, apply_clip)
from ...common._registration import register_converter
from ...common.tree_ensemble import get_default_tree_classifier_attribute_pairs
from ....proto import onnx_proto
Expand Down Expand Up @@ -452,28 +453,28 @@ def str2number(val):


def convert_lgbm_zipmap(scope, operator, container):
zipmap_attrs = {'name': scope.get_unique_operator_name('ZipMap')}
to_type = onnx_proto.TensorProto.INT64

if hasattr(operator, 'classlabels_int64s'):
zipmap_attrs['classlabels_int64s'] = operator.classlabels_int64s
elif hasattr(operator, 'classlabels_strings'):
zipmap_attrs['classlabels_strings'] = operator.classlabels_strings
to_type = onnx_proto.TensorProto.STRING

if to_type == onnx_proto.TensorProto.STRING:
apply_identity(scope, operator.inputs[0].full_name,
operator.outputs[0].full_name, container)
else:
apply_cast(scope, operator.inputs[0].full_name,
operator.outputs[0].full_name, container, to=to_type)
if operator.zipmap:
zipmap_attrs = {'name': scope.get_unique_operator_name('ZipMap')}
if hasattr(operator, 'classlabels_int64s'):
zipmap_attrs['classlabels_int64s'] = operator.classlabels_int64s
elif hasattr(operator, 'classlabels_strings'):
zipmap_attrs['classlabels_strings'] = operator.classlabels_strings
to_type = onnx_proto.TensorProto.STRING

container.add_node('ZipMap', operator.inputs[1].full_name,
operator.outputs[1].full_name,
op_domain='ai.onnx.ml', **zipmap_attrs)
else:
container.add_node('Identity', operator.inputs[1].full_name,
operator.outputs[1].full_name)
apply_clip(scope, operator.inputs[1].full_name,
operator.outputs[1].full_name, container,
min=0.0, max=1.0)


register_converter('LgbmClassifier', convert_lightgbm)
Expand Down
9 changes: 5 additions & 4 deletions tests/lightgbm/test_LightGbmTreeEnsembleConverters.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,17 @@ def test_lightgbm_classifier_zipmap(self):
assert "zipmap" in str(onx).lower()

def test_lightgbm_classifier_nozipmap(self):
X = [[0, 1], [1, 1], [2, 0], [1, 2]]
X = [[0, 1], [1, 1], [2, 0], [1, 2], [1, 5], [6, 2]]
X = numpy.array(X, dtype=numpy.float32)
y = [0, 1, 0, 1]
model = LGBMClassifier(n_estimators=3, min_child_samples=1)
y = [0, 1, 0, 1, 1, 0]
model = LGBMClassifier(n_estimators=3, min_child_samples=1, max_depth=2)
model.fit(X, y)
onx = convert_model(
model, 'dummy', input_types=[('X', FloatTensorType([None, X.shape[1]]))],
zipmap=False)
assert "zipmap" not in str(onx).lower()
sess = InferenceSession(onx.SerializeToString())
onxs = onx[0].SerializeToString()
sess = onnxruntime.InferenceSession(onxs)
exp = model.predict(X), model.predict_proba(X)
got = sess.run(None, {'X': X})
assert_almost_equal(exp[0], got[0])
Expand Down

0 comments on commit 23bb91d

Please sign in to comment.