File size: 2,029 Bytes
bdfd46f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import asyncio
import datetime
from typing import Annotated

import dbally
import sqlalchemy
from dbally import SqlAlchemyBaseView
from dbally.audit import CLIEventHandler
from dbally.embeddings import LiteLLMEmbeddingClient
from dbally.gradio import create_gradio_interface
from dbally.llms import LiteLLM
from dbally.similarity import SimilarityIndex, SimpleSqlAlchemyFetcher, FaissStore
from dbally.views import decorators
from dotenv import load_dotenv
from sqlalchemy import create_engine
from sqlalchemy.ext.automap import automap_base

# Register db-ally's CLI audit handler globally; this happens at import time,
# before main() creates any collection.
dbally.event_handlers = [CLIEventHandler()]

# SQLite database file; the path is relative to the current working directory.
engine = create_engine("sqlite:///clients.db")

# Load variables from a .env file into the process environment
# (presumably the API keys used by LiteLLM — confirm against the deployment).
load_dotenv()

# Reflect the existing database schema through the engine and expose the
# `clients` table as an automapped ORM class.
Base = automap_base()
Base.prepare(autoload_with=engine)
Clients = Base.classes.clients

# Source of candidate values for the index: the `city` column of the
# reflected `clients` table, read through the shared engine.
cities_fetcher = SimpleSqlAlchemyFetcher(sqlalchemy_engine=engine, table=Clients, column=Clients.city)

# Faiss-backed store named "cities_index" under the "indexes" directory,
# using the "text-embedding-3-small" model through LiteLLM for embeddings.
cities_store = FaissStore(
    embedding_client=LiteLLMEmbeddingClient("text-embedding-3-small"),
    index_name="cities_index",
    index_dir="indexes",
)

# Similarity index combining the two; used below as the Annotated metadata of
# the `city` filter parameter in ClientsView.
CityIndex = SimilarityIndex(store=cities_store, fetcher=cities_fetcher)


class ClientsView(SqlAlchemyBaseView):
    """A view for retrieving clients from the database."""

    # NOTE(review): per the db-ally documentation, view and filter docstrings
    # are surfaced to the LLM as descriptions — keep them short and factual,
    # and put reviewer notes in `#` comments rather than in docstrings.

    def get_select(self) -> sqlalchemy.Select:
        """Create the base select over the clients table that filters are applied to."""
        return sqlalchemy.select(Clients)

    @decorators.view_filter()
    def filter_by_city(self, city: Annotated[str, CityIndex]) -> sqlalchemy.ColumnElement:
        """Filters clients whose city equals `city`."""
        return Clients.city == city

    @decorators.view_filter()
    def eligible_for_loyalty_program(self) -> sqlalchemy.ColumnElement:
        """Filters clients with more than 3 total orders who joined more than 365 days ago."""
        # The cutoff is evaluated each time the filter is built, using naive
        # local time. NOTE(review): confirm `date_joined` is stored as naive
        # local timestamps so the comparison is meaningful.
        joined_before = datetime.datetime.now() - datetime.timedelta(days=365)
        total_orders_check = Clients.total_orders > 3
        date_joined_check = Clients.date_joined < joined_before
        return total_orders_check & date_joined_check


async def main() -> None:
    """Build the db-ally collection for ClientsView and serve it through Gradio.

    Steps: create the LiteLLM client, create the "clients" collection,
    register ClientsView with a factory bound to the module-level engine,
    refresh the collection's similarity indexes, then build and launch the
    Gradio interface.
    """
    # Local import keeps this block self-contained (no change to file imports).
    import inspect

    llm = LiteLLM(model_name="gpt-4-turbo")

    collection = dbally.create_collection("clients", llm=llm)
    # A fresh view instance is built from the shared engine on each request
    # for the view by the collection.
    collection.add(ClientsView, lambda: ClientsView(engine))

    # Must complete before serving so the city index is populated.
    await collection.update_similarity_indexes()

    interface = create_gradio_interface(collection)
    # Depending on the installed db-ally release, create_gradio_interface may
    # be a coroutine function; calling .launch() on an un-awaited coroutine
    # would fail, so await the result only when it is awaitable.
    if inspect.isawaitable(interface):
        interface = await interface
    interface.launch()


# Script entry point: run the async main() to completion via asyncio.run.
if __name__ == "__main__":
    asyncio.run(main())