Skip to content

Commit

Permalink
Allow to add custom post transform functions that are not supported b…
Browse files Browse the repository at this point in the history
…y the ONNX spec yet (#463)

* add _add_post_transform_node function to allow for custom post transform nodes that are not supported by the ONNX spec yet

Signed-off-by: Jan-Benedikt Jagusch <[email protected]>

* add 'poisson' support also to WrappedBooster

Signed-off-by: Jan-Benedikt Jagusch <[email protected]>
  • Loading branch information
janjagusch authored May 20, 2021
1 parent 9542999 commit fd044d7
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
5 changes: 3 additions & 2 deletions onnxmltools/convert/lightgbm/_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def __init__(self, booster):
if (_model_dict['objective'].startswith('binary') or
_model_dict['objective'].startswith('multiclass')):
self.operator_name = 'LgbmClassifier'
elif _model_dict['objective'].startswith('regression'):
elif (_model_dict['objective'].startswith('regression') or
_model_dict['objective'].startswith('poisson')):
self.operator_name = 'LgbmRegressor'
else:
# Other objectives are not supported.
Expand Down Expand Up @@ -170,4 +171,4 @@ def parse_lightgbm(model, initial_types=None, target_opset=None,
for variable in outputs:
raw_model_container.add_output(variable)

return topology
return topology
26 changes: 26 additions & 0 deletions onnxmltools/convert/lightgbm/operator_converters/LightGbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import copy
import numbers
import numpy as np
import onnx
from collections import Counter
from ...common._apply_operation import (
apply_div, apply_reshape, apply_sub, apply_cast, apply_identity, apply_clip)
from ...common._registration import register_converter
from ...common.tree_ensemble import get_default_tree_classifier_attribute_pairs
from ....proto import onnx_proto
from onnxconverter_common.container import ModelComponentContainer


def _translate_split_criterion(criterion):
Expand Down Expand Up @@ -222,6 +224,7 @@ def convert_lightgbm(scope, operator, container):

# Create different attributes for classifier and
# regressor, respectively
post_transform = None
if gbm_text['objective'].startswith('binary'):
n_classes = 1
attrs['post_transform'] = 'LOGISTIC'
Expand All @@ -232,6 +235,13 @@ def convert_lightgbm(scope, operator, container):
n_classes = 1 # Regressor has only one output variable
attrs['post_transform'] = 'NONE'
attrs['n_targets'] = n_classes
elif gbm_text['objective'].startswith('poisson'):
n_classes = 1 # Regressor has only one output variable
attrs['n_targets'] = n_classes
# 'Exp' is not a supported post_transform value in the ONNX spec yet,
# so we need to add an 'Exp' post transform node to the model
attrs['post_transform'] = 'NONE'
post_transform = "Exp"
else:
raise RuntimeError(
"LightGBM objective should be cleaned already not '{}'.".format(
Expand Down Expand Up @@ -392,6 +402,22 @@ def convert_lightgbm(scope, operator, container):
container.add_node('Identity', output_name,
operator.output_full_names,
name=scope.get_unique_operator_name('Identity'))
if post_transform:
_add_post_transform_node(container, post_transform)


def _add_post_transform_node(container: ModelComponentContainer, op_type: str):
"""
Add a post transform node to a ModelComponentContainer.
Useful for post transform functions that are not supported by the ONNX spec yet (e.g. 'Exp').
"""
assert len(container.outputs) == 1, "Adding a post transform node is only possible for models with 1 output."
original_output_name = container.outputs[0].name
new_output_name = f"{op_type.lower()}_{original_output_name}"
post_transform_node = onnx.helper.make_node(op_type, inputs=[original_output_name], outputs=[new_output_name])
container.nodes.append(post_transform_node)
container.outputs[0].name = new_output_name


def modify_tree_for_rule_in_set(gbm, use_float=False):
Expand Down

0 comments on commit fd044d7

Please sign in to comment.