Skip to content

Commit

Permalink
add _add_post_transform_node function to allow for custom post transf…
Browse files Browse the repository at this point in the history
…orm nodes that are not supported by the ONNX spec yet
  • Loading branch information
janjagusch committed May 11, 2021
1 parent 9542999 commit 6cc1947
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions onnxmltools/convert/lightgbm/operator_converters/LightGbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import copy
import numbers
import numpy as np
import onnx
from collections import Counter
from ...common._apply_operation import (
apply_div, apply_reshape, apply_sub, apply_cast, apply_identity, apply_clip)
from ...common._registration import register_converter
from ...common.tree_ensemble import get_default_tree_classifier_attribute_pairs
from ....proto import onnx_proto
from onnxconverter_common.container import ModelComponentContainer


def _translate_split_criterion(criterion):
Expand Down Expand Up @@ -222,6 +224,7 @@ def convert_lightgbm(scope, operator, container):

# Create different attributes for classifier and
# regressor, respectively
post_transform = None
if gbm_text['objective'].startswith('binary'):
n_classes = 1
attrs['post_transform'] = 'LOGISTIC'
Expand All @@ -232,6 +235,13 @@ def convert_lightgbm(scope, operator, container):
n_classes = 1 # Regressor has only one output variable
attrs['post_transform'] = 'NONE'
attrs['n_targets'] = n_classes
elif gbm_text['objective'].startswith('poisson'):
n_classes = 1 # Regressor has only one output variable
attrs['n_targets'] = n_classes
# 'Exp' is not a supported post_transform value in the ONNX spec yet,
# so we need to add an 'Exp' post transform node to the model
attrs['post_transform'] = 'NONE'
post_transform = "Exp"
else:
raise RuntimeError(
"LightGBM objective should be cleaned already not '{}'.".format(
Expand Down Expand Up @@ -392,6 +402,22 @@ def convert_lightgbm(scope, operator, container):
container.add_node('Identity', output_name,
operator.output_full_names,
name=scope.get_unique_operator_name('Identity'))
if post_transform:
_add_post_transform_node(container, post_transform)


def _add_post_transform_node(container: ModelComponentContainer, op_type: str):
"""
Add a post transform node to a ModelComponentContainer.
Useful for post transform functions that are not supported by the ONNX spec yet (e.g. 'Exp').
"""
assert len(container.outputs) == 1, "Adding a post transform node is only possible for models with 1 output."
original_output_name = container.outputs[0].name
new_output_name = f"{op_type.lower()}_{original_output_name}"
post_transform_node = onnx.helper.make_node(op_type, inputs=[original_output_name], outputs=[new_output_name])
container.nodes.append(post_transform_node)
container.outputs[0].name = new_output_name


def modify_tree_for_rule_in_set(gbm, use_float=False):
Expand Down

0 comments on commit 6cc1947

Please sign in to comment.