FitnessEquation / app.py
Théo Rousseaux
UI
b3ff09d
raw
history blame
5.43 kB
# Streamlit front-end for the AI coach: a routed chat (left column) and
# uploaded-video technique analysis (right column).
import streamlit as st
# NOTE(review): st_audiorec, transcribe, base64 and json are not referenced in
# this file's visible code — possibly leftovers of an audio-input feature.
from st_audiorec import st_audiorec
from Modules.Speech2Text.transcribe import transcribe
import base64
from langchain_mistralai import ChatMistralAI
from langchain_core.prompts import ChatPromptTemplate
import pandas as pd
import json
from dotenv import load_dotenv
load_dotenv() # load .env API keys into the process environment
import os
import shutil
# These Modules.* imports deliberately come after load_dotenv(); presumably they
# create LLM clients at import time that need the keys — keep this order
# (TODO confirm against the Modules package).
from Modules.rag import rag_chain
from Modules.router import router_chain
from Modules.workout_plan import workout_chain
from Modules.PoseEstimation.pose_agent import agent_executor
# Mistral API key read from the environment (None if MISTRAL_API_KEY is unset);
# passed to ChatMistralAI when the small-talk model is built.
mistral_api_key = os.getenv("MISTRAL_API_KEY")
# NOTE(review): pose_estimator is not referenced in this file's visible code.
from Modules.PoseEstimation import pose_estimator
from utils import save_uploaded_file, encode_video_H264
# Page layout must be configured before any other Streamlit element is drawn.
st.set_page_config(layout="wide", initial_sidebar_state="collapsed")
st.markdown("# Welcome to the AI Bros playground! 🏋️‍♂️🤖🏋️‍♀️")

# Two side-by-side columns: the coach chat goes in col1, the video technique
# analysis in col2.
col1, col2 = st.columns(2)
# Reassigned by the second column to the uploaded file / its saved path.
video_uploaded = None

# Deterministic (temperature=0) Mistral model used for small-talk answers.
llm = ChatMistralAI(model="mistral-large-latest", mistral_api_key=mistral_api_key, temperature=0)

# Coach persona for small talk; {question} is filled with the user's message.
prompt = ChatPromptTemplate.from_template(
    template=""" You are a personal coach with an expertise in nutrition and fitness.
You are having a conversation with your client, which is either a beginner or an advanced athlete.
You are motivating and have very good vibe.
Call your client buddy or sweetheart.
You are a human, not a robot.
Always try to answer concisely to the queries.
User: {question}
AI Coach:"""
)

# Render the persona prompt, then call the LLM; callers read `.content` from
# the result.
base_chain = prompt | llm

# Flipped to True by the chat column when a workout table should be shown.
display_workout = False
# First column: chat with the AI coach. Each user message is routed to one of
# the specialised chains (RAG advice, small talk, movement analysis, workout).
with col1:
    st.subheader("Coach chat")

    # Per-session state that survives Streamlit reruns.
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "file_name" not in st.session_state:
        # Name of the last uploaded video; written by the second column.
        st.session_state.file_name = None

    # NOTE: the conversation is appended to st.session_state.messages but is
    # not re-rendered, so only the latest exchange is visible after a rerun.

    # Named user_query (not `prompt`) so the module-level ChatPromptTemplate
    # called `prompt` is not shadowed by the user's text.
    if user_query := st.chat_input("What's up?"):
        st.session_state.messages.append({"role": "user", "content": user_query})
        with st.chat_message("user"):
            st.markdown(user_query)

        with st.chat_message("assistant", avatar="data/AI_Bro.png"):
            # Ask the router which chain should answer this message.
            direction = router_chain.invoke({"question": user_query})
            if isinstance(direction, str):
                # A trailing newline/space in the label must not silently
                # send the request to the workout branch below.
                direction = direction.strip()

            if direction == 'fitness_advices':
                with st.spinner("Retrieving relevant data..."):
                    response = rag_chain.invoke(user_query)
            elif direction == 'smalltalk':
                with st.spinner("Thinking..."):
                    response = base_chain.invoke({"question": user_query}).content
            elif direction == 'movement_analysis':
                agent_query = user_query
                if st.session_state.file_name is not None:
                    # Tell the pose agent which uploaded video to analyse, kept
                    # on its own line so it does not merge with the user's words.
                    agent_query = f"{user_query}\nThe file name is {st.session_state.file_name}"
                with st.spinner("Analyzing movement..."):
                    response = agent_executor.invoke({"input": agent_query})["output"]
            else:
                # Every other label is treated as a workout-plan request.
                with st.spinner("Creating workout program..."):
                    try:
                        json_output = workout_chain.invoke({"query": user_query})
                        workout_df = pd.DataFrame(json_output['exercises'])
                        # Raises ValueError if the LLM did not return exactly
                        # four fields per exercise.
                        workout_df.columns = ["exercice", "sets", "reps", "rest"]
                    except (KeyError, TypeError, ValueError) as err:
                        # Malformed LLM output: answer gracefully instead of
                        # crashing the whole page.
                        print(f"Could not build the workout table: {err!r}")
                        response = "Sorry buddy, I couldn't put your workout together this time. Could you ask me again?"
                    else:
                        response = "Sure! I just made a workout for you. Check on the table I just provided you."
                        display_workout = True

            st.session_state.messages.append({"role": "assistant", "content": response})
            st.markdown(response)

    # Only True when the workout branch above succeeded during this run.
    if display_workout:
        st.subheader("Workout")
        st.data_editor(workout_df)
