Spaces:
Sleeping
Sleeping
jonathanjordan21
committed on
Commit
•
162773e
1
Parent(s):
6072755
Update apis/chat_api.py
Browse files- apis/chat_api.py +68 -1
apis/chat_api.py
CHANGED
@@ -146,6 +146,73 @@ class ChatAPIApp:
|
|
146 |
raise HTTPException(status_code=e.status_code, detail=e.detail)
|
147 |
except Exception as e:
|
148 |
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
|
150 |
|
151 |
class GenerateRequest(BaseModel):
|
@@ -315,7 +382,7 @@ class ChatAPIApp:
|
|
315 |
prefix + "/chat",
|
316 |
summary="Ollama Chat completions in conversation session",
|
317 |
include_in_schema=include_in_schema,
|
318 |
-
)(self.
|
319 |
|
320 |
self.app.post(
|
321 |
prefix + "/embeddings",
|
|
|
146 |
raise HTTPException(status_code=e.status_code, detail=e.detail)
|
147 |
except Exception as e:
|
148 |
raise HTTPException(status_code=500, detail=str(e))
|
149 |
+
|
150 |
+
|
151 |
+
class ChatCompletionsPostItem(BaseModel):
    """Request body for the chat-completions endpoint.

    Field descriptions mirror the OpenAPI docs shown for this route;
    pydantic deep-copies mutable defaults, so the default `messages`
    list is safe to declare inline.
    """

    model: str = Field(
        default="nous-mixtral-8x7b",
        description="(str) `nous-mixtral-8x7b`",
    )
    messages: list = Field(
        default=[{"role": "user", "content": "Hello, who are you?"}],
        description="(list) Messages",
    )
    temperature: Union[float, None] = Field(
        default=0.5,
        description="(float) Temperature",
    )
    top_p: Union[float, None] = Field(
        default=0.95,
        description="(float) top p",
    )
    max_tokens: Union[int, None] = Field(
        default=-1,
        description="(int) Max tokens",
    )
    use_cache: bool = Field(
        default=False,
        description="(bool) Use cache",
    )
    stream: bool = Field(
        default=True,
        description="(bool) Stream",
    )
|
180 |
+
|
181 |
+
def chat_completions_ollama(
    self, item: ChatCompletionsPostItem, api_key: str = Depends(extract_api_key)
):
    """Handle a chat-completions request and return the full (non-streamed) reply.

    Dispatches on `item.model`:
      - "gpt-3.5-turbo"        -> OpenaiStreamer
      - models in PRO_MODELS   -> HuggingchatStreamer
      - anything else          -> HuggingfaceStreamer, with the prompt built
                                  by MessageComposer from `item.messages`

    NOTE(review): temperature / top_p / max_tokens / use_cache are only
    forwarded on the HuggingfaceStreamer path — confirm that is intended
    for the other two backends.

    Returns:
        dict produced by `streamer.chat_return_dict(stream_response)`.

    Raises:
        HTTPException: with the upstream status for HfApiException,
            or 500 for any other failure.
    """
    try:
        # "llama3" is accepted as a convenience alias for "llama3-8b".
        item.model = "llama3-8b" if item.model == "llama3" else item.model
        api_key = self.auth_api_key(api_key)

        if item.model == "gpt-3.5-turbo":
            streamer = OpenaiStreamer()
            stream_response = streamer.chat_response(messages=item.messages)
        elif item.model in PRO_MODELS:
            streamer = HuggingchatStreamer(model=item.model)
            stream_response = streamer.chat_response(
                messages=item.messages,
            )
        else:
            streamer = HuggingfaceStreamer(model=item.model)
            composer = MessageComposer(model=item.model)
            composer.merge(messages=item.messages)
            stream_response = streamer.chat_response(
                prompt=composer.merged_str,
                temperature=item.temperature,
                top_p=item.top_p,
                max_new_tokens=item.max_tokens,
                api_key=api_key,
                use_cache=item.use_cache,
            )

        data_response = streamer.chat_return_dict(stream_response)
        return data_response
    except HfApiException as e:
        # Preserve the upstream status/detail; chain for debuggability.
        raise HTTPException(status_code=e.status_code, detail=e.detail) from e
    except Exception as e:
        # API boundary: surface any unexpected failure as a 500.
        raise HTTPException(status_code=500, detail=str(e)) from e
|
216 |
|
217 |
|
218 |
class GenerateRequest(BaseModel):
|
|
|
382 |
prefix + "/chat",
|
383 |
summary="Ollama Chat completions in conversation session",
|
384 |
include_in_schema=include_in_schema,
|
385 |
+
)(self.chat_completions_ollama)
|
386 |
|
387 |
self.app.post(
|
388 |
prefix + "/embeddings",
|