import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
from torchvision import transforms
from transformers import Pipeline, PretrainedConfig

import data_transforms
from u2net import U2NET


class U2NetPipeline(Pipeline):
    def __init__(self, model, **kwargs):
        # The base Pipeline expects a transformers-style model: it reads
        # model.config during initialization and needs to know the framework.
        # U2NET is a plain nn.Module, so attach an empty config and name the
        # framework explicitly instead of letting the pipeline infer it.
        if not hasattr(model, "config"):
            model.config = PretrainedConfig()
        kwargs.setdefault("framework", "pt")
        super().__init__(model=model, **kwargs)
        self.model.eval()

    @classmethod
    def from_pretrained(cls, model_path, **kwargs):
        # U2NET(3, 1): 3 input channels (RGB), 1 output channel (saliency map).
        model = U2NET(3, 1)
        model.load_state_dict(torch.load(f"{model_path}/u2net.pth", map_location="cpu"))
        return cls(model, **kwargs)

    def _sanitize_parameters(self, **kwargs):
        # No tunable parameters: empty kwargs for preprocess, _forward,
        # and postprocess.
        return {}, {}, {}

    def preprocess(self, image):
        if isinstance(image, str):
            image = Image.open(image).convert("RGB")
        elif isinstance(image, Image.Image):
            image = image.convert("RGB")
        else:
            raise ValueError("Input must be a PIL Image or a path to an image file")

        image = np.array(image)
        # Keep the input resolution so postprocess() can restore it.
        original_size = image.shape[:2]

        # Reference U-2-Net preprocessing: RescaleT resizes to 320x320 (the
        # training resolution) and ToTensorLab(flag=0) rescales to [0, 1] and
        # applies ImageNet mean/std, so no further normalization is needed.
        transform = transforms.Compose([
            data_transforms.RescaleT(320),
            data_transforms.ToTensorLab(flag=0),
        ])
        sample = transform({
            "imidx": np.array([0]),
            "image": image,
            # Dummy single-channel label; the reference transforms expect HxWx1.
            "label": np.zeros((*original_size, 1)),
        })

        # ToTensorLab yields float64; the model weights are float32.
        im_tensor = sample["image"].unsqueeze(0).float()
        return {"image": im_tensor, "original_size": original_size}

    def _forward(self, model_inputs):
        with torch.no_grad():
            # U2NET returns seven side outputs (d0..d6); d0 is the fused map.
            outputs = self.model(model_inputs["image"])
        return {"outputs": outputs, "original_size": model_inputs["original_size"]}

    def postprocess(self, model_outputs):
        # d0 has shape (1, 1, 320, 320); keep it 4-D for bilinear interpolation
        # back to the original resolution.
        result = model_outputs["outputs"][0]
        result = F.interpolate(result, size=model_outputs["original_size"],
                               mode="bilinear", align_corners=False)
        result = result.squeeze().cpu().numpy()
        # Min-max normalize to [0, 1] (the epsilon guards against a flat map),
        # then convert to an 8-bit mask.
        ma, mi = result.max(), result.min()
        result = (result - mi) / (ma - mi + 1e-8)
        return (result * 255).astype(np.uint8)
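

# Minimal usage sketch, not part of the original file: it assumes the
# pretrained weights were saved as "./weights/u2net.pth" and that "input.png"
# exists; both paths are placeholders. The pipeline returns a uint8 saliency
# mask at the input image's resolution.
if __name__ == "__main__":
    pipe = U2NetPipeline.from_pretrained("./weights")
    mask = pipe("input.png")  # HxW uint8 array: 0 = background, 255 = salient
    Image.fromarray(mask).save("mask.png")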