|
import torch
|
|
|
|
from transformers import (
|
|
AutoTokenizer,
|
|
AutoModelForSeq2SeqLM,
|
|
)
|
|
|
|
from peft import (
|
|
get_peft_model,
|
|
)
|
|
|
|
class Model:
    """Wrapper around a Hugging Face seq2seq checkpoint used for summarization.

    The tokenizer is loaded eagerly in ``__init__``; the model weights are not.
    Callers must assign ``self.base_model`` before calling ``get_peft`` or
    ``generate_summary``. Typical sources are ``get_model()``,
    ``prepare_quantize(...)`` or ``get_peft(...)``.
    """

    def __init__(self, checkpoint):
        """Store *checkpoint*, pick a fallback device and load its tokenizer.

        Args:
            checkpoint: Model identifier or local path passed to
                ``from_pretrained``.
        """
        self.checkpoint = checkpoint
        # Fallback device, used only when the loaded model exposes no ``.device``.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(self.checkpoint)
        # Not loaded here; the caller assigns it (see class docstring).
        self.base_model = None

    def _require_base_model(self):
        """Return ``self.base_model``, raising a clear error if it is unset."""
        if self.base_model is None:
            raise RuntimeError(
                "base_model is not set; assign it first, e.g. "
                "`model.base_model = model.get_model()`"
            )
        return self.base_model

    def get_model(self):
        """Load and return the seq2seq model for ``self.checkpoint``.

        The result is returned only; ``self.base_model`` is not modified.
        """
        return AutoModelForSeq2SeqLM.from_pretrained(self.checkpoint)

    def get_peft(self, lora_config):
        """Wrap ``self.base_model`` with a PEFT adapter built from *lora_config*.

        Raises:
            RuntimeError: if ``self.base_model`` has not been assigned.
        """
        return get_peft_model(self._require_base_model(), lora_config)

    def prepare_quantize(self, bnb_config):
        """Load ``self.checkpoint`` with *bnb_config* quantization onto GPU 0.

        ``device_map={"": 0}`` places the whole model on CUDA device 0.
        """
        # SECURITY NOTE(review): trust_remote_code=True executes Python code
        # shipped with the checkpoint repository; only use trusted checkpoints.
        return AutoModelForSeq2SeqLM.from_pretrained(self.checkpoint,
                                                     quantization_config=bnb_config,
                                                     device_map={"": 0},
                                                     trust_remote_code=True)

    def generate_summary(self, input_text, generation_config, do_sample=True,
                         *, max_input_length=1024):
        """Generate, print and return a summary of *input_text*.

        Args:
            input_text: Text to summarize.
            generation_config: Passed straight to ``generate``.
            do_sample: Passed to ``generate``; it overrides the config's value.
            max_input_length: Inputs are truncated to this many tokens.

        Returns:
            The decoded summary with special tokens skipped.

        Raises:
            RuntimeError: if ``self.base_model`` has not been assigned.
        """
        model = self._require_base_model()

        # A single sequence needs no padding. The tokenizer's attention mask is
        # passed explicitly so generate() does not have to infer it.
        encoded = self.tokenizer(
            input_text,
            return_tensors="pt",
            max_length=max_input_length,
            truncation=True,
        )

        # Inputs must live on the same device as the model weights (e.g. GPU 0
        # after prepare_quantize). Fall back to self.device if there is no .device.
        device = getattr(model, "device", None)
        if device is None:
            device = self.device
        input_ids = encoded["input_ids"].to(device)
        attention_mask = encoded.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)

        output_ids = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            do_sample=do_sample,
            generation_config=generation_config,
        )

        # HACK (behavior kept): for BART checkpoints the token at position 1 is
        # overwritten with id 2, presumably so that skip_special_tokens drops it.
        # NOTE(review): 2 is assumed to be BART's </s> id; confirm per checkpoint.
        # The length guard avoids an IndexError on very short outputs.
        if "bart" in self.checkpoint and output_ids.shape[-1] > 1:
            output_ids[0][1] = 2

        output_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
        print(f"\033[94mSummary: {output_text}\n\033[00m")
        return output_text
|
|
|
|
class BartSum(Model):
    """Summarizer for checkpoints that ``load_model`` routes here (name contains "bart").

    There is no ``__init__`` override: ``Model.__init__`` already loads the
    tokenizer for ``checkpoint``, so loading it again here was redundant work.
    """

    def get_model(self):
        """Load and return the seq2seq model for ``self.checkpoint``."""
        return AutoModelForSeq2SeqLM.from_pretrained(self.checkpoint)
|
|
|
|
|
|
class FlanT5Sum(Model):
    """Summarizer for checkpoints that ``load_model`` routes here (name contains "flan").

    There is no ``__init__`` override: ``Model.__init__`` already loads the
    tokenizer for ``checkpoint``, so loading it again here was redundant work.
    """

    def get_model(self):
        """Load and return the seq2seq model for ``self.checkpoint``."""
        return AutoModelForSeq2SeqLM.from_pretrained(self.checkpoint)
|
|
|
|
|
|
def load_model(checkpoint):
    """Instantiate the summarizer wrapper that matches *checkpoint*.

    Dispatch works by substring of the checkpoint name. ``"bart"`` selects
    ``BartSum``; otherwise ``"flan"`` selects ``FlanT5Sum``; anything else
    falls back to the generic ``Model``.

    Args:
        checkpoint: Model identifier or local path.

    Returns:
        A ``Model`` (or subclass) instance.

    Raises:
        Exception: any error from the check or the constructor is printed,
            then re-raised unchanged.
    """
    try:
        if "bart" in checkpoint:
            print(f"\033[92mLoad Bart model from checkpoint: {checkpoint}\033[00m")
            return BartSum(checkpoint)

        if "flan" in checkpoint:
            print(f"\033[92mLoad Flan-T5 model from checkpoint: {checkpoint}\033[00m")
            return FlanT5Sum(checkpoint)

        print(f"\033[92mLoad general model from checkpoint: {checkpoint}\033[00m")
        return Model(checkpoint)

    except Exception as e:
        # f-string so the actual error text is shown (not a literal "{e}").
        print(f"Error while loading model: {e}")
        # Bare raise re-raises the same exception with its traceback intact.
        raise