import io
import json
import os

import boto3
import gradio as gr
import sagemaker
from transformers import AutoTokenizer

region = os.getenv("region")
sm_endpoint_name = os.getenv("sm_endpoint_name")
access_key = os.getenv("access_key")
secret_key = os.getenv("secret_key")

session = boto3.Session(
    aws_access_key_id=access_key,
    aws_secret_access_key=secret_key,
    region_name=region,
)
sess = sagemaker.Session(boto_session=session)
smr = session.client("sagemaker-runtime")

DEFAULT_SYSTEM_PROMPT = "You are a helpful, concise and direct Assistant."

# Load the tokenizer so prompts can be rendered with the model's chat template
# and input lengths can be checked before invoking the endpoint.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

MAX_INPUT_TOKEN_LENGTH = 256

# Hyperparameters for the LLM. The streaming loop below reads
# parameters["stop"], so a stop list must be present; "</s>" is the
# Mistral end-of-sequence token.
parameters = {
    "do_sample": True,
    "top_p": 0.6,
    "temperature": 0.9,
    "max_new_tokens": 768,
    "repetition_penalty": 1.2,
    "return_full_text": False,
    "stop": ["</s>"],
}


# Helper for reading lines from the response stream. Server-sent events can be
# split across PayloadPart chunks, so bytes are buffered until a complete
# newline-terminated line is available.
class LineIterator:
    def __init__(self, stream):
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        self.read_pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line and line[-1] == ord("\n"):
                self.read_pos += len(line)
                return line[:-1]
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                if self.read_pos < self.buffer.getbuffer().nbytes:
                    continue
                raise
            if "PayloadPart" not in chunk:
                print("Unknown event type: " + str(chunk))
                continue
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk["PayloadPart"]["Bytes"])


def format_prompt(message, history):
    # Build the conversation from prior turns plus the new user message. The
    # Mistral-Instruct chat template only accepts strictly alternating
    # user/assistant roles (a leading "system" message raises a template
    # error), so the system prompt is folded into the first user turn instead.
    messages = []
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})
    messages[0]["content"] = DEFAULT_SYSTEM_PROMPT + "\n\n" + messages[0]["content"]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    return prompt


def generate(prompt, history):
    formatted_prompt = format_prompt(prompt, history)
    check_input_token_length(formatted_prompt)

    request = {"inputs": formatted_prompt, "parameters": parameters, "stream": True}
    resp = smr.invoke_endpoint_with_response_stream(
        EndpointName=sm_endpoint_name,
        Body=json.dumps(request),
        ContentType="application/json",
    )

    output = ""
    for c in LineIterator(resp["Body"]):
        c = c.decode("utf-8")
        if c.startswith("data:"):
            # Each event is a server-sent-events line of the form "data:{...}".
            chunk = json.loads(c[len("data:"):].rstrip("\n"))
            if chunk["token"]["special"]:
                continue
            if chunk["token"]["text"] in request["parameters"]["stop"]:
                break
            output += chunk["token"]["text"]
            # Trim any stop string that slipped into the accumulated output
            # before yielding the partial response to the UI.
            for stop_str in request["parameters"]["stop"]:
                if output.endswith(stop_str):
                    output = output[: -len(stop_str)]
                    output = output.rstrip()
                    yield output
            yield output
    return output


def check_input_token_length(prompt: str) -> None:
    input_token_length = len(tokenizer(prompt)["input_ids"])
    if input_token_length > MAX_INPUT_TOKEN_LENGTH:
        raise gr.Error(
            f"The accumulated input is too long ({input_token_length} > "
            f"{MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again."
        )


theme = gr.themes.Monochrome(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    radius_size=gr.themes.sizes.radius_sm,
    font=[
        gr.themes.GoogleFont("Open Sans"),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
)

demo = gr.ChatInterface(
    generate,
    chatbot=gr.Chatbot(layout="panel"),
    theme=theme,
)

# Note: queue(concurrency_count=...) is the Gradio 3.x API; Gradio 4 renamed
# this argument to default_concurrency_limit.
demo.queue(concurrency_count=5).launch(share=False)
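# Minimal local-run sketch. All values below are placeholder assumptions
# (region, endpoint name, credentials, and the filename "app.py" are not
# taken from this repo); the app reads these four environment variables at
# import time, so they must be set before the process starts:
#
#   export region=us-east-1
#   export sm_endpoint_name=my-mistral-endpoint
#   export access_key=AKIA...
#   export secret_key=...
#   python app.py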