losses.py 265 B

123456
  1. from examples.mlperf.metrics import dice_score
  2. def dice_ce_loss(pred, tgt):
  3. ce = pred.permute(0, 2, 3, 4, 1).sparse_categorical_crossentropy(tgt.squeeze(1))
  4. dice = (1.0 - dice_score(pred, tgt, argmax=False, to_one_hot_x=False)).mean()
  5. return (dice + ce) / 2