# app.py — Hugging Face Space by coeuslearning
# Last commit: 37f0e68 ("Update app.py"), file size 4.42 kB
import html
import json
import os
from threading import Thread
from typing import Iterator

import gradio as gr
import requests
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Hugging Face access token used to download the gated Llama-2 weights.
# Read from the environment (e.g. a Space secret) instead of being committed
# to source; the write token previously hard-coded here is exposed and should
# be revoked. None lets transformers fall back to any cached login.
HF_TOKEN = os.getenv("HF_TOKEN")
# Token-budget constants. NOTE(review): neither is referenced by the code in
# this file (generate() hard-codes its own max_new_tokens default of 8192).
MAX_MAX_NEW_TOKENS = 8192
DEFAULT_MAX_NEW_TOKENS = 4096
# Longest tokenized prompt passed to the model; generate() keeps only the
# last MAX_INPUT_TOKEN_LENGTH tokens of longer conversations.
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
# Markdown header shown above the chat UI.
DESCRIPTION = """\
# Llama. Protected. With Protecto.
"""
# Load the model only when a GPU is present. On CPU, `model` and `tokenizer`
# are never defined, so generate() cannot run there; the page header says so.
if torch.cuda.is_available():
    model_id = "meta-llama/Llama-2-7b-chat-hf"
    # `token=` supersedes the deprecated `use_auth_token=` keyword.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
        token=HF_TOKEN,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
    # Do not inject the tokenizer's built-in default system prompt; only the
    # user-supplied system prompt (if any) is used.
    tokenizer.use_default_system_prompt = False
else:
    DESCRIPTION += "\n<p>Running on CPU. Please enable GPU</p>"
@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = 8192,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Generate a chat reply, mask it with Protecto, and yield it once.

    Args:
        message: The new user message.
        chat_history: Prior (user, assistant) turns supplied by the chat UI.
        system_prompt: Optional system prompt; omitted when empty.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling threshold passed to ``model.generate``.
        top_k: Top-k sampling cutoff passed to ``model.generate``.
        repetition_penalty: Repetition penalty passed to ``model.generate``.

    Yields:
        A single string: the complete reply after masking and HTML escaping.
    """
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.extend(
            [
                {"role": "user", "content": user},
                {"role": "assistant", "content": assistant},
            ]
        )
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent tokens; the oldest context is dropped.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # model.generate pushes decoded text into the streamer from a worker
    # thread; this thread consumes it below.
    generation_thread = Thread(target=model.generate, kwargs=generate_kwargs)
    generation_thread.start()
    # Masking needs the whole reply, so drain the streamer completely rather
    # than yielding partial text to the UI.
    full_output = "".join(streamer)
    generation_thread.join()

    masked_output = format_for_html(mask_with_protecto(full_output))
    yield masked_output
# Escape HTML so Protecto's entity tags in the masked output render literally.
def format_for_html(text):
    """Return *text* with ``&``, ``<`` and ``>`` escaped as HTML entities.

    ``&`` must be escaped first (``html.escape`` does this); otherwise text
    that already contains an entity such as ``&lt;`` would be rendered as
    ``<`` instead of being shown verbatim. Quotes are left unchanged, as
    before.

    Args:
        text: Plain text to display in the chat UI.

    Returns:
        The escaped text.
    """
    return html.escape(text, quote=False)
def mask_with_protecto(text_for_prompt, timeout=30.0):
    """Send *text_for_prompt* to Protecto's vault ``mask`` endpoint.

    The bearer token is read from the ``PROTECTO_API_TOKEN`` environment
    variable (e.g. a Space secret); credentials must not live in source, and
    the token previously hard-coded here should be revoked.

    Args:
        text_for_prompt: Text to mask.
        timeout: Seconds to wait for the API before giving up; without it a
            stalled request would block the chat indefinitely.

    Returns:
        On HTTP 200, ``data[0]["token_value"]`` from the JSON response (the
        masked text), as a string. On a non-200 response or a network error,
        a short "Masking failed ..." message. The unmasked input is never
        returned, so a failure cannot leak unprotected text.
    """
    mask_request_url = "https://trial.protecto.ai/api/vault/mask"
    headers = {
        "Content-Type": "application/json; charset=utf-8",
        "Authorization": f"Bearer {os.getenv('PROTECTO_API_TOKEN', '')}",
    }
    mask_input = {"mask": [{"value": text_for_prompt}]}

    try:
        response = requests.put(
            mask_request_url, headers=headers, json=mask_input, timeout=timeout
        )
    except requests.RequestException as exc:
        return f"Masking failed ({type(exc).__name__})."

    if response.status_code != 200:
        return f"Masking failed (HTTP {response.status_code})."

    masked_result = response.json()
    return str(masked_result["data"][0]["token_value"])
# Extra input rendered alongside the chat box; its value is passed to
# generate() as `system_prompt`.
_system_prompt_box = gr.Textbox(label="System prompt", lines=6)

# Chat UI backed by generate(). The retry, stop, undo and clear buttons are
# all set to None (disabled).
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[_system_prompt_box],
    clear_btn=None,
    undo_btn=None,
    stop_btn=None,
    retry_btn=None,
)
# Page layout: the Markdown description above the pre-built chat interface,
# styled by style.css.
demo = gr.Blocks(css="style.css")
with demo:
    gr.Markdown(DESCRIPTION)
    chat_interface.render()

if __name__ == "__main__":
    # Enable request queuing (at most 20 queued events), then serve the app.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch()