# GRAB-DOC / app.py
from huggingface_hub import InferenceClient
import gradio as gr
from fpdf import FPDF
from docx import Document
import os
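
# Custom CSS: widen the app container, center the page title, and hide the Gradio footer.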
css = '''
.gradio-container{max-width: 1000px !important}
h1{text-align:center}
footer {
visibility: hidden
}
'''
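
# Serverless Inference API client for the Mistral-7B-Instruct-v0.3 instruct model.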
client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")
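
# Build a Mistral-instruct style prompt: previous (user, assistant) turns are wrapped
# in [INST] ... [/INST] blocks, an optional system prompt is appended, and the new
# message becomes the final instruction.
# Illustrative example:
#   format_prompt("Hi", [("Hello", "Hey!")])
#   -> "<s>[INST] Hello [/INST] Hey!</s> [INST] Hi [/INST]"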
def format_prompt(message, history, system_prompt=None):
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    if system_prompt:
        prompt += f"[SYS] {system_prompt} [/SYS]"
    prompt += f"[INST] {message} [/INST]"
    return prompt
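
# Persist the generated text to disk as PDF (fpdf), DOCX (python-docx), or plain text,
# and return a status message describing the result.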
def save_to_file(content, filename, format):
    try:
        if format == "pdf":
            pdf = FPDF()
            pdf.add_page()
            pdf.set_auto_page_break(auto=True, margin=15)
            # NOTE: the built-in Arial font is latin-1 only; characters outside
            # latin-1 may raise here and be reported via the except branch below.
            pdf.set_font("Arial", size=12)
            pdf.multi_cell(0, 10, content)
            pdf.output(f"{filename}.pdf")
        elif format == "docx":
            doc = Document()
            doc.add_paragraph(content)
            doc.save(f"{filename}.docx")
        elif format == "txt":
            with open(f"{filename}.txt", 'w', encoding='utf-8') as file:
                file.write(content)
        return f"File saved successfully as {filename}.{format}"
    except Exception as e:
        return f"Error saving file: {str(e)}"
def generate(
    prompt, history, system_prompt=None, temperature=0.2, max_new_tokens=1024, top_p=0.95, repetition_penalty=1.0,
    save_format="txt", save_file=None
):
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )
    formatted_prompt = format_prompt(prompt, history, system_prompt)
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = ""
    for response in stream:
        output += response.token.text
        yield output
    if save_file:
        # Sanitize the filename: trim whitespace and replace spaces with underscores
        save_file = save_file.strip().replace(' ', '_')
        save_message = save_to_file(output, save_file, save_format)
        yield f"{output}\n\n{save_message}"
with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
    gr.Markdown("# GRAB DOC")
    prompt = gr.Textbox(label="Prompt", placeholder="Enter your text here...", lines=4)
    history = gr.State([])
    system_prompt = gr.Textbox(label="System Prompt", placeholder="Enter system instructions (optional)...", lines=2, visible=False)
    temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, step=0.1, value=0.2)
    max_new_tokens = gr.Slider(label="Max New Tokens", minimum=1, maximum=1024, step=1, value=1024)
    top_p = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, step=0.05, value=0.95)
    repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=0.0, maximum=2.0, step=0.1, value=1.0)
    save_format = gr.Dropdown(label="Save Format", choices=["txt", "pdf", "docx"], value="txt")
    save_file = gr.Textbox(label="Save File Name", placeholder="Enter the filename (without extension)...", lines=1)
    submit = gr.Button("Generate")
    output = gr.Textbox(label="Generated Output", lines=20)
    submit.click(
        generate,
        inputs=[prompt, history, system_prompt, temperature, max_new_tokens, top_p, repetition_penalty, save_format, save_file],
        outputs=output
    )

demo.queue().launch(show_api=False)
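
# Running locally (a sketch, assuming the dependencies imported above are installed,
# e.g. gradio, huggingface_hub, fpdf, python-docx):
#   python app.py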