import gradio as gr
import websockets
import asyncio
import json
import base64
from PIL import Image
import io

# WebSocket endpoint of the hosted Nexa Omni vision model.
WS_URL = 'wss://nexa-omni.nexa4ai.com/ws/process-image/'

# Chat-template tokens that may precede the actual answer text. Only a short
# run of them at the very start of the stream is dropped; the same strings
# appearing after real text (e.g. genuine newlines) are kept.
_PREFIX_TOKENS = frozenset({" ", " \n", "\n", "<|im_start|>", "assistant"})
_MAX_PREFIX_TOKENS = 3


def _image_to_data_url(image_path):
    """Load the image at *image_path* and return it as a base64 JPEG data URL.

    The image is converted to RGB first because JPEG cannot store alpha or
    palette image modes.
    """
    with Image.open(image_path) as img:
        rgb = img.convert('RGB')
        buffer = io.BytesIO()
        rgb.save(buffer, format="JPEG")
    encoded = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return f"data:image/jpeg;base64,{encoded}"


async def process_image_stream(image_path, prompt, max_tokens=512):
    """
    Process image with streaming response via WebSocket.

    Async generator used as a Gradio streaming handler: every yield is the
    full response text accumulated so far, so the output box updates
    incrementally as tokens arrive.

    Args:
        image_path: Path of the uploaded image (``gr.Image(type="filepath")``);
            falsy when nothing has been uploaded.
        prompt: Question/instruction sent to the model.
        max_tokens: Upper bound on generated tokens. Coerced to ``int`` so a
            slider value delivered as a float serializes as an integer.

    Yields:
        The accumulated response text, or a single error message string.
    """
    if not image_path:
        yield "Please upload an image first."
        return

    try:
        data_url = _image_to_data_url(image_path)
    except OSError as e:
        # Missing/unreadable/unsupported files are not server errors; report
        # them as such instead of "Error connecting to server".
        yield f"Error reading image: {str(e)}"
        return

    try:
        async with websockets.connect(WS_URL) as websocket:
            # Send image data and parameters as JSON.
            await websocket.send(json.dumps({
                "image": data_url,
                "prompt": prompt,
                "task": "instruct",  # the demo only exposes the instruct task
                "max_tokens": int(max_tokens),
            }))

            response = ""
            skipped = 0  # leading template tokens dropped so far

            async for message in websocket:
                try:
                    data = json.loads(message)
                except json.JSONDecodeError:
                    continue  # ignore frames that are not JSON
                if not isinstance(data, dict):
                    continue  # ignore JSON that is not an object

                status = data.get("status")
                if status == "generating":
                    token = data.get("token")
                    if not isinstance(token, str):
                        continue
                    # Strip template tokens only before any real text has
                    # been produced; a plain skip counter would also delete
                    # legitimate newlines later in the answer.
                    if (not response
                            and skipped < _MAX_PREFIX_TOKENS
                            and token in _PREFIX_TOKENS):
                        skipped += 1
                        continue
                    response += token
                    yield response
                elif status == "complete":
                    break
                elif status == "error":
                    yield f"Error: {data.get('error', 'unknown error')}"
                    break
    except Exception as e:
        # Top-level UI boundary: surface connection/protocol failures to the
        # user instead of crashing the Gradio handler.
        yield f"Error connecting to server: {str(e)}"


# Create Gradio interface
demo = gr.Interface(
    fn=process_image_stream,
    inputs=[
        gr.Image(type="filepath", label="Upload Image"),
        gr.Textbox(
            label="Question",
            placeholder="Ask a question about the image...",
            value="Describe this image"
        ),
        gr.Slider(
            minimum=50,
            maximum=200,
            value=200,
            step=1,
            label="Max Tokens"
        )
    ],
    outputs=gr.Textbox(label="Response", interactive=False),
    title="Nexa Omni Vision",
    description="""
    *Model updated on Nov 21, 2024\n Upload an image and ask questions about it.
    The model will analyze the image and provide detailed answers to your queries.
    """,
    examples=[
        ["example_images/example_1.jpg", "What kind of cat is this?", 128],
        ["example_images/example_2.jpg", "What color is this dress? ", 128],
        ["example_images/example_3.jpg", "What is this image about?", 128],
    ]
)

if __name__ == "__main__":
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)