|
import torch |
|
|
|
class MyToxicityDebiaserPipeline(object):
    """Classify a comment's toxicity, then ask a generative model to rewrite it.

    Flow of one call: ``__call__`` -> ``_sanitize_parameters`` -> ``preprocess``
    -> ``_forward`` (classifier) -> ``postprocess`` (generation + formatting).

    The classifier label ``0`` is reported as non-toxic and ``1`` as toxic; any
    other label is rejected. For both accepted labels the (decoded) input text
    is sent to ``gpt_model`` with a "rewrite this nicely" instruction, and the
    result is returned as one human-readable string.

    Only one input text per call is supported: ``_forward`` turns the predicted
    label into a Python scalar with ``.item()``.
    """

    def __init__(self, model, tokenizer, gpt_model, gpt_tokenizer, device=None, **kwargs):
        """Store the models/tokenizers and move both models to a single device.

        Args:
            model: Classification model; its call result must expose ``.logits``.
            tokenizer: Tokenizer paired with ``model``.
            gpt_model: Generative model exposing ``generate``.
            gpt_tokenizer: Tokenizer paired with ``gpt_model``.
            device: Target device for models and input tensors; CPU when ``None``.
            **kwargs: Accepted for call compatibility and ignored.
        """
        # Resolve the device *before* moving the models so the models and the
        # tensors created in _forward/postprocess always live on the same device.
        self.device = device if device is not None else torch.device("cpu")
        self.model = model.to(self.device)
        self.tokenizer = tokenizer
        self.gpt_model = gpt_model.to(self.device)
        self.gpt_tokenizer = gpt_tokenizer

    def _forward(self, inputs):
        """Run the toxicity classifier on ``inputs["text"]``.

        Returns:
            dict with ``"label"`` (int argmax class), ``"probabilities"``
            (nested list of softmax scores) and ``"text_input_ids"`` (the
            tokenized input tensor, used by ``postprocess`` to recover the text).
        """
        encoded = self.tokenizer(
            inputs["text"], truncation=True, padding=True, return_tensors="pt"
        ).to(self.device)

        # Inference only: don't build an autograd graph.
        with torch.no_grad():
            logits = self.model(
                input_ids=encoded["input_ids"],
                attention_mask=encoded["attention_mask"],
            ).logits

        probs = torch.softmax(logits, dim=-1)
        # .item() needs exactly one prediction, i.e. a single input text.
        label = torch.argmax(probs, dim=-1).item()
        return {
            "label": label,
            "probabilities": probs.tolist(),
            "text_input_ids": encoded["input_ids"],
        }

    def _sanitize_parameters(self, **kwargs):
        """Split call kwargs into (preprocess, forward, postprocess) kwargs.

        Every kwarg is routed to the preprocess slot; forward and postprocess
        receive nothing.
        """
        return kwargs, {}, {}

    def preprocess(self, inputs):
        """Wrap the raw input under the ``"text"`` key expected by ``_forward``."""
        return {"text": inputs}

    def postprocess(self, outputs):
        """Generate a polite rewrite of the input and format the final report.

        Args:
            outputs: The dict returned by ``_forward``.

        Returns:
            A string: a verdict line, then ``Original text: ...`` and
            ``Debiased text: ...`` lines.

        Raises:
            ValueError: If the classifier label is neither 0 nor 1.
        """
        label = outputs["label"]
        if label == 0:
            prompt = "This comment is non-toxic."
        elif label == 1:
            prompt = "This comment is toxic but has been debiased as follows:"
        else:
            # Previously this fell through and crashed with UnboundLocalError.
            raise ValueError(
                f"Unsupported toxicity label {label!r}; expected 0 (non-toxic) or 1 (toxic)."
            )

        # Drop [CLS]/[SEP]/[PAD]-style tokens so they don't leak into the
        # generation prompt or into the returned text.
        text = self.tokenizer.decode(outputs["text_input_ids"][0], skip_special_tokens=True)

        debias_prompt = f"Remove the offensive words and biased tone and write the same sentence nicely: {text}"
        encoded_debias_prompt = self.gpt_tokenizer(debias_prompt, return_tensors="pt").to(self.device)
        prompt_ids = encoded_debias_prompt["input_ids"]

        # Tokenizers such as GPT-2's define no pad token; fall back to EOS.
        pad_token_id = self.gpt_tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.gpt_tokenizer.eos_token_id

        generated = self.gpt_model.generate(
            input_ids=prompt_ids,
            attention_mask=encoded_debias_prompt["attention_mask"],
            do_sample=True,
            # Budget new tokens only; max_length would also count the prompt
            # and leave no room for output on long comments.
            max_new_tokens=100,
            top_p=0.95,
            temperature=0.7,
            pad_token_id=pad_token_id,
            eos_token_id=self.gpt_tokenizer.eos_token_id,
        )

        sequence = generated[0]
        # Decoder-only models echo the prompt before the continuation; strip it
        # so only the rewrite is reported. Encoder-decoder outputs don't echo it.
        is_encoder_decoder = getattr(getattr(self.gpt_model, "config", None), "is_encoder_decoder", False)
        if not is_encoder_decoder:
            sequence = sequence[prompt_ids.shape[-1]:]
        generated_text = self.gpt_tokenizer.decode(sequence, skip_special_tokens=True).strip()

        prompt += f"\nOriginal text: {text}\nDebiased text: {generated_text}"
        return prompt

    def __call__(self, inputs, *args, **kwargs):
        """Classify ``inputs`` and return the formatted debiasing report.

        Raises:
            TypeError: If extra positional arguments are given (none are supported).
        """
        if args:
            raise TypeError(
                f"{type(self).__name__}() takes no extra positional arguments ({len(args)} given)"
            )
        # Only the forward kwargs are consumed; the other two slots are unused here.
        _, forward_kwargs, _ = self._sanitize_parameters(**kwargs)
        model_inputs = self.preprocess(inputs)
        outputs = self._forward(model_inputs, **forward_kwargs)
        return self.postprocess(outputs)
|
|
|
|
|
|
|
|