File size: 5,426 Bytes
d296c34
 
e87f4b7
 
025e412
9aff9bb
c104abf
 
025e412
 
 
989534c
9a30a8c
9aff9bb
c104abf
989534c
9a30a8c
025e412
a755c90
cdb28da
d296c34
9a30a8c
d296c34
9ceb6f1
b3ff09d
9ceb6f1
d296c34
 
4c18e6f
025e412
9aff9bb
9ceb6f1
9aff9bb
9ceb6f1
 
 
9aff9bb
c104abf
9aff9bb
 
 
 
d296c34
c104abf
d296c34
 
85b06be
36269bc
e87f4b7
85b06be
 
989534c
 
 
d296c34
36269bc
 
 
e87f4b7
36269bc
85b06be
e87f4b7
85b06be
e87f4b7
6eebc5e
e87f4b7
85b06be
c104abf
 
9aff9bb
36269bc
989534c
 
 
9aff9bb
989534c
 
 
 
c104abf
989534c
 
 
 
 
 
c104abf
36269bc
 
 
 
 
9ceb6f1
36269bc
9a30a8c
025e412
 
d296c34
c104abf
 
 
d296c34
 
989534c
4c18e6f
36269bc
989534c
 
4c18e6f
a755c90
989534c
 
a586a8c
989534c
 
31207dc
4c18e6f
a586a8c
e78912d
989534c
 
9ad7a80
cdb28da
 
 
85b06be
 
a755c90
85b06be
989534c
85b06be
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import streamlit as st
from st_audiorec import st_audiorec
from Modules.Speech2Text.transcribe import transcribe
import base64
from langchain_mistralai import ChatMistralAI
from langchain_core.prompts import ChatPromptTemplate
import pandas as pd 
import json
from dotenv import load_dotenv
load_dotenv() # load .env api keys 
import os
import shutil
from Modules.rag import rag_chain
from Modules.router import router_chain
from Modules.workout_plan import workout_chain
from Modules.PoseEstimation.pose_agent import agent_executor

mistral_api_key = os.getenv("MISTRAL_API_KEY")
from Modules.PoseEstimation import pose_estimator
from utils import save_uploaded_file, encode_video_H264


st.set_page_config(layout="wide", initial_sidebar_state="collapsed")

# Title emojis (man lifting weights, robot, woman lifting weights) are written as
# escape sequences so the literal stays intact regardless of editor/file encoding;
# the raw characters had been garbled into mojibake.
st.markdown(
    "# Welcome to the AI Bros playground! "
    "\U0001F3CB\ufe0f\u200d\u2642\ufe0f\U0001F916\U0001F3CB\ufe0f\u200d\u2640\ufe0f"
)

# Two side-by-side columns: coach chat (left) and technique analysis (right).
col1, col2 = st.columns(2)
video_uploaded = None
# Mistral chat model shared by the small-talk chain below.
llm = ChatMistralAI(model="mistral-large-latest", mistral_api_key=mistral_api_key, temperature=0)
# Coach persona prompt; {question} is filled with the user's chat message.
prompt = ChatPromptTemplate.from_template(
    template =""" You are a personal coach with an expertise in nutrition and fitness. 
    You are having a conversation with your client, who is either a beginner or an advanced athlete. 
    You are motivating and have a very good vibe.
    Call your client buddy or sweetheart.
    You are a human, not a robot.
    Always try to answer the queries concisely.
    
    User: {question}
    AI Coach:"""
)
# Compose prompt and model: the formatted prompt is passed straight to the LLM.
base_chain = prompt | llm

