# Hugging Face Space application file (listed as 4,540 bytes on the Space page;
# scraped page residue -- "Sleeping" status, per-line commit hashes and the
# line-number gutter -- removed so this file is valid Python).
import gradio as gr
import boto3
import sagemaker
import json
import io
import os
from transformers import AutoTokenizer
from huggingface_hub import login
# Runtime configuration, read from environment variables (e.g. Space secrets).
region = os.getenv("region")
sm_endpoint_name = os.getenv("sm_endpoint_name")
access_key = os.getenv("access_key")
secret_key = os.getenv("secret_key")
hf_token = os.getenv("hf_read_access")
HF_TOKEN = os.getenv('HF_TOKEN')
# Never print the token values themselves: stdout is typically captured in
# hosting logs. Only report whether each one is configured.
print("hf_token set:", hf_token is not None)
print("HF_TOKEN set:", HF_TOKEN is not None)

# AWS session built from the explicit credentials and region above.
session = boto3.Session(
    aws_access_key_id=access_key,
    aws_secret_access_key=secret_key,
    region_name=region,
)
# NOTE(review): `sess` is not referenced elsewhere in the visible code.
sess = sagemaker.Session(boto_session=session)
# Low-level client used to invoke the streaming inference endpoint.
smr = session.client("sagemaker-runtime")

DEFAULT_SYSTEM_PROMPT = (
    "You are an helpful, concise and direct Assistant."
)

# Tokenizer for the served model: renders chat prompts and counts input
# tokens. `token` authenticates the download from the Hugging Face Hub.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", token=hf_token)
# Prompts longer than this many tokens are rejected before calling the endpoint.
MAX_INPUT_TOKEN_LENGTH = 256

