PUBDEV-5294 - XGBoost model in Flow doesn't seem to converge
Pavel Pscheidl committed Feb 15, 2018
1 parent 97f351b commit 9661331
Showing 1 changed file with 40 additions and 16 deletions.
XGBoostModel.java
@@ -13,6 +13,9 @@
import water.fvec.Frame;
import water.util.Log;
import hex.ModelMetrics;

import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.HashMap;
import java.util.Map;

@@ -21,6 +24,7 @@
public class XGBoostModel extends Model<XGBoostModel, XGBoostModel.XGBoostParameters, XGBoostOutput> {

private XGBoostModelInfo model_info;
private static final NumberFormat localizedNumberFormatter = DecimalFormat.getNumberInstance();

XGBoostModelInfo model_info() { return model_info; }

@@ -290,9 +294,29 @@ public static HashMap<String, Object> createParams(XGBoostParameters p, XGBoostO
Log.info(" " + s.getKey() + " = " + s.getValue());
}
Log.info("");

localizeDecimalParams(params);
return params;
}

/**
* Iterates over the given parameter map and applies locale-specific formatting
* to decimal values (Floats and Doubles), replacing them with formatted Strings.
*
* @param params Parameters to localize in place
*/
private static void localizeDecimalParams(final HashMap<String, Object> params) {

for (String key : params.keySet()) {
final Object value = params.get(key);
if (value instanceof Float || value instanceof Double) {
final String localizedValue = localizedNumberFormatter.format(value);
params.put(key, localizedValue);
}
}

}
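
For readers who want to see the formatting behaviour this change relies on in isolation, below is a minimal standalone sketch. It is not part of the commit; the class name LocaleFormatDemo and the explicit Locale arguments are illustrative assumptions. It shows how NumberFormat renders the same decimal under different locales; the committed field calls DecimalFormat.getNumberInstance() with no argument, so it formats according to the JVM's default locale.

// Hypothetical standalone sketch, not part of this commit: demonstrates the
// locale-dependent output of NumberFormat that localizeDecimalParams relies on.
import java.text.NumberFormat;
import java.util.Locale;

public class LocaleFormatDemo {
    public static void main(String[] args) {
        double eta = 0.3; // e.g. a learning-rate-style decimal parameter
        NumberFormat us = NumberFormat.getNumberInstance(Locale.US);
        NumberFormat de = NumberFormat.getNumberInstance(Locale.GERMANY);
        System.out.println(us.format(eta)); // prints "0.3"
        System.out.println(de.format(eta)); // prints "0,3" (comma decimal separator)
    }
}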

@Override
protected double[] score0(double[] data, double[] preds) {
return score0(data, preds, 0.0);
@@ -325,7 +349,7 @@ private ModelMetrics makeMetrics(Booster booster, Frame data, Frame originalData
ModelMetrics[] mms = new ModelMetrics[1];
Frame predictions = makePreds(booster, data, mms, true, predFrameKey, fs);
if (predFrameKey == null) {
predictions.remove(fs);
} else {
DKV.put(predictions, fs);
}
@@ -345,17 +369,17 @@ private Frame makePredsOnly(Booster booster, Frame data, Key<Frame> destinationK
}

private Frame makePreds(Booster booster, Frame data, ModelMetrics[] mms, boolean computeMetrics, Key<Frame> destinationKey, Futures fs) throws XGBoostError {
assert (! computeMetrics) || (mms != null && mms.length == 1);

XGBoostScoreTask.XGBoostScoreTaskResult score = XGBoostScoreTask.runScoreTask(
model_info(), _output, _parms,
booster, destinationKey, data,
computeMetrics
);
if(computeMetrics) {
mms[0] = score.mm;
}
return score.preds;
}

/**
@@ -399,8 +423,8 @@ void computeVarImp(Map<String, Integer> varimp) {
public double[] score0(double[] data, double[] preds, double offset) {
DataInfo di = model_info._dataInfoKey.get();
return XGBoostMojoModel.score0(data, offset, preds,
model_info.getBooster(), di._nums, di._cats, di._catOffsets, di._useAllFactorLevels,
_output.nclasses(), _output._priorClassDist, defaultThreshold(), _output._sparse);
}

@Override
@@ -412,8 +436,8 @@ public double[][] score0( Chunk chks[], double[] offset, int[] rowsInChunk, doub
}
DataInfo di = model_info._dataInfoKey.get();
double[][] scored = XGBoostMojoModel.bulkScore0(tmp, offset, preds,
model_info.getBooster(), di._nums, di._cats, di._catOffsets, di._useAllFactorLevels,
_output.nclasses(), _output._priorClassDist, defaultThreshold(), _output._sparse);

if(isSupervised()) {
// Correct probabilities obtained from training on oversampled data back to original distribution