Skip to content

Commit

Permalink
[DIST] Support global metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
francktcheng authored and 2sin18 committed May 8, 2022
1 parent 5355a0e commit 46e2bfd
Show file tree
Hide file tree
Showing 8 changed files with 885 additions and 0 deletions.
23 changes: 23 additions & 0 deletions docs/model.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,26 @@ def model_fn(features, labels, mode, params):
loss=loss,
train_op=train_op)
```

## 2. Global Metrics

### 2.1 APIs

```{eval-rst}
.. autofunction:: hybridbackend.tensorflow.metrics.accuracy
.. autofunction:: hybridbackend.tensorflow.metrics.auc
```

### 2.2 Example: Global AUC

```python
import tensorflow as tf
import hybridbackend.tensorflow as hb

def eval_fn():
# ...
auc_and_update = hb.metrics.auc(
labels=eval_labels,
predictions=eval_logits)
return {'auc': auc_and_update}
```
148 changes: 148 additions & 0 deletions hybridbackend/cpp/tensorflow/ops/gauc.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
/* Copyright 2021 Alibaba Group Holding Limited. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if HYBRIDBACKEND_TENSORFLOW

#include <tensorflow/core/framework/common_shape_fns.h>
#include <tensorflow/core/framework/op_kernel.h>
#include <tensorflow/core/framework/shape_inference.h>

namespace tensorflow {
namespace hybridbackend {
namespace {
template <typename T>
T GetNonClick(T* plabels, size_t k, int dim) {
if (dim == 1) return 1.0 - plabels[k];
return plabels[2 * k];
}

template <typename T>
T GetClick(T* plabels, size_t k, int dim) {
if (dim == 1) return plabels[k];
return plabels[2 * k + 1];
}

template <typename T>
bool ComputeGauc(T* plabels, T* ppreds, T* pfilter, size_t* pidx, size_t l,
size_t r, int dim, double* ret) {
std::sort(pidx + l, pidx + r, [ppreds, dim](size_t a, size_t b) {
return GetClick<T>(ppreds, a, dim) < GetClick<T>(ppreds, b, dim);
});
double fp1, tp1, fp2, tp2, auc;
fp1 = tp1 = fp2 = tp2 = auc = 0;
size_t i;
for (size_t k = l; k < r; ++k) {
i = pidx[k];
if (pfilter != nullptr && pfilter[i] == 0) continue;
fp2 += GetNonClick<T>(plabels, i, dim);
tp2 += GetClick<T>(plabels, i, dim);
auc += (fp2 - fp1) * (tp2 + tp1);
fp1 = fp2;
tp1 = tp2;
}
double threshold = static_cast<double>(r - l) - 1e-3;
if (tp2 > threshold or fp2 > threshold) {
*ret = -0.5;
return true;
}
if (tp2 * fp2 > 0) {
*ret = (1.0 - auc / (2.0 * tp2 * fp2));
return true;
}
return false;
}
} // anonymous namespace

REGISTER_OP("GaucCalc")
.Output("aucs: T")
.Output("counts: int32")
.Input("labels: T")
.Input("predictions: T")
.Input("indicators: Tindicators")
.Attr("T: {float, double}")
.Attr("Tindicators: {int32, int64}")
.SetShapeFn(shape_inference::UnknownShape);

// TODO(siran.ysr) Specify more accurate shape function and add operator docs.

template <typename T, typename Tindicators>
class GaucCalcOp : public OpKernel {
public:
explicit GaucCalcOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

void Compute(OpKernelContext* ctx) override {
const Tensor& labels_t = ctx->input(0);
const Tensor& predictions_t = ctx->input(1);
const Tensor& indicators_t = ctx->input(2);

size_t ldim = labels_t.shape().dims();
size_t n = labels_t.shape().dim_size(0);
std::vector<size_t> index(n);
for (size_t i = 0; i < n; ++i) index[i] = i;

T* labels = const_cast<T*>(&labels_t.flat<T>()(0));
T* predictions = const_cast<T*>(&predictions_t.flat<T>()(0));
auto indicators = indicators_t.flat<Tindicators>();
std::vector<double> auc_values;
std::vector<size_t> count_values;
bool first = true;
for (size_t begin = 0, end = 0; end < n; ++end) {
if (indicators(end) == indicators(begin)) continue;

if (first) {
first = false;
} else {
double auc = 0;
if (ComputeGauc<T>(labels, predictions, nullptr, index.data(), begin,
end, ldim, &auc)) {
if (auc >= 0) {
auc_values.emplace_back(auc);
count_values.emplace_back(end - begin);
}
}
}
begin = end;
}

Tensor* aucs_t;
Tensor* counts_t;
OP_REQUIRES_OK(
ctx, ctx->allocate_output(0, {static_cast<int64>(auc_values.size())},
&aucs_t));
OP_REQUIRES_OK(
ctx, ctx->allocate_output(1, {static_cast<int64>(count_values.size())},
&counts_t));
std::copy(auc_values.begin(), auc_values.end(), &aucs_t->vec<T>()(0));
std::copy(count_values.begin(), count_values.end(),
&counts_t->vec<int32>()(0));
}
};

#define REGISTER_GAUC_CALC(type, indicator_type) \
REGISTER_KERNEL_BUILDER(Name("GaucCalc") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<indicator_type>("Tindicators"), \
GaucCalcOp<type, indicator_type>)

REGISTER_GAUC_CALC(float, int32);
REGISTER_GAUC_CALC(float, int64);
REGISTER_GAUC_CALC(double, int32);
REGISTER_GAUC_CALC(double, int64);

} // namespace hybridbackend
} // namespace tensorflow

#endif // HYBRIDBACKEND_TENSORFLOW
1 change: 1 addition & 0 deletions hybridbackend/tensorflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from . import feature_column
from . import keras
from . import layers
from . import metrics
from . import ops as math
from . import saved_model
from . import training as train
Expand Down
25 changes: 25 additions & 0 deletions hybridbackend/tensorflow/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright 2021 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

r'''Metrics for evaluating models in hybridbackend.
'''

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from hybridbackend.tensorflow.metrics.auc import auc
from hybridbackend.tensorflow.metrics.accuracy import accuracy
from hybridbackend.tensorflow.metrics.gauc import gauc
86 changes: 86 additions & 0 deletions hybridbackend/tensorflow/metrics/accuracy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright 2021 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

r'''A data-parallel Accuracy metric.
'''

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics_impl

from hybridbackend.tensorflow.metrics.mean import mean


def accuracy(labels,
predictions,
weights=None,
metrics_collections=None,
updates_collections=None,
name=None):
r'''Calculates how often `predictions` matches `labels`.
The `accuracy` function creates two local variables, `total` and
`count` that are used to compute the frequency with which `predictions`
matches `labels`. This frequency is ultimately returned as `accuracy`: an
idempotent operation that simply divides `total` by `count`.
For estimation of the metric over a stream of data, the function creates an
`update_op` operation that updates these variables and returns the `accuracy`.
Internally, an `is_correct` operation computes a `Tensor` with elements 1.0
where the corresponding elements of `predictions` and `labels` match and 0.0
otherwise. Then `update_op` increments `total` with the reduced sum of the
product of `weights` and `is_correct`, and it increments `count` with the
reduced sum of `weights`.
If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
Args:
labels: The ground truth values, a `Tensor` whose shape matches
`predictions`.
predictions: The predicted values, a `Tensor` of any shape.
weights: Optional `Tensor` whose rank is either 0, or the same rank as
`labels`, and must be broadcastable to `labels` (i.e., all dimensions must
be either `1`, or the same as the corresponding `labels` dimension).
metrics_collections: An optional list of collections that `accuracy` should
be added to.
updates_collections: An optional list of collections that `update_op` should
be added to.
name: An optional variable_scope name.
Returns:
accuracy: A `Tensor` representing the accuracy, the value of `total` divided
by `count`.
update_op: An operation that increments the `total` and `count` variables
appropriately and whose value matches `accuracy`.
Raises:
ValueError: If `predictions` and `labels` have mismatched shapes, or if
`weights` is not `None` and its shape doesn't match `predictions`, or if
either `metrics_collections` or `updates_collections` are not a list or
tuple.
RuntimeError: If eager execution is enabled.
'''
predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access
predictions=predictions, labels=labels, weights=weights)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
if labels.dtype != predictions.dtype:
predictions = math_ops.cast(predictions, labels.dtype)
is_correct = math_ops.to_float(math_ops.equal(predictions, labels))
return mean(
is_correct, weights, metrics_collections, updates_collections,
name or 'accuracy')
Loading

0 comments on commit 46e2bfd

Please sign in to comment.