Chat-EUR-Lex / chat_utils.py
davidecolla's picture
Update chat_utils.py
7605874 verified
import importlib
from dataclasses import dataclass
from operator import itemgetter
from typing import Optional, List

from langchain.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import ConfigurableField
from langchain_core.runnables.base import RunnableLambda
# System prompt for answering from retrieved context. ``{context}`` is a
# str.format-style placeholder (presumably filled by a prompt template at the
# call site -- confirm). Adjacent string literals are joined by the parser, so
# every sentence fragment must end with a space or words run together
# (previously "...the text.If you believe...").
SYSTEM_PROMPT = (
    "You are an assistant specialized in the legal and compliance field who must answer and converse with the user using the context provided. "
    "When you answer the user, if it is relevant, cite the laws and articles you are referring to. NEVER mention the use of context in your answers. "
    "If the user asks for a definition, report exactly the content of the context, do not paraphrase the text. "
    "If you believe the question cannot be answered from the given context, do not make up an answer. Answer in the same language the user is speaking.\n\n ### Context:\n {context}"
)
# Fallback system prompt: tells the model to say it lacks enough information
# and to ask the user for more, phrased to fit the preceding conversation.
# Presumably used when retrieval yields nothing usable -- confirm at call site.
# NOTE(review): the stray comma in "This answer, must" and the trailing space
# are part of the prompt text and are kept verbatim on purpose.
SYSTEM_PROMPT_LOOP = (
    "You are an assistant who must inform the user that you do not have enough information "
    "to answer and ask if the user can provide you with additional information. "
    "This answer, must be adapted to the conversation that occurred with the user "
    "that is provided to you. Just write down the answer "
)
@dataclass
class Answer:
    """Container for a single chat response.

    Attributes:
        answer: The response text.
        new_documents: Optional list attached to the answer; defaults to
            ``None`` (presumably newly retrieved documents -- confirm with callers).
        status: Integer status code; defaults to ``1``. Its meaning is
            defined by the code that consumes it -- confirm there.
    """

    answer: str
    new_documents: Optional[List] = None
    status: Optional[int] = 1
class ContextInput(BaseModel):
    # Input schema holding a single required string field, ``text``.
    # Documented with comments rather than a class docstring: pydantic uses a
    # class docstring as the model's JSON-schema description, which would
    # change the schema this model exposes.
    text: str = Field(
        description="Self-explanatory summary describing what the user is asking for",
        title="Text",
    )
def get_instance_dynamic_class(lib_path: str, class_name: str, **kwargs):
    """
    Instantiate a class looked up by name inside an importable module.

    Args:
        lib_path (str): Dotted import path of the module containing the class
            (e.g. ``'langchain_community.embeddings'``).
        class_name (str): Name of the class attribute to fetch from that module.
        **kwargs: Keyword arguments forwarded unchanged to the class constructor.

    Returns:
        The object produced by calling ``class_name(**kwargs)`` from ``lib_path``.

    Raises:
        ImportError: If ``lib_path`` cannot be imported.
        AttributeError: If the module has no attribute named ``class_name``.
    """
    # importlib.import_module is the documented replacement for calling
    # __import__ directly; it returns the leaf module for a dotted path.
    module = importlib.import_module(lib_path)
    try:
        dynamic_class = getattr(module, class_name)
    except AttributeError as err:
        # Same exception type as before, but name both the module and the
        # class so misconfigured "class" entries are easy to spot.
        raise AttributeError(
            f"Module {lib_path!r} has no class {class_name!r}"
        ) from err
    return dynamic_class(**kwargs)
def get_init_modules(config):
    """Build the runtime components described by a configuration mapping.

    Args:
        config (dict): Must contain ``embeddings`` and ``llm`` (each with
            ``class`` and ``kwargs``), ``chatDB`` (with ``class``) and
            ``vectorDB`` (passed to ``get_vectorDB_module``).

    Returns:
        tuple: ``(embedder, llm, chatDB_class, retriever, retriever_chain)``.
        ``embedder`` comes from ``langchain_community.embeddings`` and ``llm``
        from ``langchain_community.chat_models``, both instantiated with their
        ``kwargs``. ``chatDB_class`` is the class itself (not an instance)
        from ``langchain_community.chat_message_histories``. ``retriever`` and
        ``retriever_chain`` come from ``get_vectorDB_module``.
    """
    embeddings_cfg = config["embeddings"]
    embedder = get_instance_dynamic_class(
        lib_path='langchain_community.embeddings',
        class_name=embeddings_cfg["class"],
        **embeddings_cfg["kwargs"],
    )

    llm_cfg = config["llm"]
    llm = get_instance_dynamic_class(
        lib_path='langchain_community.chat_models',
        class_name=llm_cfg["class"],
        **llm_cfg["kwargs"],
    )

    # The chat-history backend is handed back uninstantiated; callers
    # construct it themselves.
    history_class_name = config["chatDB"]["class"]
    history_module = __import__(
        "langchain_community.chat_message_histories",
        fromlist=[history_class_name],
    )
    chatDB_class = getattr(history_module, history_class_name)

    retriever, retriever_chain = get_vectorDB_module(config['vectorDB'], embedder)
    return embedder, llm, chatDB_class, retriever, retriever_chain
