# chatglm2-6b-ggml / chatglm_langchain.py
# Author: arkii
# Commit 3da7b5e: "Create chatglm_langchain.py" (3.1 kB)
import functools
from typing import Any, List, Mapping, Optional

import chatglm_cpp
from langchain import PromptTemplate, LLMChain
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms.base import LLM
# GGML model file loaded by default (path relative to the working directory).
DEFAULT_MODEL_PATH = "chatglm2-6b-ggml.q8_0.bin"

# The default model is loaded once, at import time, and reused for generation.
pipeline = chatglm_cpp.Pipeline(DEFAULT_MODEL_PATH)

# Callback manager whose handler echoes streamed tokens to stdout.
callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
# Keep at most one extra (non-default) model resident: GGML models are large.
@functools.lru_cache(maxsize=1)
def _load_pipeline(model_path: str) -> "chatglm_cpp.Pipeline":
    """Load and memoize a chatglm.cpp pipeline for a non-default model file.

    The default model is already loaded once as the module-level ``pipeline``;
    this helper is only used when a ``ChatGLM`` instance is configured with a
    different ``base_model`` path.
    """
    return chatglm_cpp.Pipeline(model_path)


class ChatGLM(LLM):
    """LangChain ``LLM`` wrapper around a chatglm.cpp ``Pipeline``.

    Each call sends the prompt as a single-turn chat history (``[prompt]``)
    and returns the model's reply. With ``streaming=True`` the reply is
    produced piece by piece, and every piece is forwarded to the run's
    callback handlers through ``on_llm_new_token``.
    """

    # Sampling temperature; a value <= 0 selects greedy decoding (do_sample=False).
    temperature: float = 0.7
    # Path of the GGML model file used for generation.
    base_model: str = DEFAULT_MODEL_PATH
    # Passed through as chatglm.cpp's max_length.
    max_length: int = 2048
    # Forwarded to callbacks together with each streamed token.
    verbose: bool = False
    # When True, generate with stream_chat and report tokens as they arrive.
    streaming: bool = False
    top_p: float = 0.9
    # NOTE(review): 0 presumably disables top-k filtering in chatglm.cpp — confirm.
    top_k: int = 0
    # Passed through as chatglm.cpp's max_context_length.
    max_context_length: int = 512
    # Passed as num_threads; 0 presumably lets chatglm.cpp pick — confirm.
    threads: int = 0

    @property
    def _llm_type(self) -> str:
        """Return the identifier LangChain uses for this LLM type."""
        return "chatglm"

    def _get_pipeline(self) -> "chatglm_cpp.Pipeline":
        """Return the pipeline for ``self.base_model``.

        The module-level ``pipeline`` is reused for the default model, so the
        default weights are never loaded twice.
        """
        if self.base_model == DEFAULT_MODEL_PATH:
            return pipeline
        return _load_pipeline(self.base_model)

    def _generation_kwargs(self) -> Mapping[str, Any]:
        """Build the keyword arguments shared by ``chat`` and ``stream_chat``."""
        return {
            "max_length": self.max_length,
            "max_context_length": self.max_context_length,
            # Non-positive temperature means deterministic (greedy) decoding.
            "do_sample": self.temperature > 0,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "temperature": self.temperature,
            "num_threads": self.threads,
        }

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Generate the model's reply to ``prompt``.

        Args:
            prompt: Text sent to the model as a single-turn chat history.
            stop: Stop sequences; not supported, must be ``None``.
            run_manager: Callback manager for this run. In streaming mode each
                generated piece is reported via ``on_llm_new_token``.
            **kwargs: Extra keyword arguments that newer LangChain versions
                pass through; accepted and ignored.

        Returns:
            The complete generated reply.

        Raises:
            ValueError: If ``stop`` is given.
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        history = [prompt]
        pipe = self._get_pipeline()
        gen_kwargs = self._generation_kwargs()

        if not self.streaming:
            return pipe.chat(history, **gen_kwargs)

        # Stream: surface each piece to callbacks, then return the full reply.
        pieces: List[str] = []
        for piece in pipe.stream_chat(history, **gen_kwargs):
            pieces.append(piece)
            if run_manager is not None:
                run_manager.on_llm_new_token(piece, verbose=self.verbose)
        return "".join(pieces)

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "temperature": self.temperature,
            "base_model": self.base_model,
            "max_length": self.max_length,
            "verbose": self.verbose,
            "streaming": self.streaming,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "max_context_length": self.max_context_length,
            "threads": self.threads,
        }
def main() -> None:
    """Run a small demo: ask the model a riddle through an ``LLMChain``."""
    # Prompt (Chinese): "Xiaoming's mother has two children, one is called Daming {question}"
    template = "小明的妈妈有两个孩子,一个叫大明 {question}"
    prompt = PromptTemplate(template=template, input_variables=["question"])
    # Question (Chinese): "What is the other one called?"
    question = "另外一个叫什么?"
    # `show_progress` was removed: it is not a field of ChatGLM/LLM, and
    # LangChain's pydantic-based LLM config rejects unknown kwargs.
    # `callbacks=` replaces the deprecated `callback_manager=` keyword.
    llm = ChatGLM(streaming=False, callbacks=callback_manager)
    llm_chain = LLMChain(prompt=prompt, llm=llm)
    print(llm_chain.run(question))


# Guarded so that importing this module does not run the demo.
if __name__ == "__main__":
    main()