From 46e2bfd5fa7cb3ddd5697ea5287ccb717c2e8843 Mon Sep 17 00:00:00 2001 From: "langshi.cls" Date: Thu, 28 Apr 2022 17:20:38 +0800 Subject: [PATCH] [DIST] Support global metrics --- docs/model.md | 23 ++ hybridbackend/cpp/tensorflow/ops/gauc.cc | 148 +++++++ hybridbackend/tensorflow/__init__.py | 1 + hybridbackend/tensorflow/metrics/__init__.py | 25 ++ hybridbackend/tensorflow/metrics/accuracy.py | 86 ++++ hybridbackend/tensorflow/metrics/auc.py | 394 +++++++++++++++++++ hybridbackend/tensorflow/metrics/gauc.py | 71 ++++ hybridbackend/tensorflow/metrics/mean.py | 137 +++++++ 8 files changed, 885 insertions(+) create mode 100644 hybridbackend/cpp/tensorflow/ops/gauc.cc create mode 100644 hybridbackend/tensorflow/metrics/__init__.py create mode 100644 hybridbackend/tensorflow/metrics/accuracy.py create mode 100644 hybridbackend/tensorflow/metrics/auc.py create mode 100644 hybridbackend/tensorflow/metrics/gauc.py create mode 100644 hybridbackend/tensorflow/metrics/mean.py diff --git a/docs/model.md b/docs/model.md index 7b69761b..d0f680c1 100644 --- a/docs/model.md +++ b/docs/model.md @@ -63,3 +63,26 @@ def model_fn(features, labels, mode, params): loss=loss, train_op=train_op) ``` + +## 2. Global Metrics + +### 2.1 APIs + +```{eval-rst} +.. autofunction:: hybridbackend.tensorflow.metrics.accuracy +.. autofunction:: hybridbackend.tensorflow.metrics.auc +``` + +### 2.2 Example: Global AUC + +```python +import tensorflow as tf +import hybridbackend.tensorflow as hb + +def eval_fn(): + # ... + auc_and_update = hb.metrics.auc( + labels=eval_labels, + predictions=tf.sigmoid(eval_logits)) + return {'auc': auc_and_update} +``` diff --git a/hybridbackend/cpp/tensorflow/ops/gauc.cc b/hybridbackend/cpp/tensorflow/ops/gauc.cc new file mode 100644 index 00000000..8c6fd6a4 --- /dev/null +++ b/hybridbackend/cpp/tensorflow/ops/gauc.cc @@ -0,0 +1,148 @@ +/* Copyright 2021 Alibaba Group Holding Limited. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if HYBRIDBACKEND_TENSORFLOW + +#include +#include +#include + +namespace tensorflow { +namespace hybridbackend { +namespace { +template +T GetNonClick(T* plabels, size_t k, int dim) { + if (dim == 1) return 1.0 - plabels[k]; + return plabels[2 * k]; +} + +template +T GetClick(T* plabels, size_t k, int dim) { + if (dim == 1) return plabels[k]; + return plabels[2 * k + 1]; +} + +template +bool ComputeGauc(T* plabels, T* ppreds, T* pfilter, size_t* pidx, size_t l, + size_t r, int dim, double* ret) { + std::sort(pidx + l, pidx + r, [ppreds, dim](size_t a, size_t b) { + return GetClick(ppreds, a, dim) < GetClick(ppreds, b, dim); + }); + double fp1, tp1, fp2, tp2, auc; + fp1 = tp1 = fp2 = tp2 = auc = 0; + size_t i; + for (size_t k = l; k < r; ++k) { + i = pidx[k]; + if (pfilter != nullptr && pfilter[i] == 0) continue; + fp2 += GetNonClick(plabels, i, dim); + tp2 += GetClick(plabels, i, dim); + auc += (fp2 - fp1) * (tp2 + tp1); + fp1 = fp2; + tp1 = tp2; + } + double threshold = static_cast(r - l) - 1e-3; + if (tp2 > threshold or fp2 > threshold) { + *ret = -0.5; + return true; + } + if (tp2 * fp2 > 0) { + *ret = (1.0 - auc / (2.0 * tp2 * fp2)); + return true; + } + return false; +} +} // anonymous namespace + +REGISTER_OP("GaucCalc") + .Output("aucs: T") + .Output("counts: int32") + .Input("labels: T") + .Input("predictions: T") + .Input("indicators: Tindicators") 
+ .Attr("T: {float, double}") + .Attr("Tindicators: {int32, int64}") + .SetShapeFn(shape_inference::UnknownShape); + +// TODO(siran.ysr) Specify more accurate shape function and add operator docs. + +template +class GaucCalcOp : public OpKernel { + public: + explicit GaucCalcOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + const Tensor& labels_t = ctx->input(0); + const Tensor& predictions_t = ctx->input(1); + const Tensor& indicators_t = ctx->input(2); + + size_t ldim = labels_t.shape().dims(); + size_t n = labels_t.shape().dim_size(0); + std::vector index(n); + for (size_t i = 0; i < n; ++i) index[i] = i; + + T* labels = const_cast(&labels_t.flat()(0)); + T* predictions = const_cast(&predictions_t.flat()(0)); + auto indicators = indicators_t.flat(); + std::vector auc_values; + std::vector count_values; + bool first = true; + for (size_t begin = 0, end = 0; end < n; ++end) { + if (indicators(end) == indicators(begin)) continue; + + if (first) { + first = false; + } else { + double auc = 0; + if (ComputeGauc(labels, predictions, nullptr, index.data(), begin, + end, ldim, &auc)) { + if (auc >= 0) { + auc_values.emplace_back(auc); + count_values.emplace_back(end - begin); + } + } + } + begin = end; + } + + Tensor* aucs_t; + Tensor* counts_t; + OP_REQUIRES_OK( + ctx, ctx->allocate_output(0, {static_cast(auc_values.size())}, + &aucs_t)); + OP_REQUIRES_OK( + ctx, ctx->allocate_output(1, {static_cast(count_values.size())}, + &counts_t)); + std::copy(auc_values.begin(), auc_values.end(), &aucs_t->vec()(0)); + std::copy(count_values.begin(), count_values.end(), + &counts_t->vec()(0)); + } +}; + +#define REGISTER_GAUC_CALC(type, indicator_type) \ + REGISTER_KERNEL_BUILDER(Name("GaucCalc") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tindicators"), \ + GaucCalcOp) + +REGISTER_GAUC_CALC(float, int32); +REGISTER_GAUC_CALC(float, int64); +REGISTER_GAUC_CALC(double, int32); +REGISTER_GAUC_CALC(double, 
int64); + +} // namespace hybridbackend +} // namespace tensorflow + +#endif // HYBRIDBACKEND_TENSORFLOW diff --git a/hybridbackend/tensorflow/__init__.py b/hybridbackend/tensorflow/__init__.py index 7c609869..2d3d421e 100644 --- a/hybridbackend/tensorflow/__init__.py +++ b/hybridbackend/tensorflow/__init__.py @@ -26,6 +26,7 @@ from . import feature_column from . import keras from . import layers +from . import metrics from . import ops as math from . import saved_model from . import training as train diff --git a/hybridbackend/tensorflow/metrics/__init__.py b/hybridbackend/tensorflow/metrics/__init__.py new file mode 100644 index 00000000..31a78e35 --- /dev/null +++ b/hybridbackend/tensorflow/metrics/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2021 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +r'''Metrics for evaluating models in hybridbackend. 
+''' + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from hybridbackend.tensorflow.metrics.auc import auc +from hybridbackend.tensorflow.metrics.accuracy import accuracy +from hybridbackend.tensorflow.metrics.gauc import gauc diff --git a/hybridbackend/tensorflow/metrics/accuracy.py b/hybridbackend/tensorflow/metrics/accuracy.py new file mode 100644 index 00000000..419fee8f --- /dev/null +++ b/hybridbackend/tensorflow/metrics/accuracy.py @@ -0,0 +1,86 @@ +# Copyright 2021 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +r'''A data-parallel Accuracy metric. +''' + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics_impl + +from hybridbackend.tensorflow.metrics.mean import mean + + +def accuracy(labels, + predictions, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): + r'''Calculates how often `predictions` matches `labels`. + + The `accuracy` function creates two local variables, `total` and + `count` that are used to compute the frequency with which `predictions` + matches `labels`. 
This frequency is ultimately returned as `accuracy`: an + idempotent operation that simply divides `total` by `count`. + + For estimation of the metric over a stream of data, the function creates an + `update_op` operation that updates these variables and returns the `accuracy`. + Internally, an `is_correct` operation computes a `Tensor` with elements 1.0 + where the corresponding elements of `predictions` and `labels` match and 0.0 + otherwise. Then `update_op` increments `total` with the reduced sum of the + product of `weights` and `is_correct`, and it increments `count` with the + reduced sum of `weights`. + + If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. + + Args: + labels: The ground truth values, a `Tensor` whose shape matches + `predictions`. + predictions: The predicted values, a `Tensor` of any shape. + weights: Optional `Tensor` whose rank is either 0, or the same rank as + `labels`, and must be broadcastable to `labels` (i.e., all dimensions must + be either `1`, or the same as the corresponding `labels` dimension). + metrics_collections: An optional list of collections that `accuracy` should + be added to. + updates_collections: An optional list of collections that `update_op` should + be added to. + name: An optional variable_scope name. + + Returns: + accuracy: A `Tensor` representing the accuracy, the value of `total` divided + by `count`. + update_op: An operation that increments the `total` and `count` variables + appropriately and whose value matches `accuracy`. + + Raises: + ValueError: If `predictions` and `labels` have mismatched shapes, or if + `weights` is not `None` and its shape doesn't match `predictions`, or if + either `metrics_collections` or `updates_collections` are not a list or + tuple. + RuntimeError: If eager execution is enabled. 
+ ''' + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access + predictions=predictions, labels=labels, weights=weights) + predictions.get_shape().assert_is_compatible_with(labels.get_shape()) + if labels.dtype != predictions.dtype: + predictions = math_ops.cast(predictions, labels.dtype) + is_correct = math_ops.to_float(math_ops.equal(predictions, labels)) + return mean( + is_correct, weights, metrics_collections, updates_collections, + name or 'accuracy') diff --git a/hybridbackend/tensorflow/metrics/auc.py b/hybridbackend/tensorflow/metrics/auc.py new file mode 100644 index 00000000..93e59383 --- /dev/null +++ b/hybridbackend/tensorflow/metrics/auc.py @@ -0,0 +1,394 @@ +# Copyright 2021 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +r'''A data-parallel AUC metric. 
+''' + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics_impl +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.ops import weights_broadcast_ops +from tensorflow.python.platform import tf_logging as logging + +from hybridbackend.tensorflow.distribute.communicator import CollectiveOps +from hybridbackend.tensorflow.distribute.communicator_pool import \ + CommunicatorPool + + +def _allreduce_auc(comm, inputs, inputs_deps): + r'''Communicator call to reduce auc across workers. + ''' + with ops.control_dependencies(inputs_deps): + if isinstance(inputs, (list, tuple)): + inputs = inputs[0] + sum_inputs = comm.allreduce(inputs, CollectiveOps.SUM) + return sum_inputs, None + + +def _confusion_matrix_at_thresholds(labels, + predictions, + thresholds, + weights=None): + r'''Computes true_positives, false_negatives, true_negatives, false_positives. + + This function creates up to four local variables, `true_positives`, + `true_negatives`, `false_positives` and `false_negatives`. + `true_positive[i]` is defined as the total weight of values in `predictions` + above `thresholds[i]` whose corresponding entry in `labels` is `True`. + `false_negatives[i]` is defined as the total weight of values in `predictions` + at most `thresholds[i]` whose corresponding entry in `labels` is `True`. + `true_negatives[i]` is defined as the total weight of values in `predictions` + at most `thresholds[i]` whose corresponding entry in `labels` is `False`. 
+ `false_positives[i]` is defined as the total weight of values in `predictions` + above `thresholds[i]` whose corresponding entry in `labels` is `False`. + + For estimation of these metrics over a stream of data, for each metric the + function respectively creates an `update_op` operation that updates the + variable and returns its value. + + If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. + + Args: + labels: A `Tensor` whose shape matches `predictions`. Will be cast to + `bool`. + predictions: A floating point `Tensor` of arbitrary shape and whose values + are in the range `[0, 1]`. + thresholds: A python list or tuple of float thresholds in `[0, 1]`. + weights: Optional `Tensor` whose rank is either 0, or the same rank as + `labels`, and must be broadcastable to `labels` (i.e., all dimensions must + be either `1`, or the same as the corresponding `labels` dimension). + + Returns: + values: Dict of variables of shape `[len(thresholds)]`. + update_ops: Dict of operations that increments the `values`. + + Raises: + ValueError: If `predictions` and `labels` have mismatched shapes, or if + `weights` is not `None` and its shape doesn't match `predictions`. + ''' + with ops.control_dependencies([ + check_ops.assert_greater_equal( + predictions, + math_ops.cast(0.0, dtype=predictions.dtype), + message='predictions must be in [0, 1]'), + check_ops.assert_less_equal( + predictions, + math_ops.cast(1.0, dtype=predictions.dtype), + message='predictions must be in [0, 1]')]): + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access + predictions=math_ops.to_float(predictions), + labels=math_ops.cast(labels, dtype=dtypes.bool), + weights=weights) + + num_thresholds = len(thresholds) + + # Reshape predictions and labels. 
+ predictions_2d = array_ops.reshape(predictions, [-1, 1]) + labels_2d = array_ops.reshape( + math_ops.cast(labels, dtype=dtypes.bool), [1, -1]) + + # Use static shape if known. + num_predictions = predictions_2d.get_shape().as_list()[0] + + # Otherwise use dynamic shape. + if num_predictions is None: + num_predictions = array_ops.shape(predictions_2d)[0] + thresh_tiled = array_ops.tile( + array_ops.expand_dims(array_ops.constant(thresholds), [1]), + array_ops.stack([1, num_predictions])) + + # Tile the predictions after thresholding them across different thresholds. + pred_is_pos = math_ops.greater( + array_ops.tile(array_ops.transpose(predictions_2d), [num_thresholds, 1]), + thresh_tiled) + pred_is_neg = math_ops.logical_not(pred_is_pos) + + # Tile labels by number of thresholds + label_is_pos = array_ops.tile(labels_2d, [num_thresholds, 1]) + label_is_neg = math_ops.logical_not(label_is_pos) + + if weights is not None: + weights = weights_broadcast_ops.broadcast_weights( + math_ops.to_float(weights), predictions) + weights_tiled = array_ops.tile( + array_ops.reshape(weights, [1, -1]), [num_thresholds, 1]) + thresh_tiled.get_shape().assert_is_compatible_with( + weights_tiled.get_shape()) + else: + weights_tiled = None + + values = {} + update_ops = {} + + true_p = metrics_impl.metric_variable( + [num_thresholds], dtypes.float32, name='true_positives') + is_true_positive = math_ops.to_float( + math_ops.logical_and(label_is_pos, pred_is_pos)) + if weights_tiled is not None: + is_true_positive *= weights_tiled + + false_n = metrics_impl.metric_variable( + [num_thresholds], dtypes.float32, name='false_negatives') + is_false_negative = math_ops.to_float( + math_ops.logical_and(label_is_pos, pred_is_neg)) + if weights_tiled is not None: + is_false_negative *= weights_tiled + + true_n = metrics_impl.metric_variable( + [num_thresholds], dtypes.float32, name='true_negatives') + is_true_negative = math_ops.to_float( + math_ops.logical_and(label_is_neg, pred_is_neg)) + if 
weights_tiled is not None: + is_true_negative *= weights_tiled + + false_p = metrics_impl.metric_variable( + [num_thresholds], dtypes.float32, name='false_positives') + is_false_positive = math_ops.to_float( + math_ops.logical_and(label_is_neg, pred_is_pos)) + if weights_tiled is not None: + is_false_positive *= weights_tiled + + tp_sum = math_ops.reduce_sum(is_true_positive, 1) + fn_sum = math_ops.reduce_sum(is_false_negative, 1) + tn_sum = math_ops.reduce_sum(is_true_negative, 1) + fp_sum = math_ops.reduce_sum(is_false_positive, 1) + + stacked = array_ops.stack([tp_sum, fn_sum, tn_sum, fp_sum]) + sum_stacked = CommunicatorPool.get().call( + _allreduce_auc, stacked, trainable=False) + if isinstance(sum_stacked, (list, tuple)): + sum_stacked = sum_stacked[0] + tp_sum, fn_sum, tn_sum, fp_sum = array_ops.unstack(sum_stacked) + + update_ops['tp'] = state_ops.assign_add(true_p, tp_sum) + update_ops['fn'] = state_ops.assign_add(false_n, fn_sum) + update_ops['tn'] = state_ops.assign_add(true_n, tn_sum) + update_ops['fp'] = state_ops.assign_add(false_p, fp_sum) + + values['tp'] = true_p + values['fn'] = false_n + values['tn'] = true_n + values['fp'] = false_p + + return values, update_ops + + +def auc(labels, + predictions, + weights=None, + num_thresholds=200, + metrics_collections=None, + updates_collections=None, + curve='ROC', + name=None, + summation_method='trapezoidal'): + r'''Computes the approximate AUC via a Riemann sum. + + The `auc` function creates four local variables, `true_positives`, + `true_negatives`, `false_positives` and `false_negatives` that are used to + compute the AUC. To discretize the AUC curve, a linearly spaced set of + thresholds is used to compute pairs of recall and precision values. The area + under the ROC-curve is therefore computed using the height of the recall + values by the false positive rate, while the area under the PR-curve is the + computed using the height of the precision values by the recall. 
+ + This value is ultimately returned as `auc`, an idempotent operation that + computes the area under a discretized curve of precision versus recall values + (computed using the aforementioned variables). The `num_thresholds` argument + controls the degree of discretization, with larger numbers of thresholds more + closely approximating the true AUC. The quality of the approximation may vary + dramatically depending on `num_thresholds`. + + For best results, `predictions` should be distributed approximately uniformly + in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC + approximation may be poor if this is not the case. Setting `summation_method` + to 'minoring' or 'majoring' can help quantify the error in the approximation + by providing a lower or upper bound estimate of the AUC. + + For estimation of the metric over a stream of data, the function creates an + `update_op` operation that updates these variables and returns the `auc`. + + If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. + + Args: + labels: A `Tensor` whose shape matches `predictions`. Will be cast to + `bool`. + predictions: A floating point `Tensor` of arbitrary shape and whose values + are in the range `[0, 1]`. + weights: Optional `Tensor` whose rank is either 0, or the same rank as + `labels`, and must be broadcastable to `labels` (i.e., all dimensions must + be either `1`, or the same as the corresponding `labels` dimension). + num_thresholds: The number of thresholds to use when discretizing the + curve. + metrics_collections: An optional list of collections that `auc` should be + added to. + updates_collections: An optional list of collections that `update_op` should + be added to. + curve: Specifies the name of the curve to be computed, 'ROC' [default] or + 'PR' for the Precision-Recall-curve. + name: An optional variable_scope name. 
+ summation_method: Specifies the Riemann summation method used + (https://en.wikipedia.org/wiki/Riemann_sum): 'trapezoidal' [default] that + applies the trapezoidal rule; 'careful_interpolation', a variant of it + differing only by a more correct interpolation scheme for PR-AUC - + interpolating (true/false) positives but not the ratio that is precision; + 'minoring' that applies left summation for increasing intervals and right + summation for decreasing intervals; 'majoring' that does the opposite. + Note that 'careful_interpolation' is strictly preferred to 'trapezoidal' + (to be deprecated soon) as it applies the same method for ROC, and a + better one (see Davis & Goadrich 2006 for details) for the PR curve. + + Returns: + (auc, update_op): A tuple of a scalar `Tensor` representing the current + area-under-curve and an operation that increments the `true_positives`, + `true_negatives`, `false_positives` and `false_negatives` variables + appropriately and whose value matches `auc`. + + Raises: + ValueError: If `predictions` and `labels` have mismatched shapes, or if + `weights` is not `None` and its shape doesn't match `predictions`, or if + either `metrics_collections` or `updates_collections` are not a list or + tuple. + RuntimeError: If eager execution is enabled. + ''' + with vs.variable_scope(name, 'auc', (labels, predictions, weights)): + if curve not in ('ROC', 'PR'): + raise ValueError(f'curve must be either ROC or PR, {curve} unknown') + kepsilon = 1e-7 # to account for floating point imprecisions + thresholds = [ + (i + 1) * 1.0 / (num_thresholds - 1) + for i in range(num_thresholds - 2)] + thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon] + + values, update_ops = _confusion_matrix_at_thresholds( + labels, predictions, thresholds, weights) + + # Add epsilons to avoid dividing by 0. + epsilon = 1.0e-6 + + def interpolate_pr_auc(tp, fp, fn): + r'''Interpolation formula inspired by section 4 of Davis & Goadrich 2006. 
+ + Note here we derive & use a closed formula not present in the paper + - as follows: + Modeling all of TP (true positive weight), + FP (false positive weight) and their sum P = TP + FP (positive weight) + as varying linearly within each interval [A, B] between successive + thresholds, we get + Precision = (TP_A + slope * (P - P_A)) / P + with slope = dTP / dP = (TP_B - TP_A) / (P_B - P_A). + The area within the interval is thus (slope / total_pos_weight) times + int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P} + int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P} + where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in + int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A) + Bringing back the factor (slope / total_pos_weight) we'd put aside, we get + slope * [dTP + intercept * log(P_B / P_A)] / total_pos_weight + where dTP == TP_B - TP_A. + Note that when P_A == 0 the above calculation simplifies into + int_A^B{Precision.dTP} = int_A^B{slope * dTP} = slope * (TP_B - TP_A) + which is really equivalent to imputing constant precision throughout the + first bucket having >0 true positives. + + Args: + tp: true positive counts + fp: false positive counts + fn: false negative counts + Returns: + pr_auc: an approximation of the area under the P-R curve. 
+ ''' + dtp = tp[:num_thresholds - 1] - tp[1:] + p = tp + fp + prec_slope = metrics_impl._safe_div( # pylint: disable=protected-access + dtp, p[:num_thresholds - 1] - p[1:], 'prec_slope') + intercept = tp[1:] - math_ops.multiply(prec_slope, p[1:]) + safe_p_ratio = array_ops.where( + math_ops.logical_and(p[:num_thresholds - 1] > 0, p[1:] > 0), + metrics_impl._safe_div( # pylint: disable=protected-access + p[:num_thresholds - 1], p[1:], 'recall_relative_ratio'), + array_ops.ones_like(p[1:])) + return math_ops.reduce_sum( + metrics_impl._safe_div( # pylint: disable=protected-access + prec_slope * (dtp + intercept * math_ops.log(safe_p_ratio)), + tp[1:] + fn[1:], + name='pr_auc_increment'), + name='interpolate_pr_auc') + + def compute_auc(tp, fn, tn, fp, name): + r'''Computes the roc-auc or pr-auc based on confusion counts. + ''' + if curve == 'PR': + if summation_method == 'trapezoidal': + logging.warning( + 'Trapezoidal rule is known to produce incorrect PR-AUCs; ' + 'please switch to "careful_interpolation" instead.') + elif summation_method == 'careful_interpolation': + # This one is a bit tricky and is handled separately. + return interpolate_pr_auc(tp, fp, fn) + rec = math_ops.div(tp + epsilon, tp + fn + epsilon) + if curve == 'ROC': + fp_rate = math_ops.div(fp, fp + tn + epsilon) + x = fp_rate + y = rec + else: # curve == 'PR'. + prec = math_ops.div(tp + epsilon, tp + fp + epsilon) + x = rec + y = prec + if summation_method in ('trapezoidal', 'careful_interpolation'): + # Note that the case ('PR', 'careful_interpolation') has been handled + # above. 
+ return math_ops.reduce_sum( + math_ops.multiply( + x[:num_thresholds - 1] - x[1:], + (y[:num_thresholds - 1] + y[1:]) / 2.), + name=name) + if summation_method == 'minoring': + return math_ops.reduce_sum( + math_ops.multiply( + x[:num_thresholds - 1] - x[1:], + math_ops.minimum(y[:num_thresholds - 1], y[1:])), + name=name) + if summation_method == 'majoring': + return math_ops.reduce_sum( + math_ops.multiply( + x[:num_thresholds - 1] - x[1:], + math_ops.maximum(y[:num_thresholds - 1], y[1:])), + name=name) + raise ValueError(f'Invalid summation_method: {summation_method}') + + auc_value = compute_auc( + values['tp'], values['fn'], values['tn'], values['fp'], 'value') + if metrics_collections: + ops.add_to_collections(metrics_collections, auc_value) + + update_op = compute_auc( + update_ops['tp'], + update_ops['fn'], + update_ops['tn'], + update_ops['fp'], + 'update_op') + if updates_collections: + ops.add_to_collections(updates_collections, update_op) + + return auc_value, update_op diff --git a/hybridbackend/tensorflow/metrics/gauc.py b/hybridbackend/tensorflow/metrics/gauc.py new file mode 100644 index 00000000..56d865fd --- /dev/null +++ b/hybridbackend/tensorflow/metrics/gauc.py @@ -0,0 +1,71 @@ +# Copyright 2021 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +r'''A data-parallel gAUC metric. 
+''' + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope as vs + +from hybridbackend.tensorflow.metrics.mean import mean +from hybridbackend.tensorflow.pywrap import _ops + + +def gauc(labels, + predictions, + indicators=None, + metrics_collections=None, + updates_collections=None, + name=None): + r'''Computes the approximate gAUC. + + Args: + labels: A `float32` or `float64` `Tensor` whose shape and type match + `predictions`. + predictions: A floating point `Tensor` of arbitrary shape and whose values + are in the range `[0, 1]`. + indicators: An `int32` or `int64` `Tensor` whose shape matches + `predictions`. Consecutive equal values mark examples of the same group. + metrics_collections: An optional list of collections that `mean` + should be added to. + updates_collections: An optional list of collections that `update_op` + should be added to. + name: An optional variable_scope name. + + Returns: + (gauc, update_op): A tuple of a scalar `Tensor` representing the current + g-area-under-curve and an operation that increments the `total` and + `count` variables of the underlying `mean` metric appropriately and whose + value matches `gauc`. + + Raises: + ValueError: If `predictions` and `labels` have mismatched shapes, or if + either `metrics_collections` or `updates_collections` are not a list or + tuple. + RuntimeError: If eager execution is enabled. 
+ ''' + if indicators is None: + indicators = math_ops.range( + 0, array_ops.shape(array_ops.reshape(labels, [-1]))[0], + dtype=dtypes.int32) + with vs.variable_scope(name, 'gauc', (labels, predictions, indicators)): + aucs, counts = _ops.gauc_calc(labels, predictions, indicators) + return mean(aucs, counts, metrics_collections, updates_collections, name) diff --git a/hybridbackend/tensorflow/metrics/mean.py b/hybridbackend/tensorflow/metrics/mean.py new file mode 100644 index 00000000..3c0ac7c9 --- /dev/null +++ b/hybridbackend/tensorflow/metrics/mean.py @@ -0,0 +1,137 @@ +# Copyright 2021 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +r'''A data-parallel Mean metric. 
+'''
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import metrics_impl
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope as vs
+from tensorflow.python.ops import weights_broadcast_ops
+
+from hybridbackend.tensorflow.distribute.communicator import CollectiveOps
+from hybridbackend.tensorflow.distribute.communicator_pool import \
+  CommunicatorPool
+
+
+def _allreduce_mean(comm, inputs, inputs_deps):
+  r'''Communicator call: SUM-allreduce `inputs` across workers (no division).
+  '''
+  with ops.control_dependencies(inputs_deps):
+    if isinstance(inputs, (list, tuple)):  # only item 0 is reduced
+      inputs = inputs[0]
+    sum_inputs = comm.allreduce(inputs, CollectiveOps.SUM)
+    return sum_inputs, None
+
+
+def mean(values,
+         weights=None,
+         metrics_collections=None,
+         updates_collections=None,
+         name=None):
+  r'''Computes the (weighted) mean of the given values across workers.
+
+  The `mean` function creates two local variables, `total` and `count`
+  that are used to compute the average of `values`. This average is ultimately
+  returned as `mean` which is an idempotent operation that simply divides
+  `total` by `count`.
+
+  For estimation of the metric over a stream of data, the function creates an
+  `update_op` operation that updates these variables and returns the `mean`.
+  Each step, the local sum of `values * weights` and the local sum of `weights`
+  are SUM-allreduced through the communicator, then added to `total`/`count`.
+
+  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+  Args:
+    values: A `Tensor` of arbitrary dimensions; it is cast to float32.
+    weights: Optional `Tensor` whose rank is either 0, or the same rank as
+      `values`, and must be broadcastable to `values` (i.e., all dimensions must
+      be either `1`, or the same as the corresponding `values` dimension).
+    metrics_collections: An optional list of collections that `mean`
+      should be added to.
+    updates_collections: An optional list of collections that `update_op`
+      should be added to.
+    name: An optional variable_scope name; the default scope name is `mean`.
+
+  Returns:
+    mean: A `Tensor` representing the current mean, the value of `total` divided
+      by `count`.
+    update_op: An operation that increments the `total` and `count` variables
+      with the allreduced sums and whose value is the updated `total / count`.
+
+  Raises:
+    ValueError: If `weights` is not `None` and its shape doesn't match `values`
+      (presumably raised by `broadcast_weights`; not checked explicitly here).
+    NOTE(review): this function contains no explicit eager-execution check and
+      no type check on the collections arguments; do not rely on such errors.
+  '''
+  with vs.variable_scope(name, 'mean', (values, weights)):
+    values = math_ops.to_float(values)
+
+    total = metrics_impl.metric_variable([], dtypes.float32, name='total')
+    count = metrics_impl.metric_variable([], dtypes.float32, name='count')
+
+    if weights is None:
+      num_values = math_ops.to_float(array_ops.size(values))  # weight 1 each
+      values_sum = math_ops.reduce_sum(values)
+    else:
+      values, _, weights = metrics_impl._remove_squeezable_dimensions(  # pylint: disable=protected-access
+          predictions=values, labels=None, weights=weights)
+      weights = weights_broadcast_ops.broadcast_weights(
+          math_ops.to_float(weights), values)
+      values = math_ops.multiply(values, weights)
+      values_sum = math_ops.reduce_sum(values)
+      num_values = math_ops.reduce_sum(weights)
+
+    stacked = array_ops.stack([values_sum, num_values])  # reduced in one call
+    sum_stacked = CommunicatorPool.get().call(
+        _allreduce_mean, stacked, trainable=False)
+    if isinstance(sum_stacked, (list, tuple)):  # unwrap list result
+      sum_stacked = sum_stacked[0]
+    values_sum, num_values = array_ops.unstack(sum_stacked)  # allreduced sums
+
+    update_total_op = state_ops.assign_add(total, values_sum)
+    with ops.control_dependencies([values]):  # count update waits for `values`
+      update_count_op = state_ops.assign_add(count, num_values)
+    # pylint: disable=protected-access
+    metric_op = (
+        metrics_impl._safe_scalar_div(total, count, 'value')
+        if hasattr(metrics_impl, '_safe_scalar_div')
+        else metrics_impl._safe_div(total, count, 'value'))
+
+    if metrics_collections:
+      ops.add_to_collections(metrics_collections, metric_op)
+
+    # pylint: disable=protected-access
+    update_op = (
+        metrics_impl._safe_scalar_div(
+            update_total_op, update_count_op, 'update_op')
+        if hasattr(metrics_impl, '_safe_scalar_div')
+        else metrics_impl._safe_div(
+            update_total_op, update_count_op, 'update_op'))
+
+    if updates_collections:
+      ops.add_to_collections(updates_collections, update_op)
+
+    return metric_op, update_op