Label-wise metrics (Accuracy etc.) for multi-label problems #513

Open
jphdotam opened this issue May 2, 2019 · 4 comments · May be fixed by #542
Comments

jphdotam commented May 2, 2019

Hi,

I've made a multi-label classifier using BCEWithLogitsLoss. In summary, a data sample can belong to any of 3 binary classes, which aren't mutually exclusive, so y_pred and y can each look something like [0, 1, 1].

My metrics include Accuracy(output_transform=thresholded_output_transform, is_multilabel=True) and Precision(output_transform=thresholded_output_transform, is_multilabel=True, average=True).
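
For context, a thresholded_output_transform for this setup would look something like the sketch below (sigmoid on the logits plus a 0.5 cut-off; the exact threshold is just this example's choice):

import torch

def thresholded_output_transform(output):
    # Turn raw logits from a BCEWithLogitsLoss-trained model into hard 0/1 predictions
    y_pred, y = output
    y_pred = torch.sigmoid(y_pred)
    y_pred = (y_pred > 0.5).long()
    return y_pred, y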

However, I'm interested in having label-specific metrics (i.e. 3 accuracies, one per label). This matters because it lets me see which labels are hurting my overall accuracy the most (a 70% accuracy could be a 30% error concentrated in a single label, or a more modest error scattered across all 3 labels).

There is no option to disable averaging for Accuracy() as there is for the others, and setting average=False for Precision() does not do what I expected: it yields a binary result per datum, not per label, so I end up with a tensor of size 500, not 3, for a dataset of n=500.
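
To illustrate what I'm after, the label-wise accuracy is effectively a per-label mean over the dataset, which by hand would look roughly like this (toy 0/1 tensors of shape (N, 3), purely for illustration):

import torch

# Toy thresholded predictions and targets for 3 non-exclusive labels
y_pred = torch.tensor([[0, 1, 1], [1, 1, 0], [0, 0, 1]])
y = torch.tensor([[0, 1, 0], [1, 1, 0], [1, 0, 1]])

labelwise_acc = (y_pred == y).float().mean(dim=0)  # tensor of size 3, one accuracy per label
print(labelwise_acc)  # tensor([0.6667, 1.0000, 0.6667])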

Is there a way to get label-wise metrics in multi-label problems? Or a plan to introduce one?

P.S. I'd love an invite to the Slack workspace if possible. How do I go about getting one?

jphdotam changed the title from "Class-wise metrics (Accuracy etc.) for multi-label problems" to "Label-wise metrics (Accuracy etc.) for multi-label problems" on May 2, 2019

vfdev-5 commented May 2, 2019

@jphdotam thanks for the feedback! You are correct: the multi-label case is always averaged for now for Accuracy, Precision and Recall.

Is there a way to get label-wise metrics in multi-label problems? Or a plan to introduce one?

There is an issue with a similar requirement: #467.
At the moment we don't have much bandwidth to work on it, so if you could send a PR for that, it would be awesome.

P.S. I'd love an invite to the Slack workspace if possible. How do I go about getting one?

You can find a link for that here: https://pytorch.org/resources


jphdotam commented May 2, 2019

Many thanks, I've made a pull request here: #516

I'm quite new to working on large projects, so apologies if I've gone about this inappropriately.


jphdotam commented May 3, 2019

In the meantime, while the core team decides how best to implement this, here is a custom class I've made for the task, which inherits from Accuracy:

import torch

from ignite.exceptions import NotComputableError
from ignite.metrics import Accuracy


class LabelwiseAccuracy(Accuracy):
    def __init__(self, output_transform=lambda x: x):
        self._num_correct = None
        self._num_examples = None
        super(LabelwiseAccuracy, self).__init__(output_transform=output_transform)

    def reset(self):
        self._num_correct = None
        self._num_examples = 0
        super(LabelwiseAccuracy, self).reset()

    def update(self, output):
        y_pred, y = self._check_shape(output)
        self._check_type((y_pred, y))

        # Flatten everything except the label dimension so both tensors are (N, num_classes)
        num_classes = y_pred.size(1)
        last_dim = y_pred.ndimension()
        y_pred = torch.transpose(y_pred, 1, last_dim - 1).reshape(-1, num_classes)
        y = torch.transpose(y, 1, last_dim - 1).reshape(-1, num_classes)

        correct_exact = torch.all(y == y_pred.type_as(y), dim=-1)  # Sample-wise exact match
        correct_elementwise = torch.sum(y == y_pred.type_as(y), dim=0)  # Correct count per label

        # Accumulate per-label correct counts and the total number of samples seen
        if self._num_correct is not None:
            self._num_correct = torch.add(self._num_correct, correct_elementwise)
        else:
            self._num_correct = correct_elementwise
        self._num_examples += correct_exact.shape[0]

    def compute(self):
        if self._num_examples == 0:
            raise NotComputableError('Accuracy must have at least one example before it can be computed.')
        return self._num_correct.type(torch.float) / self._num_examples
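
It attaches like any other Ignite metric; usage looks roughly like this (model, val_loader and thresholded_output_transform are placeholders for your own objects):

from ignite.engine import create_supervised_evaluator

evaluator = create_supervised_evaluator(
    model,
    metrics={"labelwise_accuracy": LabelwiseAccuracy(output_transform=thresholded_output_transform)},
)
evaluator.run(val_loader)
print(evaluator.state.metrics["labelwise_accuracy"])  # one accuracy per label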


crypdick commented Mar 7, 2020

For anyone trying to use @jphdotam's code from #513 (comment): the line

y_pred, y = self._check_shape(output)

throws an exception because that function now returns nothing. Instead, use

self._check_shape(output)
y_pred, y = output

However, something seems off, because I'm getting 'labelwise_accuracy': [0.9070000648498535, 0.8530000448226929, 0.8370000123977661, 0.7450000643730164, 0.8720000386238098, 0.7570000290870667, 0.9860000610351562, 0.9190000295639038, 0.8740000128746033] while 'avg_accuracy' is 0.285.

Edit: never mind, I stepped through the code and it was fine. The bug was on my end. Cheers!

vfdev-5 added the PyDataGlobal (PyData Global 2020 Sprint) label and removed the Hacktoberfest label on Oct 31, 2020
vfdev-5 removed the PyDataGlobal (PyData Global 2020 Sprint) label on Dec 14, 2020
vfdev-5 added the module: metrics (Metrics module) label on Jan 18, 2021