|
import copy
import os
import types
import uuid
from typing import Any, Dict, List, Union, Optional, Tuple, Mapping
import time
import queue
import pathlib
from datetime import datetime

from langchain.schema import BasePromptTemplate
from langchain.chains import LLMChain
from langchain.chains import MapReduceDocumentsChain, StuffDocumentsChain, ReduceDocumentsChain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.summarize import map_reduce_prompt, LoadingCallable, _load_stuff_chain, _load_map_reduce_chain, \
    _load_refine_chain
from langchain.schema.language_model import BaseLanguageModel

from src.utils import hash_file, get_sha

from langchain.callbacks.base import BaseCallbackHandler, Callbacks
from langchain.schema import LLMResult
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.docstore.document import Document
|
|
|
|
|
class StreamingGradioCallbackHandler(BaseCallbackHandler):
    """
    Similar to H2OTextIteratorStreamer, which serves the HF backend; this serves the LangChain backend.
    """

    def __init__(self, timeout: Optional[float] = None, block=True, max_time=None, verbose=False):
        super().__init__()
        self.text_queue = queue.SimpleQueue()
        self.stop_signal = None
        self.do_stop = False
        self.timeout = timeout
        self.block = block
        self.max_time = max_time
        self.tgen0 = None
        self.verbose = verbose

    def on_llm_start(
            self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts running. Clean the queue."""
        self.tgen0 = time.time()
        while not self.text_queue.empty():
            try:
                self.text_queue.get(block=False)
            except queue.Empty:
                continue

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        if self.tgen0 is not None and self.max_time is not None and (time.time() - self.tgen0) > self.max_time:
            if self.verbose:
                print("Took too long in StreamingGradioCallbackHandler: %s" % (time.time() - self.tgen0), flush=True)
            # signal the consumer to stop instead of streaming indefinitely
            self.text_queue.put(self.stop_signal)
        else:
            self.text_queue.put(token)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running."""
        self.text_queue.put(self.stop_signal)

    def on_llm_error(
            self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Run when LLM errors."""
        self.text_queue.put(self.stop_signal)

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            try:
                # default to stop_signal in case do_stop is hit before the queue yields
                value = self.stop_signal
                if self.do_stop:
                    print("hit stop", flush=True)
                    raise StopIteration()
                value = self.text_queue.get(block=self.block, timeout=self.timeout)
                break
            except queue.Empty:
                time.sleep(0.01)
        if value == self.stop_signal:
            raise StopIteration()
        else:
            return value
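

# Usage sketch (illustrative, not part of the module): stream tokens from a
# LangChain run into a Gradio-style generator.  `make_chain` is a hypothetical
# helper; any chain whose LLM was built with streaming enabled and this handler
# in its callbacks behaves the same way.
#
#   from threading import Thread
#
#   handler = StreamingGradioCallbackHandler(timeout=None, max_time=60)
#   chain = make_chain(callbacks=[handler])  # hypothetical constructor
#   thread = Thread(target=chain.run, args=(prompt,))
#   thread.start()
#   text = ''
#   for token in handler:  # iterates until stop_signal is queued
#       text += token
#       yield text  # e.g. inside a Gradio event handler
#   thread.join()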
|
|
|
|
|
def _chunk_sources(sources, chunk=True, chunk_size=512, language=None, db_type=None):
    assert db_type is not None

    if not isinstance(sources, (list, tuple, types.GeneratorType)) and not callable(sources):
        # handle case of a single document
        sources = [sources]
    if not chunk:
        for x in sources:
            x.metadata.update(dict(chunk_id=0))
        if db_type in ['chroma', 'chroma_old']:
            # make a copy so the original and chunked documents can diverge
            source_chunks = [Document(page_content=x.page_content,
                                      metadata=copy.deepcopy(x.metadata) or {})
                             for x in sources]
        else:
            source_chunks = sources
    else:
        if language and False:
            # language-aware splitting is intentionally disabled for now
            keep_separator = True
            separators = RecursiveCharacterTextSplitter.get_separators_for_language(language)
        else:
            separators = ["\n\n", "\n", " ", ""]
            keep_separator = False
        splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=0, keep_separator=keep_separator,
                                                  separators=separators)
        source_chunks = splitter.split_documents(sources)

        # mark chunk order so it can be recovered later
        for chunk_id, x in enumerate(source_chunks):
            x.metadata.update(dict(chunk_id=chunk_id))

    if db_type in ['chroma', 'chroma_old']:
        # also keep the original unchunked documents, marked with chunk_id=-1,
        # so they remain available for tasks like summarization
        for x in sources:
            x.metadata.update(dict(chunk_id=-1))
        return list(sources) + source_chunks
    else:
        return source_chunks
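

# Usage sketch (illustrative): chunk two in-memory documents for a chroma store.
#
#   docs = [Document(page_content='a' * 2000, metadata={'source': 'x.txt'}),
#           Document(page_content='b' * 100, metadata={'source': 'y.txt'})]
#   out = _chunk_sources(docs, chunk=True, chunk_size=512, db_type='chroma')
#   # out holds the originals (chunk_id=-1) followed by the chunks (chunk_id=0, 1, ...)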
|
|
|
|
|
def add_parser(docs1, parser):
    # record which parser produced each document, keeping any existing value
    for x in docs1:
        x.metadata.update(dict(parser=x.metadata.get('parser', parser)))
|
|
|
|
|
def _add_meta(docs1, file, headsize=50, filei=0, parser='NotSet'):
    if os.path.isfile(file):
        file_extension = pathlib.Path(file).suffix
        hashid = hash_file(file)
    else:
        # not a real file (e.g. a URL or raw text), so hash the string itself
        file_extension = str(file)
        hashid = get_sha(file)
    doc_hash = str(uuid.uuid4())[:10]
    if not isinstance(docs1, (list, tuple, types.GeneratorType)):
        docs1 = [docs1]
    for order_id, x in enumerate(docs1):
        x.metadata.update(dict(input_type=file_extension,
                               parser=x.metadata.get('parser', parser),
                               date=str(datetime.now()),
                               time=time.time(),
                               order_id=order_id,
                               hashid=hashid,
                               doc_hash=doc_hash,
                               file_id=filei,
                               head=x.page_content[:headsize].strip()))
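

# Usage sketch (illustrative; assumes 'example.txt' exists on disk):
#
#   docs = [Document(page_content='hello world', metadata={})]
#   _add_meta(docs, 'example.txt', filei=0, parser='TextLoader')
#   # docs[0].metadata now includes input_type='.txt', order_id=0, file_id=0,
#   # a stable hashid for the file contents, and head='hello world'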
|
|
|
|
|
def fix_json_meta(docs1):
    if not isinstance(docs1, (list, tuple, types.GeneratorType)):
        docs1 = [docs1]
    # ensure JSON-derived fields exist and are not None, since some vector
    # stores (e.g. chroma) reject None-valued metadata
    for x in docs1:
        x.metadata.update(dict(sender_name=x.metadata.get('sender_name') or ''))
        x.metadata.update(dict(timestamp_ms=x.metadata.get('timestamp_ms') or ''))
|
|
|
|
|
class H2OMapReduceDocumentsChain(MapReduceDocumentsChain):
    def combine_docs(
            self,
            docs: List[Document],
            token_max: Optional[int] = None,
            callbacks: Callbacks = None,
            **kwargs: Any,
    ) -> Tuple[List, dict]:
        """Combine documents in a map-only manner.

        Map llm_chain over all documents and return the per-document results
        directly, skipping the reduce step of the parent class.  token_max is
        accepted for API compatibility but unused here.
        """
        map_results = self.llm_chain.apply(
            # one input dict per document
            [{self.document_variable_name: d.page_content, **kwargs} for d in docs],
            callbacks=callbacks,
        )
        question_result_key = self.llm_chain.output_key
        result_docs = [
            # pair each mapped text with the metadata of its source document
            Document(page_content=r[question_result_key], metadata=docs[i].metadata)
            for i, r in enumerate(map_results)
        ]
        extra_return_dict = {}
        if self.return_intermediate_steps:
            intermediate_steps = [r[question_result_key] for r in map_results]
            extra_return_dict["intermediate_steps"] = intermediate_steps
        result_docs_content = [x.page_content for x in result_docs]
        return result_docs_content, extra_return_dict

    async def acombine_docs(
            self,
            docs: List[Document],
            token_max: Optional[int] = None,
            callbacks: Callbacks = None,
            **kwargs: Any,
    ) -> Tuple[List, dict]:
        """Async version of combine_docs: map llm_chain over all documents and
        return the per-document results directly, skipping the reduce step."""
        map_results = await self.llm_chain.aapply(
            # one input dict per document
            [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs],
            callbacks=callbacks,
        )
        question_result_key = self.llm_chain.output_key
        result_docs = [
            # pair each mapped text with the metadata of its source document
            Document(page_content=r[question_result_key], metadata=docs[i].metadata)
            for i, r in enumerate(map_results)
        ]
        extra_return_dict = {}
        if self.return_intermediate_steps:
            intermediate_steps = [r[question_result_key] for r in map_results]
            extra_return_dict["intermediate_steps"] = intermediate_steps
        result_docs_content = [x.page_content for x in result_docs]
        return result_docs_content, extra_return_dict

    @property
    def _chain_type(self) -> str:
        return "map_documents_chain"
|
|
|
|
|
def _load_map_chain(
        llm: BaseLanguageModel,
        map_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
        combine_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
        combine_document_variable_name: str = "text",
        map_reduce_document_variable_name: str = "text",
        collapse_prompt: Optional[BasePromptTemplate] = None,
        reduce_llm: Optional[BaseLanguageModel] = None,
        collapse_llm: Optional[BaseLanguageModel] = None,
        verbose: Optional[bool] = None,
        token_max: int = 3000,
        callbacks: Callbacks = None,
        **kwargs: Any,
) -> H2OMapReduceDocumentsChain:
    map_chain = LLMChain(
        llm=llm, prompt=map_prompt, verbose=verbose, callbacks=callbacks
    )
    _reduce_llm = reduce_llm or llm
    reduce_chain = LLMChain(
        llm=_reduce_llm, prompt=combine_prompt, verbose=verbose, callbacks=callbacks
    )
    combine_documents_chain = StuffDocumentsChain(
        llm_chain=reduce_chain,
        document_variable_name=combine_document_variable_name,
        verbose=verbose,
        callbacks=callbacks,
    )
    if collapse_prompt is None:
        collapse_chain = None
        if collapse_llm is not None:
            raise ValueError(
                "collapse_llm provided, but collapse_prompt was not: please "
                "provide one or stop providing collapse_llm."
            )
    else:
        _collapse_llm = collapse_llm or llm
        collapse_chain = StuffDocumentsChain(
            llm_chain=LLMChain(
                llm=_collapse_llm,
                prompt=collapse_prompt,
                verbose=verbose,
                callbacks=callbacks,
            ),
            document_variable_name=combine_document_variable_name,
        )
    # the reduce chain is constructed for interface compatibility, but
    # H2OMapReduceDocumentsChain.combine_docs never invokes it
    reduce_documents_chain = ReduceDocumentsChain(
        combine_documents_chain=combine_documents_chain,
        collapse_documents_chain=collapse_chain,
        token_max=token_max,
        verbose=verbose,
        callbacks=callbacks,
    )
    return H2OMapReduceDocumentsChain(
        llm_chain=map_chain,
        reduce_documents_chain=reduce_documents_chain,
        document_variable_name=map_reduce_document_variable_name,
        verbose=verbose,
        callbacks=callbacks,
        **kwargs,
    )
|
|
|
|
|
def load_general_summarization_chain(
        llm: BaseLanguageModel,
        chain_type: str = "stuff",
        verbose: Optional[bool] = None,
        **kwargs: Any,
) -> BaseCombineDocumentsChain:
    """Load summarizing chain.

    Args:
        llm: Language Model to use in the chain.
        chain_type: Type of document combining chain to use. Should be one of
            "stuff", "map_reduce", "refine", and "map".
        verbose: Whether chains should be run in verbose mode or not. Note that
            this applies to all chains that make up the final chain.

    Returns:
        A chain to use for summarizing.
    """
    loader_mapping: Mapping[str, LoadingCallable] = {
        "stuff": _load_stuff_chain,
        "map_reduce": _load_map_reduce_chain,
        "refine": _load_refine_chain,
        "map": _load_map_chain,
    }
    if chain_type not in loader_mapping:
        raise ValueError(
            f"Got unsupported chain type: {chain_type}. "
            f"Should be one of {loader_mapping.keys()}"
        )
    return loader_mapping[chain_type](llm, verbose=verbose, **kwargs)
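

# Usage sketch (illustrative; `llm` is any BaseLanguageModel instance):
#
#   chain = load_general_summarization_chain(llm, chain_type='map_reduce')
#   docs = [Document(page_content='first long text ...'),
#           Document(page_content='second long text ...')]
#   summary = chain.run(docs)
#
# chain_type='map' builds H2OMapReduceDocumentsChain above, whose output is the
# list of per-document summaries rather than a single combined summary.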
|
|
|
|
|
"""Utils for interacting with the Semantic Scholar API.""" |
|
import logging |
|
from typing import Any, Dict, Optional |
|
|
|
from langchain_core.pydantic_v1 import BaseModel, root_validator |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class H2OSemanticScholarAPIWrapper(BaseModel):
    """Wrapper around the semanticscholar.org API.
    https://github.com/danielnsilva/semanticscholar

    You should have this library installed:

    `pip install semanticscholar`

    The Semantic Scholar API can conduct searches and fetch document metadata
    like title, abstract, authors, etc.

    Attributes:
        top_k_results: number of the top-scored documents used for the Semantic Scholar tool
        load_max_docs: a limit on the number of loaded documents

    Example:
        .. code-block:: python

            from langchain_community.utilities.semanticscholar import SemanticScholarAPIWrapper
            ss = SemanticScholarAPIWrapper(
                top_k_results=3,
                load_max_docs=3
            )
            ss.run("biases in large language models")
    """

    semanticscholar_search: Any
    top_k_results: int = 5
    S2_MAX_QUERY_LENGTH: int = 300
    load_max_docs: int = 100
    doc_content_chars_max: Optional[int] = 4000
    returned_fields = [
        "title",
        "abstract",
        "venue",
        "year",
        "paperId",
        "citationCount",
        "openAccessPdf",
        "authors",
        "externalIds",
    ]

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the python package exists in the environment."""
        try:
            from semanticscholar import SemanticScholar

            # S2_API_KEY is optional; without it the API applies stricter rate limits
            sch = SemanticScholar(api_key=os.getenv('S2_API_KEY'))
            values["semanticscholar_search"] = sch.search_paper
        except ImportError:
            raise ImportError(
                "Could not import Semanticscholar python package. "
                "Please install it with `pip install semanticscholar`."
            )
        return values

    def run(self, query: str) -> str:
        """Run the Semantic Scholar API."""
        results = self.semanticscholar_search(
            query, limit=self.load_max_docs, fields=self.returned_fields
        )
        documents = []
        for item in results[: self.top_k_results]:
            authors = ", ".join(
                author["name"] for author in getattr(item, "authors", [])
            )
            documents.append(
                f"Published year: {getattr(item, 'year', None)}\n"
                f"Title: {getattr(item, 'title', None)}\n"
                f"Authors: {authors}\n"
                f"Abstract: {getattr(item, 'abstract', None)}\n"
            )

        if documents:
            return "\n\n".join(documents)[: self.doc_content_chars_max]
        else:
            return "No results found."
|
|