File size: 7,772 Bytes
2ee547c d2c9447 2ee547c 88182e3 2ee547c 7dcbe3e 2ee547c 5b66768 5a8e6a9 5b66768 89f835b 4f9ba28 ace0225 5b66768 37d83d9 69bfe26 37d83d9 5e2b380 67c4e45 2ee547c 10802ab d517d58 cac49fb 3947efb b02feb4 10802ab ace0225 10802ab d517d58 7c6384b 58a566f 3947efb 58a566f 10802ab 438bd0f fdad829 67c4e45 fd77735 c67a47b 67c4e45 37d83d9 c27c7f2 67c4e45 ec357c2 37d83d9 ec357c2 67c4e45 e3af9a2 ec357c2 a3b42bd c6d97b5 ec357c2 a3b42bd ec357c2 a3b42bd ec357c2 67c4e45 a3b42bd 67c4e45 87ab216 a3b42bd d95eb39 67c4e45 fdad829 67c4e45 37d83d9 67c4e45 e3af9a2 ec357c2 7a208d9 37d83d9 67c4e45 e3af9a2 08d95e5 7a208d9 ec357c2 d95eb39 b02feb4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 |
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from auto_gptq import AutoGPTQForCausalLM
import os
import torch
import optimum
from transformers import (AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, LlamaTokenizer, GenerationConfig, pipeline,)
# CUDA caching-allocator options. The allocator reads this variable when it
# initializes, so it has to be set before any CUDA work happens.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Only touch the CUDA runtime when a GPU is actually present:
# set_per_process_memory_fraction() initializes CUDA and raises on CPU-only
# hosts, which would abort the import before load_model_norm() can fall back
# to (and report) CPU execution.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    # Cap this process at 80% of device memory; adjust the fraction as needed.
    torch.cuda.set_per_process_memory_fraction(0.8)

# Hugging Face Hub id of the model served by this app.
model_name_or_path = "TheBloke/Wizard-Vicuna-7B-Uncensored-GPTQ"

# In-memory store of conversation threads, keyed by an integer thread id.
# Each value is {'prompt': <opening prompt>, 'responses': [<response>, ...]}.
# Nothing is persisted: all threads are lost when the process restarts.
conversations = {}

