# FitnessEquation / app.py
# Author: Doux Thibault
# Commit 9aff9bb: "add router to smalltalk chain and RAG chain"
import streamlit as st
from st_audiorec import st_audiorec
from Modules.Speech2Text.transcribe import transcribe
# NOTE(review): base64 appears unused in the code visible here — confirm
# before removing.
import base64
from langchain_mistralai import ChatMistralAI
from langchain_core.prompts import ChatPromptTemplate
from dotenv import load_dotenv
# Load API keys from .env before importing the project modules below.
# NOTE(review): presumably some Modules.* read environment variables at
# import time, which would explain why this call sits in the middle of the
# import block — confirm before reordering these lines.
load_dotenv() # load .env api keys
import os
from Modules.rag import rag_chain
from Modules.router import router_chain
# Movement-analysis agent is disabled for now (its route is commented out
# in the chat column as well).
# from Modules.PoseEstimation.pose_agent import agent_executor
# os.getenv returns None when MISTRAL_API_KEY is not set.
mistral_api_key = os.getenv("MISTRAL_API_KEY")
from Modules.PoseEstimation import pose_estimator
from utils import save_uploaded_file
def format_messages(messages):
    """Render a chat history as plain text, one ``role: content`` line per message.

    Args:
        messages: Iterable of mappings with ``"role"`` and ``"content"`` keys
            (the same shape this app stores in ``st.session_state.messages``).

    Returns:
        The concatenated lines, each terminated by a newline; an empty
        string when ``messages`` is empty.
    """
    # str.join builds the result in one pass instead of repeated ``+=``,
    # which is quadratic in the history length.
    return "".join(
        f"{message['role']}: {message['content']}\n" for message in messages
    )
# Page layout: wide canvas with the sidebar initially collapsed. This must be
# the first Streamlit call of the script run.
st.set_page_config(initial_sidebar_state="collapsed", layout="wide")

# Two side-by-side columns: col1 hosts the voice chat and movement analysis,
# col2 the agenda, video analysis and graphs.
col1, col2 = st.columns(2)

# Uploaded-video handle; assigned by the video-analysis section of col2.
video_uploaded = None

# Mistral chat model with temperature 0; used by the small-talk chain below.
llm = ChatMistralAI(
    temperature=0,
    model="mistral-large-latest",
    mistral_api_key=mistral_api_key,
)

# Coach persona prompt; ``{question}`` receives the transcribed user query.
_COACH_TEMPLATE = """ You are a personal AI sports coach with an expertise in nutrition and fitness.
You are having a conversation with your client, which is either a beginner or an advanced athlete.
You must be gentle, kind, and motivative.
Always try to answer concisely to the queries.
User: {question}
AI Coach:"""
prompt = ChatPromptTemplate.from_template(template=_COACH_TEMPLATE)

# Prompt formatting piped into the model (LangChain runnable composition).
base_chain = prompt | llm
# First column: voice query -> routed LLM answer, plus movement analysis.
with col1:
    st.subheader("Audio Recorder")
    temp_path = 'data/temp_audio/audio_file.wav'
    wav_audio_data = st_audiorec()

    # Streamlit reruns this whole script on every widget interaction, and a
    # custom component such as the recorder reports its last value again on
    # each rerun. Only treat the audio as a new query when it differs from
    # the recording already answered. Otherwise any unrelated interaction
    # (e.g. the pose-estimation button) would re-transcribe it and append a
    # duplicate exchange to the chat history.
    instruction = None
    is_new_recording = (
        wav_audio_data is not None
        and wav_audio_data != st.session_state.get("last_audio")
    )
    if is_new_recording:
        st.session_state.last_audio = wav_audio_data
        # open(..., 'wb') raises FileNotFoundError if the folder is missing.
        os.makedirs(os.path.dirname(temp_path), exist_ok=True)
        with open(temp_path, 'wb') as f:
            # Write the audio data to the file
            f.write(wav_audio_data)
        instruction = transcribe(temp_path)
        print(instruction)

    st.subheader("LLM answering")
    if "messages" not in st.session_state:
        st.session_state.messages = []
    # Replay the stored conversation; each entry has "role" and "content".
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if instruction is not None:
        st.session_state.messages.append({"role": "user", "content": instruction})
        with st.chat_message("user"):
            st.markdown(instruction)
        with st.chat_message("assistant"):
            # Ask the router which chain should answer this query.
            direction = router_chain.invoke({"question": instruction})
            if direction == 'fitness_advices':
                response = rag_chain.invoke(instruction)
            else:
                # The general coach chain answers 'smalltalk'. Any other label
                # (e.g. 'movement_analysis', whose agent is not wired in yet)
                # also falls back to it, so `response` is always defined.
                # TODO: once pose_agent is enabled, route it here:
                #   response = agent_executor.invoke({"input": instruction})["output"]
                response = base_chain.invoke({"question": instruction}).content
            st.session_state.messages.append({"role": "assistant", "content": response})
            st.markdown(response)

    st.subheader("Movement Analysis")
    # TO DO
# Second column: sports agenda, video upload + pose estimation, graphs.
with col2:
    st.subheader("Sports Agenda")
    # TO DO
    st.subheader("Video Analysis")
    # Placeholder holding the uploader, so it can be cleared once a file is chosen.
    ask_video = st.empty()
    video_uploaded = ask_video.file_uploader("Choose a video file", type=["mp4", "ogg", "webm"])
    if video_uploaded:
        video_uploaded = save_uploaded_file(video_uploaded)
        ask_video.empty()
        # Center the player by using the middle of three sub-columns.
        _left, mid, _right = st.columns(3)
        with mid:
            st.video(video_uploaded)
            apply_pose = st.button("Apply Pose Estimation")
            if apply_pose:
                with st.spinner("Processing video"):
                    # NOTE(review): the method name looks odd for a video input —
                    # confirm it against Modules.PoseEstimation.pose_estimator.
                    # TODO: `keypoints` is not consumed yet; persist it (e.g. in
                    # st.session_state) for the Graph Displayer, because locals
                    # are lost on the next rerun.
                    keypoints = pose_estimator.get_keypoints_from_keypoints(pose_estimator.model, video_uploaded)
    st.subheader("Graph Displayer")
    # TO DO