# llama2.py (2.51 kB) -- commit dd486e6, "Create llama2.py", by chansung.
import json
import os

import requests
import sseclient
from pingpong import PingPong
from pingpong.gradio import GradioChatUIFmt
from pingpong.pingpong import PPManager
from pingpong.pingpong import PromptFmt
from pingpong.pingpong import UIFmt
class LLaMA2ChatPromptFmt(PromptFmt):
    """Prompt formatter that emits ``<<SYS>>`` and ``[INST]`` delimited text."""

    @classmethod
    def ctx(cls, context):
        """Wrap ``context`` in a ``<<SYS>>`` ... ``<</SYS>>`` block.

        Returns an empty string when ``context`` is ``None`` or ``""``;
        otherwise the wrapped context followed by a trailing newline.
        """
        if context is None or context == "":
            return ""
        return f"<<SYS>>\n{context}\n<</SYS>>\n"

    @classmethod
    def prompt(cls, pingpong, truncate_size):
        """Render one exchange as ``[INST] <ping> [/INST] <pong>``.

        Both ``pingpong.ping`` and ``pingpong.pong`` are sliced with
        ``[:truncate_size]`` (a ``None`` size keeps them whole). A ``pong``
        of ``None`` is rendered as an empty string.
        """
        question = pingpong.ping[:truncate_size]
        if pingpong.pong is None:
            answer = ""
        else:
            answer = pingpong.pong[:truncate_size]
        return f"[INST] {question} [/INST] {answer}"
class LLaMA2ChatPPManager(PPManager):
    """PPManager that joins the stored exchanges into a single prompt string."""

    def build_prompts(
        self,
        from_idx: int = 0,
        to_idx: int = -1,
        fmt: PromptFmt = LLaMA2ChatPromptFmt,
        truncate_size: int = None,
    ):
        """Build the prompt for ``self.pingpongs[from_idx:to_idx]``.

        Args:
            from_idx: Start index of the slice of exchanges to include.
            to_idx: End index (exclusive). ``-1`` or any value at or beyond
                the number of exchanges means "up to the last one".
            fmt: Formatter providing ``ctx`` and ``prompt``.
            truncate_size: Passed through to ``fmt.prompt`` for every exchange.

        Returns:
            ``fmt.ctx(self.ctx)`` followed by the formatted exchanges,
            concatenated in order.
        """
        total = len(self.pingpongs)
        if to_idx == -1 or to_idx >= total:
            to_idx = total

        # The context header is rendered first, then every selected exchange.
        header = fmt.ctx(self.ctx)
        turns = [
            fmt.prompt(exchange, truncate_size=truncate_size)
            for exchange in self.pingpongs[from_idx:to_idx]
        ]
        return header + "".join(turns)
class GradioLLaMA2ChatPPManager(LLaMA2ChatPPManager):
    """LLaMA2 chat manager that can also render its exchanges for a UI."""

    def build_uis(
        self,
        from_idx: int = 0,
        to_idx: int = -1,
        fmt: UIFmt = GradioChatUIFmt,
    ):
        """Return ``fmt.ui(...)`` for each exchange in the selected slice.

        Args:
            from_idx: Start index of the slice of exchanges to include.
            to_idx: End index (exclusive). ``-1`` or any value at or beyond
                the number of exchanges means "up to the last one".
            fmt: Formatter providing ``ui``.

        Returns:
            A list with one ``fmt.ui`` result per selected exchange, in order.
        """
        count = len(self.pingpongs)
        stop = count if (to_idx == -1 or to_idx >= count) else to_idx
        return [fmt.ui(exchange) for exchange in self.pingpongs[from_idx:stop]]
async def gen_text(
    prompt,
    hf_model='meta-llama/Llama-2-70b-chat-hf',
    hf_token=None,
    parameters=None
):
    """Stream generated text for ``prompt`` from the Hugging Face Inference API.

    This is an async generator: it POSTs a streaming request and yields the
    ``text`` of the ``token`` object found in each server-sent event.

    Args:
        prompt: Prompt sent as the ``inputs`` field of the request body.
        hf_model: Model repository id appended to the Inference API URL.
        hf_token: Access token sent as the ``Authorization`` bearer token.
        parameters: Generation parameters sent under ``parameters``. When
            ``None``, the default sampling configuration below is used.

    Yields:
        ``event['token']['text']`` for every event of the response stream.

    Raises:
        ValueError: If ``hf_token`` is ``None``.
        requests.HTTPError: If the endpoint answers with an error status.
    """
    # Validate before building anything so a missing token fails fast.
    if hf_token is None:
        raise ValueError("Hugging Face Token is not set")

    if parameters is None:
        parameters = {
            'max_new_tokens': 512,
            'do_sample': True,
            'return_full_text': False,
            'temperature': 1.0,
            'top_k': 50,
            # 'top_p': 1.0,
            'repetition_penalty': 1.2
        }

    url = f'https://api-inference.huggingface.co/models/{hf_model}'
    headers = {
        'Authorization': f'Bearer {hf_token}',
        'Content-type': 'application/json'
    }
    data = {
        'inputs': prompt,
        'stream': True,
        'options': {
            'use_cache': False,
        },
        'parameters': parameters
    }

    # NOTE(review): the HTTP call and the event iteration are synchronous, so
    # the event loop is blocked while this generator waits on the network.
    #
    # The ``with`` block guarantees the streamed connection is released, both
    # on normal exhaustion and when the consumer stops iterating early.
    with requests.post(
        url,
        headers=headers,
        data=json.dumps(data),
        stream=True
    ) as r:
        # An error response carries no token events; surface it explicitly
        # instead of handing its body to the SSE/JSON parsing below.
        r.raise_for_status()

        client = sseclient.SSEClient(r)
        for event in client.events():
            yield json.loads(event.data)['token']['text']