Label-wise metrics (Accuracy etc.) for multi-label problems #513
Comments
@jphdotam thanks for the feedback! You are correct, the multi-label case is always averaged for now for Accuracy, Precision and Recall.
There is an issue with a similar requirement: #467
As for the Slack workspace, you can find a link for that here: https://pytorch.org/resources
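To illustrate the current behaviour, here is a small sketch (made-up tensors, not from the thread) showing that Accuracy with is_multilabel=True returns a single scalar rather than one value per label:

import torch
from ignite.metrics import Accuracy

# made-up predictions and targets for a 3-label problem
y_pred = torch.tensor([[1, 0, 1], [0, 1, 1]])
y = torch.tensor([[1, 1, 1], [0, 1, 0]])

acc = Accuracy(is_multilabel=True)
acc.update((y_pred, y))
print(acc.compute())  # a single scalar covering all three labels, not a per-label breakdown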
Many thanks, I've made a pull request here: #516. I'm quite new to working on large projects, so apologies if I have gone about this inappropriately.
In the meantime, whilst the core team decide how best to implement this, here is a custom class I've made for the task, which inherits from Accuracy:

import torch
from ignite.exceptions import NotComputableError
from ignite.metrics import Accuracy

class LabelwiseAccuracy(Accuracy):
    def __init__(self, output_transform=lambda x: x):
        self._num_correct = None
        self._num_examples = None
        super(LabelwiseAccuracy, self).__init__(output_transform=output_transform)

    def reset(self):
        self._num_correct = None
        self._num_examples = 0
        super(LabelwiseAccuracy, self).reset()

    def update(self, output):
        y_pred, y = self._check_shape(output)
        self._check_type((y_pred, y))

        # flatten everything except the class dimension to shape (n_samples, num_classes)
        num_classes = y_pred.size(1)
        last_dim = y_pred.ndimension()
        y_pred = torch.transpose(y_pred, 1, last_dim - 1).reshape(-1, num_classes)
        y = torch.transpose(y, 1, last_dim - 1).reshape(-1, num_classes)

        correct_exact = torch.all(y == y_pred.type_as(y), dim=-1)       # sample-wise exact match (used only to count examples)
        correct_elementwise = torch.sum(y == y_pred.type_as(y), dim=0)  # per-label count of correct predictions

        if self._num_correct is not None:
            self._num_correct = torch.add(self._num_correct, correct_elementwise)
        else:
            self._num_correct = correct_elementwise
        self._num_examples += correct_exact.shape[0]

    def compute(self):
        if self._num_examples == 0:
            raise NotComputableError('Accuracy must have at least one example before it can be computed.')
        # per-label accuracy: correct counts divided by the number of examples seen
        return self._num_correct.type(torch.float) / self._num_examples
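For example, it can be attached like any other ignite metric. Below is a rough usage sketch, where model and val_loader are hypothetical placeholders and thresholded_output_transform is assumed to round sigmoid outputs to 0/1:

import torch
from ignite.engine import create_supervised_evaluator

def thresholded_output_transform(output):
    # hard-threshold the sigmoid probabilities so the metric sees binary predictions
    y_pred, y = output
    y_pred = torch.round(torch.sigmoid(y_pred))
    return y_pred, y

evaluator = create_supervised_evaluator(
    model,  # hypothetical multi-label model
    metrics={"labelwise_accuracy": LabelwiseAccuracy(output_transform=thresholded_output_transform)},
)

state = evaluator.run(val_loader)  # hypothetical validation DataLoader
print(state.metrics["labelwise_accuracy"])  # a tensor with one accuracy per label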
For anyone trying to use @jphdotam's code in #513 (comment): the line y_pred, y = self._check_shape(output) throws an exception because that function now returns nothing. Instead, call _check_shape(output) for validation only and unpack y_pred and y from output directly, as sketched below.
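A minimal sketch of that adjustment, assuming the newer ignite API where _check_shape(output) only validates the shapes and returns None (the rest of update() stays as in the LabelwiseAccuracy class above):

    def update(self, output):
        self._check_shape(output)   # validation only in newer ignite; returns None
        y_pred, y = output          # unpack the (prediction, target) tuple directly
        self._check_type((y_pred, y))
        # ... the rest of update() unchanged from LabelwiseAccuracy above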
However, there's something wrong with it because I'm getting an error. Edit: never mind, I stepped through the code and it was fine. The bug was on my end. Cheers!
Hi,

I've made a multi-label classifier using BCEWithLogitsLoss. In summary, a data sample can belong to any of 3 binary classes, which aren't mutually exclusive, so y_pred and y can look something like [0, 1, 1]. My metrics include Accuracy(output_transform=thresholded_output_transform, is_multilabel=True) and Precision(output_transform=thresholded_output_transform, is_multilabel=True, average=True), set up roughly as sketched below.

However, I'm interested in having label-specific metrics (i.e. having 3 accuracies etc.). This is important because it allows me to see which labels are compromising my overall accuracy the most (a 70% accuracy could be a 30% error in a single label, or a more modest error scattered across all 3 labels).

There is no option to disable averaging for Accuracy() as there is for the other metrics, and setting average=False for Precision() does not do what I expected: it yields a binary result per datum, not per label, so I end up with a tensor of size 500, not 3, if my dataset has n=500. Is there a way to get label-wise metrics in multilabel problems? Or a plan to introduce it?
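For reference, a minimal sketch of the metric setup described above, assuming thresholded_output_transform simply rounds sigmoid outputs to 0/1:

import torch
from ignite.metrics import Accuracy, Precision

def thresholded_output_transform(output):
    # turn the logits trained with BCEWithLogitsLoss into hard 0/1 predictions
    y_pred, y = output
    y_pred = torch.round(torch.sigmoid(y_pred))
    return y_pred, y

metrics = {
    "accuracy": Accuracy(output_transform=thresholded_output_transform, is_multilabel=True),
    "precision": Precision(output_transform=thresholded_output_transform, is_multilabel=True, average=True),
}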
P.S. I'd love to get an invite to the Slack workspace if possible. How do I go about doing that?