import sys
import warnings

import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM

warnings.filterwarnings("ignore")

# Command-line arguments: argv[1] selects the device ("gpu" -> Intel XPU,
# anything else -> CPU), argv[2] is the Hugging Face access token, and the
# remaining arguments form the prompt.
access_token = sys.argv[2]
device = "xpu:0" if sys.argv[1] == "gpu" else "cpu:0"

# Load the tokenizer saved during fine-tuning and configure padding.
tokenizer = AutoTokenizer.from_pretrained("./tokenizer/")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Load the base Gemma 2B model in bfloat16.
base_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b",
    token=access_token,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.bfloat16,
)

# Attach the fine-tuned LoRA adapter weights and move the model to the target device.
model = PeftModel.from_pretrained(base_model, "adapter_model")
model = model.to(device)

prompt = " ".join(sys.argv[3:])
print("Prompt:", prompt)

# Tokenize the prompt and generate up to 200 new tokens.
# Note: with do_sample=False (greedy decoding), top_k and temperature have no effect.
inputs = tokenizer(prompt, return_tensors="pt").to(device)
outputs = model.generate(
    **inputs,
    max_new_tokens=200,
    do_sample=False,
    top_k=100,
    temperature=0.1,
    eos_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
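To run the script, pass the device flag, your Hugging Face access token, and the prompt on the command line. A hypothetical invocation (the script filename, the token placeholder, and the prompt are illustrative, not taken from the source):

# run inference on an Intel GPU (XPU); use "cpu" instead of "gpu" to fall back to CPU
python inference.py gpu <HF_ACCESS_TOKEN> What is LoRA fine-tuning?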