# dice.py (3.4 KB)
# https://github.com/mlcommons/training/blob/master/image_segmentation/pytorch/model/losses.py
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. class Dice:
  6. def __init__(self,
  7. to_onehot_y: bool = True,
  8. to_onehot_x: bool = False,
  9. use_softmax: bool = True,
  10. use_argmax: bool = False,
  11. include_background: bool = False,
  12. layout: str = "NCDHW"):
  13. self.include_background = include_background
  14. self.to_onehot_y = to_onehot_y
  15. self.to_onehot_x = to_onehot_x
  16. self.use_softmax = use_softmax
  17. self.use_argmax = use_argmax
  18. self.smooth_nr = 1e-6
  19. self.smooth_dr = 1e-6
  20. self.layout = layout
  21. def __call__(self, prediction, target):
  22. if self.layout == "NCDHW":
  23. channel_axis = 1
  24. reduce_axis = list(range(2, len(prediction.shape)))
  25. else:
  26. channel_axis = -1
  27. reduce_axis = list(range(1, len(prediction.shape) - 1))
  28. num_pred_ch = prediction.shape[channel_axis]
  29. if self.use_softmax:
  30. prediction = torch.softmax(prediction, dim=channel_axis)
  31. elif self.use_argmax:
  32. prediction = torch.argmax(prediction, dim=channel_axis)
  33. if self.to_onehot_y:
  34. target = to_one_hot(target, self.layout, channel_axis)
  35. if self.to_onehot_x:
  36. prediction = to_one_hot(prediction, self.layout, channel_axis)
  37. if not self.include_background:
  38. assert num_pred_ch > 1, \
  39. f"To exclude background the prediction needs more than one channel. Got {num_pred_ch}."
  40. if self.layout == "NCDHW":
  41. target = target[:, 1:]
  42. prediction = prediction[:, 1:]
  43. else:
  44. target = target[..., 1:]
  45. prediction = prediction[..., 1:]
  46. assert (target.shape == prediction.shape), \
  47. f"Target and prediction shape do not match. Target: ({target.shape}), prediction: ({prediction.shape})."
  48. intersection = torch.sum(target * prediction, dim=reduce_axis)
  49. target_sum = torch.sum(target, dim=reduce_axis)
  50. prediction_sum = torch.sum(prediction, dim=reduce_axis)
  51. return (2.0 * intersection + self.smooth_nr) / (target_sum + prediction_sum + self.smooth_dr)
  52. def to_one_hot(array, layout, channel_axis):
  53. if len(array.shape) >= 5:
  54. array = torch.squeeze(array, dim=channel_axis)
  55. array = F.one_hot(array.long(), num_classes=3)
  56. if layout == "NCDHW":
  57. array = array.permute(0, 4, 1, 2, 3).float()
  58. return array
  59. class DiceCELoss(nn.Module):
  60. def __init__(self, to_onehot_y, use_softmax, layout, include_background):
  61. super(DiceCELoss, self).__init__()
  62. self.dice = Dice(to_onehot_y=to_onehot_y, use_softmax=use_softmax, layout=layout,
  63. include_background=include_background)
  64. self.cross_entropy = nn.CrossEntropyLoss()
  65. def forward(self, y_pred, y_true):
  66. cross_entropy = self.cross_entropy(y_pred, torch.squeeze(y_true, dim=1).long())
  67. dice = torch.mean(1.0 - self.dice(y_pred, y_true))
  68. return (dice + cross_entropy) / 2
  69. class DiceScore:
  70. def __init__(self, to_onehot_y: bool = True, use_argmax: bool = True, layout: str = "NCDHW",
  71. include_background: bool = False):
  72. self.dice = Dice(to_onehot_y=to_onehot_y, to_onehot_x=True, use_softmax=False,
  73. use_argmax=use_argmax, layout=layout, include_background=include_background)
  74. def __call__(self, y_pred, y_true):
  75. return torch.mean(self.dice(y_pred, y_true), dim=0)