import os
from threading import Thread
from typing import Iterator
import requests
import json
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
HF_TOKEN = os.environ.get("HF_TOKEN")  # Hugging Face access token for the gated Llama-2 weights; read it from the environment instead of hard-coding it
MAX_MAX_NEW_TOKENS = 8192
DEFAULT_MAX_NEW_TOKENS = 4096
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
DESCRIPTION = """\
# Llama. Protected. With Protecto.
"""
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU. This demo requires a GPU, so the model will not be loaded.</p>"
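# Load the Llama-2 chat model in fp16 across the available GPUs, along with its tokenizer.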
if torch.cuda.is_available():
model_id = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto", use_auth_token=HF_TOKEN)
tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=HF_TOKEN)
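    # Don't inject the tokenizer's built-in Llama-2 system prompt; only the user-supplied one is used.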
tokenizer.use_default_system_prompt = False
@spaces.GPU
def generate(
message: str,
chat_history: list[tuple[str, str]],
system_prompt: str,
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
temperature: float = 0.6,
top_p: float = 0.9,
top_k: int = 50,
repetition_penalty: float = 1.2,
) -> Iterator[str]:
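    # Build a chat-formatted prompt, generate with streaming, then mask PII in the full reply before yielding it.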
conversation = []
if system_prompt:
conversation.append({"role": "system", "content": system_prompt})
for user, assistant in chat_history:
conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
conversation.append({"role": "user", "content": message})
input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
input_ids = input_ids.to(model.device)
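    # Run generation in a background thread and read the decoded text back through the streamer.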
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = dict(
{"input_ids": input_ids},
streamer=streamer,
max_new_tokens=max_new_tokens,
do_sample=True,
top_p=top_p,
top_k=top_k,
temperature=temperature,
num_beams=1,
repetition_penalty=repetition_penalty,
)
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
    # Collect the full generation before masking; the reply cannot be streamed
    # incrementally because Protecto masks the complete text in one call.
    concatenated_outputs = "".join(streamer)
    # Mask PII in the model output with Protecto, then escape it for HTML display.
    masked_output = mask_with_protecto(concatenated_outputs)
    masked_output = format_for_html(masked_output)
    yield masked_output
def format_for_html(text):
    # Escape HTML-sensitive characters so the masked text renders literally in the chat window.
    text = text.replace("&", "&amp;")
    text = text.replace("<", "&lt;")
    text = text.replace(">", "&gt;")
    return text
# Example of a masked sentence after HTML escaping:
#   format_for_html("N7JI9bzqbf M1L4LPGkyj is the CEO of Apple Inc.")
def mask_with_protecto(text_for_prompt):
mask_request_url = "https://trial.protecto.ai/api/vault/mask"
headers = {
"Content-Type": "application/json; charset=utf-8",
"Authorization": "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3N1ZXIiOiJQcm90ZWN0byIsImV4cGlyYXRpb25fZGF0ZSI6IjIwMjMtMTEtMDQiLCJwZXJtaXNzaW9ucyI6WyJyZWFkIiwid3JpdGUiXSwidXNlcl9uYW1lIjoiZGlwYXlhbkBjb2V1c2xlYXJuaW5nLmNvbSIsImRiX25hbWUiOiJwcm90ZWN0b19jb2V1c2xlYXJuaW5nX25ydG1mYmFrIiwiaGFzaGVkX3Bhc3N3b3JkIjoiMjIyMTI2ZWNiZTlkZTRmNWJlODdiY2QyYWFlZWRlM2FmNDc5MzMxZmNhOTUxMWU0MDRiNzkxNDM1MGI4MWUyYiJ9.DeIK00NuhM51lRwWdnUXuQSBA1aBn5AQ8qM3pIeM01U"
}
mask_input = {
"mask": [
{
"value": text_for_prompt
}
]
}
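    # Send the generated text to Protecto's vault/mask endpoint, which returns masked token values.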
    response = requests.put(mask_request_url, headers=headers, json=mask_input, timeout=60)
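    # A successful response (shape inferred from the parsing below) looks roughly like:
    #   {"data": [{"token_value": "N7JI9bzqbf M1L4LPGkyj is the CEO of Apple Inc."}]}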
    if response.status_code == 200:
        # Parse the masked token value out of the API response and return it for display.
        masked_result = response.json()
        masked_result_token_value = str(masked_result["data"][0]["token_value"])
        return masked_result_token_value
    else:
        # Surface a readable error message if the API request was not successful.
        return f"Protecto mask request failed with status code {response.status_code}"
chat_interface = gr.ChatInterface(
fn=generate,
additional_inputs=[
gr.Textbox(label="System prompt", lines=6),
],
retry_btn=None,
stop_btn=None,
undo_btn=None,
clear_btn=None,
)
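# Compose the description and the chat interface into a Blocks page styled by style.css.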
with gr.Blocks(css="style.css") as demo:
gr.Markdown(DESCRIPTION)
chat_interface.render()
if __name__ == "__main__":
demo.queue(max_size=20).launch()