File size: 2,084 Bytes
787d3cb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
from __future__ import annotations
import inspect
from typing import Any, Dict, List, Optional
from pydantic import Extra
from langchain.schema.language_model import BaseLanguageModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.prompts.base import BasePromptTemplate
from typing import Any, Dict, List
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
from langchain.docstore.document import Document
from langchain.pydantic_v1 import Field
from langchain.schema import BaseRetriever
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chains.router.llm_router import LLMRouterChain
class CustomRetrievalQAWithSourcesChain(RetrievalQAWithSourcesChain):
    """``RetrievalQAWithSourcesChain`` that short-circuits when nothing is retrieved.

    If document retrieval returns no documents, the combine-documents chain is
    not invoked; ``fallback_answer`` is returned as the answer together with an
    empty sources string. Otherwise the retrieved documents are passed to
    ``combine_documents_chain`` and its output is split into answer and sources
    with ``_split_sources``. The retrieved documents are attached under
    ``"source_documents"`` when ``return_source_documents`` is enabled.
    """

    # Answer returned verbatim when retrieval yields zero documents.
    fallback_answer: str = "No sources available to answer this question."

    def _assemble_result(
        self, answer: str, sources: str, docs: List[Any]
    ) -> Dict[str, Any]:
        """Package an answer/sources pair into the chain's output dict.

        ``docs`` is included under ``"source_documents"`` only when
        ``return_source_documents`` is set.
        """
        result: Dict[str, Any] = {
            self.answer_key: answer,
            self.sources_answer_key: sources,
        }
        if self.return_source_documents:
            result["source_documents"] = docs
        return result

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Retrieve documents and answer synchronously, falling back when none are found.

        Args:
            inputs: Chain inputs; passed to ``_get_docs`` and forwarded as
                keyword arguments to the combine-documents chain.
            run_manager: Optional callback manager; a no-op manager is used
                when omitted.

        Returns:
            Dict with the ``answer_key`` and ``sources_answer_key`` entries
            (sources is always a string), plus ``"source_documents"`` when
            ``return_source_documents`` is set.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        # Only pass run_manager if this _get_docs implementation accepts it.
        if "run_manager" in inspect.signature(self._get_docs).parameters:
            docs = self._get_docs(inputs, run_manager=_run_manager)
        else:
            docs = self._get_docs(inputs)  # type: ignore[call-arg]
        if not docs:
            # Skip the LLM call entirely; use "" so the sources value has the
            # same type as the string produced by _split_sources.
            return self._assemble_result(self.fallback_answer, "", docs)
        raw_answer = self.combine_documents_chain.run(
            input_documents=docs, callbacks=_run_manager.get_child(), **inputs
        )
        answer, sources = self._split_sources(raw_answer)
        return self._assemble_result(answer, sources, docs)

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Async counterpart of ``_call`` with the same empty-retrieval fallback.

        The inherited async path does not know about ``fallback_answer`` (it is
        defined only on this class), so it is overridden here to keep sync and
        async behaviour consistent.

        Args:
            inputs: Chain inputs; passed to ``_aget_docs`` and forwarded as
                keyword arguments to the combine-documents chain.
            run_manager: Optional async callback manager; a no-op manager is
                used when omitted.

        Returns:
            Same shape as ``_call``.
        """
        _run_manager = (
            run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        )
        # Only pass run_manager if this _aget_docs implementation accepts it.
        if "run_manager" in inspect.signature(self._aget_docs).parameters:
            docs = await self._aget_docs(inputs, run_manager=_run_manager)
        else:
            docs = await self._aget_docs(inputs)  # type: ignore[call-arg]
        if not docs:
            return self._assemble_result(self.fallback_answer, "", docs)
        raw_answer = await self.combine_documents_chain.arun(
            input_documents=docs, callbacks=_run_manager.get_child(), **inputs
        )
        answer, sources = self._split_sources(raw_answer)
        return self._assemble_result(answer, sources, docs)
|