display_workout = False
# First column: chat with the AI coach.
with col1:

    st.subheader("Coach chat")

    # Conversation log and the name of the last uploaded video persist across reruns.
    if "messages" not in st.session_state:
        st.session_state.messages = []

    if "file_name" not in st.session_state:
        st.session_state.file_name = None

    # NOTE: messages are recorded but past turns are not re-rendered; uncomment to
    # show the full history on every rerun.
    # for message in st.session_state.messages:
    #     with st.chat_message(message["role"]):
    #         st.markdown(message["content"])

    # Named `user_input` (not `prompt`) so the module-level ChatPromptTemplate
    # called `prompt` is not overwritten by the user's text.
    if user_input := st.chat_input("What's up?"):
        st.session_state.messages.append({"role": "user", "content": user_input})
        with st.chat_message("user"):
            st.markdown(user_input)

        with st.chat_message("assistant", avatar="data/AI_Bro.png"):
            # Ask the router which handler should answer this message.
            direction = router_chain.invoke({"question": user_input})
            if direction == 'fitness_advices':
                with st.spinner("Retrieving relevant data..."):
                    response = rag_chain.invoke(user_input)
            elif direction == 'smalltalk':
                with st.spinner("Thinking..."):
                    response = base_chain.invoke({"question": user_input}).content
            elif direction == 'movement_analysis':
                agent_input = user_input
                if st.session_state.file_name is not None:
                    # Put the uploaded video's name on its own line; plain
                    # concatenation glued it onto the user's last word.
                    agent_input += "\nThe file name is " + st.session_state.file_name
                with st.spinner("Analyzing movement..."):
                    response = agent_executor.invoke({"input": agent_input})["output"]
            else:
                # Any other route is treated as a workout-plan request.
                with st.spinner("Creating workout program..."):
                    response = "Sure! I just made a workout for you. Check on the table I just provided you."
                    json_output = workout_chain.invoke({"query": user_input})
                    workout_df = pd.DataFrame(json_output['exercises'])
                    # Relabel only when the shape matches: assigning 4 names to a
                    # frame with a different column count raises ValueError.
                    if workout_df.shape[1] == 4:
                        workout_df.columns = ["exercice", "sets", "reps", "rest"]
                    # Keep the table in session state so it survives reruns
                    # (editing it in st.data_editor triggers one).
                    st.session_state.workout_df = workout_df
            st.session_state.messages.append({"role": "assistant", "content": response})
            st.markdown(response)

    # Show the most recent workout, including on reruns that did not regenerate it.
    display_workout = "workout_df" in st.session_state
    if display_workout:
        st.subheader("Workout")
        st.data_editor(st.session_state.workout_df)
# Second column: upload a video and show the pose-estimation output for it.

# Root of the pose-estimation runs; the viewer below expects the layout
# <POSE_RUNS_DIR>/pose/<run>/<video>. NOTE(review): hard-coded absolute path kept
# from the original -- presumably must match where the pose agent writes; confirm.
POSE_RUNS_DIR = '/home/user/.pyenv/runs'


def _run_dir_sort_key(name):
    """Return a sort key that orders run names by their numeric suffix.

    Plain string sorting puts ``predict10`` before ``predict2``; this key yields
    ``(stem, number)`` so ``predict`` < ``predict2`` < ``predict10``.
    """
    stem = name.rstrip("0123456789")
    suffix = name[len(stem):]
    return (stem, int(suffix) if suffix else 0)


def _latest_pose_video(runs_dir):
    """Return the path of the video in the most recent pose run, or None.

    Looks in ``<runs_dir>/pose``, picks the run directory with the highest
    numeric suffix and returns its first file (alphabetically). Returns None when
    the ``pose`` directory is missing, contains no run directories, or the latest
    run is empty, so the caller can fall back to the uploaded video.
    """
    pose_dir = os.path.join(runs_dir, 'pose')
    if not os.path.isdir(pose_dir):
        return None
    run_dirs = [d for d in os.listdir(pose_dir) if os.path.isdir(os.path.join(pose_dir, d))]
    if not run_dirs:
        return None
    latest_run = os.path.join(pose_dir, max(run_dirs, key=_run_dir_sort_key))
    videos = sorted(os.listdir(latest_run))
    return os.path.join(latest_run, videos[0]) if videos else None


with col2:
    # TODO: add a "Sports Agenda" section here.
    st.subheader("Technique Analysis")

    video_uploaded = st.file_uploader("Choose a video file", type=["mp4", "ogg", "webm", "MOV"])
    if video_uploaded:
        video_uploaded = save_uploaded_file(video_uploaded)
        uploaded_name = os.path.basename(video_uploaded)
        if uploaded_name != st.session_state.file_name:
            # A different video was uploaded: discard the graphs and pose runs
            # produced for the previous one.
            shutil.rmtree('fig', ignore_errors=True)
            shutil.rmtree(POSE_RUNS_DIR, ignore_errors=True)
        st.session_state.file_name = uploaded_name
        _left, mid, _right = st.columns([1, 2, 1])
        with mid:
            pose_video = _latest_pose_video(POSE_RUNS_DIR)
            if pose_video is not None:
                # Re-encode with the project helper (presumably so the browser can
                # play it) before displaying; remove_original=True as in the original.
                st.video(encode_video_H264(pose_video, remove_original=True), loop=True)
            else:
                st.video(video_uploaded)

    if os.path.exists('fig'):
        st.subheader("Graph Displayer")
        # Sorted so graphs appear in a stable order across reruns.
        for fig_name in sorted(os.listdir('fig')):
            st.image(os.path.join('fig', fig_name))