from typing import Any, List, Mapping, Optional

import chatglm_cpp
from langchain import LLMChain, PromptTemplate
from langchain.callbacks.manager import CallbackManager, CallbackManagerForLLMRun
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms.base import LLM

# Path to the quantized ChatGLM2-6B weights in ggml format (q8_0 quantization).
DEFAULT_MODEL_PATH = "chatglm2-6b-ggml.q8_0.bin"

# Callback manager that prints generated tokens to stdout as they arrive.
callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])

# Load the model once at module scope so every ChatGLM instance shares it.
pipeline = chatglm_cpp.Pipeline(DEFAULT_MODEL_PATH)
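
# Optional sanity check of the raw pipeline before wrapping it in LangChain.
# A minimal sketch, assuming this chatglm_cpp version accepts the history as
# a plain list of strings (newer releases expect ChatMessage objects); left
# commented out so importing this module does not trigger generation:
#
#   print(pipeline.chat(["Hello"], max_length=64))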


class ChatGLM(LLM):
    """LangChain LLM wrapper around a local chatglm.cpp pipeline."""

    temperature: float = 0.7
    base_model: str = DEFAULT_MODEL_PATH
    max_length: int = 2048
    verbose: bool = False
    streaming: bool = False
    top_p: float = 0.9
    top_k: int = 0
    max_context_length: int = 512
    threads: int = 0  # 0 lets chatglm_cpp choose the thread count

    @property
    def _llm_type(self) -> str:
        return "chatglm"
    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        if self.verbose:
            print("Prompt: ", prompt)
        # This chatglm_cpp version takes the conversation history as a list
        # of strings; a single-turn chat is just the prompt itself.
        history = [prompt]
        if self.streaming:
            response = ""
            for piece in pipeline.stream_chat(
                history,
                max_length=self.max_length,
                max_context_length=self.max_context_length,
                do_sample=self.temperature > 0,
                top_k=self.top_k,
                top_p=self.top_p,
                temperature=self.temperature,
                num_threads=self.threads,
            ):
                response += piece
                # Forward each new token to LangChain callbacks (e.g. the
                # stdout streamer); without this, streaming is invisible.
                if run_manager:
                    run_manager.on_llm_new_token(piece)
            return response
        return pipeline.chat(
            history,
            max_length=self.max_length,
            max_context_length=self.max_context_length,
            do_sample=self.temperature > 0,
            top_k=self.top_k,
            top_p=self.top_p,
            temperature=self.temperature,
            num_threads=self.threads,
        )
    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "temperature": self.temperature,
            "base_model": self.base_model,
            "max_length": self.max_length,
            "verbose": self.verbose,
            "streaming": self.streaming,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "max_context_length": self.max_context_length,
            "threads": self.threads,
        }


# Demo: wire the wrapper into a simple LLMChain.
template = "Xiao Ming's mother has two children; one is called Da Ming. {question}"
prompt = PromptTemplate(template=template, input_variables=["question"])
question = "What is the other one called?"

llm = ChatGLM(streaming=False, callback_manager=callback_manager)
llm_chain = LLMChain(prompt=prompt, llm=llm)
print(llm_chain.run(question))
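
# Streaming variant (a sketch, not part of the original demo): with
# streaming=True, _call() forwards each generated token to the run manager,
# so the StreamingStdOutCallbackHandler above prints tokens as they arrive.
streaming_llm = ChatGLM(streaming=True, callback_manager=callback_manager)
streaming_chain = LLMChain(prompt=prompt, llm=streaming_llm)
streaming_chain.run(question)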