import torch
torch.set_float32_matmul_precision('high')

from flask import Flask, send_from_directory, request, Response
import os
import base64
import numpy as np
from inference import OmniInference
import io

app = Flask(__name__)

# Initialize OmniInference
try:
    print("Initializing OmniInference...")
    omni = OmniInference()
    print("OmniInference initialized successfully.")
except Exception as e:
    print(f"Error initializing OmniInference: {str(e)}")
    raise


@app.route('/')
def serve_html():
    return send_from_directory('.', 'webui/omni_html_demo.html')


@app.route('/chat', methods=['POST'])
def chat():
    try:
        audio_data = request.json['audio']
        if not audio_data:
            return "No audio data received", 400

        # Check if the audio_data contains the expected base64 prefix
        if ',' in audio_data:
            audio_bytes = base64.b64decode(audio_data.split(',')[1])
        else:
            audio_bytes = base64.b64decode(audio_data)

        # Save audio to a temporary file
        temp_audio_path = 'temp_audio.wav'
        with open(temp_audio_path, 'wb') as f:
            f.write(audio_bytes)

        # Generate response using OmniInference
        try:
            response_generator = omni.run_AT_batch_stream(temp_audio_path)

            # Concatenate all audio chunks
            all_audio = b''
            for audio_chunk in response_generator:
                all_audio += audio_chunk

            # Clean up temporary file
            os.remove(temp_audio_path)

            return Response(all_audio, mimetype='audio/wav')
        except Exception as inner_e:
            print(f"Error in OmniInference processing: {str(inner_e)}")
            return f"An error occurred during audio processing: {str(inner_e)}", 500
        finally:
            # Ensure temporary file is removed even if an error occurs
            if os.path.exists(temp_audio_path):
                os.remove(temp_audio_path)
    except Exception as e:
        print(f"Error in chat endpoint: {str(e)}")
        return f"An error occurred: {str(e)}", 500


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)
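

# Example client call for the /chat endpoint (a minimal sketch, not part of the
# server itself): it assumes the server is running locally on port 7860 and that
# 'input.wav' exists on disk; the file names and host are placeholders. It posts
# the audio as base64 JSON, which is the shape request.json['audio'] expects above,
# and writes the returned WAV bytes to a file.
#
#     import base64
#     import requests
#
#     with open('input.wav', 'rb') as f:
#         audio_b64 = base64.b64encode(f.read()).decode('utf-8')
#
#     resp = requests.post('http://localhost:7860/chat', json={'audio': audio_b64})
#     resp.raise_for_status()
#
#     with open('reply.wav', 'wb') as f:
#         f.write(resp.content)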