# Uploaded by dtruong46me ("Upload 29 files", commit 97e4014, verified).
import torch
from transformers import (
AutoTokenizer,
AutoModelForSeq2SeqLM,
)
from peft import (
get_peft_model,
)
class Model:
    """Seq2seq summarization wrapper around a Hugging Face checkpoint.

    The tokenizer is loaded eagerly in ``__init__``; the model is not.
    ``base_model`` starts as ``None`` and must be assigned by the caller
    (presumably from ``get_model()``, ``prepare_quantize()`` or
    ``get_peft()`` -- confirm against callers) before ``get_peft`` or
    ``generate_summary`` are used.
    """

    def __init__(self, checkpoint):
        """Store the checkpoint id/path and load its tokenizer.

        Args:
            checkpoint: Model id or local path accepted by ``from_pretrained``.
        """
        self.checkpoint = checkpoint
        # Preferred device; also the fallback input device in generate_summary
        # when base_model does not expose a ``device`` attribute.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(self.checkpoint)
        self.base_model = None

    def get_model(self):
        """Load and return a fresh, unquantized seq2seq model for the checkpoint.

        The model is returned, not stored on ``self.base_model``.
        """
        return AutoModelForSeq2SeqLM.from_pretrained(self.checkpoint)

    def get_peft(self, lora_config):
        """Wrap ``self.base_model`` with ``lora_config`` via ``get_peft_model``."""
        return get_peft_model(self.base_model, lora_config)

    def prepare_quantize(self, bnb_config, device_map=None):
        """Load the checkpoint with a quantization config.

        Args:
            bnb_config: Passed to ``from_pretrained`` as ``quantization_config``.
            device_map: Forwarded to ``from_pretrained``. ``None`` (default)
                keeps the previous hard-coded value ``{"": 0}``, which places
                the whole model on GPU 0.

        Returns:
            The loaded model. It is not stored on ``self.base_model``.
        """
        if device_map is None:
            device_map = {"": 0}
        # NOTE: for k-bit training, callers may still need to enable gradient
        # checkpointing and apply peft's prepare_model_for_kbit_training to the
        # returned model (this was previously left here as commented-out code).
        return AutoModelForSeq2SeqLM.from_pretrained(
            self.checkpoint,
            quantization_config=bnb_config,
            device_map=device_map,
            trust_remote_code=True,
        )

    def generate_summary(self, input_text, generation_config, do_sample=True,
                         *, max_length=1024):
        """Generate, print and return a summary of ``input_text``.

        Args:
            input_text: Text to summarize.
            generation_config: Forwarded to ``base_model.generate``.
            do_sample: Forwarded to ``base_model.generate``.
            max_length: Input is truncated/padded to this many tokens
                (default 1024, as before).

        Returns:
            The decoded first output sequence, with special tokens skipped.
        """
        encoded = self.tokenizer(
            input_text,
            return_tensors="pt",
            max_length=max_length,
            truncation=True,
            padding="max_length",
        )
        # Inputs must be on the same device as the model weights (e.g. cuda:0
        # after prepare_quantize), not left on the CPU.
        device = getattr(self.base_model, "device", self.device)
        input_ids = encoded["input_ids"].to(device)
        # Padding to max_length adds pad tokens, so pass the mask explicitly
        # instead of relying on generate() to infer it.
        attention_mask = encoded["attention_mask"].to(device)
        output_ids = self.base_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            do_sample=do_sample,
            generation_config=generation_config,
        )
        if "bart" in self.checkpoint and len(output_ids[0]) > 1:
            # NOTE(review): overwrites the second generated token with id 2;
            # the reason is not visible here (presumably a BART special-token
            # workaround) -- confirm before changing. The length guard avoids an
            # IndexError on a one-token output.
            output_ids[0][1] = 2
        output_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
        print(f"\033[94mSummary: {output_text}\n\033[00m")
        return output_text
class BartSum(Model):
    """Summarizer wrapper for BART checkpoints.

    Construction (including the single tokenizer load) and ``get_model`` are
    inherited from ``Model``. The earlier overrides reloaded the tokenizer a
    second time and duplicated ``get_model`` verbatim, so they were dropped.
    """
class FlanT5Sum(Model):
    """Summarizer wrapper for Flan-T5 checkpoints.

    Construction (including the single tokenizer load) and ``get_model`` are
    inherited from ``Model``. The earlier overrides reloaded the tokenizer a
    second time and duplicated ``get_model`` verbatim, so they were dropped.
    """
def load_model(checkpoint):
    """Instantiate the wrapper class that matches ``checkpoint``.

    Dispatch is a case-sensitive substring test on the checkpoint string:
    ``"bart"`` -> ``BartSum``; otherwise ``"flan"`` -> ``FlanT5Sum``;
    otherwise the generic ``Model``.

    Args:
        checkpoint: Model id or local path.

    Returns:
        The constructed wrapper instance.

    Raises:
        Whatever the selected constructor raises. The error is printed and
        then re-raised unchanged.
    """
    try:
        if "bart" in checkpoint:
            print(f"\033[92mLoad Bart model from checkpoint: {checkpoint}\033[00m")
            return BartSum(checkpoint)
        if "flan" in checkpoint:
            print(f"\033[92mLoad Flan-T5 model from checkpoint: {checkpoint}\033[00m")
            return FlanT5Sum(checkpoint)
        print(f"\033[92mLoad general model from checkpoint: {checkpoint}\033[00m")
        return Model(checkpoint)
    except Exception as e:
        # The f-prefix was missing, so the literal "{e}" used to be printed
        # instead of the actual error.
        print(f"Error while loading model: {e}")
        raise