File size: 3,138 Bytes
3da7b5e 83ef88f 3da7b5e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
import functools
from typing import Any, List, Mapping, Optional

import chatglm_cpp
from langchain import PromptTemplate, LLMChain
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms.base import LLM
# Model file handed to chatglm_cpp; a relative path, so it is looked up from
# the current working directory. (The filename suggests a q8_0-quantized
# ChatGLM2-6B GGML checkpoint — confirm against the actual file.)
DEFAULT_MODEL_PATH = "chatglm2-6b-ggml.q8_0.bin"

# Callback manager holding LangChain's StreamingStdOutCallbackHandler, which
# writes streamed tokens to stdout.
callback_manager = CallbackManager(handlers=[StreamingStdOutCallbackHandler()])

# Shared chatglm.cpp pipeline for DEFAULT_MODEL_PATH. It is constructed here,
# at import time, so importing this module loads the model.
pipeline = chatglm_cpp.Pipeline(DEFAULT_MODEL_PATH)
class ChatGLM(LLM):
    """LangChain ``LLM`` backed by a chatglm.cpp ``Pipeline``.

    Every call sends the prompt to the pipeline as a one-message chat history
    and returns the generated reply. When ``streaming`` is enabled, each
    generated piece is also forwarded to LangChain's run manager so streaming
    callbacks fire.
    """

    # Sampling temperature; sampling (``do_sample``) is enabled only when > 0.
    temperature: float = 0.7
    # Path of the model file loaded by chatglm_cpp.Pipeline.
    base_model: str = DEFAULT_MODEL_PATH
    # Forwarded to chatglm_cpp as ``max_length``.
    max_length: int = 2048
    # Prints the prompt before generation and is passed to token callbacks.
    verbose: bool = False
    # Use ``stream_chat`` and report tokens incrementally instead of ``chat``.
    streaming: bool = False
    # Forwarded to chatglm_cpp as ``top_p``.
    top_p: float = 0.9
    # Forwarded to chatglm_cpp as ``top_k``.
    top_k: int = 0
    # Forwarded to chatglm_cpp as ``max_context_length``.
    max_context_length: int = 512
    # Forwarded to chatglm_cpp as ``num_threads``.
    threads: int = 0

    @property
    def _llm_type(self) -> str:
        """Return the LangChain type tag for this LLM."""
        return "chatglm"

    @staticmethod
    @functools.lru_cache(maxsize=2)
    def _load_pipeline(model_path: str) -> "chatglm_cpp.Pipeline":
        """Load a pipeline for a non-default model path, caching it per path.

        The cache is bounded so switching between many model files does not
        keep every loaded model alive.
        """
        return chatglm_cpp.Pipeline(model_path)

    def _get_pipeline(self) -> "chatglm_cpp.Pipeline":
        """Return the pipeline for ``self.base_model``.

        The default path reuses the module-level ``pipeline`` so the model is
        not loaded twice; any other path is loaded lazily and cached.
        """
        if self.base_model == DEFAULT_MODEL_PATH:
            return pipeline
        return self._load_pipeline(self.base_model)

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        """Generate a reply to ``prompt``.

        Args:
            prompt: Text sent to the model as a single-message chat history.
            stop: Not supported; must be ``None``.
            run_manager: LangChain run manager. In streaming mode every
                generated piece is reported through ``on_llm_new_token``.

        Returns:
            The generated text.

        Raises:
            ValueError: If ``stop`` is provided.
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        if self.verbose:
            print("Prompt: ", prompt)

        history = [prompt]
        # Shared by both the blocking and the streaming entry points.
        gen_kwargs = {
            "max_length": self.max_length,
            "max_context_length": self.max_context_length,
            "do_sample": self.temperature > 0,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "temperature": self.temperature,
            "num_threads": self.threads,
        }
        model = self._get_pipeline()

        if not self.streaming:
            return model.chat(history, **gen_kwargs)

        pieces: List[str] = []
        for piece in model.stream_chat(history, **gen_kwargs):
            pieces.append(piece)
            # Without this, streaming callbacks (e.g. stdout printers) never fire.
            if run_manager is not None:
                run_manager.on_llm_new_token(piece, verbose=self.verbose)
        return "".join(pieces)

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters (all configurable fields)."""
        return {"temperature": self.temperature,
                "base_model": self.base_model,
                "max_length": self.max_length,
                "verbose": self.verbose,
                "streaming": self.streaming,
                "top_p": self.top_p,
                "top_k": self.top_k,
                "max_context_length": self.max_context_length,
                "threads": self.threads}
# Demo: ask the model a short follow-up question through an LLMChain.
# NOTE(review): this runs at import time; consider moving it under an
# ``if __name__ == "__main__":`` guard if the module is ever imported.
template = "小明的妈妈有两个孩子,一个叫大明 {question}"
prompt = PromptTemplate(template=template, input_variables=["question"])
question = "另外一个叫什么?"
# ``callbacks`` is the current name of the deprecated ``callback_manager``
# argument. The stdout handler prints streamed tokens, so it has no visible
# effect while ``streaming=False``.
llm = ChatGLM(streaming=False, callbacks=callback_manager)
llm_chain = LLMChain(prompt=prompt, llm=llm)
print(llm_chain.run(question))