Vitrous committed
Commit 37d83d9 · verified · 1 Parent(s): 096a16c

Update app.py

Files changed (1)
  1. app.py +14 -3
app.py CHANGED
@@ -12,6 +12,8 @@ origins = [
     "https://articko-artickbot.hf.space/",
     "https://articko-artickbot.hf.space/api/start_conversation",
     "https://articko-artickbot.hf.space/api/get_response/",
+    "https://127.0.0.1/",
+    "https://127.0.0.1/api/",
 ]
 app.add_middleware(
     CORSMiddleware,
@@ -47,7 +49,7 @@ def load_quantized_model(model_id, model_basename):
 
 
 # Making the code device-agnostic
-model, tokenizer = load_quantized_model(model_name_or_path, "model.safetensors")
+#model, tokenizer = load_quantized_model(model_name_or_path, "model.safetensors")
 
 
 
@@ -66,6 +68,9 @@ def load_model_norm():
 
     return model, tokenizer
 
+
+model, tokenizer = load_model_norm()
+
 
 # Function to generate a response using the model
 def generate_response(prompt: str) -> str:
@@ -105,12 +110,14 @@ def generate_response(prompt: str) -> str:
     return generated_text
 
 
+
+
+
 @app.get("/", tags=["Home"])
 async def api_home():
     return {'detail': 'Welcome to Eren Bot!'}
 
-
-
+
 # Endpoint to start a new conversation thread
 @app.post('/api/start_conversation')
 async def start_conversation(request: Request):
@@ -129,6 +136,8 @@ async def start_conversation(request: Request):
 
 
 # Endpoint to get the response of a conversation thread
+
+
 @app.get('/api/get_response/{thread_id}')
 async def get_response(thread_id: int):
     if thread_id not in conversations:
@@ -143,6 +152,8 @@ async def get_response(thread_id: int):
     return {'response': response}
 
 
+
+
 @app.post('/api/chat')
 async def chat(request: Request):
     data = await request.json()
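
For context on the first hunk: it only appends two local origins to the list, but the middleware call it sits next to is cut off after the CORSMiddleware class name. Below is a minimal sketch of the usual FastAPI wiring, where every keyword argument passed to add_middleware is an assumption rather than something visible in this diff:

    from fastapi import FastAPI
    from fastapi.middleware.cors import CORSMiddleware

    app = FastAPI()

    origins = [
        "https://articko-artickbot.hf.space/",
        "https://articko-artickbot.hf.space/api/start_conversation",
        "https://articko-artickbot.hf.space/api/get_response/",
        "https://127.0.0.1/",      # added in this commit
        "https://127.0.0.1/api/",  # added in this commit
    ]

    app.add_middleware(
        CORSMiddleware,
        allow_origins=origins,     # assumed: the diff shows only the class name
        allow_methods=["*"],       # assumed
        allow_headers=["*"],       # assumed
    )

One design note: browsers send only scheme://host[:port] (no path, no trailing slash) in the Origin header, so the path-bearing and trailing-slash entries in this list are unlikely to match a real request origin as written; bare entries such as "https://articko-artickbot.hf.space" and "https://127.0.0.1" are what CORSMiddleware actually compares against.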
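The substantive change is in the second and third hunks: the quantized load_quantized_model(model_name_or_path, "model.safetensors") call is commented out and load_model_norm() is now invoked at module level, so the weights are loaded once when app.py is imported rather than per request. The body of load_model_norm() is outside this diff; the sketch below is a hypothetical stand-in built on transformers, with model_name_or_path as a placeholder for whatever checkpoint app.py actually defines:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name_or_path = "some-org/some-model"  # placeholder; the real value lives elsewhere in app.py


    def load_model_norm():
        # "Making the code device-agnostic": pick GPU when available, otherwise CPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        model = AutoModelForCausalLM.from_pretrained(model_name_or_path).to(device)
        return model, tokenizer


    # Mirrors the commit: load once at import time so generate_response() can reuse it.
    model, tokenizer = load_model_norm()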
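The remaining hunks only add blank lines around the endpoint definitions, but they do show the three routes the app exposes. A hypothetical client walkthrough of those routes follows; the JSON payload keys and the thread_id key name are assumptions, since the diff shows only the decorators, data = await request.json(), and the {'response': response} return:

    import requests

    BASE = "https://articko-artickbot.hf.space"

    # Start a new conversation thread (payload key assumed).
    resp = requests.post(f"{BASE}/api/start_conversation", json={"prompt": "Hello!"})
    thread_id = resp.json().get("thread_id")  # key name assumed

    # Fetch the generated reply for that thread; the route returns {'response': ...}.
    resp = requests.get(f"{BASE}/api/get_response/{thread_id}")
    print(resp.json().get("response"))

    # One-shot chat endpoint (payload key again assumed).
    resp = requests.post(f"{BASE}/api/chat", json={"prompt": "Who are you?"})
    print(resp.json())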