import torch
import torch.optim as optim
import torch.nn.functional as F


def gram_matrix(feature):
    """Return the (un-normalized) Gram matrix of a 4-D feature map.

    The feature map is unpacked as ``(b, c, h, w)``. Batch and channel axes are
    merged into rows and the spatial axes into columns, so the result is a
    ``(b * c, b * c)`` matrix of inner products between flattened channel maps.
    """
    b, c, h, w = feature.size()
    # reshape (not view) so non-contiguous feature maps, e.g. produced by a
    # transpose or slicing inside the feature extractor, are accepted too.
    flat = feature.reshape(b * c, h * w)
    return flat @ flat.t()


def compute_loss(generated, content, style, bg_masks, alpha, beta):
    """Compute weighted content, style and total losses.

    Args:
        generated: Sequence of feature maps of the image being optimized.
        content: Sequence of content-image feature maps, paired with
            ``generated`` by position.
        style: Sequence of style-image feature maps, paired by position.
        bg_masks: Optional sequence of per-layer masks. Each mask is multiplied
            into both the generated and the style features before their Gram
            matrices are taken. ``None`` (or an empty sequence) disables
            masking.
        alpha: Weight applied to the content loss.
        beta: Weight applied to the style loss.

    Returns:
        Tuple ``(alpha * content_loss, beta * style_loss, total)``, where
        ``total`` is the sum of the first two. Layers are paired with ``zip``,
        so extra entries in the longer sequence are ignored.
    """
    content_loss = sum(F.mse_loss(gf, cf) for gf, cf in zip(generated, content))
    masks = bg_masks or [None] * len(generated)
    # Each layer's Gram-matrix MSE is divided by the number of layers, so the
    # style term is a per-layer average.
    style_loss = sum(
        F.mse_loss(
            gram_matrix(gf if bg is None else gf * bg),
            gram_matrix(sf if bg is None else sf * bg),
        ) / len(generated)
        for gf, sf, bg in zip(generated, style, masks)
    )
    weighted_content = alpha * content_loss
    weighted_style = beta * style_loss
    return weighted_content, weighted_style, weighted_content + weighted_style


def inference(
    *,
    model,
    sod_model,
    content_image,
    content_image_norm,
    style_features,
    apply_to_background,
    lr=1.5e-2,
    iterations=51,
    optim_caller=optim.AdamW,
    alpha=1,
    beta=1,
    mask_threshold=0.7,
):
    """Optimize an image so its features match the content and style targets.

    Starts from a copy of ``content_image`` and makes ``iterations``
    ``optimizer.step(closure)`` calls that minimize ``compute_loss`` over the
    feature maps returned by ``model``.

    Args:
        model: Callable mapping an image tensor to a sequence of feature maps
            (one per layer). It is called once on ``content_image`` without
            gradients, then on the generated image each time the closure runs.
        sod_model: Saliency/segmentation model, used only when
            ``apply_to_background`` is true. ``sod_model(content_image_norm)[0]``
            is passed through a sigmoid. This assumes the result is a
            (N, H, W) map, so that ``unsqueeze(1)`` gives (N, 1, H, W);
            TODO confirm against the SOD model.
        content_image: Image tensor to start from and to match in content.
        content_image_norm: Input passed to ``sod_model`` (presumably a
            normalized copy of ``content_image``; verify against the caller).
        style_features: Precomputed style feature maps, paired by position
            with ``model``'s outputs.
        apply_to_background: If true, the style term only sees pixels whose
            sigmoid score is ``<= mask_threshold`` (the background). After
            optimization, the remaining foreground pixels are pasted back from
            ``content_image``.
        lr: Learning rate passed to ``optim_caller``.
        iterations: Number of optimizer steps.
        optim_caller: Optimizer factory, called as
            ``optim_caller([generated_image], lr=lr)``.
        alpha: Content-loss weight.
        beta: Style-loss weight.
        mask_threshold: Sigmoid score at or below which a pixel counts as
            background. Defaults to 0.7, the previously hard-coded value.

    Returns:
        The optimized image tensor. It is a leaf that still has
        ``requires_grad=True``.
    """
    generated_image = content_image.clone().requires_grad_(True)
    optimizer = optim_caller([generated_image], lr=lr)

    bg_masks = None
    with torch.no_grad():
        content_features = model(content_image)
        if apply_to_background:
            seg_output = torch.sigmoid(sod_model(content_image_norm)[0])
            # 1.0 where the saliency score is low (background), else 0.0.
            bg_mask = (seg_output <= mask_threshold).float()
            # Resize the mask to each feature layer's spatial size.
            bg_masks = [
                F.interpolate(
                    bg_mask.unsqueeze(1),
                    size=cf.shape[2:],
                    mode='bilinear',
                    align_corners=False,
                )
                for cf in content_features
            ]

    def closure():
        optimizer.zero_grad()
        generated_features = model(generated_image)
        _, _, total_loss = compute_loss(
            generated_features, content_features, style_features, bg_masks, alpha, beta
        )
        total_loss.backward()
        return total_loss

    for _ in range(iterations):
        optimizer.step(closure)

    if apply_to_background:
        with torch.no_grad():
            # bg_masks[0] is at the first feature layer's resolution; upsample
            # it back to image size to get the foreground mask.
            fg_mask = F.interpolate(
                1 - bg_masks[0], size=generated_image.shape[2:], mode='nearest'
            )
            # Blend: keep generated pixels in the background and restore the
            # content pixels in the foreground. In-place ops on a leaf are
            # allowed under no_grad, so .data is not needed.
            generated_image.mul_(1 - fg_mask).add_(content_image * fg_mask)

    return generated_image