File size: 4,562 Bytes
7ea4cc0
 
b5a3b2c
7ea4cc0
 
55f2708
b5a3b2c
4082459
 
55f2708
 
 
 
 
509bdb6
9bfcf52
 
 
7ea4cc0
50e8e05
 
 
 
 
 
 
b5a3b2c
 
 
 
 
 
 
4082459
b5a3b2c
 
 
 
7ea4cc0
 
 
 
 
b5a3b2c
 
7ea4cc0
3854982
7ea4cc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b5a3b2c
 
 
 
 
 
7ea4cc0
b5a3b2c
 
 
 
7ea4cc0
b5a3b2c
 
 
 
 
7ea4cc0
b5a3b2c
 
 
 
7ea4cc0
b5a3b2c
 
7ea4cc0
b5a3b2c
7ea4cc0
b5a3b2c
 
2bf22d0
b5a3b2c
 
7ea4cc0
b5a3b2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ea4cc0
0d0b874
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import gradio as gr
import boto3
import sagemaker
import json
import io
import os
from transformers import AutoTokenizer
from huggingface_hub import login


# Deployment configuration, read from environment variables.
region = os.getenv("region")
sm_endpoint_name = os.getenv("sm_endpoint_name")
access_key = os.getenv("access_key")
secret_key = os.getenv("secret_key")
# Two Hugging Face token variables are read; only `hf_token` is passed to the
# tokenizer below.
hf_token = os.getenv("hf_read_access")
HF_TOKEN = os.getenv('HF_TOKEN')
# Report only whether each token is configured: printing the token values
# themselves would leak credentials into the application logs.
print("hf_token set:", hf_token is not None)
print("HF_TOKEN set:", HF_TOKEN is not None)

session = boto3.Session(
    aws_access_key_id=access_key,
    aws_secret_access_key=secret_key,
    region_name=region
)
sess = sagemaker.Session(boto_session=session)

# Runtime client used to call the streaming inference endpoint.
smr = session.client("sagemaker-runtime")

DEFAULT_SYSTEM_PROMPT = (
    "You are a helpful, concise and direct Assistant."
)

# load the tokenizer; used to render the chat template and to count prompt tokens
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2",token=hf_token)


# Upper bound on the tokenized prompt length accepted by check_input_token_length.
MAX_INPUT_TOKEN_LENGTH = 256

# hyperparameters for llm, sent as "parameters" in every request body.
# NOTE(review): presumably interpreted by a text-generation-inference style
# container behind the endpoint — confirm against the deployed image.
# The "stop" strings are also used by generate() to truncate the streamed text.
parameters = {
    "do_sample": True,
    "top_p": 0.6,
    "temperature": 0.9,
    "max_new_tokens": 768,
    "repetition_penalty": 1.2,
    "return_full_text": False,
    "stop": ["</s>"],
}


# Helper for reading lines from a stream
class LineIterator:
    """Yield newline-delimited lines from a SageMaker response event stream.

    The stream is an iterable of event dicts. Payload events carry byte
    fragments under ``event["PayloadPart"]["Bytes"]``, and those fragments are
    not assumed to line up with line boundaries. Fragments are therefore
    buffered, and one complete line (``bytes``, without the trailing
    ``b"\\n"``) is returned per ``next()`` call. A final fragment that has no
    trailing newline is returned as the last line once the stream ends.
    Non-payload events are reported with ``print`` and skipped.
    """

    def __init__(self, stream):
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        # Offset of the first buffered byte not yet returned to the caller.
        self.read_pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line and line[-1] == ord("\n"):
                self.read_pos += len(line)
                return line[:-1]
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                # The stream is exhausted. Whatever follows read_pos (already
                # read into `line` by readline) is a final line without a
                # newline: return it once. Looping back here would spin forever,
                # because no further bytes can arrive.
                if line:
                    self.read_pos += len(line)
                    return line
                raise
            if "PayloadPart" not in chunk:
                # Events are dicts, so convert explicitly: "str" + dict raises
                # TypeError instead of reporting the event.
                print("Unknown event type:" + str(chunk))
                continue
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk["PayloadPart"]["Bytes"])


