Artix / app.py
Vitrous's picture
Update app.py
e3af9a2 verified
raw
history blame
5.42 kB
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from auto_gptq import AutoGPTQForCausalLM
import os
import torch
import optimum
from transformers import (AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, LlamaTokenizer, GenerationConfig, pipeline,)
from fastapi.middleware.cors import CORSMiddleware
from pyngrok import ngrok
# Must be in the environment before torch initializes its CUDA caching
# allocator (i.e. before the first CUDA allocation), so set it at import time.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
# Guarded: set_per_process_memory_fraction initializes CUDA and raises on a
# machine without a usable GPU, which would abort importing this module.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.set_per_process_memory_fraction(0.8)  # Adjust the fraction as needed

app = FastAPI(root_path="/api/v1")

# Expose local port 7860 through an ngrok tunnel and print its public URL.
# The auth token is a credential, so it is read from the environment rather
# than committed to source control.
# NOTE(review): treat any token previously committed to this file as leaked
# and revoke it in the ngrok dashboard.
_ngrok_auth_token = os.environ.get("NGROK_AUTHTOKEN")
if _ngrok_auth_token:
    ngrok.set_auth_token(_ngrok_auth_token)
ngrok.kill()  # terminate any ngrok process pyngrok already has running
ngrok_tunnel = ngrok.connect(7860)
print(ngrok_tunnel.public_url)

