# File size: 4,420 Bytes
# 26bb9d0 63ad984 26bb9d0 fb8b8c3 c1ded61 e3382f4 26bb9d0 bae1d90 26bb9d0 6370d36 26bb9d0 0005f3b 26bb9d0 e3382f4 26bb9d0 992e802 26bb9d0 8be8efe b470587 f97d79d 71a7f7f 992e802 f97d79d 37f0e68 f97d79d 6d4301b f97d79d 992e802 63ad984 cfb1781 82c6071 71a7f7f e3382f4 63ad984 cfb1781 26bb9d0 73565c8 26bb9d0 566b50e 1f04ccd 26bb9d0 |
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
import os
from threading import Thread
from typing import Iterator
import requests
import json
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Hugging Face access token used when downloading the model and tokenizer.
# It is read from the environment so the credential never lives in source
# control; set HF_TOKEN in the deployment's secret/variable settings.
HF_TOKEN = os.getenv("HF_TOKEN")

# Token budgets. Only MAX_INPUT_TOKEN_LENGTH can be overridden through the
# environment; the other two are fixed constants.
MAX_MAX_NEW_TOKENS = 8192
DEFAULT_MAX_NEW_TOKENS = 4096
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Markdown shown at the top of the page.
DESCRIPTION = """\
# Llama. Protected. With Protecto.
"""

if torch.cuda.is_available():
    # The model and tokenizer are only loaded when a CUDA device is present.
    model_id = "meta-llama/Llama-2-7b-chat-hf"
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
        use_auth_token=HF_TOKEN,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=HF_TOKEN)
    tokenizer.use_default_system_prompt = False
else:
    # No GPU: leave `model`/`tokenizer` undefined and tell the user why.
    DESCRIPTION += "\n<p>Running on CPU. Please enable GPU</p>"
@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = 8192,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Generate a reply to *message* and yield it once, masked and HTML-formatted.

    The conversation is assembled from the optional system prompt, the prior
    (user, assistant) turns in *chat_history*, and the new *message*, then
    rendered with the tokenizer's chat template. If the prompt is longer than
    MAX_INPUT_TOKEN_LENGTH tokens, only the last MAX_INPUT_TOKEN_LENGTH tokens
    are kept and a Gradio warning is raised in the UI.

    Generation runs in a background thread that feeds a TextIteratorStreamer.
    The whole output is collected before anything is yielded, because the
    complete text is sent through ``mask_with_protecto`` and then
    ``format_for_html``.

    Args:
        message: The new user message.
        chat_history: Earlier turns as (user, assistant) pairs.
        system_prompt: Optional system prompt; skipped when empty.
        max_new_tokens: Forwarded to ``model.generate``.
        temperature: Forwarded to ``model.generate``.
        top_p: Forwarded to ``model.generate``.
        top_k: Forwarded to ``model.generate``.
        repetition_penalty: Forwarded to ``model.generate``.

    Yields:
        A single string: the masked, HTML-formatted model output.
    """
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.extend(
            [
                {"role": "user", "content": user},
                {"role": "assistant", "content": assistant},
            ]
        )
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep the most recent tokens; the oldest part of the prompt is dropped.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "top_p": top_p,
        "top_k": top_k,
        "temperature": temperature,
        "num_beams": 1,
        "repetition_penalty": repetition_penalty,
    }
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    # Drain the streamer completely: masking needs the full text, not chunks.
    full_output = "".join(streamer)
    # The streamer is exhausted, so generation is finishing; reap the thread.
    worker.join()

    # Mask the output here
    masked_output = mask_with_protecto(full_output)
    masked_output = format_for_html(masked_output)
    yield masked_output
# Ensuring entity tags are properly rendered
def format_for_html(text):
    """Escape angle brackets in *text* so tag-like tokens are shown literally.

    Each ``<`` becomes ``&lt;`` and each ``>`` becomes ``&gt;``; every other
    character is returned unchanged.

    Args:
        text: The string to escape.

    Returns:
        The escaped string.
    """
    text = text.replace("<", "&lt;")
    text = text.replace(">", "&gt;")
    return text
def mask_with_protecto(text_for_prompt):
    """Send *text_for_prompt* to the Protecto mask endpoint and return the result.

    The bearer token is read from the ``PROTECTO_API_TOKEN`` environment
    variable so the credential is not embedded in the source.

    Args:
        text_for_prompt: The text to mask.

    Returns:
        On HTTP 200, the ``token_value`` of the first entry in the response's
        ``data`` list, as a string. On any other status, the HTTP status code
        as a string (callers receive it in place of masked text).

    Raises:
        requests.RequestException: If the request fails or times out.
    """
    mask_request_url = "https://trial.protecto.ai/api/vault/mask"
    api_token = os.environ.get("PROTECTO_API_TOKEN", "")
    headers = {
        "Content-Type": "application/json; charset=utf-8",
        "Authorization": f"Bearer {api_token}",
    }
    mask_input = {
        "mask": [
            {
                "value": text_for_prompt,
            },
        ],
    }
    # A timeout keeps a stalled endpoint from blocking the chat handler forever.
    response = requests.put(
        mask_request_url,
        headers=headers,
        json=mask_input,
        timeout=30,
    )
    if response.status_code != 200:
        # Return the status code as text if the API request was not successful.
        return str(response.status_code)

    # Parse the masked result from the API response and format it for display
    masked_result = response.json()
    return str(masked_result["data"][0]["token_value"])
# Chat UI backed by `generate`. The extra textbox is supplied as an additional
# input after the message and history (presumably landing in `generate`'s
# `system_prompt` parameter - confirm against the Gradio version in use).
# None is passed for the retry/stop/undo/clear buttons.
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6),
    ],
    retry_btn=None,
    stop_btn=None,
    undo_btn=None,
    clear_btn=None,
)

# Page layout: the description banner followed by the chat interface.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    chat_interface.render()

if __name__ == "__main__":
    # Queue requests (at most 20 waiting) and start the app.
    demo.queue(max_size=20).launch()