def build_chat_messages(message, history, system_prompt=None):
    """Build the role/content message list for a chat template.

    Args:
        message: The newest user message.
        history: Earlier turns, either as ``[user, assistant]`` pairs or as
            ``{"role": ..., "content": ...}`` dicts. ``None`` or empty means
            no prior turns.
        system_prompt: Optional instructions. When given, they are prepended
            to the first user turn instead of being sent as a separate
            ``system`` message. NOTE(review): this assumes the Mistral
            Instruct chat template does not accept a ``system`` role; confirm
            against the tokenizer's template.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts that ends with the
        new user turn.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            messages.append({"role": turn["role"], "content": turn["content"]})
        else:
            user_text, assistant_text = turn
            messages.append({"role": "user", "content": user_text})
            messages.append({"role": "assistant", "content": assistant_text})
    messages.append({"role": "user", "content": message})

    # messages is never empty here: the new user turn was just appended.
    first = messages[0]
    if system_prompt and first["role"] == "user":
        messages[0] = {
            "role": "user",
            "content": f"{system_prompt}\n\n{first['content']}",
        }
    return messages


def format_prompt(message, history):
    """Render the conversation so far as one prompt string for the endpoint.

    Args:
        message: The newest user message.
        history: Earlier turns as supplied by the chat UI (see
            ``build_chat_messages`` for the accepted formats).

    Returns:
        The prompt text produced by the tokenizer's chat template, with the
        generation prompt appended so that the model answers as the assistant.
    """
    messages = build_chat_messages(message, history, DEFAULT_SYSTEM_PROMPT)
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )


def iter_stream_text(lines, stop_sequences):
    """Accumulate generated text from streamed ``data:`` event lines.

    Args:
        lines: Iterable of raw ``bytes`` lines, such as those yielded by
            ``LineIterator``. Lines that do not start with ``data:`` are
            ignored. The rest must hold a JSON object whose ``token`` entry
            has ``text`` and ``special`` fields.
        stop_sequences: Strings that end generation. A token equal to one of
            them stops the stream. A stop string that only appears once several
            tokens have been joined is trimmed from the output, and then the
            stream stops.

    Yields:
        The full text generated so far, after each non-special token.
    """
    output = ""
    for raw in lines:
        line = raw.decode("utf-8")
        if not line.startswith("data:"):
            continue
        # Slice off the literal prefix; str.lstrip("data:") strips any of
        # the characters d/a/t/: rather than the prefix itself.
        event = json.loads(line[len("data:"):])
        token = event["token"]
        if token["special"]:
            continue
        if token["text"] in stop_sequences:
            return
        output += token["text"]
        for stop_str in stop_sequences:
            if stop_str and output.endswith(stop_str):
                # Trim the stop string and finish. Continuing would emit the
                # same text twice and then append tokens that follow the stop.
                yield output[: -len(stop_str)].rstrip()
                return
        yield output


def generate(
    prompt,
    history,
):
    """Stream an assistant reply for ``prompt`` from the SageMaker endpoint.

    Args:
        prompt: The newest user message.
        history: Earlier chat turns from the UI.

    Yields:
        The reply text accumulated so far, once per streamed token.

    Raises:
        gr.Error: From ``check_input_token_length`` when the rendered prompt
            is longer than ``MAX_INPUT_TOKEN_LENGTH`` tokens.
    """
    formatted_prompt = format_prompt(prompt, history)
    check_input_token_length(formatted_prompt)

    request = {"inputs": formatted_prompt, "parameters": parameters, "stream": True}

    resp = smr.invoke_endpoint_with_response_stream(
        EndpointName=sm_endpoint_name,
        Body=json.dumps(request),
        ContentType="application/json",
    )
    output = ""
    for output in iter_stream_text(LineIterator(resp["Body"]), parameters["stop"]):
        yield output
    return output


def check_input_token_length(prompt: str) -> None:
    """Reject prompts whose token count exceeds ``MAX_INPUT_TOKEN_LENGTH``.

    Args:
        prompt: The fully rendered prompt string.

    Raises:
        gr.Error: With a message asking the user to clear the chat history,
            when the tokenized prompt is too long.
    """
    token_count = len(tokenizer(prompt)["input_ids"])
    if token_count <= MAX_INPUT_TOKEN_LENGTH:
        return
    raise gr.Error(
        f"The accumulated input is too long ({token_count} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again."
    )


# UI theme: Monochrome base with indigo/blue/slate hues, small corner radius,
# and Open Sans with system sans-serif fallbacks.
theme = gr.themes.Monochrome(
    font=[
        gr.themes.GoogleFont("Open Sans"),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
    radius_size=gr.themes.sizes.radius_sm,
    neutral_hue="slate",
    secondary_hue="blue",
    primary_hue="indigo",
)


# Chat UI whose replies come from the streaming `generate` generator.
demo = gr.ChatInterface(
    fn=generate,
    theme=theme,
    chatbot=gr.Chatbot(layout="panel"),
)

# Enable request queuing, then start the server without a public share link.
demo.queue()
demo.launch(share=False)