# Generation parameters sent with every request.
parameters = {
    "do_sample": True,
    "top_p": 0.6,
    "temperature": 0.9,
    "max_new_tokens": 768,
    "repetition_penalty": 1.2,
    "return_full_text": False,
    # The streaming code reads parameters["stop"]; without this key every
    # token raised KeyError. "</s>" is Mistral's end-of-sequence text.
    "stop": ["</s>"],
}
# Helper for reading lines from a stream
class LineIterator:
    """Iterate over newline-delimited lines of a streamed response body.

    The stream is an iterable of event dicts; only ``PayloadPart`` events are
    used, and their ``["PayloadPart"]["Bytes"]`` payloads are appended to an
    internal buffer. Because chunks are buffered, a line split across several
    chunks is reassembled, and a chunk holding several lines yields them one
    at a time. Each line is returned as ``bytes`` without its trailing
    newline; unterminated bytes left at the end of the stream are returned as
    a final line.
    """

    def __init__(self, stream):
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        # Offset of the first buffered byte not yet returned as a line.
        self.read_pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line and line[-1] == ord("\n"):
                self.read_pos += len(line)
                return line[:-1]
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                # Stream exhausted. `readline` above already returned every
                # unread byte (it stops only at "\n" or the buffer end), so
                # emit that trailing partial line once. Previously this path
                # looped forever re-calling the exhausted iterator.
                if line:
                    self.read_pos += len(line)
                    return line
                raise
            if "PayloadPart" not in chunk:
                # str() is required: `chunk` is a dict, and str + dict
                # raises TypeError.
                print("Unknown event type:" + str(chunk))
                continue
            # Append new bytes at the end without moving `read_pos`.
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk["PayloadPart"]["Bytes"])
def format_prompt(message, history):
    """Render the conversation into a prompt string using the model's chat template.

    Previously this ignored both arguments and always sent a hard-coded
    question; it now uses the real user message and chat history.

    Args:
        message: The new user message.
        history: Earlier turns as passed by ``gr.ChatInterface`` -- either
            ``(user, assistant)`` pairs or ``{"role": ..., "content": ...}``
            dicts. ``None`` is treated as empty.

    Returns:
        The untokenized prompt text from ``tokenizer.apply_chat_template``,
        with the generation prompt appended.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            # Gradio "messages" history format.
            messages.append({"role": turn["role"], "content": turn["content"]})
        else:
            # Gradio "tuples" history format: one (user, assistant) pair per exchange.
            user_text, assistant_text = turn
            messages.append({"role": "user", "content": user_text})
            messages.append({"role": "assistant", "content": assistant_text})
    messages.append({"role": "user", "content": message})

    # Mistral-Instruct chat templates may accept only alternating
    # user/assistant roles (a "system" role can raise), so the system prompt
    # is prepended to the first user turn instead of sent as its own message.
    first = messages[0]
    if first["role"] == "user":
        messages[0] = {
            "role": "user",
            "content": f"{DEFAULT_SYSTEM_PROMPT}\n\n{first['content']}",
        }

    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
def generate(
    prompt,
    history,
):
    """Stream a model reply for ``gr.ChatInterface``.

    Builds the prompt from ``prompt`` and ``history``, rejects it if it is too
    long, then calls the SageMaker endpoint with streaming enabled. Each
    ``data:`` line of the response is parsed as JSON, and the token text is
    accumulated.

    Yields:
        The reply text so far, with trailing whitespace and any trailing stop
        sequence removed. The final yielded value is also the generator's
        return value.
    """
    formatted_prompt = format_prompt(prompt, history)
    check_input_token_length(formatted_prompt)
    request = {"inputs": formatted_prompt, "parameters": parameters, "stream": True}
    resp = smr.invoke_endpoint_with_response_stream(
        EndpointName=sm_endpoint_name,
        Body=json.dumps(request),
        ContentType="application/json",
    )

    # "stop" may be missing from `parameters` (direct indexing raised
    # KeyError). Empty strings are dropped: `output[:-0]` would erase all text.
    stop_sequences = [s for s in request["parameters"].get("stop", []) if s]
    prefix = "data:"

    output = ""
    for raw_line in LineIterator(resp["Body"]):
        line = raw_line.decode("utf-8")
        if not line.startswith(prefix):
            continue
        # Slice the prefix off; str.lstrip("data:") strips a character set,
        # not a prefix.
        chunk = json.loads(line[len(prefix):])
        token = chunk["token"]
        if token["special"]:
            continue
        if token["text"] in stop_sequences:
            break
        output += token["text"]

        # A stop sequence may arrive split across several tokens.
        hit_stop = False
        for stop_str in stop_sequences:
            if output.endswith(stop_str):
                output = output[: -len(stop_str)]
                hit_stop = True
                break
        if hit_stop:
            break

        # Trim only the displayed copy. Stripping `output` itself permanently
        # dropped whitespace-only tokens (e.g. newlines between lines).
        yield output.rstrip()

    output = output.rstrip()
    yield output
    return output
def check_input_token_length(prompt: str) -> None:
    """Reject prompts whose tokenized length exceeds MAX_INPUT_TOKEN_LENGTH.

    The length is the number of ``input_ids`` the module tokenizer produces
    for ``prompt``.

    Raises:
        gr.Error: if that count is greater than MAX_INPUT_TOKEN_LENGTH. The
            message tells the user to clear the chat history.
    """
    token_count = len(tokenizer(prompt)["input_ids"])
    if token_count <= MAX_INPUT_TOKEN_LENGTH:
        return
    raise gr.Error(
        f"The accumulated input is too long ({token_count} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again."
    )
# Gradio theme: Monochrome base with indigo/blue/slate hues, small corner
# radius, and an Open Sans font stack with system fallbacks.
theme = gr.themes.Monochrome(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    radius_size=gr.themes.sizes.radius_sm,
    font=[
        gr.themes.GoogleFont("Open Sans"),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
)

# Chat UI backed by the streaming `generate` handler.
demo = gr.ChatInterface(
    generate,
    chatbot=gr.Chatbot(layout="panel"),
    theme=theme,
)
# Enable request queueing, then start the server without a public share link.
# (The trailing " |" on the original line was scrape residue and a syntax error.)
demo.queue().launch(share=False)