Théo Rousseaux commited on
Commit
6eebc5e
1 Parent(s): 7e029d2

hugging face

Browse files
Files changed (2) hide show
  1. Modules/PoseEstimation/pose_agent.py +3 -1
  2. app.py +1 -1
Modules/PoseEstimation/pose_agent.py CHANGED
@@ -3,6 +3,7 @@ from langchain.agents import AgentExecutor, create_tool_calling_agent
3
  from langchain_core.prompts import ChatPromptTemplate
4
  from langchain_core.messages import HumanMessage
5
  from langchain_mistralai.chat_models import ChatMistralAI
 
6
  import os
7
  import sys
8
  import json
@@ -11,6 +12,7 @@ from Modules.PoseEstimation.pose_estimator import model, compute_right_knee_angl
11
 
12
  # If api_key is not passed, default behavior is to use the `MISTRAL_API_KEY` environment variable.
13
  llm = ChatMistralAI(model='mistral-large-latest', api_key="i5jSJkCFNGKfgIztloxTMjfckiFbYBj4")
 
14
 
15
  @tool
16
  def get_keypoints_from_keypoints(video_path: str) -> str:
@@ -27,7 +29,7 @@ def get_keypoints_from_keypoints(video_path: str) -> str:
27
  save_folder='tmp'
28
  os.makedirs(save_folder, exist_ok=True)
29
  keypoints = []
30
- results = model(video_path, save=True, show_conf=False, show_boxes=False)
31
  for (i, frame) in enumerate(results):
32
  frame_dict = {}
33
  frame_dict['frame'] = i
 
3
  from langchain_core.prompts import ChatPromptTemplate
4
  from langchain_core.messages import HumanMessage
5
  from langchain_mistralai.chat_models import ChatMistralAI
6
+ import torch
7
  import os
8
  import sys
9
  import json
 
12
 
13
  # If api_key is not passed, default behavior is to use the `MISTRAL_API_KEY` environment variable.
14
  llm = ChatMistralAI(model='mistral-large-latest', api_key="i5jSJkCFNGKfgIztloxTMjfckiFbYBj4")
15
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
16
 
17
  @tool
18
  def get_keypoints_from_keypoints(video_path: str) -> str:
 
29
  save_folder='tmp'
30
  os.makedirs(save_folder, exist_ok=True)
31
  keypoints = []
32
+ results = model(video_path, save=True, show_conf=False, show_boxes=False, device=device)
33
  for (i, frame) in enumerate(results):
34
  frame_dict = {}
35
  frame_dict['frame'] = i
app.py CHANGED
@@ -52,7 +52,7 @@ with col1:
52
  with st.chat_message("user"):
53
  st.markdown(prompt)
54
 
55
- with st.chat_message("assistant"):
56
  # Build answer from LLM
57
  direction = router_chain.invoke({"question":prompt})
58
 
 
52
  with st.chat_message("user"):
53
  st.markdown(prompt)
54
 
55
+ with st.chat_message("assistant", avatar="data/AI_Bro.png"):
56
  # Build answer from LLM
57
  direction = router_chain.invoke({"question":prompt})
58