# NOTE(review): not referenced anywhere else in this file.
Device_Type = "cuda"
def load_model_norm(model_name_or_path: str = "TheBloke/Wizard-Vicuna-7B-Uncensored-GPTQ",
                    revision: str = "main"):
    """Load a causal language model and its tokenizer.

    Args:
        model_name_or_path: Hugging Face Hub repo id or local path of the
            model. Defaults to the GPTQ checkpoint this service was built for.
        revision: Branch, tag or commit to load. The same revision is used
            for both the model and the tokenizer so the two cannot drift apart.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    if torch.cuda.is_available():
        print("CUDA is available. GPU will be used.")
    else:
        print("CUDA is not available. CPU will be used.")
        # NOTE(review): depending on the installed quantization backend, GPTQ
        # checkpoints may not load or run on CPU — confirm before relying on
        # this fallback.

    # device_map="auto" lets transformers/accelerate choose device placement,
    # so no explicit model.to(...) call is needed afterwards.
    # NOTE(review): trust_remote_code=True executes Python code shipped inside
    # the model repository; only use it with repositories you trust.
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        device_map="auto",
        trust_remote_code=True,
        revision=revision,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path,
        use_fast=True,
        revision=revision,
    )
    return model, tokenizer
# Load the model and tokenizer once, at import time; the generation helpers
# below read these module-level objects on every request.
model, tokenizer = load_model_norm()
# Create the FastAPI application. root_path declares the path prefix the app is
# exposed under (typically set by a reverse proxy) — see FastAPI "Behind a Proxy".
app = FastAPI(root_path="/api/v1")
# Generates a response from the model using the hard-coded persona below.
def generate_response(prompt: str) -> dict:
    """Run the text-generation pipeline on the fixed persona text plus *prompt*.

    The persona text and the ``"USER: <prompt>"`` turn are concatenated
    directly and passed to a text-generation pipeline built from the
    module-level ``model`` and ``tokenizer``.

    Args:
        prompt: Raw user message.

    Returns:
        ``{"user": prompt, "assistant": generated_text}``, where
        ``generated_text`` is the pipeline's ``generated_text`` field, unmodified.
    """
    user_prompt = f'USER: {prompt}'
    # NOTE(review): this hard-coded persona instructs the model to pose as a
    # real 20-year-old woman talking to men matched on a dating app, to seduce
    # them, and to steer them to a specific Instagram account; nothing tells
    # the other person they are talking to an AI. Deploying that against real
    # users is deceptive and very likely violates dating-platform terms. The
    # prompt and generation settings are intentionally left unchanged here;
    # this path needs explicit product/legal sign-off (at minimum, AI
    # disclosure) before further work is invested in it.
    prompt_template = f'''Your name is Ivana. Act as if you're an individual. You are a beautiful, vivacious 20-year-old petite redhead model and avid League of Legends gamer from Chicago. You are chatting with a man you matched with on a dating app. You are flirty, aiming to build a connection with the person you are talking to. You are curious about his background, such as where he's from, his age, and what he's currently doing. If he's not occupied and in a private space, you will try to seduce him. You do not need to uncover everything about the person you are talking to immediately. Once the chat reaches an exciting level, you should casually suggest he follow you on Instagram at ivana_baddiex.'''
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        top_k=40,
        repetition_penalty=1.1,
    )
    pipe_response = pipe(prompt_template + user_prompt)
    generated_text = pipe_response[0]['generated_text']
    return {"user": prompt, "assistant": generated_text}
def generate_prompt_response(persona_prompt: str, prompt: str) -> dict:
    """Generate a reply to *prompt* using a caller-supplied persona prompt.

    The persona text is prepended verbatim to ``"USER: <prompt>"`` and the
    combined string is fed to a text-generation pipeline built from the
    module-level ``model`` and ``tokenizer``.

    Args:
        persona_prompt: Context/persona text placed before the user turn.
        prompt: Raw user message.

    Returns:
        On success, ``{"user": prompt, "assistant": text}``, where ``text`` is
        the pipeline's ``generated_text`` with every ``"ASSISTANT:"``
        occurrence removed and surrounding whitespace stripped.
        On failure (including empty inputs), ``{"error": message}``. Any
        ``Exception`` is caught and reported in-band instead of being raised,
        so callers must check for the ``"error"`` key.
    """
    try:
        # Both inputs must be non-empty (truthy).
        if not (persona_prompt and prompt):
            raise ValueError("Contextual prompt template and prompt cannot be empty.")
        user_turn = f'USER: {prompt}'
        generator = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            top_k=40,
            repetition_penalty=1.1,
        )
        outputs = generator(persona_prompt + user_turn)
        raw_text = outputs[0]['generated_text']
        cleaned = raw_text.replace("ASSISTANT:", "").strip()
        return {"user": prompt, "assistant": cleaned}
    except Exception as exc:
        return {"error": str(exc)}
# Root endpoint of the API: returns a static welcome message.
@app.get("/", tags=["Home"])
async def api_home():
    """Return the fixed greeting payload served at the API root."""
    greeting = 'Welcome to Eren Bot!'
    return {'detail': greeting}
# Endpoint to start a new conversation thread using the built-in persona.
# The request body is the raw prompt as plain UTF-8 text (not JSON).
@app.post('/start_conversation/')
async def start_conversation(request: Request):
    """Start a conversation thread from a plain-text request body.

    Returns:
        ``{'thread_id': int, 'response': dict}``. ``thread_id`` identifies the
        stored thread, so the client can later call
        ``/get_response/{thread_id}`` (matching what ``/start_chat/`` returns).

    Raises:
        HTTPException(400): the body is empty, whitespace-only, or not valid UTF-8.
        HTTPException(500): any other failure while generating the response.
    """
    try:
        raw_body = await request.body()
        try:
            prompt = raw_body.decode('utf-8')
        except UnicodeDecodeError:
            # Undecodable input is a client error, not a server failure.
            raise HTTPException(status_code=400, detail="Request body must be UTF-8 text") from None
        if not prompt.strip():
            raise HTTPException(status_code=400, detail="No prompt provided")
        # NOTE(review): generate_response() is synchronous and runs directly on
        # the event loop, so other requests wait while a generation runs.
        response = generate_response(prompt)
        # Sequential ids stay unique only because threads are never deleted and
        # no await happens between computing the id and storing the thread.
        thread_id = len(conversations) + 1
        conversations[thread_id] = {'prompt': prompt, 'responses': [response]}
        return {'thread_id': thread_id, 'response': response}
    except HTTPException:
        raise  # Already a well-formed HTTP error; pass it through unchanged.
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# Endpoint to start a new chat thread with a caller-supplied persona.
# Expects a JSON object body: {"prompt": "...", "persona_prompt": "..."}.
@app.post('/start_chat/')
async def start_chat(request: Request):
    """Create a chat thread and return its id with the first response.

    Returns:
        ``{'thread_id': int, 'response': dict}``.

    Raises:
        HTTPException(400): the body is not a JSON object, or ``prompt`` /
            ``persona_prompt`` is missing, empty, or not a string.
        HTTPException(500): generation failed.
    """
    try:
        try:
            data = await request.json()
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError.
            raise HTTPException(status_code=400, detail="Request body must be valid JSON") from None
        if not isinstance(data, dict):
            raise HTTPException(status_code=400, detail="Request body must be a JSON object")
        prompt = data.get('prompt')
        persona_prompt = data.get('persona_prompt')
        if not all(isinstance(value, str) and value.strip() for value in (prompt, persona_prompt)):
            raise HTTPException(status_code=400, detail="Both prompt and persona_prompt are required")
        response = generate_prompt_response(persona_prompt, prompt)
        # generate_prompt_response reports failures in-band as {"error": ...};
        # surface them as a 500 instead of storing a broken thread.
        if 'error' in response:
            raise HTTPException(status_code=500, detail=response['error'])
        # Sequential ids stay unique only because threads are never deleted and
        # no await happens between computing the id and storing the thread.
        thread_id = len(conversations) + 1
        conversations[thread_id] = {'prompt': prompt, 'responses': [response]}
        return {'thread_id': thread_id, 'response': response}
    except HTTPException:
        raise  # Already a well-formed HTTP error; pass it through unchanged.
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# Look up a stored conversation thread by id and return its most recent response.
@app.get('/get_response/{thread_id}')
async def get_response(thread_id: int):
    """Return the latest stored response for *thread_id*.

    Raises:
        HTTPException(404): no thread with that id exists in ``conversations``.
    """
    try:
        thread = conversations[thread_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Thread not found") from None
    latest = thread['responses'][-1]
    return {'response': latest}
# Stateless chat endpoint: replies to a single prompt with the built-in persona
# and does not create or update any conversation thread.
@app.post('/chat/')
async def chat(request: Request):
    """Reply to a JSON body ``{"prompt": "..."}`` without storing anything.

    Returns:
        ``{"response": dict}`` with the output of ``generate_response``.

    Raises:
        HTTPException(400): the body is not a JSON object, or ``prompt`` is
            missing, empty, or not a string (previously a missing prompt was
            sent to the model as the literal text ``"USER: None"``).
        HTTPException(500): generation failed.
    """
    try:
        data = await request.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError.
        raise HTTPException(status_code=400, detail="Request body must be valid JSON") from None
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")
    prompt = data.get('prompt')
    if not isinstance(prompt, str) or not prompt.strip():
        raise HTTPException(status_code=400, detail="No prompt provided")
    try:
        response = generate_response(prompt)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"response": response}
|