# Second column: video upload and technique (pose-estimation) analysis.


def _latest_pose_video(pose_dir):
    """Return a file from the newest run folder under *pose_dir*, or None.

    Run folders are ordered by name, comparing any trailing number
    numerically, so "predict10" comes after "predict9" (a plain string sort
    would place it before "predict2"). Returns None when *pose_dir* does not
    exist, contains no run folders, or the newest run folder is empty.
    """
    try:
        entries = os.listdir(pose_dir)
    except (FileNotFoundError, NotADirectoryError):
        return None
    run_dirs = [name for name in entries if os.path.isdir(os.path.join(pose_dir, name))]
    if not run_dirs:
        return None

    def run_order(name):
        # Split "predict12" into ("predict", 12); a bare "predict" counts as 0.
        stem = name.rstrip("0123456789")
        suffix = name[len(stem):]
        return (stem, int(suffix) if suffix else 0)

    latest_run = os.path.join(pose_dir, max(run_dirs, key=run_order))
    run_files = sorted(os.listdir(latest_run))
    if not run_files:
        return None
    return os.path.join(latest_run, run_files[0])


with col2:
    # TODO: add a "Sports Agenda" section here (st.subheader("Sports Agenda")).
    st.subheader("Technique Analysis")
    video_uploaded = st.file_uploader("Choose a video file", type=["mp4", "ogg", "webm", "MOV"])
    if video_uploaded:
        # Environment-specific locations of analysis outputs read back below:
        # annotated videos in <runs_dir>/pose/<run>/ and graph images in fig/.
        runs_dir = '/home/user/.pyenv/runs'
        figures_dir = 'fig'

        video_uploaded = save_uploaded_file(video_uploaded)
        uploaded_name = os.path.basename(video_uploaded)
        if uploaded_name != st.session_state.file_name:
            # A different video: drop the graphs and annotated videos that were
            # produced for the previous one.
            shutil.rmtree(figures_dir, ignore_errors=True)
            shutil.rmtree(runs_dir, ignore_errors=True)
        # Remembered so the chat column can hand the file name to the pose agent.
        st.session_state.file_name = uploaded_name

        _left, mid, _right = st.columns([1, 2, 1])
        with mid:
            annotated_path = _latest_pose_video(os.path.join(runs_dir, 'pose'))
            if annotated_path is not None:
                # NOTE(review): this re-encodes on every rerun; confirm that
                # encode_video_H264 is cheap/idempotent on an encoded file.
                annotated_path = encode_video_H264(annotated_path, remove_original=True)
                st.video(annotated_path, loop=True)
            else:
                # No analysis output yet: show the raw upload.
                st.video(video_uploaded)

            if os.path.isdir(figures_dir):
                st.subheader("Graph Displayer")
                for figure_name in sorted(os.listdir(figures_dir)):
                    st.image(os.path.join(figures_dir, figure_name))