Draichi commited on
Commit
b04134b
1 Parent(s): a98f893

feat: adds `laps_report.py` first version

Browse files
multi-agents-analysis/laps_report.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sqlalchemy import (
2
+ create_engine,
3
+ MetaData,
4
+ Table,
5
+ Column,
6
+ String,
7
+ Integer,
8
+ select,
9
+ column,
10
+ )
11
+ import os
12
+ from llama_index.core import Settings, VectorStoreIndex
13
+ from llama_index.core import SQLDatabase
14
+ from llama_index.llms.ollama import Ollama
15
+ from llama_index.core.query_engine import NLSQLTableQueryEngine
16
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
17
+ from llama_index.core.objects import (
18
+ SQLTableNodeMapping,
19
+ ObjectIndex,
20
+ SQLTableSchema,
21
+ )
22
+ from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
23
+ from rich.console import Console
24
+ from rich.theme import Theme
25
+
26
+ custom_theme = Theme({
27
+ "title": "bold white on orchid1",
28
+ "text": "dim chartreuse1",
29
+ })
30
+
31
+ console = Console(theme=custom_theme)
32
+
33
+ Settings.llm = Ollama(model="phi3", request_timeout=360.0)
34
+ Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-base-en-v1.5")
35
+
36
+ engine = create_engine("sqlite:///multi-agents-analysis/data/laps.db")
37
+ metadata_obj = MetaData()
38
+
39
+ sql_database = SQLDatabase(engine)
40
+
41
+ # manually set extra context text
42
+ city_stats_text = """This table gives information regarding the performance in a race about each driver.
43
+ The time is split into 3 different sectors.
44
+ The speed is split into SpeedI1, SpeedI2, SpeedFL and SpeedST"""
45
+
46
+ table_node_mapping = SQLTableNodeMapping(sql_database)
47
+ table_schema_objs = [
48
+ (SQLTableSchema(table_name="laps", context_str=city_stats_text))
49
+ ]
50
+
51
+ obj_index = ObjectIndex.from_objects(
52
+ table_schema_objs,
53
+ table_node_mapping
54
+ )
55
+
56
+ query_engine = SQLTableRetrieverQueryEngine(
57
+ sql_database, obj_index.as_retriever(similarity_top_k=1)
58
+ )
59
+ response = query_engine.query("Which driver had the lowers time in sector 1?")
60
+ print(response)