-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmetrics.py
30 lines (25 loc) · 1.06 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from typing import Set
import editdistance
import torch
from torch import Tensor
from torchmetrics import Metric
class CharacterErrorRate(Metric):
    """Mean normalized edit distance between predicted and target token sequences.

    For every sample, tokens whose ids appear in ``ignore_indices`` are removed
    from both the prediction and the target. The Levenshtein distance between
    the remaining sequences (``editdistance.distance``) is then divided by the
    length of the longer of the two, so each per-sample score lies in [0, 1].
    A sample whose filtered sequences are both empty contributes 0 error but is
    still counted. ``compute`` returns the average score over all samples seen.

    Note: normalizing by ``max(len(pred), len(target))`` differs from the
    textbook CER (which divides by ``len(target)``); it is kept unchanged for
    compatibility with existing results.
    """

    full_state_update = False

    def __init__(self, ignore_indices: Set[int], *args, **kwargs):
        """Create the metric.

        Args:
            ignore_indices: Token ids to strip from both sequences before
                comparing (presumably padding/special tokens — confirm with
                caller). Any iterable is accepted; it is stored as a set so
                membership tests are O(1).
            *args, **kwargs: Forwarded unchanged to ``Metric.__init__``
                (keyword forwarding lets callers pass base-class options).
        """
        super().__init__(*args, **kwargs)
        self.ignore_indices = set(ignore_indices)
        # Running sum of per-sample normalized distances (sum-reduced on sync).
        self.add_state("error", default=torch.tensor(0.0), dist_reduce_fx="sum")
        # Running count of samples seen (sum-reduced on sync).
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
        self.error: Tensor
        self.total: Tensor

    def update(self, preds: Tensor, targets: Tensor) -> None:
        """Accumulate errors for one batch.

        Args:
            preds: Predicted token ids, one sequence per row (each row is
                iterated as a flat list of ids — presumably shape
                (batch, seq_len); TODO confirm with caller).
            targets: Reference token ids with the same number of rows.

        Raises:
            ValueError: If ``preds`` and ``targets`` have different numbers of rows.
        """
        # Convert each batch once instead of once per row (one device sync).
        pred_rows = preds.tolist()
        target_rows = targets.tolist()
        if len(pred_rows) != len(target_rows):
            raise ValueError(
                "preds and targets must have the same batch size, "
                f"got {len(pred_rows)} and {len(target_rows)}"
            )

        ignore = self.ignore_indices
        batch_error = 0.0
        for pred_row, target_row in zip(pred_rows, target_rows):
            pred = [token for token in pred_row if token not in ignore]
            target = [token for token in target_row if token not in ignore]
            longest = max(len(pred), len(target))
            # Both empty -> sequences are identical; skip to avoid dividing by 0.
            if longest:
                batch_error += editdistance.distance(pred, target) / longest

        # One in-place tensor update per batch rather than per sample.
        self.error += batch_error
        self.total += len(pred_rows)

    def compute(self) -> Tensor:
        """Return the mean per-sample normalized edit distance.

        Returns NaN if called before any sample has been accumulated
        (the division is 0.0 / 0).
        """
        return self.error / self.total