Spaces:
Sleeping
Sleeping
feat: add new 5 tools to the agent
Browse files- app.py +10 -6
- tools/__init__.py +14 -75
- tools/driver_performance.py +88 -0
- tools/event_performance.py +74 -0
- tools/telemetry_analysis.py +77 -0
- tools/tyre_performance.py +82 -0
- tools/weather_impact.py +76 -0
app.py
CHANGED
@@ -10,15 +10,10 @@ from rich.console import Console
|
|
10 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
11 |
from gradio import ChatMessage
|
12 |
import textwrap
|
13 |
-
from tools import
|
14 |
load_dotenv()
|
15 |
os.environ['LANGCHAIN_PROJECT'] = 'gradio-test'
|
16 |
|
17 |
-
console = Console(style="chartreuse1 on grey7")
|
18 |
-
|
19 |
-
# * Initialize database
|
20 |
-
db = SQLDatabase.from_uri("sqlite:///db/Bahrain_2023_Q.db")
|
21 |
-
|
22 |
# * Initialize LLM
|
23 |
llm = ChatGoogleGenerativeAI(
|
24 |
model="gemini-1.5-flash",
|
@@ -32,8 +27,17 @@ llm = ChatGoogleGenerativeAI(
|
|
32 |
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
|
33 |
tools = toolkit.get_tools()
|
34 |
|
|
|
|
|
35 |
get_telemetry_tool = GetTelemetry()
|
|
|
|
|
|
|
|
|
|
|
36 |
tools.append(get_telemetry_tool)
|
|
|
|
|
37 |
|
38 |
# * Initialize agent
|
39 |
agent_prompt = open("agent_prompt.txt", "r")
|
|
|
10 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
11 |
from gradio import ChatMessage
|
12 |
import textwrap
|
13 |
+
from tools import *
|
14 |
load_dotenv()
|
15 |
os.environ['LANGCHAIN_PROJECT'] = 'gradio-test'
|
16 |
|
|
|
|
|
|
|
|
|
|
|
17 |
# * Initialize LLM
|
18 |
llm = ChatGoogleGenerativeAI(
|
19 |
model="gemini-1.5-flash",
|
|
|
27 |
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
|
28 |
tools = toolkit.get_tools()
|
29 |
|
30 |
+
get_driver_performance_tool = GetDriverPerformance()
|
31 |
+
get_event_performance_tool = GetEventPerformance()
|
32 |
get_telemetry_tool = GetTelemetry()
|
33 |
+
get_tyre_performance_tool = GetTyrePerformance()
|
34 |
+
get_weather_impact_tool = GetWeatherImpact()
|
35 |
+
|
36 |
+
tools.append(get_driver_performance_tool)
|
37 |
+
tools.append(get_event_performance_tool)
|
38 |
tools.append(get_telemetry_tool)
|
39 |
+
tools.append(get_tyre_performance_tool)
|
40 |
+
tools.append(get_weather_impact_tool)
|
41 |
|
42 |
# * Initialize agent
|
43 |
agent_prompt = open("agent_prompt.txt", "r")
|
tools/__init__.py
CHANGED
@@ -1,82 +1,21 @@
|
|
1 |
-
from pydantic import BaseModel, Field
|
2 |
-
from typing import Type
|
3 |
-
from langchain_core.tools import BaseTool
|
4 |
from langchain_community.utilities import SQLDatabase
|
5 |
from rich.console import Console
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
console = Console(style="chartreuse1 on grey7")
|
8 |
|
9 |
db = SQLDatabase.from_uri("sqlite:///db/Bahrain_2023_Q.db")
|
10 |
|
11 |
-
|
12 |
-
|
13 |
-
""
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
"""Output for the get_telemetry_and_weather tool"""
|
21 |
-
lap_id: int = Field(description="Lap ID")
|
22 |
-
lap_number: int = Field(description="Lap number")
|
23 |
-
lap_time_in_seconds: float | None = Field(
|
24 |
-
description="Lap time in seconds")
|
25 |
-
avg_speed: float = Field(description="Average speed in km/h")
|
26 |
-
max_speed: float = Field(description="Maximum speed in km/h")
|
27 |
-
avg_RPM: float = Field(description="Average RPM")
|
28 |
-
max_RPM: float = Field(description="Maximum RPM")
|
29 |
-
avg_throttle: float = Field(description="Average throttle")
|
30 |
-
brake_percentage: float = Field(description="Brake percentage")
|
31 |
-
drs_usage_percentage: float = Field(description="Drs usage percentage")
|
32 |
-
off_track_percentage: float = Field(description="Off track percentage")
|
33 |
-
avg_air_temp: float | None = Field(
|
34 |
-
description="Average air temperature in celsius")
|
35 |
-
avg_track_temp: float | None = Field(
|
36 |
-
description="Average track temperature in celsius")
|
37 |
-
avg_wind_speed: float | None = Field(
|
38 |
-
description="Average wind speed in meters per second")
|
39 |
-
|
40 |
-
|
41 |
-
class GetTelemetry(BaseTool):
|
42 |
-
name: str = "get_telemetry"
|
43 |
-
description: str = "useful for when you need to answer questions about telemetry for a given driver and lap"
|
44 |
-
args_schema: Type[BaseModel] = GetTelemetryAndWeatherInput
|
45 |
-
|
46 |
-
def _run(
|
47 |
-
self, driver_name: str, lap_number: int
|
48 |
-
) -> GetTelemetryAndWeatherOutput:
|
49 |
-
# """Use the tool."""
|
50 |
-
sql_file = open("tools/telemetry_and_weather_query.sql", "r")
|
51 |
-
sql_query = sql_file.read()
|
52 |
-
sql_file.close()
|
53 |
-
console.print("getting telemetry")
|
54 |
-
response = db.run(sql_query, parameters={
|
55 |
-
"driver_name": driver_name,
|
56 |
-
"lap_number": lap_number})
|
57 |
-
|
58 |
-
if not isinstance(response, str):
|
59 |
-
response = str(response)
|
60 |
-
|
61 |
-
clean_response = response.strip('[]()').split(',')
|
62 |
-
# Convert to appropriate types and create dictionary
|
63 |
-
return GetTelemetryAndWeatherOutput(
|
64 |
-
lap_id=int(float(clean_response[0])),
|
65 |
-
lap_number=int(float(clean_response[1])),
|
66 |
-
lap_time_in_seconds=float(
|
67 |
-
clean_response[2]) if clean_response[2].strip() != 'None' else None,
|
68 |
-
avg_speed=float(clean_response[3]),
|
69 |
-
max_speed=float(clean_response[4]),
|
70 |
-
avg_RPM=float(clean_response[5]),
|
71 |
-
max_RPM=float(clean_response[6]),
|
72 |
-
avg_throttle=float(clean_response[7]),
|
73 |
-
brake_percentage=float(clean_response[8]),
|
74 |
-
drs_usage_percentage=float(clean_response[9]),
|
75 |
-
off_track_percentage=float(clean_response[10]),
|
76 |
-
avg_air_temp=float(
|
77 |
-
clean_response[11]) if clean_response[11].strip() != 'None' else None,
|
78 |
-
avg_track_temp=float(
|
79 |
-
clean_response[12]) if clean_response[12].strip() != 'None' else None,
|
80 |
-
avg_wind_speed=float(
|
81 |
-
clean_response[13]) if clean_response[13].strip() != 'None' else None
|
82 |
-
)
|
|
|
|
|
|
|
|
|
1 |
from langchain_community.utilities import SQLDatabase
|
2 |
from rich.console import Console
|
3 |
+
from .driver_performance import GetDriverPerformance
|
4 |
+
from .event_performance import GetEventPerformance
|
5 |
+
from .telemetry_analysis import GetTelemetry
|
6 |
+
from .tyre_performance import GetTyrePerformance
|
7 |
+
from .weather_impact import GetWeatherImpact
|
8 |
|
9 |
console = Console(style="chartreuse1 on grey7")
|
10 |
|
11 |
db = SQLDatabase.from_uri("sqlite:///db/Bahrain_2023_Q.db")
|
12 |
|
13 |
+
__all__ = [
|
14 |
+
"GetDriverPerformance",
|
15 |
+
"GetEventPerformance",
|
16 |
+
"GetTelemetry",
|
17 |
+
"GetTyrePerformance",
|
18 |
+
"GetWeatherImpact",
|
19 |
+
"console",
|
20 |
+
"db"
|
21 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tools/driver_performance.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pydantic import BaseModel, Field
|
2 |
+
from langchain_core.tools import BaseTool
|
3 |
+
from . import db, console
|
4 |
+
|
5 |
+
|
6 |
+
class GetDriverPerformanceOutput(BaseModel):
|
7 |
+
"""Output for the get_driver_performance tool"""
|
8 |
+
driver_name: str = Field(description="Name of the driver")
|
9 |
+
event_name: str = Field(description="Name of the event")
|
10 |
+
session_type: str = Field(
|
11 |
+
description="Type of session (Practice, Qualifying, Race)")
|
12 |
+
track_name: str = Field(description="Name of the track")
|
13 |
+
total_laps: int = Field(description="Total number of laps completed")
|
14 |
+
avg_lap_time: float | None = Field(
|
15 |
+
description="Average lap time in seconds")
|
16 |
+
best_lap_time: float | None = Field(description="Best lap time in seconds")
|
17 |
+
avg_sector1_time: float | None = Field(
|
18 |
+
description="Average sector 1 time in seconds")
|
19 |
+
avg_sector2_time: float | None = Field(
|
20 |
+
description="Average sector 2 time in seconds")
|
21 |
+
avg_sector3_time: float | None = Field(
|
22 |
+
description="Average sector 3 time in seconds")
|
23 |
+
avg_finish_line_speed: float | None = Field(
|
24 |
+
description="Average finish line speed in km/h")
|
25 |
+
personal_best_laps: int = Field(description="Number of personal best laps")
|
26 |
+
avg_air_temp: float | None = Field(
|
27 |
+
description="Average air temperature in celsius")
|
28 |
+
avg_track_temp: float | None = Field(
|
29 |
+
description="Average track temperature in celsius")
|
30 |
+
rain_percentage: float = Field(
|
31 |
+
description="Percentage of time it rained during the session")
|
32 |
+
|
33 |
+
|
34 |
+
class GetDriverPerformance(BaseTool):
|
35 |
+
name: str = "get_driver_performance"
|
36 |
+
description: str = "useful for when you need to analyze driver performance statistics across different sessions and events"
|
37 |
+
|
38 |
+
def _run(self) -> list[GetDriverPerformanceOutput]:
|
39 |
+
"""Use the tool."""
|
40 |
+
sql_file = open("tools/sql/driver_performance.query.sql", "r")
|
41 |
+
sql_query = sql_file.read()
|
42 |
+
sql_file.close()
|
43 |
+
|
44 |
+
console.print("getting driver performance data")
|
45 |
+
response = db.run(sql_query)
|
46 |
+
|
47 |
+
if not isinstance(response, str):
|
48 |
+
response = str(response)
|
49 |
+
|
50 |
+
# Remove the outer brackets and split by rows
|
51 |
+
rows = response.strip('[]').split('), (')
|
52 |
+
|
53 |
+
results = []
|
54 |
+
for row in rows:
|
55 |
+
# Clean up the row string and split by columns
|
56 |
+
clean_row = row.strip('()').split(',')
|
57 |
+
|
58 |
+
# Convert to appropriate types and create output object
|
59 |
+
driver_data = GetDriverPerformanceOutput(
|
60 |
+
driver_name=clean_row[0].strip("'"),
|
61 |
+
event_name=clean_row[1].strip("'"),
|
62 |
+
session_type=clean_row[2].strip("'"),
|
63 |
+
track_name=clean_row[3].strip("'"),
|
64 |
+
total_laps=int(float(clean_row[4])),
|
65 |
+
avg_lap_time=float(
|
66 |
+
clean_row[5]) if clean_row[5].strip() != 'None' else None,
|
67 |
+
best_lap_time=float(
|
68 |
+
clean_row[6]) if clean_row[6].strip() != 'None' else None,
|
69 |
+
avg_sector1_time=float(
|
70 |
+
clean_row[7]) if clean_row[7].strip() != 'None' else None,
|
71 |
+
avg_sector2_time=float(
|
72 |
+
clean_row[8]) if clean_row[8].strip() != 'None' else None,
|
73 |
+
avg_sector3_time=float(
|
74 |
+
clean_row[9]) if clean_row[9].strip() != 'None' else None,
|
75 |
+
avg_finish_line_speed=float(
|
76 |
+
clean_row[10]) if clean_row[10].strip() != 'None' else None,
|
77 |
+
personal_best_laps=int(
|
78 |
+
float(clean_row[11])) if clean_row[11].strip() != 'None' else 0,
|
79 |
+
avg_air_temp=float(
|
80 |
+
clean_row[12]) if clean_row[12].strip() != 'None' else None,
|
81 |
+
avg_track_temp=float(
|
82 |
+
clean_row[13]) if clean_row[13].strip() != 'None' else None,
|
83 |
+
rain_percentage=float(
|
84 |
+
clean_row[14]) if clean_row[14].strip() != 'None' else 0.0
|
85 |
+
)
|
86 |
+
results.append(driver_data)
|
87 |
+
|
88 |
+
return results
|
tools/event_performance.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pydantic import BaseModel, Field
|
2 |
+
from typing import Type
|
3 |
+
from langchain_core.tools import BaseTool
|
4 |
+
from . import db, console
|
5 |
+
|
6 |
+
|
7 |
+
class GetEventPerformanceOutput(BaseModel):
|
8 |
+
"""Output for the get_event_performance tool"""
|
9 |
+
event_name: str = Field(description="Name of the event")
|
10 |
+
country: str = Field(description="Country where the event took place")
|
11 |
+
location: str = Field(description="Specific location of the event")
|
12 |
+
session_type: str = Field(
|
13 |
+
description="Type of session (Practice, Qualifying, Race)")
|
14 |
+
driver_count: int = Field(
|
15 |
+
description="Number of drivers that participated")
|
16 |
+
avg_lap_time: float = Field(description="Average lap time in seconds")
|
17 |
+
best_lap_time: float = Field(description="Best lap time in seconds")
|
18 |
+
max_finish_line_speed: float = Field(
|
19 |
+
description="Maximum speed at finish line in km/h")
|
20 |
+
avg_air_temp: float | None = Field(
|
21 |
+
description="Average air temperature in celsius")
|
22 |
+
avg_track_temp: float | None = Field(
|
23 |
+
description="Average track temperature in celsius")
|
24 |
+
rain_percentage: float = Field(
|
25 |
+
description="Percentage of time it rained during the session")
|
26 |
+
|
27 |
+
|
28 |
+
class GetEventPerformance(BaseTool):
|
29 |
+
name: str = "get_event_performance"
|
30 |
+
description: str = "useful for when you need to get performance statistics for Formula 1 events"
|
31 |
+
|
32 |
+
def _run(self) -> list[GetEventPerformanceOutput]:
|
33 |
+
"""Use the tool."""
|
34 |
+
sql_file = open("tools/sql/event_performance.query.sql", "r")
|
35 |
+
sql_query = sql_file.read()
|
36 |
+
sql_file.close()
|
37 |
+
|
38 |
+
console.print("getting event performance data")
|
39 |
+
response = db.run(sql_query)
|
40 |
+
|
41 |
+
if not isinstance(response, str):
|
42 |
+
response = str(response)
|
43 |
+
|
44 |
+
# Remove the outer brackets and split by rows
|
45 |
+
rows = response.strip('[]').split('), (')
|
46 |
+
|
47 |
+
results = []
|
48 |
+
for row in rows:
|
49 |
+
# Clean up the row string and split by columns
|
50 |
+
clean_row = row.strip('()').split(',')
|
51 |
+
|
52 |
+
# Convert to appropriate types and create output object
|
53 |
+
event_data = GetEventPerformanceOutput(
|
54 |
+
event_name=clean_row[0].strip("'"),
|
55 |
+
country=clean_row[1].strip("'"),
|
56 |
+
location=clean_row[2].strip("'"),
|
57 |
+
session_type=clean_row[3].strip("'"),
|
58 |
+
driver_count=int(float(clean_row[4])),
|
59 |
+
avg_lap_time=float(
|
60 |
+
clean_row[5]) if clean_row[5].strip() != 'None' else 0.0,
|
61 |
+
best_lap_time=float(
|
62 |
+
clean_row[6]) if clean_row[6].strip() != 'None' else 0.0,
|
63 |
+
max_finish_line_speed=float(
|
64 |
+
clean_row[7]) if clean_row[7].strip() != 'None' else 0.0,
|
65 |
+
avg_air_temp=float(
|
66 |
+
clean_row[8]) if clean_row[8].strip() != 'None' else None,
|
67 |
+
avg_track_temp=float(
|
68 |
+
clean_row[9]) if clean_row[9].strip() != 'None' else None,
|
69 |
+
rain_percentage=float(
|
70 |
+
clean_row[10]) if clean_row[10].strip() != 'None' else 0.0
|
71 |
+
)
|
72 |
+
results.append(event_data)
|
73 |
+
|
74 |
+
return results
|
tools/telemetry_analysis.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pydantic import BaseModel, Field
|
2 |
+
from typing import Type
|
3 |
+
from langchain_core.tools import BaseTool
|
4 |
+
from . import db, console
|
5 |
+
|
6 |
+
|
7 |
+
class GetTelemetryAndWeatherInput(BaseModel):
|
8 |
+
"""Input for the get_telemetry_and_weather tool"""
|
9 |
+
driver_name: str = Field(
|
10 |
+
description="Name of the driver to analyze (e.g., 'VER', 'HAM', 'LEC', etc.)")
|
11 |
+
lap_number: int = Field(description="Lap number to analyze")
|
12 |
+
|
13 |
+
|
14 |
+
class GetTelemetryAndWeatherOutput(BaseModel):
|
15 |
+
"""Output for the get_telemetry_and_weather tool"""
|
16 |
+
lap_id: int = Field(description="Lap ID")
|
17 |
+
lap_number: int = Field(description="Lap number")
|
18 |
+
lap_time_in_seconds: float | None = Field(
|
19 |
+
description="Lap time in seconds")
|
20 |
+
avg_speed: float = Field(description="Average speed in km/h")
|
21 |
+
max_speed: float = Field(description="Maximum speed in km/h")
|
22 |
+
avg_RPM: float = Field(description="Average RPM")
|
23 |
+
max_RPM: float = Field(description="Maximum RPM")
|
24 |
+
avg_throttle: float = Field(description="Average throttle")
|
25 |
+
brake_percentage: float = Field(description="Brake percentage")
|
26 |
+
drs_usage_percentage: float = Field(description="Drs usage percentage")
|
27 |
+
off_track_percentage: float = Field(description="Off track percentage")
|
28 |
+
avg_air_temp: float | None = Field(
|
29 |
+
description="Average air temperature in celsius")
|
30 |
+
avg_track_temp: float | None = Field(
|
31 |
+
description="Average track temperature in celsius")
|
32 |
+
avg_wind_speed: float | None = Field(
|
33 |
+
description="Average wind speed in meters per second")
|
34 |
+
|
35 |
+
|
36 |
+
class GetTelemetry(BaseTool):
|
37 |
+
name: str = "get_telemetry"
|
38 |
+
description: str = "useful for when you need to answer questions about telemetry for a given driver and lap"
|
39 |
+
args_schema: Type[BaseModel] = GetTelemetryAndWeatherInput
|
40 |
+
|
41 |
+
def _run(
|
42 |
+
self, driver_name: str, lap_number: int
|
43 |
+
) -> GetTelemetryAndWeatherOutput:
|
44 |
+
# """Use the tool."""
|
45 |
+
sql_file = open("tools/sql/telemetry_analysis.query.sql", "r")
|
46 |
+
sql_query = sql_file.read()
|
47 |
+
sql_file.close()
|
48 |
+
console.print("getting telemetry")
|
49 |
+
response = db.run(sql_query, parameters={
|
50 |
+
"driver_name": driver_name,
|
51 |
+
"lap_number": lap_number})
|
52 |
+
|
53 |
+
if not isinstance(response, str):
|
54 |
+
response = str(response)
|
55 |
+
|
56 |
+
clean_response = response.strip('[]()').split(',')
|
57 |
+
# Convert to appropriate types and create dictionary
|
58 |
+
return GetTelemetryAndWeatherOutput(
|
59 |
+
lap_id=int(float(clean_response[0])),
|
60 |
+
lap_number=int(float(clean_response[1])),
|
61 |
+
lap_time_in_seconds=float(
|
62 |
+
clean_response[2]) if clean_response[2].strip() != 'None' else None,
|
63 |
+
avg_speed=float(clean_response[3]),
|
64 |
+
max_speed=float(clean_response[4]),
|
65 |
+
avg_RPM=float(clean_response[5]),
|
66 |
+
max_RPM=float(clean_response[6]),
|
67 |
+
avg_throttle=float(clean_response[7]),
|
68 |
+
brake_percentage=float(clean_response[8]),
|
69 |
+
drs_usage_percentage=float(clean_response[9]),
|
70 |
+
off_track_percentage=float(clean_response[10]),
|
71 |
+
avg_air_temp=float(
|
72 |
+
clean_response[11]) if clean_response[11].strip() != 'None' else None,
|
73 |
+
avg_track_temp=float(
|
74 |
+
clean_response[12]) if clean_response[12].strip() != 'None' else None,
|
75 |
+
avg_wind_speed=float(
|
76 |
+
clean_response[13]) if clean_response[13].strip() != 'None' else None
|
77 |
+
)
|
tools/tyre_performance.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pydantic import BaseModel, Field
|
2 |
+
from typing import Type
|
3 |
+
from langchain_core.tools import BaseTool
|
4 |
+
from . import db, console
|
5 |
+
|
6 |
+
|
7 |
+
class GetTyrePerformanceInput(BaseModel):
|
8 |
+
"""Input for the get_tyre_performance tool"""
|
9 |
+
driver_name: str = Field(description="Name of the driver to analyze")
|
10 |
+
|
11 |
+
|
12 |
+
class GetTyrePerformanceOutput(BaseModel):
|
13 |
+
"""Output for the get_tyre_performance tool"""
|
14 |
+
driver_name: str = Field(description="Name of the driver")
|
15 |
+
lap_number: int = Field(description="Lap number")
|
16 |
+
tyre_compound: str = Field(description="Type of tyre compound used")
|
17 |
+
avg_tyre_life: float | None = Field(
|
18 |
+
description="Average tyre life in laps")
|
19 |
+
avg_lap_time: float | None = Field(
|
20 |
+
description="Average lap time in seconds")
|
21 |
+
avg_top_speed: float | None = Field(
|
22 |
+
description="Average top speed in longest straight in km/h")
|
23 |
+
fresh_tyre_laps: int = Field(
|
24 |
+
description="Number of laps done with fresh tyres")
|
25 |
+
used_tyre_laps: int = Field(
|
26 |
+
description="Number of laps done with used tyres")
|
27 |
+
avg_track_temp: float | None = Field(
|
28 |
+
description="Average track temperature in celsius")
|
29 |
+
avg_air_temp: float | None = Field(
|
30 |
+
description="Average air temperature in celsius")
|
31 |
+
|
32 |
+
|
33 |
+
class GetTyrePerformance(BaseTool):
|
34 |
+
name: str = "get_tyre_performance"
|
35 |
+
description: str = "useful for when you need to analyze tyre performance and degradation for a specific driver across all their laps"
|
36 |
+
args_schema: Type[BaseModel] = GetTyrePerformanceInput
|
37 |
+
|
38 |
+
def _run(self, driver_name: str) -> list[GetTyrePerformanceOutput]:
|
39 |
+
"""Use the tool."""
|
40 |
+
sql_file = open("tools/sql/tyre_performance.query.sql", "r")
|
41 |
+
sql_query = sql_file.read()
|
42 |
+
sql_file.close()
|
43 |
+
|
44 |
+
console.print("getting tyre performance data")
|
45 |
+
response = db.run(sql_query, parameters={
|
46 |
+
"driver_name": driver_name
|
47 |
+
})
|
48 |
+
|
49 |
+
if not isinstance(response, str):
|
50 |
+
response = str(response)
|
51 |
+
|
52 |
+
# Remove the outer brackets and split by rows
|
53 |
+
rows = response.strip('[]').split('), (')
|
54 |
+
|
55 |
+
results = []
|
56 |
+
for row in rows:
|
57 |
+
# Clean up the row string and split by columns
|
58 |
+
clean_row = row.strip('()').split(',')
|
59 |
+
|
60 |
+
# Convert to appropriate types and create output object
|
61 |
+
tyre_data = GetTyrePerformanceOutput(
|
62 |
+
driver_name=clean_row[0].strip("'"),
|
63 |
+
lap_number=int(float(clean_row[1])),
|
64 |
+
tyre_compound=clean_row[2].strip("'"),
|
65 |
+
avg_tyre_life=float(
|
66 |
+
clean_row[3]) if clean_row[3].strip() != 'None' else None,
|
67 |
+
avg_lap_time=float(
|
68 |
+
clean_row[4]) if clean_row[4].strip() != 'None' else None,
|
69 |
+
avg_top_speed=float(
|
70 |
+
clean_row[5]) if clean_row[5].strip() != 'None' else None,
|
71 |
+
fresh_tyre_laps=int(
|
72 |
+
float(clean_row[6])) if clean_row[6].strip() != 'None' else 0,
|
73 |
+
used_tyre_laps=int(
|
74 |
+
float(clean_row[7])) if clean_row[7].strip() != 'None' else 0,
|
75 |
+
avg_track_temp=float(
|
76 |
+
clean_row[8]) if clean_row[8].strip() != 'None' else None,
|
77 |
+
avg_air_temp=float(
|
78 |
+
clean_row[9]) if clean_row[9].strip() != 'None' else None
|
79 |
+
)
|
80 |
+
results.append(tyre_data)
|
81 |
+
|
82 |
+
return results
|
tools/weather_impact.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pydantic import BaseModel, Field
|
2 |
+
from typing import Type
|
3 |
+
from langchain_core.tools import BaseTool
|
4 |
+
from rich.console import Console
|
5 |
+
from . import db
|
6 |
+
|
7 |
+
console = Console(style="chartreuse1 on grey7")
|
8 |
+
|
9 |
+
|
10 |
+
class GetWeatherImpactInput(BaseModel):
|
11 |
+
"""Input for the get_weather_impact tool"""
|
12 |
+
# Note: This tool doesn't require input parameters as it returns weather impact data
|
13 |
+
pass
|
14 |
+
|
15 |
+
|
16 |
+
class GetWeatherImpactOutput(BaseModel):
|
17 |
+
"""Output for the get_weather_impact tool"""
|
18 |
+
event_name: str = Field(description="Name of the event")
|
19 |
+
session_type: str = Field(
|
20 |
+
description="Type of session (Practice, Qualifying, Race)")
|
21 |
+
track_name: str = Field(description="Name of the track")
|
22 |
+
avg_air_temp: float | None = Field(
|
23 |
+
description="Average air temperature in celsius")
|
24 |
+
avg_track_temp: float | None = Field(
|
25 |
+
description="Average track temperature in celsius")
|
26 |
+
avg_humidity: float | None = Field(
|
27 |
+
description="Average relative humidity percentage")
|
28 |
+
avg_wind_speed: float | None = Field(
|
29 |
+
description="Average wind speed in meters per second")
|
30 |
+
rain_percentage: float = Field(
|
31 |
+
description="Percentage of time it rained during the session")
|
32 |
+
avg_lap_time: float | None = Field(
|
33 |
+
description="Average lap time in seconds")
|
34 |
+
best_lap_time: float | None = Field(description="Best lap time in seconds")
|
35 |
+
|
36 |
+
|
37 |
+
class GetWeatherImpact(BaseTool):
|
38 |
+
name: str = "get_weather_impact"
|
39 |
+
description: str = "useful for when you need to analyze how weather conditions impact Formula 1 session performance"
|
40 |
+
args_schema: Type[BaseModel] = GetWeatherImpactInput
|
41 |
+
|
42 |
+
def _run(self) -> GetWeatherImpactOutput:
|
43 |
+
"""Use the tool."""
|
44 |
+
sql_file = open("tools/sql/weather_impact.query.sql", "r")
|
45 |
+
sql_query = sql_file.read()
|
46 |
+
sql_file.close()
|
47 |
+
|
48 |
+
console.print("getting weather impact data")
|
49 |
+
response = db.run(sql_query)
|
50 |
+
|
51 |
+
if not isinstance(response, str):
|
52 |
+
response = str(response)
|
53 |
+
|
54 |
+
# Clean up the single row response
|
55 |
+
clean_row = response.strip('[]()').split(',')
|
56 |
+
|
57 |
+
# Convert to appropriate types and create output object
|
58 |
+
return GetWeatherImpactOutput(
|
59 |
+
event_name=clean_row[0].strip("'"),
|
60 |
+
session_type=clean_row[1].strip("'"),
|
61 |
+
track_name=clean_row[2].strip("'"),
|
62 |
+
avg_air_temp=float(
|
63 |
+
clean_row[3]) if clean_row[3].strip() != 'None' else None,
|
64 |
+
avg_track_temp=float(
|
65 |
+
clean_row[4]) if clean_row[4].strip() != 'None' else None,
|
66 |
+
avg_humidity=float(
|
67 |
+
clean_row[5]) if clean_row[5].strip() != 'None' else None,
|
68 |
+
avg_wind_speed=float(
|
69 |
+
clean_row[6]) if clean_row[6].strip() != 'None' else None,
|
70 |
+
rain_percentage=float(
|
71 |
+
clean_row[7]) if clean_row[7].strip() != 'None' else 0.0,
|
72 |
+
avg_lap_time=float(
|
73 |
+
clean_row[8]) if clean_row[8].strip() != 'None' else None,
|
74 |
+
best_lap_time=float(
|
75 |
+
clean_row[9]) if clean_row[9].strip() != 'None' else None
|
76 |
+
)
|