"""Gradio demo that asks a classifier to compare two candidate translations.

The script defines an mT5-encoder-based two-class model (``MTRanker``), loads
its weights and tokenizer, and serves a small web UI in which a user enters a
source sentence plus two translations and gets back a score for each.
"""

import gradio as gr
import torch
from transformers import (
    AutoModel,
    AutoTokenizer,
    PretrainedConfig,
    PreTrainedModel,
    MT5EncoderModel,
)


class MTRankerConfig(PretrainedConfig):
    """Configuration for :class:`MTRanker`.

    Args:
        backbone: Identifier handed to ``MT5EncoderModel.from_pretrained`` when
            the model builds its encoder.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    def __init__(self, backbone='google/mt5-base', **kwargs):
        # Set before calling the parent constructor, as in the original code.
        self.backbone = backbone
        super().__init__(**kwargs)


class MTRanker(PreTrainedModel):
    """MT5 encoder followed by a linear layer producing two logits."""

    config_class = MTRankerConfig

    def __init__(self, config):
        """Build the encoder named by ``config.backbone`` and the classifier head."""
        super().__init__(config)
        self.encoder = MT5EncoderModel.from_pretrained(config.backbone)
        self.num_classes = 2
        self.classifier = torch.nn.Linear(
            self.encoder.config.hidden_size, self.num_classes
        )

    def forward(self, input_ids, attention_mask):
        """Return classification logits for a batch of tokenized inputs.

        The encoder's last hidden states are multiplied by ``attention_mask``,
        summed over the sequence dimension (dim 1) and divided by the per-row
        mask sum, i.e. an attention-mask-weighted mean, before being passed to
        the linear classifier.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Mask forwarded to the encoder and used as the
                pooling weights; indexed as (batch, sequence) by this method.

        Returns:
            The output of the linear classifier (``num_classes`` values per row).
        """
        hidden_states = self.encoder(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        token_counts = torch.sum(attention_mask, dim=1, keepdim=True)
        # Broadcasting the mask over the hidden dimension is equivalent to the
        # explicit expand() to hidden_size.
        masked_sum = torch.sum(hidden_states * attention_mask.unsqueeze(-1), dim=1)
        pooled_hidden_state = masked_sum / token_counts
        return self.classifier(pooled_hidden_state)


config = MTRankerConfig(backbone='google/mt5-base')
tokenizer = AutoTokenizer.from_pretrained(config.backbone)
model = MTRanker.from_pretrained('ibraheemmoosa/mt-ranker-base')


def predict(source, translation1, translation2):
    """Score two translations of ``source`` with the module-level model.

    Args:
        source: Source sentence text.
        translation1: First candidate translation (labelled "Translation 0"
            in the model prompt).
        translation2: Second candidate translation (labelled "Translation 1"
            in the model prompt).

    Returns:
        A dict mapping ``'Translation 1'`` and ``'Translation 2'`` to the
        softmax scores at output index 0 and 1 respectively, as Python floats.
    """
    # The prompt numbers the candidates from 0; the UI labels number them from 1.
    model_input = "Source: {} Translation 0: {} Translation 1: {}".format(
        source, translation1, translation2
    )
    inputs = tokenizer(
        [model_input],
        max_length=512,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )
    with torch.inference_mode():
        logits = model(inputs.input_ids, inputs.attention_mask)
        # Single-item batch: softmax over the class dimension, then take row 0.
        scores = torch.softmax(logits, dim=1)[0]
    # Hand the Label component plain floats rather than 0-dim tensors.
    return {
        'Translation 1': scores[0].item(),
        'Translation 2': scores[1].item(),
    }


source_textbox = gr.Textbox(
    label="Source",
    info="Source Sentence",
    value="Le chat est sur la tapis.",
)
translation1_textbox = gr.Textbox(
    label="Translation 1",
    info="Translation 1",
    value="The cat is on the bed.",
)
translation2_textbox = gr.Textbox(
    label="Translation 2",
    info="Translation 2",
    value="The cat is on the carpet.",
)
output = gr.Label(label="Result")

iface = gr.Interface(
    fn=predict,
    inputs=[source_textbox, translation1_textbox, translation2_textbox],
    outputs=output,
)
iface.launch()