kz919's picture
Update get_fn.py
b015fe9 verified
raw
history blame
941 Bytes
import os
import base64
from openai import OpenAI
import gradio as gr
from typing import Callable
def get_fn(model_name: str, preprocess: Callable, postprocess: Callable, api_key: str) -> Callable:
    """Build a streaming chat handler bound to one model and API key.

    Args:
        model_name: Model identifier sent with every completion request.
        preprocess: Called as ``preprocess(message, history)``; must return a
            mapping whose ``"messages"`` entry is the chat payload.
        postprocess: Called with the accumulated response text after each
            streamed chunk; its result is what the handler yields.
        api_key: Credential passed to the OpenAI-compatible client.

    Returns:
        A generator function ``fn(message, history)``. It yields the
        post-processed response text as it grows. If creating the client,
        issuing the request, reading the stream, or ``postprocess`` raises,
        it yields one final ``"Error: ..."`` string instead of propagating
        the exception. Exceptions raised by ``preprocess`` propagate.
    """

    def fn(message, history):
        inputs = preprocess(message, history)
        try:
            # Client construction is inside the try so that a construction
            # failure is reported the same way as a request failure.
            client = OpenAI(
                base_url="https://fast-cloud-snova-ai-dev-0-api.cloud.snova.ai/v1",
                api_key=api_key,
            )
            completion = client.chat.completions.create(
                model=model_name,
                messages=inputs["messages"],
                stream=True,
            )
            response_text = ""
            for chunk in completion:
                # A streamed chunk may carry an empty ``choices`` list
                # (e.g. a usage-only chunk); there is no text to append.
                if not chunk.choices:
                    continue
                response_text += chunk.choices[0].delta.content or ""
                yield postprocess(response_text)
        except Exception as e:
            # ``fn`` is a generator, so ``return <value>`` would only set
            # StopIteration.value, which iterating consumers never see.
            # Yield the message so the caller actually receives it.
            yield f"Error: {e}"

    return fn