import os
from threading import Thread
from typing import Iterator
import requests
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Hugging Face access token for the gated Llama 2 checkpoint; read it from the environment
# (e.g. a Space secret) instead of hard-coding the credential in the source.
HF_TOKEN = os.getenv("HF_TOKEN")

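# Token budgets: DEFAULT_MAX_NEW_TOKENS bounds the completion length and MAX_INPUT_TOKEN_LENGTH
# bounds the prompt length (longer prompts are trimmed inside generate()).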
MAX_MAX_NEW_TOKENS = 8192
DEFAULT_MAX_NEW_TOKENS = 4096
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

DESCRIPTION = """\
# Llama. Protected. With Protecto.
"""

if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU. This demo requires a GPU to generate responses.</p>"


if torch.cuda.is_available():
    model_id = "meta-llama/Llama-2-7b-chat-hf"
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.float16, device_map="auto", token=HF_TOKEN
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
    tokenizer.use_default_system_prompt = False

@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
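    """Generate a Llama 2 chat response, mask PII in the completed output via Protecto, and yield the masked text."""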
    conversation = []

    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})

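    # Apply the Llama 2 chat template, then trim the prompt from the left if it exceeds the input-token budget.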
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

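    # Generate in a background thread; TextIteratorStreamer yields decoded text chunks as they are produced.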
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        {"input_ids": input_ids},
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Drain the streamer so the complete response can be masked in a single Protecto call.
    concatenated_outputs = "".join(streamer)

    # Mask the output here
    masked_output = mask_with_protecto(concatenated_outputs)
    masked_output = format_for_html(masked_output)
    yield masked_output


# Escape HTML special characters so masked entity tags (e.g. <PER>) are displayed as text
# in the chat window instead of being interpreted as markup.
def format_for_html(text):
    text = text.replace("&", "&amp;")
    text = text.replace("<", "&lt;")
    text = text.replace(">", "&gt;")
    return text

    
def mask_with_protecto(text_for_prompt):
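    """Send the text to the Protecto vault mask API and return the masked token value from the response."""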
    mask_request_url = "https://trial.protecto.ai/api/vault/mask"
    headers = {
        "Content-Type": "application/json; charset=utf-8",
        # Protecto bearer token; read it from the environment (assumed variable name) rather than
        # embedding the credential in the source file.
        "Authorization": f"Bearer {os.getenv('PROTECTO_API_KEY')}",
    }
    mask_input = {
        "mask": [
            {
                "value": text_for_prompt
            }
        ]
    }
    response = requests.put(mask_request_url, headers=headers, json=mask_input, timeout=30)
    if response.status_code == 200:
        # Parse the masked result from the API response and return the token-masked text.
        masked_result = response.json()
        return str(masked_result["data"][0]["token_value"])
    else:
        # Return an error message if the API request was not successful.
        return f"Protecto mask request failed with status code {response.status_code}"



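# Wire the generate function into a Gradio chat UI, with the system prompt exposed as an extra input.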
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6),
    ],
    retry_btn=None,
    stop_btn=None,
    undo_btn=None,
    clear_btn=None,
)
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    chat_interface.render()

if __name__ == "__main__":
    demo.queue(max_size=20).launch()