"""Streamlit front end for the "AI Bros" fitness-coach playground.

Left column: a chat with an AI coach. Every user message is classified by
``router_chain`` and answered by one of: the RAG chain (``fitness_advices``),
a plain LLM chain (``smalltalk``), the pose agent (``movement_analysis``), or
the workout chain (any other route), which also renders a workout table.

Right column: upload a training video. When pose-estimation output exists
under ``RUNS_DIR/pose`` the newest output file is shown, otherwise the raw
upload; images found in ``FIG_DIR`` are listed under "Graph Displayer".
"""
import base64  # noqa: F401
import json  # noqa: F401
import os
import re
import shutil

import pandas as pd
import streamlit as st
from dotenv import load_dotenv
from langchain_core.prompts import ChatPromptTemplate
from langchain_mistralai import ChatMistralAI
from st_audiorec import st_audiorec  # noqa: F401

# Load API keys from .env before importing project modules, in case any of
# them read environment variables at import time.
load_dotenv()

from Modules.Speech2Text.transcribe import transcribe  # noqa: E402,F401
from Modules.rag import rag_chain  # noqa: E402
from Modules.router import router_chain  # noqa: E402
from Modules.workout_plan import workout_chain  # noqa: E402
from Modules.PoseEstimation.pose_agent import agent_executor  # noqa: E402
from Modules.PoseEstimation import pose_estimator  # noqa: E402,F401
from utils import save_uploaded_file, encode_video_H264  # noqa: E402

# Root of the pose-estimation runs; output files are read from
# RUNS_DIR/pose/<run>/ and the whole tree is wiped when a new video arrives.
RUNS_DIR = "/home/user/.pyenv/runs"
# Images shown under "Graph Displayer"; wiped when a new video arrives.
FIG_DIR = "fig"

mistral_api_key = os.getenv("MISTRAL_API_KEY")


def _natural_key(name):
    """Sort key that compares embedded digit runs numerically ("r2" < "r10")."""
    # re.split with a capturing group alternates text, digits, text, ... so
    # odd positions are always digit runs and element types line up between
    # any two keys (no str/int comparisons).
    parts = re.split(r"(\d+)", name)
    return [int(part) if index % 2 else part for index, part in enumerate(parts)]


def _latest_pose_video(runs_dir):
    """Return the output file of the newest run under ``runs_dir/pose``.

    Run sub-directories are ordered with ``_natural_key`` so numbered runs
    sort numerically; inside the newest run the first entry in sorted order
    is returned.

    Returns None when ``runs_dir/pose`` does not exist, contains no run
    sub-directory, or the newest run directory is empty.
    """
    pose_dir = os.path.join(runs_dir, "pose")
    if not os.path.isdir(pose_dir):
        return None
    run_dirs = sorted(
        (
            entry
            for entry in os.listdir(pose_dir)
            if os.path.isdir(os.path.join(pose_dir, entry))
        ),
        key=_natural_key,
    )
    if not run_dirs:
        return None
    latest_run = os.path.join(pose_dir, run_dirs[-1])
    run_files = sorted(os.listdir(latest_run))
    if not run_files:
        return None
    return os.path.join(latest_run, run_files[0])


st.set_page_config(layout="wide", initial_sidebar_state="collapsed")
st.markdown("# Welcome to the AI Bros playground! 🏋️‍♂️🤖🏋️‍♀️")

# Create two columns: chat on the left, video analysis on the right.
col1, col2 = st.columns(2)
video_uploaded = None

llm = ChatMistralAI(model="mistral-large-latest", mistral_api_key=mistral_api_key, temperature=0)

# Persona prompt for the "smalltalk" route; {question} is the user's message.
coach_prompt = ChatPromptTemplate.from_template(
    template="""
You are a personal coach with an expertise in nutrition and fitness. You are having a conversation with your client, which is either a beginner or an advanced athlete. You are motivative and have very good vibe. Call your client buddy or sweatheart. You are a human, not a robot. Always try to answer concisely to the queries.
User: {question}
AI Coach:"""
)
base_chain = coach_prompt | llm

# Set only by the workout route, which also builds `workout_df`.
display_workout = False

# First column containers
with col1:
    st.subheader("Coach chat")
    if "messages" not in st.session_state:
        st.session_state.messages = []
    # Base name of the last uploaded video; the right column writes it and
    # the movement-analysis route passes it to the agent.
    if "file_name" not in st.session_state:
        st.session_state.file_name = None

    if user_query := st.chat_input("What's up?"):
        st.session_state.messages.append({"role": "user", "content": user_query})
        with st.chat_message("user"):
            st.markdown(user_query)

        with st.chat_message("assistant", avatar="data/AI_Bro.png"):
            # Pick which chain answers this message.
            direction = router_chain.invoke({"question": user_query})
            print(type(direction))
            print(direction)

            if direction == "fitness_advices":
                with st.spinner("Retrieving relevant data..."):
                    response = rag_chain.invoke(user_query)
            elif direction == "smalltalk":
                with st.spinner("Thinking..."):
                    response = base_chain.invoke({"question": user_query}).content
            elif direction == "movement_analysis":
                agent_query = user_query
                if st.session_state.file_name is not None:
                    # Put the hint on its own line so it does not run into
                    # the last word of the user's message.
                    agent_query += "\nthe file name is " + st.session_state.file_name
                with st.spinner("Analyzing movement..."):
                    response = agent_executor.invoke({"input": agent_query})["output"]
            else:
                with st.spinner("Creating workout program..."):
                    response = "Sure! I just made a workout for you. Check on the table I just provided you."
                    json_output = workout_chain.invoke({"query": user_query})
                    workout_df = pd.DataFrame(json_output["exercises"])
                    workout_df.columns = ["exercice", "sets", "reps", "rest"]
                    display_workout = True

            print(type(response))
            st.session_state.messages.append({"role": "assistant", "content": response})
            st.markdown(response)

    if display_workout:
        st.subheader("Workout")
        st.data_editor(workout_df)

# Second column containers
with col2:
    # TODO: "Sports Agenda" section.
    st.subheader("Technique Analysis")
    video_uploaded = st.file_uploader("Choose a video file", type=["mp4", "ogg", "webm", "MOV"])
    if video_uploaded:
        video_uploaded = save_uploaded_file(video_uploaded)
        uploaded_name = os.path.basename(video_uploaded)
        if uploaded_name != st.session_state.file_name:
            # A different video makes previous graphs and pose runs stale.
            shutil.rmtree(FIG_DIR, ignore_errors=True)
            shutil.rmtree(RUNS_DIR, ignore_errors=True)
        st.session_state.file_name = uploaded_name

        _left, mid, _right = st.columns([1, 2, 1])
        with mid:
            pose_video = _latest_pose_video(RUNS_DIR)
            if pose_video is not None:
                # NOTE(review): Streamlit reruns this script on every
                # interaction, so this re-encodes (remove_original=True) on
                # each rerun — confirm encode_video_H264 tolerates that.
                st.video(encode_video_H264(pose_video, remove_original=True), loop=True)
            else:
                st.video(video_uploaded)

        if os.path.exists(FIG_DIR):
            st.subheader("Graph Displayer")
            for fig_name in sorted(os.listdir(FIG_DIR)):
                st.image(os.path.join(FIG_DIR, fig_name))