# app.py — last commit 3388a44 ("Update app.py", verified) by PerryCheng614.
import gradio as gr
import websockets
import asyncio
import json
import base64
from PIL import Image
import io
# Tokens that may open a reply (whitespace and chat-template markers) and
# should not be shown to the user.
_LEADING_NOISE_TOKENS = (" ", " \n", "\n", "<|im_start|>", "assistant")
# Only tokens at stream positions 0 .. _LEADING_TOKEN_WINDOW - 1 may be skipped;
# the same strings later in the reply are real content (e.g. newlines in lists).
_LEADING_TOKEN_WINDOW = 3


def _encode_image_data_uri(image_path):
    """Load *image_path*, re-encode it as RGB JPEG and return a base64 data URI."""
    with Image.open(image_path) as img:
        rgb = img.convert("RGB")
        buffer = io.BytesIO()
        rgb.save(buffer, format="JPEG")
    encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return f"data:image/jpeg;base64,{encoded}"


def _parse_message(message):
    """Decode one WebSocket frame.

    Returns the decoded JSON object as a dict, or None when the frame is not
    valid JSON (including undecodable bytes) or is JSON but not an object.
    """
    try:
        data = json.loads(message)
    except ValueError:  # covers json.JSONDecodeError and UnicodeDecodeError
        return None
    return data if isinstance(data, dict) else None


def _is_leading_noise(index, token):
    """Return True if *token*, at 0-based stream position *index*, should be dropped."""
    return index < _LEADING_TOKEN_WINDOW and token in _LEADING_NOISE_TOKENS


async def process_image_stream(image_path, prompt, max_tokens=512):
    """
    Process image with streaming response via WebSocket.

    Async generator used as the Gradio handler. Each yielded string is the
    full text to display so far (not a delta).

    Args:
        image_path: Path to the uploaded image, or a falsy value if none.
        prompt: Question/instruction sent to the model.
        max_tokens: Forwarded to the server as the ``max_tokens`` field.

    Yields:
        The accumulated response after each kept token, or a single notice or
        error message. Failures are reported by yielding text instead of
        raising, so the UI always shows something.
    """
    if not image_path:
        yield "Please upload an image first."
        return

    try:
        image_uri = _encode_image_data_uri(image_path)
    except Exception as e:  # PIL can raise several types (OSError, UnidentifiedImageError, ...)
        yield f"Error reading image: {e}"
        return

    try:
        async with websockets.connect('wss://nexa-omni.nexa4ai.com/ws/process-image/') as websocket:
            # Send image data and parameters as a single JSON request.
            await websocket.send(json.dumps({
                "image": image_uri,
                "prompt": prompt,
                "task": "instruct",  # Fixed to instruct
                "max_tokens": max_tokens,
            }))

            response = ""
            # Position of each "generating" message in the stream, counted
            # whether or not the token is kept. The original counted only
            # skipped tokens, so noise tokens were dropped anywhere in the reply.
            token_index = 0
            async for message in websocket:
                data = _parse_message(message)
                if data is None:
                    continue  # ignore frames that are not JSON objects

                status = data.get("status")
                if status == "generating":
                    token = data.get("token") or ""
                    skip = _is_leading_noise(token_index, token)
                    token_index += 1
                    if skip:
                        continue
                    response += token
                    yield response
                elif status == "complete":
                    break
                elif status == "error":
                    yield f"Error: {data.get('error', 'unknown error')}"
                    break
    except Exception as e:
        yield f"Error connecting to server: {e}"
# Create Gradio interface.
# The literal "\n" plus the physical line break leaves a blank line after the date.
_DESCRIPTION = """
*Model updated on Nov 21, 2024\n
Upload an image and ask questions about it. The model will analyze the image and provide detailed answers to your queries.
"""

# Example rows, one value per input in the same order as ``inputs`` below:
# (image path, question, max tokens).
_EXAMPLES = [
    ["example_images/example_1.jpg", "What kind of cat is this?", 128],
    ["example_images/example_2.jpg", "What color is this dress? ", 128],
    ["example_images/example_3.jpg", "What is this image about?", 128],
]

# Input components, created in the order they are passed to the interface.
_image_input = gr.Image(type="filepath", label="Upload Image")
_question_input = gr.Textbox(
    label="Question",
    placeholder="Ask a question about the image...",
    value="Describe this image",
)
# The slider range (50-200) is what the UI sends as max_tokens; the handler's
# own default of 512 is not used from this interface.
_max_tokens_input = gr.Slider(
    minimum=50,
    maximum=200,
    value=200,
    step=1,
    label="Max Tokens",
)
_response_output = gr.Textbox(label="Response", interactive=False)

demo = gr.Interface(
    fn=process_image_stream,
    inputs=[_image_input, _question_input, _max_tokens_input],
    outputs=_response_output,
    title="Nexa Omni Vision",
    description=_DESCRIPTION,
    examples=_EXAMPLES,
)
if __name__ == "__main__":
    # Turn on Gradio's request queue, then serve on all network
    # interfaces (0.0.0.0) at port 7860.
    queued_app = demo.queue()
    queued_app.launch(server_name="0.0.0.0", server_port=7860)