from dataclasses import dataclass
from typing import Optional, List
from langchain.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import ConfigurableField
from langchain_core.runnables.base import RunnableLambda
from operator import itemgetter
SYSTEM_PROMPT = (
    "You are an assistant specialized in the legal and compliance field who must answer and converse with the user using the context provided. "
    "When you answer the user, if it is relevant, cite the laws and articles you are referring to. NEVER mention the use of context in your answers. "
    "If you believe the question cannot be answered from the given context, do not make up an answer. Answer in the same language the user is speaking.\n\n ### Context:\n {context}"
)

SYSTEM_PROMPT_LOOP = (
    "You are an assistant who must inform the user that you do not have enough information to answer and ask if the user can provide additional information. "
    "This answer must be adapted to the conversation with the user that is provided to you. Just write the answer."
)
@dataclass
class Answer:
    answer: str
    new_documents: Optional[List] = None
    status: Optional[int] = 1


class ContextInput(BaseModel):
    text: str = Field(
        title="Text",
        description="Self-explanatory summary describing what the user is asking for"
    )
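
# Illustrative sketch (not part of the original module): how these containers might
# be populated. The example strings below are assumptions, not project data.
# answer = Answer(answer="Erasure is governed by Article 17 GDPR.", status=1)
# query = ContextInput(text="What does the GDPR say about the right to erasure?")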
def get_instance_dynamic_class(lib_path: str, class_name: str, **kwargs):
    """
    Instantiate a dynamically imported class from a given library path and class name.

    Args:
        lib_path (str): The path to the library/module containing the class.
        class_name (str): The name of the class to instantiate.
        **kwargs: Additional keyword arguments to pass to the class constructor.

    Returns:
        An instance of the dynamically imported class initialized with the provided arguments.
    """
    mod = __import__(lib_path, fromlist=[class_name])
    dynamic_class = getattr(mod, class_name)
    return dynamic_class(**kwargs)
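
# Example usage (illustrative sketch): the class name and kwargs below are assumptions,
# not values taken from this project's configuration.
# embedder = get_instance_dynamic_class(
#     lib_path="langchain_community.embeddings",
#     class_name="HuggingFaceEmbeddings",
#     model_name="sentence-transformers/all-MiniLM-L6-v2",
# )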
def get_init_modules(config):
    """Build the embedder, LLM, chat-history class and retriever declared in the config."""
    embedder = get_instance_dynamic_class(
        lib_path='langchain_community.embeddings',
        class_name=config["embeddings"]["class"],
        **config["embeddings"]["kwargs"]
    )
    llm = get_instance_dynamic_class(
        lib_path='langchain_community.chat_models',
        class_name=config["llm"]["class"],
        **config["llm"]["kwargs"]
    )
    mod_chat = __import__("langchain_community.chat_message_histories",
                          fromlist=[config["chatDB"]["class"]])
    chatDB_class = getattr(mod_chat, config["chatDB"]["class"])
    retriever, retriever_chain = get_vectorDB_module(config['vectorDB'], embedder)
    return embedder, llm, chatDB_class, retriever, retriever_chain
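
# Illustrative config sketch for get_init_modules. The keys mirror how the function
# reads the config; the concrete class names and kwargs are assumptions.
# config = {
#     "embeddings": {"class": "HuggingFaceEmbeddings",
#                    "kwargs": {"model_name": "sentence-transformers/all-MiniLM-L6-v2"}},
#     "llm": {"class": "ChatOllama", "kwargs": {"model": "llama3"}},
#     "chatDB": {"class": "RedisChatMessageHistory"},
#     "vectorDB": {"class": "Qdrant",
#                  "kwargs": {"url": "http://localhost:6333", "collection_name": "laws"},
#                  "retriever_args": {"search_type": "similarity",
#                                     "search_kwargs": {"k": 4}}},
# }
# embedder, llm, chatDB_class, retriever, retriever_chain = get_init_modules(config)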
def get_vectorDB_module(db_config, embedder):
    """Build the vector-store retriever (optionally reranked) declared in the config."""
    mod_vectorstores = __import__("langchain_community.vectorstores",
                                  fromlist=[db_config["class"]])
    vectorDB_class = getattr(mod_vectorstores, db_config["class"])
    if db_config["class"] == 'Qdrant':
        from qdrant_client import QdrantClient
        import inspect

        # Get the QdrantClient init parameter names from its signature and use them
        # to split the configured kwargs between the client and the vector store.
        signature_params = inspect.signature(QdrantClient.__init__).parameters.values()
        params_to_exclude = ['self', 'kwargs']
        client_args = [el.name for el in signature_params if el.name not in params_to_exclude]
        client_kwargs = {k: v for k, v in db_config['kwargs'].items() if k in client_args}
        db_kwargs = {k: v for k, v in db_config['kwargs'].items() if k not in client_kwargs}
        client = QdrantClient(**client_kwargs)
        retriever = vectorDB_class(
            client, embeddings=embedder, **db_kwargs
        ).as_retriever(
            search_type=db_config["retriever_args"]["search_type"],
            search_kwargs={**db_config["retriever_args"]["search_kwargs"]}
        )
    else:
        retriever = vectorDB_class(embeddings=embedder, **db_config["kwargs"]).as_retriever(
            search_type=db_config["retriever_args"]["search_type"],
            search_kwargs=db_config["retriever_args"]["search_kwargs"]
        )
    # Expose search_kwargs as a runtime-configurable field of the retriever.
    retriever = retriever.configurable_fields(
        search_kwargs=ConfigurableField(
            id="search_kwargs",
            name="Search Kwargs",
            description="The search kwargs to use. Includes dynamic category adjustment.",
        )
    )
    chain = RunnableLambda(lambda x: x['question']) | retriever
    if db_config.get("rerank"):
        if db_config["rerank"]["class"] == "CohereRerank":
            module_compressors = __import__("langchain.retrievers.document_compressors",
                                            fromlist=[db_config["rerank"]["class"]])
            rerank_class = getattr(module_compressors, db_config["rerank"]["class"])
            rerank = rerank_class(**db_config["rerank"]["kwargs"])
            # Retrieve first, then rerank the retrieved documents against the question.
            chain = (
                {
                    "docs": chain,
                    "query": itemgetter("question"),
                }
                | RunnableLambda(lambda x: rerank.compress_documents(x['docs'], x['query']))
            )
        else:
            raise NotImplementedError(db_config["rerank"]["class"])
    return retriever, chain
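
# Usage sketch (assumptions: `db_config` and `embedder` are built as shown above).
# The chain expects a dict with a "question" key; search_kwargs can be overridden
# per call through the ConfigurableField declared on the retriever.
# retriever, retriever_chain = get_vectorDB_module(db_config, embedder)
# docs = retriever_chain.invoke(
#     {"question": "Which articles regulate data retention?"},
#     config={"configurable": {"search_kwargs": {"k": 4}}},
# )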