Spaces:
Sleeping
Sleeping
import gradio as gr | |
import boto3 | |
import sagemaker | |
import json | |
import io | |
import os | |
from transformers import AutoTokenizer | |
from huggingface_hub import login | |
# Runtime configuration is read from environment variables.
region = os.getenv("region")
sm_endpoint_name = os.getenv("sm_endpoint_name")
access_key = os.getenv("access_key")
secret_key = os.getenv("secret_key")
hf_token = os.getenv("hf_read_access")
HF_TOKEN = os.getenv('HF_TOKEN')
# Never print secret values: application logs are often readable by others.
# Report only whether each token was provided, which is enough for debugging.
print("hf_token set:", hf_token is not None)
print("HF_TOKEN set:", HF_TOKEN is not None)

# AWS session built from the explicit credentials/region read above.
session = boto3.Session(
    aws_access_key_id=access_key,
    aws_secret_access_key=secret_key,
    region_name=region,
)
# NOTE(review): `sess` is not referenced anywhere in this file view; kept in
# case other code depends on it.
sess = sagemaker.Session(boto_session=session)
# Client used to invoke the SageMaker endpoint (streaming responses).
smr = session.client("sagemaker-runtime")

DEFAULT_SYSTEM_PROMPT = (
    "You are a helpful, concise and direct Assistant."
)
# Tokenizer used to render chat prompts and to count prompt tokens.
# `hf_token` is passed as the Hugging Face Hub access token for the download.
# NOTE(review): presumably the same model the SageMaker endpoint serves — confirm.
tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.2",
    token=hf_token,
)

# Longest prompt (in tokens) accepted before a request is sent to the endpoint.
MAX_INPUT_TOKEN_LENGTH = 256

# Sampling/decoding settings sent as "parameters" with every endpoint request.
parameters = dict(
    do_sample=True,
    top_p=0.6,
    temperature=0.9,
    max_new_tokens=768,
    repetition_penalty=1.2,
    return_full_text=False,
    stop=["</s>"],
)
# Helper for reading lines from a stream | |
class LineIterator: | |
def __init__(self, stream): | |
self.byte_iterator = iter(stream) | |
self.buffer = io.BytesIO() | |
self.read_pos = 0 | |
def __iter__(self): | |
return self | |
def __next__(self): | |
while True: | |
self.buffer.seek(self.read_pos) | |
line = self.buffer.readline() | |
if line and line[-1] == ord("\n"): | |
self.read_pos += len(line) | |
return line[:-1] | |
try: | |
chunk = next(self.byte_iterator) | |
except StopIteration: | |
if self.read_pos < self.buffer.getbuffer().nbytes: | |
continue | |
raise | |
if "PayloadPart" not in chunk: | |
print("Unknown event type:" + chunk) | |
continue | |
self.buffer.seek(0, io.SEEK_END) | |
self.buffer.write(chunk["PayloadPart"]["Bytes"]) | |
def format_prompt(message, history):
    """Render the prior conversation plus the new user message as a prompt.

    Previously this ignored both arguments and always sent a fixed question.
    It now builds the real conversation.

    Args:
        message: The latest user message.
        history: Prior turns from ``gr.ChatInterface``. Either
            ``[user, assistant]`` pairs, as the earlier commented-out draft
            assumed, or ``{"role": ..., "content": ...}`` dicts. Which one
            arrives depends on the Gradio chat format (assumption; confirm
            against the installed Gradio version).

    Returns:
        The prompt string from ``tokenizer.apply_chat_template`` with the
        generation prompt appended. It is not tokenized.
    """
    # NOTE(review): the earlier draft also prepended a {"role": "system"} turn
    # holding DEFAULT_SYSTEM_PROMPT. It is left out here because the
    # Mistral-7B-Instruct-v0.2 chat template is believed to accept only
    # alternating user/assistant roles. Confirm before re-adding it.
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            # Messages-style history: copy only the fields the template needs.
            messages.append({"role": turn["role"], "content": turn["content"]})
        else:
            user_text, assistant_text = turn[0], turn[1]
            messages.append({"role": "user", "content": user_text})
            messages.append({"role": "assistant", "content": assistant_text})
    messages.append({"role": "user", "content": message})
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
def generate(
    prompt,
    history,
):
    """Stream the model's reply to `prompt` from the SageMaker endpoint.

    Args:
        prompt: The latest user message.
        history: Prior chat turns, forwarded to ``format_prompt``.

    Yields:
        The reply generated so far, updated as tokens arrive. Any trailing stop
        string is removed and trailing whitespace is stripped for display. The
        final text is yielded once more after the stream ends.

    Raises:
        gr.Error: Via ``check_input_token_length`` when the formatted prompt is
            longer than ``MAX_INPUT_TOKEN_LENGTH`` tokens.
    """
    formatted_prompt = format_prompt(prompt, history)
    check_input_token_length(formatted_prompt)

    request = {"inputs": formatted_prompt, "parameters": parameters, "stream": True}
    resp = smr.invoke_endpoint_with_response_stream(
        EndpointName=sm_endpoint_name,
        Body=json.dumps(request),
        ContentType="application/json",
    )

    stop_strings = request["parameters"]["stop"]
    data_prefix = "data:"
    # `generated` keeps the exact text produced so far. `output` is only the
    # cleaned-up view shown to the user. Before this change, rstrip() was
    # applied to the accumulated text itself, which erased newline and
    # whitespace-only tokens from the reply.
    generated = ""
    output = ""
    for raw_line in LineIterator(resp["Body"]):
        line = raw_line.decode("utf-8")
        if not line.startswith(data_prefix):
            continue
        # Slice off the prefix. str.lstrip("data:") strips a character *set*,
        # not a prefix, and the old rstrip("/n") was a typo for a newline.
        chunk = json.loads(line[len(data_prefix):])
        token = chunk["token"]
        if token["special"]:
            continue
        if token["text"] in stop_strings:
            break
        generated += token["text"]

        output = generated
        for stop_str in stop_strings:
            if output.endswith(stop_str):
                output = output[: -len(stop_str)]
        output = output.rstrip()
        yield output
    yield output
    return output
def check_input_token_length(prompt: str) -> None:
    """Reject a prompt whose token count exceeds ``MAX_INPUT_TOKEN_LENGTH``.

    Args:
        prompt: The fully formatted prompt text that is about to be sent.

    Raises:
        gr.Error: If the tokenized prompt is longer than
            ``MAX_INPUT_TOKEN_LENGTH`` tokens.
    """
    token_count = len(tokenizer(prompt)["input_ids"])
    if token_count <= MAX_INPUT_TOKEN_LENGTH:
        return
    raise gr.Error(
        f"The accumulated input is too long ({token_count} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again."
    )
# Visual theme for the chat UI: a web font first, then system fallbacks.
_font_stack = [
    gr.themes.GoogleFont("Open Sans"),
    "ui-sans-serif",
    "system-ui",
    "sans-serif",
]
theme = gr.themes.Monochrome(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    radius_size=gr.themes.sizes.radius_sm,
    font=_font_stack,
)

# Chat UI whose responses come from the streaming `generate` callback.
demo = gr.ChatInterface(
    generate,
    theme=theme,
    chatbot=gr.Chatbot(layout="panel"),
)

# Enable Gradio's request queue, then start the server without a share link.
_queued_demo = demo.queue()
_queued_demo.launch(share=False)