Draichi commited on
Commit
cdd1718
1 Parent(s): c8073ff

feat: add new 5 tools to the agent

Browse files
app.py CHANGED
@@ -10,15 +10,10 @@ from rich.console import Console
10
  from langchain_google_genai import ChatGoogleGenerativeAI
11
  from gradio import ChatMessage
12
  import textwrap
13
- from tools import GetTelemetry
14
  load_dotenv()
15
  os.environ['LANGCHAIN_PROJECT'] = 'gradio-test'
16
 
17
- console = Console(style="chartreuse1 on grey7")
18
-
19
- # * Initialize database
20
- db = SQLDatabase.from_uri("sqlite:///db/Bahrain_2023_Q.db")
21
-
22
  # * Initialize LLM
23
  llm = ChatGoogleGenerativeAI(
24
  model="gemini-1.5-flash",
@@ -32,8 +27,17 @@ llm = ChatGoogleGenerativeAI(
32
  toolkit = SQLDatabaseToolkit(db=db, llm=llm)
33
  tools = toolkit.get_tools()
34
 
 
 
35
  get_telemetry_tool = GetTelemetry()
 
 
 
 
 
36
  tools.append(get_telemetry_tool)
 
 
37
 
38
  # * Initialize agent
39
  agent_prompt = open("agent_prompt.txt", "r")
 
10
  from langchain_google_genai import ChatGoogleGenerativeAI
11
  from gradio import ChatMessage
12
  import textwrap
13
+ from tools import *
14
  load_dotenv()
15
  os.environ['LANGCHAIN_PROJECT'] = 'gradio-test'
16
 
 
 
 
 
 
17
  # * Initialize LLM
18
  llm = ChatGoogleGenerativeAI(
19
  model="gemini-1.5-flash",
 
27
  toolkit = SQLDatabaseToolkit(db=db, llm=llm)
28
  tools = toolkit.get_tools()
29
 
30
+ get_driver_performance_tool = GetDriverPerformance()
31
+ get_event_performance_tool = GetEventPerformance()
32
  get_telemetry_tool = GetTelemetry()
33
+ get_tyre_performance_tool = GetTyrePerformance()
34
+ get_weather_impact_tool = GetWeatherImpact()
35
+
36
+ tools.append(get_driver_performance_tool)
37
+ tools.append(get_event_performance_tool)
38
  tools.append(get_telemetry_tool)
39
+ tools.append(get_tyre_performance_tool)
40
+ tools.append(get_weather_impact_tool)
41
 
42
  # * Initialize agent
43
  agent_prompt = open("agent_prompt.txt", "r")
tools/__init__.py CHANGED
@@ -1,82 +1,21 @@
1
- from pydantic import BaseModel, Field
2
- from typing import Type
3
- from langchain_core.tools import BaseTool
4
  from langchain_community.utilities import SQLDatabase
5
  from rich.console import Console
 
 
 
 
 
6
 
7
  console = Console(style="chartreuse1 on grey7")
8
 
9
  db = SQLDatabase.from_uri("sqlite:///db/Bahrain_2023_Q.db")
10
 
11
-
12
- class GetTelemetryAndWeatherInput(BaseModel):
13
- """Input for the get_telemetry_and_weather tool"""
14
- driver_name: str = Field(
15
- description="Name of the driver to analyze (e.g., 'VER', 'HAM', 'LEC', etc.)")
16
- lap_number: int = Field(description="Lap number to analyze")
17
-
18
-
19
- class GetTelemetryAndWeatherOutput(BaseModel):
20
- """Output for the get_telemetry_and_weather tool"""
21
- lap_id: int = Field(description="Lap ID")
22
- lap_number: int = Field(description="Lap number")
23
- lap_time_in_seconds: float | None = Field(
24
- description="Lap time in seconds")
25
- avg_speed: float = Field(description="Average speed in km/h")
26
- max_speed: float = Field(description="Maximum speed in km/h")
27
- avg_RPM: float = Field(description="Average RPM")
28
- max_RPM: float = Field(description="Maximum RPM")
29
- avg_throttle: float = Field(description="Average throttle")
30
- brake_percentage: float = Field(description="Brake percentage")
31
- drs_usage_percentage: float = Field(description="Drs usage percentage")
32
- off_track_percentage: float = Field(description="Off track percentage")
33
- avg_air_temp: float | None = Field(
34
- description="Average air temperature in celsius")
35
- avg_track_temp: float | None = Field(
36
- description="Average track temperature in celsius")
37
- avg_wind_speed: float | None = Field(
38
- description="Average wind speed in meters per second")
39
-
40
-
41
- class GetTelemetry(BaseTool):
42
- name: str = "get_telemetry"
43
- description: str = "useful for when you need to answer questions about telemetry for a given driver and lap"
44
- args_schema: Type[BaseModel] = GetTelemetryAndWeatherInput
45
-
46
- def _run(
47
- self, driver_name: str, lap_number: int
48
- ) -> GetTelemetryAndWeatherOutput:
49
- # """Use the tool."""
50
- sql_file = open("tools/telemetry_and_weather_query.sql", "r")
51
- sql_query = sql_file.read()
52
- sql_file.close()
53
- console.print("getting telemetry")
54
- response = db.run(sql_query, parameters={
55
- "driver_name": driver_name,
56
- "lap_number": lap_number})
57
-
58
- if not isinstance(response, str):
59
- response = str(response)
60
-
61
- clean_response = response.strip('[]()').split(',')
62
- # Convert to appropriate types and create dictionary
63
- return GetTelemetryAndWeatherOutput(
64
- lap_id=int(float(clean_response[0])),
65
- lap_number=int(float(clean_response[1])),
66
- lap_time_in_seconds=float(
67
- clean_response[2]) if clean_response[2].strip() != 'None' else None,
68
- avg_speed=float(clean_response[3]),
69
- max_speed=float(clean_response[4]),
70
- avg_RPM=float(clean_response[5]),
71
- max_RPM=float(clean_response[6]),
72
- avg_throttle=float(clean_response[7]),
73
- brake_percentage=float(clean_response[8]),
74
- drs_usage_percentage=float(clean_response[9]),
75
- off_track_percentage=float(clean_response[10]),
76
- avg_air_temp=float(
77
- clean_response[11]) if clean_response[11].strip() != 'None' else None,
78
- avg_track_temp=float(
79
- clean_response[12]) if clean_response[12].strip() != 'None' else None,
80
- avg_wind_speed=float(
81
- clean_response[13]) if clean_response[13].strip() != 'None' else None
82
- )
 
 
 
 
1
  from langchain_community.utilities import SQLDatabase
2
  from rich.console import Console
3
+ from .driver_performance import GetDriverPerformance
4
+ from .event_performance import GetEventPerformance
5
+ from .telemetry_analysis import GetTelemetry
6
+ from .tyre_performance import GetTyrePerformance
7
+ from .weather_impact import GetWeatherImpact
8
 
9
  console = Console(style="chartreuse1 on grey7")
10
 
11
  db = SQLDatabase.from_uri("sqlite:///db/Bahrain_2023_Q.db")
12
 
13
+ __all__ = [
14
+ "GetDriverPerformance",
15
+ "GetEventPerformance",
16
+ "GetTelemetry",
17
+ "GetTyrePerformance",
18
+ "GetWeatherImpact",
19
+ "console",
20
+ "db"
21
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tools/driver_performance.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field
2
+ from langchain_core.tools import BaseTool
3
+ from . import db, console
4
+
5
+
6
+ class GetDriverPerformanceOutput(BaseModel):
7
+ """Output for the get_driver_performance tool"""
8
+ driver_name: str = Field(description="Name of the driver")
9
+ event_name: str = Field(description="Name of the event")
10
+ session_type: str = Field(
11
+ description="Type of session (Practice, Qualifying, Race)")
12
+ track_name: str = Field(description="Name of the track")
13
+ total_laps: int = Field(description="Total number of laps completed")
14
+ avg_lap_time: float | None = Field(
15
+ description="Average lap time in seconds")
16
+ best_lap_time: float | None = Field(description="Best lap time in seconds")
17
+ avg_sector1_time: float | None = Field(
18
+ description="Average sector 1 time in seconds")
19
+ avg_sector2_time: float | None = Field(
20
+ description="Average sector 2 time in seconds")
21
+ avg_sector3_time: float | None = Field(
22
+ description="Average sector 3 time in seconds")
23
+ avg_finish_line_speed: float | None = Field(
24
+ description="Average finish line speed in km/h")
25
+ personal_best_laps: int = Field(description="Number of personal best laps")
26
+ avg_air_temp: float | None = Field(
27
+ description="Average air temperature in celsius")
28
+ avg_track_temp: float | None = Field(
29
+ description="Average track temperature in celsius")
30
+ rain_percentage: float = Field(
31
+ description="Percentage of time it rained during the session")
32
+
33
+
34
+ class GetDriverPerformance(BaseTool):
35
+ name: str = "get_driver_performance"
36
+ description: str = "useful for when you need to analyze driver performance statistics across different sessions and events"
37
+
38
+ def _run(self) -> list[GetDriverPerformanceOutput]:
39
+ """Use the tool."""
40
+ sql_file = open("tools/sql/driver_performance.query.sql", "r")
41
+ sql_query = sql_file.read()
42
+ sql_file.close()
43
+
44
+ console.print("getting driver performance data")
45
+ response = db.run(sql_query)
46
+
47
+ if not isinstance(response, str):
48
+ response = str(response)
49
+
50
+ # Remove the outer brackets and split by rows
51
+ rows = response.strip('[]').split('), (')
52
+
53
+ results = []
54
+ for row in rows:
55
+ # Clean up the row string and split by columns
56
+ clean_row = row.strip('()').split(',')
57
+
58
+ # Convert to appropriate types and create output object
59
+ driver_data = GetDriverPerformanceOutput(
60
+ driver_name=clean_row[0].strip("'"),
61
+ event_name=clean_row[1].strip("'"),
62
+ session_type=clean_row[2].strip("'"),
63
+ track_name=clean_row[3].strip("'"),
64
+ total_laps=int(float(clean_row[4])),
65
+ avg_lap_time=float(
66
+ clean_row[5]) if clean_row[5].strip() != 'None' else None,
67
+ best_lap_time=float(
68
+ clean_row[6]) if clean_row[6].strip() != 'None' else None,
69
+ avg_sector1_time=float(
70
+ clean_row[7]) if clean_row[7].strip() != 'None' else None,
71
+ avg_sector2_time=float(
72
+ clean_row[8]) if clean_row[8].strip() != 'None' else None,
73
+ avg_sector3_time=float(
74
+ clean_row[9]) if clean_row[9].strip() != 'None' else None,
75
+ avg_finish_line_speed=float(
76
+ clean_row[10]) if clean_row[10].strip() != 'None' else None,
77
+ personal_best_laps=int(
78
+ float(clean_row[11])) if clean_row[11].strip() != 'None' else 0,
79
+ avg_air_temp=float(
80
+ clean_row[12]) if clean_row[12].strip() != 'None' else None,
81
+ avg_track_temp=float(
82
+ clean_row[13]) if clean_row[13].strip() != 'None' else None,
83
+ rain_percentage=float(
84
+ clean_row[14]) if clean_row[14].strip() != 'None' else 0.0
85
+ )
86
+ results.append(driver_data)
87
+
88
+ return results
tools/event_performance.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field
2
+ from typing import Type
3
+ from langchain_core.tools import BaseTool
4
+ from . import db, console
5
+
6
+
7
+ class GetEventPerformanceOutput(BaseModel):
8
+ """Output for the get_event_performance tool"""
9
+ event_name: str = Field(description="Name of the event")
10
+ country: str = Field(description="Country where the event took place")
11
+ location: str = Field(description="Specific location of the event")
12
+ session_type: str = Field(
13
+ description="Type of session (Practice, Qualifying, Race)")
14
+ driver_count: int = Field(
15
+ description="Number of drivers that participated")
16
+ avg_lap_time: float = Field(description="Average lap time in seconds")
17
+ best_lap_time: float = Field(description="Best lap time in seconds")
18
+ max_finish_line_speed: float = Field(
19
+ description="Maximum speed at finish line in km/h")
20
+ avg_air_temp: float | None = Field(
21
+ description="Average air temperature in celsius")
22
+ avg_track_temp: float | None = Field(
23
+ description="Average track temperature in celsius")
24
+ rain_percentage: float = Field(
25
+ description="Percentage of time it rained during the session")
26
+
27
+
28
+ class GetEventPerformance(BaseTool):
29
+ name: str = "get_event_performance"
30
+ description: str = "useful for when you need to get performance statistics for Formula 1 events"
31
+
32
+ def _run(self) -> list[GetEventPerformanceOutput]:
33
+ """Use the tool."""
34
+ sql_file = open("tools/sql/event_performance.query.sql", "r")
35
+ sql_query = sql_file.read()
36
+ sql_file.close()
37
+
38
+ console.print("getting event performance data")
39
+ response = db.run(sql_query)
40
+
41
+ if not isinstance(response, str):
42
+ response = str(response)
43
+
44
+ # Remove the outer brackets and split by rows
45
+ rows = response.strip('[]').split('), (')
46
+
47
+ results = []
48
+ for row in rows:
49
+ # Clean up the row string and split by columns
50
+ clean_row = row.strip('()').split(',')
51
+
52
+ # Convert to appropriate types and create output object
53
+ event_data = GetEventPerformanceOutput(
54
+ event_name=clean_row[0].strip("'"),
55
+ country=clean_row[1].strip("'"),
56
+ location=clean_row[2].strip("'"),
57
+ session_type=clean_row[3].strip("'"),
58
+ driver_count=int(float(clean_row[4])),
59
+ avg_lap_time=float(
60
+ clean_row[5]) if clean_row[5].strip() != 'None' else 0.0,
61
+ best_lap_time=float(
62
+ clean_row[6]) if clean_row[6].strip() != 'None' else 0.0,
63
+ max_finish_line_speed=float(
64
+ clean_row[7]) if clean_row[7].strip() != 'None' else 0.0,
65
+ avg_air_temp=float(
66
+ clean_row[8]) if clean_row[8].strip() != 'None' else None,
67
+ avg_track_temp=float(
68
+ clean_row[9]) if clean_row[9].strip() != 'None' else None,
69
+ rain_percentage=float(
70
+ clean_row[10]) if clean_row[10].strip() != 'None' else 0.0
71
+ )
72
+ results.append(event_data)
73
+
74
+ return results
tools/telemetry_analysis.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field
2
+ from typing import Type
3
+ from langchain_core.tools import BaseTool
4
+ from . import db, console
5
+
6
+
7
+ class GetTelemetryAndWeatherInput(BaseModel):
8
+ """Input for the get_telemetry_and_weather tool"""
9
+ driver_name: str = Field(
10
+ description="Name of the driver to analyze (e.g., 'VER', 'HAM', 'LEC', etc.)")
11
+ lap_number: int = Field(description="Lap number to analyze")
12
+
13
+
14
+ class GetTelemetryAndWeatherOutput(BaseModel):
15
+ """Output for the get_telemetry_and_weather tool"""
16
+ lap_id: int = Field(description="Lap ID")
17
+ lap_number: int = Field(description="Lap number")
18
+ lap_time_in_seconds: float | None = Field(
19
+ description="Lap time in seconds")
20
+ avg_speed: float = Field(description="Average speed in km/h")
21
+ max_speed: float = Field(description="Maximum speed in km/h")
22
+ avg_RPM: float = Field(description="Average RPM")
23
+ max_RPM: float = Field(description="Maximum RPM")
24
+ avg_throttle: float = Field(description="Average throttle")
25
+ brake_percentage: float = Field(description="Brake percentage")
26
+ drs_usage_percentage: float = Field(description="Drs usage percentage")
27
+ off_track_percentage: float = Field(description="Off track percentage")
28
+ avg_air_temp: float | None = Field(
29
+ description="Average air temperature in celsius")
30
+ avg_track_temp: float | None = Field(
31
+ description="Average track temperature in celsius")
32
+ avg_wind_speed: float | None = Field(
33
+ description="Average wind speed in meters per second")
34
+
35
+
36
+ class GetTelemetry(BaseTool):
37
+ name: str = "get_telemetry"
38
+ description: str = "useful for when you need to answer questions about telemetry for a given driver and lap"
39
+ args_schema: Type[BaseModel] = GetTelemetryAndWeatherInput
40
+
41
+ def _run(
42
+ self, driver_name: str, lap_number: int
43
+ ) -> GetTelemetryAndWeatherOutput:
44
+ # """Use the tool."""
45
+ sql_file = open("tools/sql/telemetry_analysis.query.sql", "r")
46
+ sql_query = sql_file.read()
47
+ sql_file.close()
48
+ console.print("getting telemetry")
49
+ response = db.run(sql_query, parameters={
50
+ "driver_name": driver_name,
51
+ "lap_number": lap_number})
52
+
53
+ if not isinstance(response, str):
54
+ response = str(response)
55
+
56
+ clean_response = response.strip('[]()').split(',')
57
+ # Convert to appropriate types and create dictionary
58
+ return GetTelemetryAndWeatherOutput(
59
+ lap_id=int(float(clean_response[0])),
60
+ lap_number=int(float(clean_response[1])),
61
+ lap_time_in_seconds=float(
62
+ clean_response[2]) if clean_response[2].strip() != 'None' else None,
63
+ avg_speed=float(clean_response[3]),
64
+ max_speed=float(clean_response[4]),
65
+ avg_RPM=float(clean_response[5]),
66
+ max_RPM=float(clean_response[6]),
67
+ avg_throttle=float(clean_response[7]),
68
+ brake_percentage=float(clean_response[8]),
69
+ drs_usage_percentage=float(clean_response[9]),
70
+ off_track_percentage=float(clean_response[10]),
71
+ avg_air_temp=float(
72
+ clean_response[11]) if clean_response[11].strip() != 'None' else None,
73
+ avg_track_temp=float(
74
+ clean_response[12]) if clean_response[12].strip() != 'None' else None,
75
+ avg_wind_speed=float(
76
+ clean_response[13]) if clean_response[13].strip() != 'None' else None
77
+ )
tools/tyre_performance.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field
2
+ from typing import Type
3
+ from langchain_core.tools import BaseTool
4
+ from . import db, console
5
+
6
+
7
+ class GetTyrePerformanceInput(BaseModel):
8
+ """Input for the get_tyre_performance tool"""
9
+ driver_name: str = Field(description="Name of the driver to analyze")
10
+
11
+
12
+ class GetTyrePerformanceOutput(BaseModel):
13
+ """Output for the get_tyre_performance tool"""
14
+ driver_name: str = Field(description="Name of the driver")
15
+ lap_number: int = Field(description="Lap number")
16
+ tyre_compound: str = Field(description="Type of tyre compound used")
17
+ avg_tyre_life: float | None = Field(
18
+ description="Average tyre life in laps")
19
+ avg_lap_time: float | None = Field(
20
+ description="Average lap time in seconds")
21
+ avg_top_speed: float | None = Field(
22
+ description="Average top speed in longest straight in km/h")
23
+ fresh_tyre_laps: int = Field(
24
+ description="Number of laps done with fresh tyres")
25
+ used_tyre_laps: int = Field(
26
+ description="Number of laps done with used tyres")
27
+ avg_track_temp: float | None = Field(
28
+ description="Average track temperature in celsius")
29
+ avg_air_temp: float | None = Field(
30
+ description="Average air temperature in celsius")
31
+
32
+
33
+ class GetTyrePerformance(BaseTool):
34
+ name: str = "get_tyre_performance"
35
+ description: str = "useful for when you need to analyze tyre performance and degradation for a specific driver across all their laps"
36
+ args_schema: Type[BaseModel] = GetTyrePerformanceInput
37
+
38
+ def _run(self, driver_name: str) -> list[GetTyrePerformanceOutput]:
39
+ """Use the tool."""
40
+ sql_file = open("tools/sql/tyre_performance.query.sql", "r")
41
+ sql_query = sql_file.read()
42
+ sql_file.close()
43
+
44
+ console.print("getting tyre performance data")
45
+ response = db.run(sql_query, parameters={
46
+ "driver_name": driver_name
47
+ })
48
+
49
+ if not isinstance(response, str):
50
+ response = str(response)
51
+
52
+ # Remove the outer brackets and split by rows
53
+ rows = response.strip('[]').split('), (')
54
+
55
+ results = []
56
+ for row in rows:
57
+ # Clean up the row string and split by columns
58
+ clean_row = row.strip('()').split(',')
59
+
60
+ # Convert to appropriate types and create output object
61
+ tyre_data = GetTyrePerformanceOutput(
62
+ driver_name=clean_row[0].strip("'"),
63
+ lap_number=int(float(clean_row[1])),
64
+ tyre_compound=clean_row[2].strip("'"),
65
+ avg_tyre_life=float(
66
+ clean_row[3]) if clean_row[3].strip() != 'None' else None,
67
+ avg_lap_time=float(
68
+ clean_row[4]) if clean_row[4].strip() != 'None' else None,
69
+ avg_top_speed=float(
70
+ clean_row[5]) if clean_row[5].strip() != 'None' else None,
71
+ fresh_tyre_laps=int(
72
+ float(clean_row[6])) if clean_row[6].strip() != 'None' else 0,
73
+ used_tyre_laps=int(
74
+ float(clean_row[7])) if clean_row[7].strip() != 'None' else 0,
75
+ avg_track_temp=float(
76
+ clean_row[8]) if clean_row[8].strip() != 'None' else None,
77
+ avg_air_temp=float(
78
+ clean_row[9]) if clean_row[9].strip() != 'None' else None
79
+ )
80
+ results.append(tyre_data)
81
+
82
+ return results
tools/weather_impact.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field
2
+ from typing import Type
3
+ from langchain_core.tools import BaseTool
4
+ from rich.console import Console
5
+ from . import db
6
+
7
+ console = Console(style="chartreuse1 on grey7")
8
+
9
+
10
+ class GetWeatherImpactInput(BaseModel):
11
+ """Input for the get_weather_impact tool"""
12
+ # Note: This tool doesn't require input parameters as it returns weather impact data
13
+ pass
14
+
15
+
16
+ class GetWeatherImpactOutput(BaseModel):
17
+ """Output for the get_weather_impact tool"""
18
+ event_name: str = Field(description="Name of the event")
19
+ session_type: str = Field(
20
+ description="Type of session (Practice, Qualifying, Race)")
21
+ track_name: str = Field(description="Name of the track")
22
+ avg_air_temp: float | None = Field(
23
+ description="Average air temperature in celsius")
24
+ avg_track_temp: float | None = Field(
25
+ description="Average track temperature in celsius")
26
+ avg_humidity: float | None = Field(
27
+ description="Average relative humidity percentage")
28
+ avg_wind_speed: float | None = Field(
29
+ description="Average wind speed in meters per second")
30
+ rain_percentage: float = Field(
31
+ description="Percentage of time it rained during the session")
32
+ avg_lap_time: float | None = Field(
33
+ description="Average lap time in seconds")
34
+ best_lap_time: float | None = Field(description="Best lap time in seconds")
35
+
36
+
37
+ class GetWeatherImpact(BaseTool):
38
+ name: str = "get_weather_impact"
39
+ description: str = "useful for when you need to analyze how weather conditions impact Formula 1 session performance"
40
+ args_schema: Type[BaseModel] = GetWeatherImpactInput
41
+
42
+ def _run(self) -> GetWeatherImpactOutput:
43
+ """Use the tool."""
44
+ sql_file = open("tools/sql/weather_impact.query.sql", "r")
45
+ sql_query = sql_file.read()
46
+ sql_file.close()
47
+
48
+ console.print("getting weather impact data")
49
+ response = db.run(sql_query)
50
+
51
+ if not isinstance(response, str):
52
+ response = str(response)
53
+
54
+ # Clean up the single row response
55
+ clean_row = response.strip('[]()').split(',')
56
+
57
+ # Convert to appropriate types and create output object
58
+ return GetWeatherImpactOutput(
59
+ event_name=clean_row[0].strip("'"),
60
+ session_type=clean_row[1].strip("'"),
61
+ track_name=clean_row[2].strip("'"),
62
+ avg_air_temp=float(
63
+ clean_row[3]) if clean_row[3].strip() != 'None' else None,
64
+ avg_track_temp=float(
65
+ clean_row[4]) if clean_row[4].strip() != 'None' else None,
66
+ avg_humidity=float(
67
+ clean_row[5]) if clean_row[5].strip() != 'None' else None,
68
+ avg_wind_speed=float(
69
+ clean_row[6]) if clean_row[6].strip() != 'None' else None,
70
+ rain_percentage=float(
71
+ clean_row[7]) if clean_row[7].strip() != 'None' else 0.0,
72
+ avg_lap_time=float(
73
+ clean_row[8]) if clean_row[8].strip() != 'None' else None,
74
+ best_lap_time=float(
75
+ clean_row[9]) if clean_row[9].strip() != 'None' else None
76
+ )