import torch
import transformers
from torch import nn
from torch.nn import functional as F
from transformers.modeling_outputs import SemanticSegmenterOutput


class FaceSegmenterConfig(transformers.PretrainedConfig):
    model_type = "image-segmentation"

    _id2label = {
        0: "skin",
        1: "l_brow",
        2: "r_brow",
        3: "l_eye",
        4: "r_eye",
        5: "eye_g",
        6: "l_ear",
        7: "r_ear",
        8: "ear_r",
        9: "nose",
        10: "mouth",
        11: "u_lip",
        12: "l_lip",
        13: "neck",
        14: "neck_l",
        15: "cloth",
        16: "hair",
        17: "hat",
    }
    # Derived from _id2label so the two mappings always stay in sync.
    _label2id = {label: label_id for label_id, label in _id2label.items()}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Copy so the class-level default is never mutated in place.
        self.id2label = dict(kwargs.get("id2label", self._id2label))
        # Pipelines round-trip the config through JSON, whose object keys are
        # always strings, so coerce the label ids back to ints here.
        id_keys = list(self.id2label.keys())
        for label_id in id_keys:
            label_value = self.id2label.pop(label_id)
            self.id2label[int(label_id)] = label_value
        self.label2id = kwargs.get("label2id", self._label2id)
        self.num_classes = kwargs.get("num_classes", len(self.id2label))


def encode_down(c_in: int, c_out: int) -> nn.Sequential:
    """Two 3x3 conv + batch-norm + ReLU blocks, as in the U-Net encoder."""
    return nn.Sequential(
        nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, padding=1),
        nn.BatchNorm2d(num_features=c_out),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_channels=c_out, out_channels=c_out, kernel_size=3, padding=1),
        nn.BatchNorm2d(num_features=c_out),
        nn.ReLU(inplace=True),
    )


def decode_up(c: int) -> nn.ConvTranspose2d:
    """2x2 transposed convolution that doubles spatial size and halves channels."""
    return nn.ConvTranspose2d(
        in_channels=c,
        out_channels=c // 2,
        kernel_size=2,
        stride=2,
    )


class FaceUNet(nn.Module):
    def __init__(self, num_classes: int):
        super().__init__()
        self.num_classes = num_classes
        # U-Net encoder path.
        self.down_1 = nn.Conv2d(
            in_channels=3,
            out_channels=64,
            kernel_size=3,
            padding=1,
        )
        self.down_2 = encode_down(64, 128)
        self.down_3 = encode_down(128, 256)
        self.down_4 = encode_down(256, 512)
        self.down_5 = encode_down(512, 1024)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Decoder path. `in_channels` becomes 1024 (then 512, 256, 128) again
        # because each up-step concatenates the upsampled features with the
        # matching encoder skip connection before convolving.
        self.up_1 = decode_up(1024)
        self.up_c1 = encode_down(1024, 512)
        self.up_2 = decode_up(512)
        self.up_c2 = encode_down(512, 256)
        self.up_3 = decode_up(256)
        self.up_c3 = encode_down(256, 128)
        self.up_4 = decode_up(128)
        self.up_c4 = encode_down(128, 64)
        self.segment = nn.Conv2d(
            in_channels=64,
            out_channels=self.num_classes,
            kernel_size=3,
            padding=1,
        )

    def forward(self, x):
        # Encoder: convolve, keep each activation for its skip connection, pool.
        d1 = self.down_1(x)
        d2 = self.pool(d1)
        d3 = self.down_2(d2)
        d4 = self.pool(d3)
        d5 = self.down_3(d4)
        d6 = self.pool(d5)
        d7 = self.down_4(d6)
        d8 = self.pool(d7)
        d9 = self.down_5(d8)
        # Decoder: upsample, concatenate the skip connection, convolve.
        u1 = self.up_1(d9)
        x = self.up_c1(torch.cat([d7, u1], 1))
        u2 = self.up_2(x)
        x = self.up_c2(torch.cat([d5, u2], 1))
        u3 = self.up_3(x)
        x = self.up_c3(torch.cat([d3, u3], 1))
        u4 = self.up_4(x)
        x = self.up_c4(torch.cat([d1, u4], 1))
        # Per-pixel class logits.
        x = self.segment(x)
        return x


class Segformer(transformers.PreTrainedModel):
    config_class = FaceSegmenterConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.model = FaceUNet(num_classes=config.num_classes)

    def forward(self, tensor):
        # FaceUNet defines no separate `forward_features`; its forward pass
        # already returns the final feature map of logits.
        return self.model(tensor)


class SegformerForSemanticSegmentation(transformers.PreTrainedModel):
    config_class = FaceSegmenterConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.model = FaceUNet(num_classes=config.num_classes)

    def forward(self, pixel_values, labels=None):
        logits = self.model(pixel_values)
        values = {"logits": logits}
        if labels is not None:
            # `torch.nn` has no `cross_entropy`; the functional form lives in
            # `torch.nn.functional`.
            loss = F.cross_entropy(logits, labels)
            values["loss"] = loss
        return SemanticSegmenterOutput(**values)
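

# --- Usage sketch (illustrative only, not part of the model definition) ---
# A minimal smoke test: build the model from a fresh config and confirm that a
# dummy batch produces one logit map per class at the input resolution. The
# 512x512 input size is an assumption for illustration; any spatial size
# divisible by 16 works, since the encoder pools four times.
if __name__ == "__main__":
    config = FaceSegmenterConfig()
    model = SegformerForSemanticSegmentation(config)
    model.eval()

    pixel_values = torch.randn(1, 3, 512, 512)  # dummy RGB batch
    with torch.no_grad():
        output = model(pixel_values)
    # Expected: (batch, num_classes, height, width) == (1, 18, 512, 512).
    print(output.logits.shape)

    # Hypothetical loss computation with random labels, showing the shape
    # `cross_entropy` expects: integer class indices with no channel dimension.
    labels = torch.randint(0, config.num_classes, (1, 512, 512))
    output = model(pixel_values, labels=labels)
    print(output.loss)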