jonathanjordan21 committed
Commit 162773e
1 Parent(s): 6072755

Update apis/chat_api.py

Files changed (1):
  1. apis/chat_api.py +68 -1
apis/chat_api.py CHANGED
@@ -146,6 +146,73 @@ class ChatAPIApp:
             raise HTTPException(status_code=e.status_code, detail=e.detail)
         except Exception as e:
             raise HTTPException(status_code=500, detail=str(e))
+
+
+    class ChatCompletionsPostItem(BaseModel):
+        model: str = Field(
+            default="nous-mixtral-8x7b",
+            description="(str) `nous-mixtral-8x7b`",
+        )
+        messages: list = Field(
+            default=[{"role": "user", "content": "Hello, who are you?"}],
+            description="(list) Messages",
+        )
+        temperature: Union[float, None] = Field(
+            default=0.5,
+            description="(float) Temperature",
+        )
+        top_p: Union[float, None] = Field(
+            default=0.95,
+            description="(float) top p",
+        )
+        max_tokens: Union[int, None] = Field(
+            default=-1,
+            description="(int) Max tokens",
+        )
+        use_cache: bool = Field(
+            default=False,
+            description="(bool) Use cache",
+        )
+        stream: bool = Field(
+            default=True,
+            description="(bool) Stream",
+        )
+
+    def chat_completions_ollama(
+        self, item: ChatCompletionsPostItem, api_key: str = Depends(extract_api_key)
+    ):
+        try:
+            print(item.messages)
+            item.model = "llama3-8b" if item.model == "llama3" else item.model
+            api_key = self.auth_api_key(api_key)
+
+            if item.model == "gpt-3.5-turbo":
+                streamer = OpenaiStreamer()
+                stream_response = streamer.chat_response(messages=item.messages)
+            elif item.model in PRO_MODELS:
+                streamer = HuggingchatStreamer(model=item.model)
+                stream_response = streamer.chat_response(
+                    messages=item.messages,
+                )
+            else:
+                streamer = HuggingfaceStreamer(model=item.model)
+                composer = MessageComposer(model=item.model)
+                composer.merge(messages=item.messages)
+                stream_response = streamer.chat_response(
+                    prompt=composer.merged_str,
+                    temperature=item.temperature,
+                    top_p=item.top_p,
+                    max_new_tokens=item.max_tokens,
+                    api_key=api_key,
+                    use_cache=item.use_cache,
+                )
+
+            data_response = streamer.chat_return_dict(stream_response)
+            return data_response
+        except HfApiException as e:
+            raise HTTPException(status_code=e.status_code, detail=e.detail)
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=str(e))


    class GenerateRequest(BaseModel):
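
Editor's note: the new ChatCompletionsPostItem above maps one-to-one onto the JSON body a client POSTs to the rebound /chat route (wired up in the next hunk), and its field set mirrors an OpenAI-style chat-completions body plus use_cache. Below is a minimal client sketch; the host, port, route prefix, and Authorization scheme are assumptions about the deployment, not part of this commit.

# Minimal client sketch for the rebound /chat route. BASE_URL and the
# Bearer header are assumptions, not part of this commit; the payload
# fields come straight from ChatCompletionsPostItem above.
import requests

BASE_URL = "http://localhost:23333/api/v1"  # assumed host and prefix

payload = {
    "model": "nous-mixtral-8x7b",
    "messages": [{"role": "user", "content": "Hello, who are you?"}],
    "temperature": 0.5,
    "top_p": 0.95,
    "max_tokens": -1,   # -1 leaves the token limit to the backend
    "use_cache": False,
    "stream": True,     # accepted by the model, though this handler returns a full dict
}

response = requests.post(
    f"{BASE_URL}/chat",
    json=payload,
    headers={"Authorization": "Bearer <your-api-key>"},  # assumed auth scheme
    timeout=60,
)
response.raise_for_status()
print(response.json())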
 
@@ -315,7 +382,7 @@ class ChatAPIApp:
             prefix + "/chat",
             summary="Ollama Chat completions in conversation session",
             include_in_schema=include_in_schema,
-        )(self.chat_completions)
+        )(self.chat_completions_ollama)
 
         self.app.post(
             prefix + "/embeddings",
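
Editor's note: one consequence worth flagging is that although stream defaults to True, chat_completions_ollama always funnels the response through streamer.chat_return_dict(...), so the rebound /chat route returns one complete JSON object rather than an event stream. A quick smoke test with FastAPI's TestClient could look like the sketch below; the no-argument ChatAPIApp() constructor and the /api/v1 prefix are assumptions, since the diff only confirms that the FastAPI instance lives at self.app.

# Smoke-test sketch for the rebound /chat route. ChatAPIApp() with no
# arguments and the "/api/v1" prefix are assumptions about this fork;
# note the request exercises the real upstream model backends.
from fastapi.testclient import TestClient

from apis.chat_api import ChatAPIApp

app = ChatAPIApp().app  # self.app, as used in the route registration above
client = TestClient(app)

resp = client.post(
    "/api/v1/chat",  # assumed prefix
    json={
        "model": "nous-mixtral-8x7b",
        "messages": [{"role": "user", "content": "ping"}],
    },
)
print(resp.status_code, resp.json())  # expect 200 and a chat-completion dict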