|
import os |
|
import base64 |
|
from openai import OpenAI |
|
import gradio as gr |
|
from typing import Callable |
|
|
|
def get_fn(model_name: str, preprocess: Callable, postprocess: Callable, api_key: str):
    """Build a streaming chat handler backed by SambaNova's OpenAI-compatible API.

    Args:
        model_name: Model identifier passed to ``chat.completions.create``.
        preprocess: Called as ``preprocess(message, history)``; must return a
            mapping whose ``"messages"`` entry is the message list to send.
        postprocess: Called on the accumulated response text after every
            streamed chunk; its result is what the handler yields.
        api_key: API key forwarded to the ``OpenAI`` client.

    Returns:
        A generator function ``fn(message, history)``. On success it yields
        progressively longer post-processed responses. If the request or
        stream fails, it yields a final ``"Error: ..."`` string.
    """

    def fn(message, history):
        inputs = preprocess(message, history)

        # Using the client as a context manager closes its HTTP connection pool
        # when this generator finishes (or is closed early by the consumer),
        # rather than leaving one unclosed client behind per message.
        with OpenAI(
            base_url="https://api.sambanova.ai/v1/",
            api_key=api_key,
        ) as client:
            try:
                completion = client.chat.completions.create(
                    model=model_name,
                    messages=inputs["messages"],
                    stream=True,
                )

                response_text = ""
                for chunk in completion:
                    # Some streamed chunks (e.g. a trailing usage chunk) carry
                    # no choices; skip them instead of raising IndexError.
                    if not chunk.choices:
                        continue
                    response_text += chunk.choices[0].delta.content or ""
                    yield postprocess(response_text)

            except Exception as e:
                # `fn` is a generator: a `return value` here would be stored in
                # StopIteration and never reach the UI, so yield the message.
                yield f"Error: {e}"

    return fn