import os

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel, AutoPeftModelForCausalLM
from huggingface_hub import login

# Authenticate with the Hugging Face Hub using a token stored in the environment.
login(os.environ['huggingface_token'])
SYSTEM_PROMPT = """You are a mental health therapist/psychologist/licensed professional counsellor.
People will talk to you about their personal life and mental health issues, and you will reply to them
the same way a professional therapist would, while being empathetic."""
def load_peft_model_and_tokenizer(peft_model, base_model):
    # Load the tokenizer from the PEFT adapter repo (it may include tokens added during fine-tuning).
    tokenizer = AutoTokenizer.from_pretrained(peft_model)

    # 4-bit NF4 quantization with double quantization; computation runs in bfloat16.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    # Load the quantized base model.
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        quantization_config=bnb_config,
    )

    # Resize the embedding matrix to match the tokenizer, in case tokens were added during fine-tuning.
    base_model.resize_token_embeddings(len(tokenizer))

    # Attach the PEFT adapter weights on top of the quantized base model.
    model = PeftModel.from_pretrained(model=base_model, model_id=peft_model)
    return tokenizer, model
# Alternative loader: AutoPeftModelForCausalLM resolves the base model from the
# adapter config automatically, so no explicit base_model argument is needed.
# def load_peft_model_and_tokenizer(peft_model):
#     # Load the tokenizer from the specified model path or identifier
#     tokenizer = AutoTokenizer.from_pretrained(peft_model)
#     # Load the PEFT model for causal language modeling with specific device map and torch dtype
#     model = AutoPeftModelForCausalLM.from_pretrained(
#         peft_model,
#         device_map="auto",
#         torch_dtype=torch.float16,
#     )
#     return tokenizer, model
def get_chatbot_response(model, tokenizer, message):
    # Build the prompt in the "### System / ### Human / ### Assistant" format used for fine-tuning.
    input_ids = tokenizer(
        f"### System: {SYSTEM_PROMPT} ### Human: {message} ### Assistant: ",
        return_tensors="pt",
        truncation=True,
    ).input_ids.to(model.device)

    # Sample a response (nucleus + top-k sampling with a mild repetition penalty).
    outputs = model.generate(
        input_ids=input_ids,
        max_new_tokens=100,
        do_sample=True,
        temperature=0.5,
        top_p=0.9,
        top_k=50,
        repetition_penalty=1.1,
    )

    # Decode the generated sequence, which includes the echoed prompt.
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Keep only the text after the final "### Assistant: " marker.
    response = response.split("### Assistant: ")[-1].strip()
    return response
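

# A minimal usage sketch. The model identifiers below are placeholders for
# illustration, not this project's actual checkpoints; substitute the adapter
# and base model IDs the adapter was trained against.
if __name__ == "__main__":
    # Hypothetical adapter and base model IDs (assumptions, not the repo's real ones).
    PEFT_MODEL_ID = "your-username/therapist-peft-adapter"
    BASE_MODEL_ID = "meta-llama/Llama-2-7b-hf"

    tokenizer, model = load_peft_model_and_tokenizer(PEFT_MODEL_ID, BASE_MODEL_ID)
    print(get_chatbot_response(model, tokenizer, "I've been feeling anxious at work lately."))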