# get_fn.py — "Create get_fn.py" by kz919 (commit a9b015c, 914 bytes)
import os
import base64
from openai import OpenAI
import gradio as gr
from typing import Callable
def get_fn(model_name: str, preprocess: Callable, postprocess: Callable, api_key: str):
    """Build a streaming chat callback backed by the SambaNova OpenAI-compatible API.

    The returned ``fn(message, history)`` is a generator function. It:

    1. converts the chat state with ``preprocess``;
    2. sends the resulting messages to ``https://api.sambanova.ai/v1/`` with
       streaming enabled;
    3. yields ``postprocess(text_so_far)`` after every streamed chunk.

    If the request fails, a single ``"Error: ..."`` string is yielded instead
    of raising.

    Args:
        model_name: Model identifier passed to ``chat.completions.create``.
        preprocess: Called as ``preprocess(message, history)``. Must return a
            mapping whose ``"messages"`` entry is the chat message list to send.
        postprocess: Applied to the accumulated response text before each yield.
        api_key: API key for the SambaNova endpoint.

    Returns:
        The generator function ``fn(message, history)``.
    """

    def fn(message, history):
        inputs = preprocess(message, history)
        try:
            # The context manager closes the client (and its HTTP connection
            # pool) when the stream finishes, fails, or the consumer stops
            # iterating early.
            with OpenAI(
                base_url="https://api.sambanova.ai/v1/",
                api_key=api_key,
            ) as client:
                completion = client.chat.completions.create(
                    model=model_name,
                    messages=inputs["messages"],
                    stream=True,
                )
                response_text = ""
                for chunk in completion:
                    # Guard against chunks with an empty `choices` list
                    # (e.g. a trailing usage-only chunk, if the server sends
                    # one) instead of failing with IndexError mid-response.
                    if not chunk.choices:
                        continue
                    response_text += chunk.choices[0].delta.content or ""
                    yield postprocess(response_text)
        except Exception as e:
            # `fn` is a generator: a `return value` here would be hidden in the
            # StopIteration payload and never reach the caller, so yield the
            # message to surface it.
            yield f"Error: {str(e)}"

    return fn