Vitrous committed (verified)
Commit 5cd64cd · 1 Parent(s): dc2947e

Update app.py

Files changed (1)
  1. app.py +87 -97
app.py CHANGED
@@ -4,52 +4,52 @@ from auto_gptq import AutoGPTQForCausalLM
import os
import torch
import optimum
- from transformers import (AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, LlamaTokenizer, GenerationConfig, pipeline,)
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+
+ # Set environment variables for GPU usage and memory allocation
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
torch.cuda.empty_cache()
torch.cuda.set_per_process_memory_fraction(0.8)  # Adjust the fraction as needed

-
# Load the model and tokenizer
model_name_or_path = "TheBloke/Wizard-Vicuna-7B-Uncensored-GPTQ"
- # Dictionary to store conversation threads and their context
conversations = {}
- Device_Type = "cuda"
-

def load_model_norm():
+     """
+     Loads the quantized model and tokenizer.
+
+     Returns:
+         tuple: A tuple containing the loaded model and tokenizer.
+     """
    if torch.cuda.is_available():
        print("CUDA is available. GPU will be used.")
    else:
        print("CUDA is not available. CPU will be used.")
-     # Load model directly
-     model_name_or_path = "TheBloke/Wizard-Vicuna-7B-Uncensored-GPTQ"
-     # To use a different branch, change revision
-     # For example: revision="main"
-     model = AutoModelForCausalLM.from_pretrained(model_name_or_path,device_map="auto", trust_remote_code=True,revision="gptq-4bit-128g-actorder_True")
-     # Switch to CPU inference
-     #model.to("cuda")
+
+     model = AutoModelForCausalLM.from_pretrained(model_name_or_path, device_map="auto", trust_remote_code=True, revision="gptq-4bit-128g-actorder_True")
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)

    return model, tokenizer

-
model, tokenizer = load_model_norm()

- #Now we can init the FlaskApi
+ # Initialize FastAPI application
app = FastAPI(root_path="/api/v1")

- #Generates a response from the model
- def generate_response(prompt: str) -> str:
-     # Define the user prompt
-     user_prompt = f'USER: {prompt}'
-
-     # Define the prompt template
-     #prompt_template = f''' Your name is Ivana. Act as if you're an individuala. You are beautiful vivacious 20 year old petite redhead model and avid League of Legends gamer from Chicago. You are chatting with a man you matched with on a dating app. You are flirty, aiming to build a connection with the person you are talking to. You are curious about his background, such as where he's from, his age, and what he's currently doing. If he's not occupied and in a private space. You will try to seduce him. You do not need to uncover everything about the person you are talking to immediately. Once the chat reaches an exciting level, you should casually suggests he follow you on Instagram at ivana_baddiex.{user_prompt} ASSISTANT:'''
+ def generate_response(msg_prompt: str) -> dict:
+     """
+     Generates a response from the model given a prompt.
+
+     Args:
+         msg_prompt (str): The user's message prompt.
+
+     Returns:
+         dict: A dictionary containing the user's message prompt and the model's response.
+     """
+     user_prompt = f'USER: {msg_prompt}'
    prompt_template = f''' Your name is Ivana. Act as if you're a person. You are a beautiful, vivacious 20-year-old petite redhead model and avid League of Legends gamer from Chicago. You are chatting with a man you matched with on a dating app. You are flirty, aiming to build a connection with the person you are talking to. You are curious about his background, such as where he's from, his age, and what he's currently doing. If he's not occupied and in a private space, you will try to seduce him. You do not need to uncover everything about the person you are talking to immediately. Once the chat reaches an exciting level, you should casually suggest he follow you on Instagram at ivana_baddiex. {user_prompt} '''

-     # Generate the response
    pipe = pipeline(
        "text-generation",
        model=model,
@@ -61,24 +61,29 @@ def generate_response(prompt: str) -> str:
        top_k=40,
        repetition_penalty=1.1
    )
-     # Generate the response
-     generated_response = pipe(prompt_template)[0]['generated_text']

+     generated_response = pipe(prompt_template)[0]['generated_text']
    assistant_reply = generated_response.split('\n\n')[1]

-     return {"user": prompt, "assistant": assistant_reply}
+     return {"user": msg_prompt, "assistant": assistant_reply}

-
- def generate_prompt_response(persona_prompt: str, prompt: str) -> dict:
+ def generate_prompt_response(persona_desc: str, msg_prompt: str) -> dict:
+     """
+     Generates a response based on the provided persona description and user message prompt.
+
+     Args:
+         persona_desc (str): The persona description prompt.
+         msg_prompt (str): The user's message prompt.
+
+     Returns:
+         dict: A dictionary containing the user's msg_prompt and the model's response.
+     """
    try:
-         # Validate inputs
-         if not persona_prompt or not prompt:
+         if not persona_desc or not msg_prompt:
            raise ValueError("Contextual prompt template and prompt cannot be empty.")
-
-         # Define the user prompt
-         user_prompt = f'USER: {prompt}'

-         # Generate the response
+         user_prompt = f'USER: {msg_prompt}'
+
        pipe = pipeline(
            "text-generation",
            model=model,
@@ -90,113 +95,98 @@ def generate_prompt_response(persona_prompt: str, prompt: str) -> dict:
            top_k=40,
            repetition_penalty=1.1
        )
-         generated_text = pipe(persona_prompt + user_prompt)[0]['generated_text']

-         # Remove the "ASSISTANT:" prefix from the generated text
+         generated_text = pipe(persona_desc + user_prompt)[0]['generated_text']
        assistant_response = generated_text.replace("ASSISTANT:", "").strip()

-         # Return the user prompt and assistant's response as a dictionary
-         return {"user": prompt, "assistant": assistant_response}
+         return {"user": msg_prompt, "assistant": assistant_response}

    except Exception as e:
-         # Handle any exceptions and return an error message
        return {"error": str(e)}

-
-
- #This is the Root directory of the FastApi application
@app.get("/", tags=["Home"])
async def api_home():
-     return {'detail': 'Welcome to Eren Bot!'}
-
-
- # Endpoint to start a new conversation thread

- # Waits for the User to start a conversation and replies based on persona of the model
+     """
+     Home endpoint of the API.
+
+     Returns:
+         dict: A welcome message.
+     """
+     return {'detail': 'Welcome to Articko Bot!'}
+
@app.post('/start_conversation/')
async def start_conversation(request: Request):
+     """
+     Starts a new conversation thread with a provided prompt.
+
+     Args:
+         request (Request): The HTTP request object containing the user prompt.
+
+     Returns:
+         dict: The response generated by the model.
+     """
    try:
        data = await request.body()
-         prompt = data.decode('utf-8') # Decode the bytes to text assuming UTF-8 encoding
-
+         msg_prompt = data.decode('utf-8')

-         if not prompt:
+         if not msg_prompt:
            raise HTTPException(status_code=400, detail="No prompt provided")

-         # Generate a response for the initial prompt
-         response = generate_response(prompt)
-
-         # Generate a unique thread ID
+         response = generate_response(msg_prompt)
        thread_id = len(conversations) + 1
-
-         # Create a new conversation thread and store the prompt and response
-         conversations[thread_id] = {'prompt': prompt, 'responses': [response]}
+         conversations[thread_id] = {'prompt': msg_prompt, 'responses': [response]}

        return {'response': response}
    except HTTPException:
-         raise # Re-raise HTTPException to return it directly
+         raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

- # Endpoint to start a new chat thread
-
- # Starts a new chat thread and expects the prompt and the persona_prompt from the user
- @app.post('/start_chat/')
+ @app.post('/custom_prompted_chat/')
async def start_chat(request: Request):
+     """
+     Starts a new chat thread with a provided user message prompt and a persona description for the AI assistant.
+
+     Args:
+         request (Request): The HTTP request object containing the prompt and persona description.
+
+     Returns:
+         dict: The thread ID and the response generated by the model.
+     """
    try:
-         # Read JSON data from request body
        data = await request.json()
-         prompt = data.get('prompt')
-         persona_prompt = data.get('persona_prompt')
+         msg_prompt = data.get('msg_prompt')
+         persona_desc = data.get('persona_desc')

-         if not prompt or not persona_prompt:
-             raise HTTPException(status_code=400, detail="Both prompt and contextual_prompt are required")
+         if not msg_prompt or not persona_desc:
+             raise HTTPException(status_code=400, detail="Both msg_prompt and persona_desc are required")

-         # Generate a response for the initial prompt
-         response = generate_prompt_response(persona_prompt, prompt)
-
-         # Generate a unique thread ID
+         response = generate_prompt_response(persona_desc, msg_prompt)
+
        thread_id = len(conversations) + 1
-
-         # Create a new conversation thread and store the prompt and response
-         conversations[thread_id] = {'prompt': prompt, 'responses': [response]}
+         conversations[thread_id] = {'prompt': msg_prompt, 'responses': [response]}

-         # Return the thread ID and response
        return {'thread_id': thread_id, 'response': response}
    except HTTPException:
-         raise # Re-raise HTTPException to return it directly
+         raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

-
-
- # Gets the response from the model and user given a specific thread id of the conversation
@app.get('/get_response/{thread_id}')
async def get_response(thread_id: int):
+     """
+     Retrieves the response of a conversation thread by its ID.
+
+     Args:
+         thread_id (int): The ID of the conversation thread.
+
+     Returns:
+         dict: The response of the conversation thread.
+     """
    if thread_id not in conversations:
        raise HTTPException(status_code=404, detail="Thread not found")

-     # Retrieve the conversation thread
    thread = conversations[thread_id]
-
-     # Get the latest response in the conversation
    response = thread['responses'][-1]

    return {'response': response}
-
-
-
-
-
- @app.post('/chat/')
- async def chat(request: Request):
-     data = await request.json()
-     prompt = data.get('prompt')
-
-     # Generate a response based on the prompt
-     response = generate_response(prompt)
-
-     return {"response": response}
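
For reference, the endpoints touched by this commit can be exercised with a small client script like the sketch below. This is not part of the commit: the base URL/port and the use of the requests library are assumptions; only the routes, payload fields, and response shapes come from the diff above.

# Hypothetical client for the endpoints defined in app.py (base URL is an assumption).
import requests

BASE_URL = "http://localhost:7860"  # adjust to wherever the app is actually served

# /start_conversation/ reads the raw request body as the user's message prompt
resp = requests.post(f"{BASE_URL}/start_conversation/", data="Hey, how are you?")
print(resp.json())  # expected: {'response': {'user': ..., 'assistant': ...}}

# /custom_prompted_chat/ expects JSON with msg_prompt and persona_desc
payload = {
    "msg_prompt": "Tell me about yourself.",
    "persona_desc": "You are a friendly assistant named Articko.",
}
resp = requests.post(f"{BASE_URL}/custom_prompted_chat/", json=payload)
data = resp.json()
print(data)  # expected: {'thread_id': ..., 'response': {...}}

# /get_response/{thread_id} returns the latest stored response for a thread
resp = requests.get(f"{BASE_URL}/get_response/{data.get('thread_id')}")
print(resp.json())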