File size: 597 Bytes
404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
import torch.nn.functional as F
import torch
def l1_loss(output, target_rgb, target_raw, weight=1.0):
    """Combine L1 reconstruction losses from the raw and RGB outputs.

    Both terms use ``F.l1_loss`` with its default reduction.

    Args:
        output: mapping that holds the predictions under the keys
            ``"reconstruct_raw"`` and ``"reconstruct_rgb"``.
        target_rgb: ground truth compared against ``output["reconstruct_rgb"]``.
        target_raw: ground truth compared against ``output["reconstruct_raw"]``.
        weight: multiplier applied to the RGB term before it is added to
            the raw term.

    Returns:
        Tuple ``(total_loss, raw_loss, rgb_loss)`` where
        ``total_loss = raw_loss + weight * rgb_loss``.
    """
    pred_raw, pred_rgb = output["reconstruct_raw"], output["reconstruct_rgb"]
    raw_term = F.l1_loss(pred_raw, target_raw)
    rgb_term = F.l1_loss(pred_rgb, target_rgb)
    # The raw term is unweighted; only the RGB term is scaled by `weight`.
    combined = raw_term + weight * rgb_term
    return combined, raw_term, rgb_term
def l2_loss(output, target_rgb, target_raw, weight=1.0):
    """Combine mean-squared-error losses from the raw and RGB outputs.

    Both terms use ``F.mse_loss`` with its default reduction.

    Args:
        output: mapping that holds the predictions under the keys
            ``"reconstruct_raw"`` and ``"reconstruct_rgb"``.
        target_rgb: ground truth compared against ``output["reconstruct_rgb"]``.
        target_raw: ground truth compared against ``output["reconstruct_raw"]``.
        weight: multiplier applied to the RGB term before it is added to
            the raw term.

    Returns:
        Tuple ``(total_loss, raw_loss, rgb_loss)`` where
        ``total_loss = raw_loss + weight * rgb_loss``.
    """
    pred_raw, pred_rgb = output["reconstruct_raw"], output["reconstruct_rgb"]
    raw_term = F.mse_loss(pred_raw, target_raw)
    rgb_term = F.mse_loss(pred_rgb, target_rgb)
    # The raw term is unweighted; only the RGB term is scaled by `weight`.
    combined = raw_term + weight * rgb_term
    return combined, raw_term, rgb_term
|