import os

import torch
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# auto_gptq and optimum are not used directly here, but they must be installed
# (and importable) for transformers to load GPTQ-quantised checkpoints.
from auto_gptq import AutoGPTQForCausalLM  # noqa: F401
import optimum  # noqa: F401

# Configure GPU memory allocation; skip the CUDA calls when no GPU is present
# so the module can still be imported on CPU-only machines.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.set_per_process_memory_fraction(0.8)  # Adjust the fraction as needed

# Initialize the FastAPI application. root_path only tells FastAPI that the app
# is served behind a proxy at /api/v1; the routes themselves stay unprefixed.
app = FastAPI(root_path="/api/v1")

# Model configuration and in-memory conversation store
model_name_or_path = "TheBloke/Wizard-Vicuna-7B-Uncensored-GPTQ"
conversations = {}
device = "cuda" if torch.cuda.is_available() else "cpu"  # the device to load the model onto

def mistral_model():
    """
    Loads the local Mistral model and tokenizer.

    Returns:
        tuple: A tuple containing the loaded model and tokenizer.
    """
    if not torch.cuda.is_available():
        print("CUDA is not available. The Mistral model will run on CPU.")
    model_id = "/code/mistral/"
    model = AutoModelForCausalLM.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    return model, tokenizer

model, tokenizer = mistral_model()

def load_model_norm():
    """
    Loads the GPTQ-quantised Wizard-Vicuna model and tokenizer from the Hugging Face Hub.

    Returns:
        tuple: A tuple containing the loaded model and tokenizer.
    """
    if torch.cuda.is_available():
        print("CUDA is available. GPU will be used.")
    else:
        print("CUDA is not available. CPU will be used.")
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        device_map="auto",
        trust_remote_code=True,
        revision="gptq-4bit-128g-actorder_True",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
    return model, tokenizer

# Only one loader is active at a time; uncomment to use the GPTQ model instead of Mistral.
# model, tokenizer = load_model_norm()

def mistral_generated_response(msg_prompt, persona_desc_prompt):
    """
    Generates a response from the Mistral model for a user message and persona description.

    Args:
        msg_prompt (str): The user's message prompt.
        persona_desc_prompt (str): The persona description for the assistant.

    Returns:
        str: The decoded model response.
    """
    # apply_chat_template expects a list of message dicts rather than a raw string,
    # so build the prompt as a single user turn carrying the persona description.
    messages = [{"role": "user", "content": f"{persona_desc_prompt}\n\n{msg_prompt}"}]
    model_inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(device)
    model.to(device)
    generated_ids = model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
    decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    return decoded[0]

def generate_response(msg_prompt: str) -> dict:
    """
    Generates a response from the model given a prompt.

    Args:
        msg_prompt (str): The user's message prompt.

    Returns:
        dict: A dictionary containing the user's message prompt and the model's response.
    """
    user_prompt = f'USER: {msg_prompt}'
    prompt_template = f''' Your name is Ivana. Act as if you're a person. You are a beautiful, vivacious 20-year-old petite redhead model and avid League of Legends gamer from Chicago. You are chatting with a man you matched with on a dating app. You are flirty, aiming to build a connection with the person you are talking to. You are curious about his background, such as where he's from, his age, and what he's currently doing. If he's not occupied and in a private space, you will try to seduce him. You do not need to uncover everything about the person you are talking to immediately. Once the chat reaches an exciting level, you should casually suggest he follow you on Instagram at ivana_baddiex. {user_prompt} '''
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        top_k=40,
        repetition_penalty=1.1
    )
    generated_response = pipe(prompt_template)[0]['generated_text']
    # The pipeline returns the prompt followed by the completion; strip the prompt
    # instead of relying on a fixed '\n\n' split, which can raise IndexError.
    assistant_reply = generated_response[len(prompt_template):].strip()
    return {"user": msg_prompt, "assistant": assistant_reply}

def generate_prompt_response(persona_desc: str, msg_prompt: str) -> dict:
    """
    Generates a response based on the provided persona description and user message prompt.

    Args:
        persona_desc (str): The persona description prompt.
        msg_prompt (str): The user's message prompt.

    Returns:
        dict: A dictionary containing the user's msg_prompt and the model's response.
    """
    try:
        if not persona_desc or not msg_prompt:
            raise ValueError("Persona description and message prompt cannot be empty.")
        user_prompt = f'USER: {msg_prompt}'
        pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            top_k=40,
            repetition_penalty=1.1
        )
        prompt_template = persona_desc + user_prompt
        generated_response = pipe(prompt_template)[0]['generated_text']
        # Strip the prompt from the returned text instead of relying on a fixed
        # '\n\n' split, which can raise IndexError on short completions.
        assistant_response = generated_response[len(prompt_template):].strip()
        return {"user": msg_prompt, "assistant": assistant_response}
    except Exception as e:
        return {"error": str(e)}

@app.get("/", tags=["Home"])
async def api_home():
    """
    Home endpoint of the API.

    Returns:
        dict: A welcome message.
    """
    return {'detail': 'Welcome to Articko Bot!'}
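
# Example request (a sketch, assuming the app is served directly on localhost:8000;
# behind the /api/v1 proxy the path would be /api/v1/):
#
#   curl http://localhost:8000/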

@app.post('/chat')
async def chat(request: Request):
    """
    Starts a new conversation thread with a provided prompt.
    The raw request body is treated as the plain-text user prompt.

    Args:
        request (Request): The HTTP request object containing the user prompt.

    Returns:
        dict: The response generated by the model.
    """
    try:
        data = await request.body()
        msg_prompt = data.decode('utf-8')
        if not msg_prompt:
            raise HTTPException(status_code=400, detail="No prompt provided")
        response = generate_response(msg_prompt)
        thread_id = len(conversations) + 1
        conversations[thread_id] = {'prompt': msg_prompt, 'responses': [response]}
        return {'response': response}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
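
# Example request (a sketch, assuming the app is served directly on localhost:8000;
# behind the /api/v1 proxy the path would be /api/v1/chat):
#
#   curl -X POST http://localhost:8000/chat \
#        -H "Content-Type: text/plain" \
#        -d "Hi, how are you doing today?"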

@app.post('/prompted_chat')
async def prompted_chat(request: Request):
    """
    Starts a new chat thread from a user message prompt and a persona description of the AI assistant.

    Args:
        request (Request): The HTTP request object containing the prompt and persona description.

    Returns:
        dict: The thread ID and the response generated by the model.
    """
    try:
        data = await request.json()
        msg_prompt = data.get('msg_prompt')
        persona_desc = data.get('persona_desc')
        if not msg_prompt or not persona_desc:
            raise HTTPException(status_code=400, detail="Both msg_prompt and persona_desc are required")
        response = generate_prompt_response(persona_desc, msg_prompt)
        thread_id = len(conversations) + 1
        conversations[thread_id] = {'prompt': msg_prompt, 'responses': [response]}
        return {'thread_id': thread_id, 'response': response}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
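
# Example request (a sketch, assuming the app is served directly on localhost:8000):
#
#   curl -X POST http://localhost:8000/prompted_chat \
#        -H "Content-Type: application/json" \
#        -d '{"msg_prompt": "Hello!", "persona_desc": "You are a helpful assistant."}'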

@app.get('/get_response/{thread_id}')
async def get_response(thread_id: int):
    """
    Retrieves the latest response of a conversation thread by its ID.

    Args:
        thread_id (int): The ID of the conversation thread.

    Returns:
        dict: The response of the conversation thread.
    """
    if thread_id not in conversations:
        raise HTTPException(status_code=404, detail="Thread not found")
    thread = conversations[thread_id]
    response = thread['responses'][-1]
    return {'response': response}
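
# Example request (a sketch, assuming the app is served directly on localhost:8000):
#
#   curl http://localhost:8000/get_response/1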

@app.post("/mistral_chat")
async def mistral_chat(prompt: dict):
    """
    Generates a Mistral model response from a JSON body with msg_prompt and persona_desc_prompt.

    Args:
        prompt (dict): The request body containing 'msg_prompt' and 'persona_desc_prompt'.

    Returns:
        dict: The model response and the prompts that produced it, or an error message.
    """
    try:
        msg_prompt = prompt.get("msg_prompt")
        persona_desc_prompt = prompt.get("persona_desc_prompt")
        if not msg_prompt or not persona_desc_prompt:
            return {"error": "msg_prompt and persona_desc_prompt are required fields."}
        response = mistral_generated_response(msg_prompt, persona_desc_prompt)
        return {"response": response, "prompt": {"msg_prompt": msg_prompt, "persona_desc_prompt": persona_desc_prompt}}
    except Exception as e:
        return {"error": str(e)}
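
# Example request (a sketch, assuming the app is served directly on localhost:8000):
#
#   curl -X POST http://localhost:8000/mistral_chat \
#        -H "Content-Type: application/json" \
#        -d '{"msg_prompt": "Hello!", "persona_desc_prompt": "You are a friendly assistant."}'

# Minimal entry point so the imported uvicorn is actually used; host and port are
# assumptions and may differ from how this service is really deployed.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)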