metrics.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. import re
  2. import string
  3. from collections import Counter
  def levenshtein(a, b):
      """Return the Levenshtein (edit) distance between two sequences.

      Works on any pair of sized sequences whose items can be compared with
      ``!=`` (e.g. two strings, or two lists of words). Insertions, deletions
      and substitutions each cost 1.

      Only two rows of the dynamic-programming table are kept, and the shorter
      sequence is laid along the row, so the extra memory used is proportional
      to ``min(len(a), len(b))``.
      """
      # Put the shorter sequence along the row; the distance is symmetric.
      shorter, longer = (b, a) if len(a) > len(b) else (a, b)
      width = len(shorter)

      # Row 0: distance from the empty prefix of `longer` to each prefix of `shorter`.
      row = list(range(width + 1))
      for i, long_item in enumerate(longer, start=1):
          above = row
          row = [i]  # distance from longer[:i] to the empty prefix
          for j, short_item in enumerate(shorter, start=1):
              substitute = above[j - 1]
              if short_item != long_item:
                  substitute += 1
              row.append(min(above[j] + 1, row[j - 1] + 1, substitute))
      return row[width]
  def word_error_rate(x, y):
      """Compute the corpus-level word error rate of ``x`` measured against ``y``.

      ``x`` and ``y`` are iterated in lockstep (``zip`` stops at the shorter of
      the two). Each pair of strings is split on whitespace, the word-level
      edit distance of every pair is accumulated, and the total is divided by
      the total number of words found in ``y``.

      Args:
          x: iterable of strings to score.
          y: iterable of strings to score against; its word count is the
              denominator of the rate.

      Returns:
          A tuple ``(wer, scores, words)``: the error rate as a float, the
          summed edit distance as a float, and the total word count of ``y``
          as an int. When ``y`` contributes no words the rate is undefined,
          so ``float("inf")`` is returned instead of raising
          ``ZeroDivisionError``.
      """
      scores = words = 0
      for hyp, ref in zip(x, y):
          ref_words = ref.split()
          words += len(ref_words)
          scores += levenshtein(hyp.split(), ref_words)
      # Guard the division: empty inputs or all-empty strings in `y` give words == 0.
      wer = float(scores) / words if words else float("inf")
      return wer, float(scores), words
  def one_hot(x, num_classes=3):
      """One-hot encode ``x`` and move the new class axis to position 1.

      Args:
          x: tensor-like object exposing ``one_hot``, ``squeeze`` and
              ``permute`` methods.
              NOTE(review): the fixed ``permute(0, 4, 1, 2, 3)`` implies the
              value is rank 5 after the squeeze; presumably ``x`` holds
              integer labels shaped (batch, 1, d, h, w) and ``one_hot``
              appends the class axis last - confirm against callers.
          num_classes: value forwarded to ``x.one_hot``. Defaults to 3, the
              value that used to be hard-coded here.

      Returns:
          The result of ``x.one_hot(num_classes).squeeze(1).permute(0, 4, 1, 2, 3)``.
      """
      encoded = x.one_hot(num_classes)
      # Squeeze axis 1, then bring the last axis (index 4) forward to position 1.
      return encoded.squeeze(1).permute(0, 4, 1, 2, 3)
  def dice_score(prediction, target, channel_axis=1, smooth_nr=1e-6, smooth_dr=1e-6, argmax=True, to_one_hot_x=True):
      """Compute a smoothed Dice score per sample and class, skipping class 0.

      Args:
          prediction: tensor-like model output exposing ``shape``, ``argmax``
              and ``softmax``.
              NOTE(review): presumably (batch, classes, *spatial) - confirm.
          target: tensor-like labels; always passed through ``one_hot``.
          channel_axis: axis of ``prediction`` over which ``argmax`` /
              ``softmax`` is taken. Previously this argument was overwritten
              with 1 and therefore ignored; it is now honoured. Every step
              after that still expects the class axis at position 1 (which is
              where ``one_hot`` puts it).
          smooth_nr: constant added to the numerator.
          smooth_dr: constant added to the denominator (avoids 0/0).
          argmax: if true, reduce ``prediction`` with ``argmax``; otherwise
              apply ``softmax``.
          to_one_hot_x: if true, pass ``prediction`` through ``one_hot``.

      Returns:
          ``(2 * intersection + smooth_nr) / (target_sum + prediction_sum + smooth_dr)``,
          with the sums taken over every axis from index 2 onward.

      Raises:
          AssertionError: if prediction and target shapes differ after
              encoding and slicing.
      """
      # Axes to sum over: everything after the first two, based on the input rank.
      reduce_axis = tuple(range(2, len(prediction.shape)))

      if argmax:
          prediction = prediction.argmax(axis=channel_axis)
      else:
          prediction = prediction.softmax(axis=channel_axis)
      if to_one_hot_x:
          prediction = one_hot(prediction)
      target = one_hot(target)

      # Drop index 0 along axis 1 (presumably the background class).
      prediction, target = prediction[:, 1:], target[:, 1:]
      assert prediction.shape == target.shape, f"prediction ({prediction.shape}) and target ({target.shape}) shapes do not match"

      intersection = (prediction * target).sum(axis=reduce_axis)
      target_sum = target.sum(axis=reduce_axis)
      prediction_sum = prediction.sum(axis=reduce_axis)
      return (2.0 * intersection + smooth_nr) / (target_sum + prediction_sum + smooth_dr)
  def normalize_string(s):
      """Canonicalise text for token comparison.

      Steps, in order: lowercase, delete every character listed in
      ``string.punctuation``, replace the standalone words "a", "an" and "the"
      with a space, then collapse all whitespace runs to single spaces and
      strip the ends.
      """
      lowered = s.lower()
      without_punct = lowered.translate(str.maketrans("", "", string.punctuation))
      without_articles = re.sub(r'\b(a|an|the)\b', ' ', without_punct)
      return " ".join(without_articles.split())
  def f1_score(x, y):
      """Return the token-level F1 score between two strings.

      Both strings are normalised with ``normalize_string`` and split into
      tokens. The overlap is the multiset intersection of the two token
      lists. Precision is the overlap divided by the number of tokens in
      ``x``, recall is the overlap divided by the number of tokens in ``y``.
      Returns 0.0 when no tokens are shared, which also covers either side
      being empty after normalisation.
      """
      x_tokens = normalize_string(x).split()
      y_tokens = normalize_string(y).split()

      shared = Counter(x_tokens) & Counter(y_tokens)
      overlap = sum(shared.values())
      if overlap == 0:
          return 0.0

      precision = overlap / len(x_tokens)
      recall = overlap / len(y_tokens)
      return 2 * precision * recall / (precision + recall)