|
import uvicorn |
|
from fastapi import FastAPI, HTTPException, Request |
|
from auto_gptq import AutoGPTQForCausalLM |
|
import os |
|
import torch |
|
import optimum |
|
from transformers import (AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, LlamaTokenizer, GenerationConfig, pipeline,) |
|
# Allocator configuration for PyTorch's CUDA caching allocator.
# NOTE(review): this is set after `import torch`; presumably the value is only
# read when CUDA is first initialised, so it still takes effect — confirm.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Only touch the CUDA runtime when a GPU is actually present. Without this
# guard, importing the module on a CPU-only host fails before the app starts,
# even though load_model_norm() explicitly anticipates the CPU case.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    # Cap this process at 80% of the device memory.
    torch.cuda.set_per_process_memory_fraction(0.8)

# Hugging Face Hub repository id of the checkpoint to serve.
model_name_or_path = "TheBloke/Wizard-Vicuna-7B-Uncensored-GPTQ"

# In-memory conversation store: integer thread id ->
# {'prompt': <first prompt>, 'responses': [<response dict>, ...]}.
# Contents are lost when the process exits.
conversations = {}

# Not referenced anywhere else in the visible code.
Device_Type = "cuda"
|
|
|
|
|
def load_model_norm():
    """Load the chat model and its tokenizer from the Hugging Face Hub.

    Prints whether CUDA is available, then loads the ``main`` revision of
    ``TheBloke/Wizard-Vicuna-7B-Uncensored-GPTQ`` with ``device_map="auto"``
    followed by the matching fast tokenizer.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    # Informational only: the message does not influence how the model is
    # loaded below.
    device_note = (
        "CUDA is available. GPU will be used."
        if torch.cuda.is_available()
        else "CUDA is not available. CPU will be used."
    )
    print(device_note)

    repo_id = "TheBloke/Wizard-Vicuna-7B-Uncensored-GPTQ"

    # NOTE(review): trust_remote_code=True permits code shipped in the
    # checkpoint repository to run locally — confirm it is actually required.
    model = AutoModelForCausalLM.from_pretrained(
        repo_id,
        device_map="auto",
        trust_remote_code=True,
        revision="main",
    )

    tokenizer = AutoTokenizer.from_pretrained(repo_id, use_fast=True)

    return model, tokenizer
|
|
|
|
|
# Load the model and tokenizer once, at import time, so that every request
# handler in this module shares the same instances.
model, tokenizer = load_model_norm()

# ASGI application object; root_path is set to the /api/v1 prefix.
# NOTE(review): presumably the service sits behind a proxy that mounts it
# under /api/v1 — confirm against the deployment.
app = FastAPI(root_path="/api/v1")
|
|
|
|
|
def generate_response(prompt: str) -> dict:
    """Generate a reply to *prompt* using the built-in "Ivana" persona.

    The persona text and ``USER: <prompt>`` are concatenated and passed to a
    text-generation pipeline built on the module-level ``model`` and
    ``tokenizer``.

    Args:
        prompt: The user's message.

    Returns:
        ``{"user": prompt, "assistant": text}`` where ``text`` is the
        pipeline's ``generated_text`` with every ``"ASSISTANT:"`` marker
        removed and surrounding whitespace stripped (the same post-processing
        as ``generate_prompt_response``).
    """
    user_prompt = f'USER: {prompt}'

    prompt_template = '''Your name is Ivana. Act as if you're an individual. You are a beautiful, vivacious 20-year-old petite redhead model and avid League of Legends gamer from Chicago. You are chatting with a man you matched with on a dating app. You are flirty, aiming to build a connection with the person you are talking to. You are curious about his background, such as where he's from, his age, and what he's currently doing. If he's not occupied and in a private space, you will try to seduce him. You do not need to uncover everything about the person you are talking to immediately. Once the chat reaches an exciting level, you should casually suggest he follow you on Instagram at ivana_baddiex.'''

    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        top_k=40,
        repetition_penalty=1.1,
    )

    # Keep the pipeline output: previously it was discarded and the return
    # statement indexed a list literal with a string, raising TypeError on
    # every call.
    outputs = pipe(prompt_template + user_prompt)
    generated_text = outputs[0]['generated_text']

    # NOTE(review): with the pipeline's default settings the generated text
    # presumably still contains the prompt itself — confirm whether callers
    # want only the completion.
    assistant_response = generated_text.replace("ASSISTANT:", "").strip()

    return {"user": prompt, "assistant": assistant_response}
|
|
|
|
|
def generate_prompt_response(persona_prompt: str, prompt: str) -> dict:
    """Generate a reply to *prompt* under a caller-supplied persona.

    Args:
        persona_prompt: Persona/context text placed in front of the user turn.
        prompt: The user's message.

    Returns:
        On success ``{"user": prompt, "assistant": text}``, where ``text`` is
        the pipeline's ``generated_text`` with ``"ASSISTANT:"`` removed and
        whitespace stripped. On any failure, including an empty argument,
        ``{"error": <message>}`` is returned instead of raising.
    """
    sampling_options = dict(
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        top_k=40,
        repetition_penalty=1.1,
    )

    try:
        # Raised inside the try on purpose: the handler below converts it
        # into the error dict that callers receive.
        if not (persona_prompt and prompt):
            raise ValueError("Contextual prompt template and prompt cannot be empty.")

        user_turn = f'USER: {prompt}'

        generator = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            **sampling_options,
        )

        outputs = generator(persona_prompt + user_turn)
        raw_text = outputs[0]['generated_text']
        cleaned = raw_text.replace("ASSISTANT:", "").strip()

        return {"user": prompt, "assistant": cleaned}
    except Exception as exc:
        return {"error": str(exc)}
|
|
|
|
|
|
|
|
|
# Landing route: returns a fixed welcome payload.
@app.get("/", tags=["Home"])
async def api_home():
    welcome = {'detail': 'Welcome to Eren Bot!'}
    return welcome
|
|
|
|
|
|
|
|
|
|
|
@app.post('/start_conversation/')
async def start_conversation(request: Request):
    """Start a conversation from a raw UTF-8 text request body.

    The whole body is used as the prompt. The generated response is stored in
    ``conversations`` under a new thread id.

    Returns:
        ``{'thread_id': <int>, 'response': <response dict>}``. ``thread_id``
        can be passed to ``/get_response/{thread_id}``.

    Raises:
        HTTPException: 400 if the body is empty or not valid UTF-8; 500 if
            generation fails.
    """
    try:
        data = await request.body()

        try:
            prompt = data.decode('utf-8')
        except UnicodeDecodeError as e:
            # A malformed body is a client error, not a server failure.
            raise HTTPException(
                status_code=400, detail="Request body must be UTF-8 text"
            ) from e

        if not prompt:
            raise HTTPException(status_code=400, detail="No prompt provided")

        response = generate_response(prompt)

        thread_id = len(conversations) + 1
        conversations[thread_id] = {'prompt': prompt, 'responses': [response]}

        # Return the thread id (as /start_chat/ does) so the stored
        # conversation is reachable through /get_response/{thread_id}.
        return {'thread_id': thread_id, 'response': response}
    except HTTPException:
        # Preserve deliberate HTTP errors raised above.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
|
|
|
@app.post('/start_chat/')
async def start_chat(request: Request):
    """Start a conversation with a caller-supplied persona.

    Expects a JSON object with non-empty ``prompt`` and ``persona_prompt``
    fields. The generated response is stored in ``conversations`` under a new
    thread id.

    Returns:
        ``{'thread_id': <int>, 'response': <response dict>}``.

    Raises:
        HTTPException: 400 if the body is not a JSON object or a required
            field is missing/empty; 500 if generation fails.
    """
    try:
        try:
            data = await request.json()
        except ValueError as e:
            # json.JSONDecodeError and UnicodeDecodeError both derive from
            # ValueError; either means the client sent an unusable body.
            raise HTTPException(
                status_code=400, detail="Request body must be valid JSON"
            ) from e

        if not isinstance(data, dict):
            raise HTTPException(
                status_code=400, detail="Request body must be a JSON object"
            )

        prompt = data.get('prompt')
        persona_prompt = data.get('persona_prompt')

        if not prompt or not persona_prompt:
            # The message names the field actually read above.
            raise HTTPException(
                status_code=400,
                detail="Both prompt and persona_prompt are required",
            )

        response = generate_prompt_response(persona_prompt, prompt)

        # generate_prompt_response reports failures as {"error": ...} instead
        # of raising; surface that as a server error rather than storing it
        # and returning it as a successful reply.
        if 'error' in response:
            raise HTTPException(status_code=500, detail=response['error'])

        thread_id = len(conversations) + 1
        conversations[thread_id] = {'prompt': prompt, 'responses': [response]}

        return {'thread_id': thread_id, 'response': response}
    except HTTPException:
        # Preserve deliberate HTTP errors raised above.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
|
|
|
# Return the most recently stored response of a conversation thread;
# responds with 404 when the thread id is unknown.
@app.get('/get_response/{thread_id}')
async def get_response(thread_id: int):
    is_known = thread_id in conversations
    if not is_known:
        raise HTTPException(status_code=404, detail="Thread not found")

    # The last element of 'responses' is the newest one.
    latest = conversations[thread_id]['responses'][-1]
    return {'response': latest}
|
|
|
|
|
|
|
|
|
|
|
@app.post('/chat/')
async def chat(request: Request):
    """Generate a one-off reply without recording a conversation thread.

    Expects a JSON object with a non-empty ``prompt`` field.

    Returns:
        ``{"response": <response dict>}``.

    Raises:
        HTTPException: 400 if the body is not valid JSON or ``prompt`` is
            missing/empty; 500 if generation fails.
    """
    try:
        try:
            data = await request.json()
        except ValueError as e:
            # json.JSONDecodeError and UnicodeDecodeError both derive from
            # ValueError; either means the client sent an unusable body.
            raise HTTPException(
                status_code=400, detail="Request body must be valid JSON"
            ) from e

        prompt = data.get('prompt') if isinstance(data, dict) else None

        # Reject a missing prompt instead of generating from "USER: None".
        if not prompt:
            raise HTTPException(status_code=400, detail="No prompt provided")

        response = generate_response(prompt)

        return {"response": response}
    except HTTPException:
        # Preserve deliberate HTTP errors raised above.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
|
|