"""Load a 4-bit quantized causal LM with a PEFT adapter and use it to generate
empathetic, therapist-style replies to a user's message."""
import os

import torch
from huggingface_hub import login
from peft import AutoPeftModelForCausalLM, PeftModel  # noqa: F401  (AutoPeft*: see alternative-loader note below)
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Authenticate with the Hugging Face Hub at import time so the checkpoints can be
# downloaded. Raises KeyError if the `huggingface_token` env var is unset.
login(os.environ['huggingface_token'])

# NOTE(review): presumably this is the exact system prompt used during
# fine-tuning; confirm before changing its wording (including spelling).
SYSTEM_PROMPT = """You are a mental health therapist/pyschologist/licensed professional counsellor. People will talk to you about their personal life/mental health issues and you will reply them the same way a professional therapist would, while being empathetic."""

# Role markers of the prompt template; generation is cut at the first of these
# so the model cannot append an invented next turn to its reply.
_TURN_MARKERS = ("### Human:", "### Assistant:", "### System:")


def load_peft_model_and_tokenizer(peft_model, base_model):
    """Load the adapter's tokenizer and a 4-bit base model with the adapter applied.

    Args:
        peft_model: Hub id or local path of the PEFT adapter. The tokenizer is
            loaded from here as well.
        base_model: Hub id or local path of the base causal LM the adapter was
            trained on.

    Returns:
        ``(tokenizer, model)`` where ``model`` is a ``PeftModel`` wrapping the
        quantized base model.
    """
    tokenizer = AutoTokenizer.from_pretrained(peft_model)

    # 4-bit NF4 weights with nested (double) quantization; compute in bfloat16.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    base = AutoModelForCausalLM.from_pretrained(
        base_model,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        quantization_config=bnb_config,
    )
    # The adapter's tokenizer may hold tokens added during fine-tuning; grow the
    # embedding matrix so every token id is valid before attaching the adapter.
    base.resize_token_embeddings(len(tokenizer))

    model = PeftModel.from_pretrained(model=base, model_id=peft_model)
    return tokenizer, model


# NOTE: an unquantized alternative loads base model + adapter in one call:
#   AutoPeftModelForCausalLM.from_pretrained(
#       peft_model, device_map="auto", torch_dtype=torch.float16)


def get_chatbot_response(
    model,
    tokenizer,
    message,
    *,
    max_new_tokens=100,
    temperature=0.5,
    top_p=0.9,
    top_k=50,
    repetition_penalty=1.1,
):
    """Generate one therapist-style reply to ``message``.

    Args:
        model: Causal LM, e.g. as returned by ``load_peft_model_and_tokenizer``.
        tokenizer: Tokenizer matching ``model``.
        message: The user's message.
        max_new_tokens, temperature, top_p, top_k, repetition_penalty:
            Sampling settings forwarded to ``model.generate``; the defaults are
            the values previously hard-coded here.

    Returns:
        The generated reply with surrounding whitespace stripped. Sampling is
        enabled, so repeated calls may return different text.
    """
    # NOTE(review): section separators reconstructed as newlines; confirm they
    # match the fine-tuning template (the "." after SYSTEM_PROMPT is kept as-is).
    prompt = f"### System: {SYSTEM_PROMPT}.\n### Human: {message}\n### Assistant: "

    # Follow the model's device instead of hard-coding `.cuda()`, since the
    # model was placed with device_map="auto".
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
    input_ids = inputs["input_ids"]

    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        # generate() would otherwise warn and fall back to EOS on its own.
        pad_token_id = tokenizer.eos_token_id

    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=inputs.get("attention_mask"),
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        pad_token_id=pad_token_id,
    )

    # Decode only the tokens generated after the prompt. Splitting the full
    # decoded text on the "### Assistant: " marker raised IndexError when
    # decoding did not reproduce the marker exactly, and returned the wrong
    # text when the user's message itself contained the marker.
    reply = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)

    # Drop any turn the model invents after its own reply.
    for marker in _TURN_MARKERS:
        reply = reply.split(marker, 1)[0]
    return reply.strip()