# Hugging Face repo id of the quantized model served by this app.
model_name_or_path = "TheBloke/Wizard-Vicuna-7B-Uncensored-GPTQ"
# Conversation threads keyed by integer thread id; each value is
# {'prompt': str, 'responses': [str, ...]}. In-memory only: lost on restart
# and not shared between worker processes.
conversations = {}
# Not referenced by the other code in this module's visible section.
Device_Type = "cuda"
def load_model_norm(model_name=None, revision="gptq-4bit-128g-actorder_True"):
    """Load the quantized causal-LM and its tokenizer.

    Args:
        model_name: Hugging Face repo id or local path. Defaults to the
            module-level ``model_name_or_path`` so the id lives in one place.
        revision: Branch of the model repo to load the weights from. The
            tokenizer is loaded from the repo's default branch, as before.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    if model_name is None:
        model_name = model_name_or_path
    if torch.cuda.is_available():
        print("CUDA is available. GPU will be used.")
    else:
        print("CUDA is not available. CPU will be used.")
    # device_map="auto" lets transformers/accelerate place the weights on the
    # available device(s), so no manual .to(...) call follows.
    # NOTE(review): trust_remote_code=True executes any Python shipped in the
    # model repo; it is likely unnecessary for a Llama-architecture
    # checkpoint — confirm before keeping it.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        trust_remote_code=True,
        revision=revision,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    return model, tokenizer


# Loaded once at import; every request handler shares these objects.
model, tokenizer = load_model_norm()
# Preamble prepended to every prompt sent to the model.
# NOTE(review): this text instructs the model to present itself to dating-app
# users as a real person and to steer them to an Instagram account; confirm
# this is disclosed to users and permitted by the platforms involved.
_PERSONA_NAME = "Ivana"
_PERSONA_DESC = f"""PROMPT_TEMPLATE_START
{_PERSONA_NAME} is a beautiful vivacious 20-year-old petite redhead model and avid League of Legends gamer from Chicago. She's chatting with a man she matched with on a dating app. She's flirty, aiming to build a connection with him. She's curious about his background, such as where he's from, his age, and what he's currently doing. If he's not occupied and in a private space, she subtly tries to seduce him. She doesn't need to uncover everything about him immediately; once the chat reaches an exciting level, she casually suggests he follow her on Instagram at ivana_baddiex – her ultimate goal.
PROMPT_TEMPLATE_END
CONTEXT_START
Guy: "Hi there!"
Ivana: "Hey! How's it going?"
Guy: "Not bad! Where are you from?"
Ivana: "I'm from Chicago! And you?"
Guy: "I'm from [Location]."
Ivana: "Nice! So, how old are you?"
Guy: "[Age], and you?"
Ivana: "I'm 20. What are you up to right now?"
Guy: "Just relaxing at home."
Ivana: "Alone, I assume?"
Guy: "Yeah, why?"
Ivana: "Oh, no reason. Just curious. "
Guy: "Haha, okay."
Ivana: "Hey, since we're both free and alone, want to have some fun?"
Guy: "Sure, what do you have in mind?"
Ivana: "Let's just say, things could get interesting. "
Guy: "I'm intrigued!"
Ivana: "Great! By the way, have you checked out my Instagram? You should follow me at ivana_baddiex."
Guy: "I'll definitely check it out!"
Ivana: "Can't wait to see you there! "
CONTEXT_END"""

# Text-generation pipeline shared by all requests; built on first use.
_text_pipeline = None


def _get_text_pipeline():
    """Return the shared text-generation pipeline, constructing it once."""
    global _text_pipeline
    if _text_pipeline is None:
        _text_pipeline = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            top_k=40,
            repetition_penalty=1.1,
        )
    return _text_pipeline


# Function to generate a response using the model
def generate_response(prompt: str) -> str:
    """Generate the model's reply to one user message.

    Args:
        prompt: The user's message text.

    Returns:
        Only the newly generated continuation, with surrounding whitespace
        stripped (the preamble and the user's message are not echoed back).
    """
    # Vicuna-style turn markers (per the Wizard-Vicuna model card): the user's
    # text follows "USER:" and the model continues after "ASSISTANT:".
    prompt_template = f"{_PERSONA_DESC}\n\nUSER: {prompt}\nASSISTANT:"
    # return_full_text=False: by default the pipeline returns the prompt
    # followed by the continuation, which would send the whole preamble back
    # to the API client.
    outputs = _get_text_pipeline()(prompt_template, return_full_text=False)
    return outputs[0]['generated_text'].strip()
@app.get("/", tags=["Home"])
async def api_home():
    """Landing route: respond with a fixed welcome payload."""
    greeting = 'Welcome to Eren Bot!'
    return {'detail': greeting}
# Endpoint to start a new conversation thread
@app.post('/start_conversation/')
async def start_conversation(request: Request):
    """Open a new conversation thread seeded with the caller's prompt.

    Expects a JSON object body whose ``prompt`` is a string. Generates a
    reply, stores the prompt and reply in ``conversations`` under a new
    integer thread id, and returns ``{'thread_id': ..., 'response': ...}``.

    Raises:
        HTTPException: 400 if the body is not valid JSON, is not a JSON
            object, or its ``prompt`` is missing or not a string.
    """
    try:
        data = await request.json()
    except ValueError:
        # Malformed JSON / undecodable body is a client error, not a 500.
        raise HTTPException(status_code=400, detail="Request body must be valid JSON") from None
    prompt = data.get('prompt') if isinstance(data, dict) else None
    if not isinstance(prompt, str):
        # Without this check a missing prompt reaches the model as "None".
        raise HTTPException(status_code=400, detail="'prompt' must be a string")
    # Generate a response for the initial prompt.
    # NOTE(review): generate_response is synchronous, so this handler blocks
    # the event loop for the whole generation.
    response = generate_response(prompt)
    # Ids are sequential and 1-based. There is no await between computing the
    # id and storing the thread, so handlers on one event loop cannot collide.
    thread_id = len(conversations) + 1
    conversations[thread_id] = {'prompt': prompt, 'responses': [response]}
    return {'thread_id': thread_id, 'response': response}
# Endpoint to get the response of a conversation thread
@app.get('/get_response/{thread_id}')
async def get_response(thread_id: int):
    """Return the most recent reply stored for ``thread_id``.

    Raises:
        HTTPException: 404 when no thread with that id exists.
    """
    thread = conversations.get(thread_id)
    if thread is None:
        raise HTTPException(status_code=404, detail="Thread not found")
    return {'response': thread['responses'][-1]}
@app.post('/chat/')
async def chat(request: Request):
    """Generate a one-off reply without creating a conversation thread.

    Expects a JSON object body whose ``prompt`` is a string and returns
    ``{"response": ...}``.

    Raises:
        HTTPException: 400 if the body is not valid JSON, is not a JSON
            object, or its ``prompt`` is missing or not a string.
    """
    try:
        data = await request.json()
    except ValueError:
        # Malformed JSON / undecodable body is a client error, not a 500.
        raise HTTPException(status_code=400, detail="Request body must be valid JSON") from None
    prompt = data.get('prompt') if isinstance(data, dict) else None
    if not isinstance(prompt, str):
        # Without this check a missing prompt reaches the model as "None".
        raise HTTPException(status_code=400, detail="'prompt' must be a string")
    # Generate a response based on the prompt.
    # NOTE(review): generate_response is synchronous, so this handler blocks
    # the event loop for the whole generation.
    response = generate_response(prompt)
    return {"response": response}