def _split_kwargs_by_signature(func, kwargs):
    """Partition ``kwargs`` into ``(accepted, rest)`` using ``func``'s signature.

    A key goes to ``accepted`` when it matches an explicitly named parameter
    of ``func`` (``self`` excluded); every other key goes to ``rest``.
    """
    import inspect

    # Skip the bound receiver and any catch-all *args/**kwargs collector.
    # Filtering collectors by kind rather than by the literal name "kwargs"
    # stays correct even if the collector is named differently.
    variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    named_params = {
        param.name
        for param in inspect.signature(func).parameters.values()
        if param.name != "self" and param.kind not in variadic
    }
    accepted = {k: v for k, v in kwargs.items() if k in named_params}
    rest = {k: v for k, v in kwargs.items() if k not in named_params}
    return accepted, rest


def _build_base_retriever(db_config, embedder):
    """Instantiate the configured LangChain vector store and return it as a retriever.

    For ``class == 'Qdrant'``, the entries of ``db_config['kwargs']`` are split.
    Keys matching ``QdrantClient.__init__`` parameters build the client; the
    remaining keys go to the vector store.
    """
    store_class_name = db_config["class"]
    store_class = getattr(
        importlib.import_module("langchain_community.vectorstores"),
        store_class_name,
    )
    if store_class_name == "Qdrant":
        # Imported lazily: qdrant_client is only required for this backend.
        from qdrant_client import QdrantClient

        client_kwargs, store_kwargs = _split_kwargs_by_signature(
            QdrantClient.__init__, db_config["kwargs"]
        )
        vector_store = store_class(
            QdrantClient(**client_kwargs), embeddings=embedder, **store_kwargs
        )
    else:
        vector_store = store_class(embeddings=embedder, **db_config["kwargs"])

    retriever_args = db_config["retriever_args"]
    return vector_store.as_retriever(
        search_type=retriever_args["search_type"],
        # Shallow-copy in every branch (previously only for Qdrant) so the
        # retriever never shares the top-level dict with the caller's config.
        search_kwargs={**retriever_args["search_kwargs"]},
    )


def _add_rerank_step(chain, rerank_config):
    """Append a reranking step to ``chain``; only ``CohereRerank`` is supported.

    The returned runnable takes the same input as ``chain`` (a mapping with a
    ``'question'`` key). It runs ``chain`` to get documents and returns
    ``reranker.compress_documents(docs, question)``.

    Raises:
        NotImplementedError: If ``rerank_config['class']`` is not ``'CohereRerank'``.
    """
    rerank_class_name = rerank_config["class"]
    if rerank_class_name != "CohereRerank":
        raise NotImplementedError(rerank_class_name)
    rerank_class = getattr(
        importlib.import_module("langchain.retrievers.document_compressors"),
        rerank_class_name,
    )
    reranker = rerank_class(**rerank_config["kwargs"])
    # LangChain coerces the dict into a parallel step. Each value receives
    # the same input: "docs" runs the retrieval chain and "query" extracts
    # the question.
    return {
        "docs": chain,
        "query": itemgetter("question"),
    } | RunnableLambda(
        lambda x: reranker.compress_documents(x["docs"], x["query"])
    )


def get_vectorDB_module(db_config, embedder):
    """Build the document retriever and the question-to-documents chain.

    Args:
        db_config (dict): Vector-store configuration. Keys read:
            ``class`` (a name in ``langchain_community.vectorstores``),
            ``kwargs`` (constructor kwargs), and ``retriever_args`` (with
            ``search_type`` and ``search_kwargs`` for ``as_retriever``).
            ``rerank`` is optional and holds the ``class`` and ``kwargs``
            of a reranker.
        embedder: Embedding model passed to the vector store as ``embeddings``.

    Returns:
        tuple: ``(retriever, chain)``.
        ``retriever`` exposes ``search_kwargs`` as a configurable field with
        id ``"search_kwargs"``. ``chain`` maps a ``{"question": ...}`` input
        to documents; the documents are reranked when ``db_config["rerank"]``
        is truthy.

    Raises:
        NotImplementedError: If a reranker other than ``CohereRerank`` is requested.
    """
    retriever = _build_base_retriever(db_config, embedder).configurable_fields(
        search_kwargs=ConfigurableField(
            id="search_kwargs",
            name="Search Kwargs",
            description="The search kwargs to use. Includes dynamic category adjustment.",
        )
    )
    chain = RunnableLambda(lambda x: x["question"]) | retriever
    if db_config.get("rerank"):
        chain = _add_rerank_step(chain, db_config["rerank"])
    return retriever, chain