Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -22,23 +22,23 @@ os.system('python -m spacy download en_core_web_sm')
|
|
22 |
nlp = spacy.load("en_core_web_sm")
|
23 |
|
24 |
# Function for generating text and tokenizing
|
25 |
-
def historical_generation(prompt, max_new_tokens=600):
|
26 |
prompt = f"### Text ###\n{prompt}"
|
27 |
inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=1024)
|
28 |
input_ids = inputs["input_ids"].to(device)
|
29 |
attention_mask = inputs["attention_mask"].to(device)
|
30 |
|
31 |
-
# Generate text
|
32 |
output = model.generate(
|
33 |
input_ids,
|
34 |
attention_mask=attention_mask,
|
35 |
max_new_tokens=max_new_tokens,
|
36 |
pad_token_id=tokenizer.eos_token_id,
|
37 |
-
top_k=
|
38 |
-
temperature=
|
39 |
-
top_p=
|
40 |
do_sample=True,
|
41 |
-
repetition_penalty=
|
42 |
bos_token_id=tokenizer.bos_token_id,
|
43 |
eos_token_id=tokenizer.eos_token_id
|
44 |
)
|
@@ -53,11 +53,11 @@ def historical_generation(prompt, max_new_tokens=600):
|
|
53 |
# Tokenize the generated text
|
54 |
tokens = tokenizer.tokenize(generated_text)
|
55 |
|
56 |
-
# Create highlighted text output
|
57 |
highlighted_text = []
|
58 |
for token in tokens:
|
59 |
clean_token = token.replace("Ġ", "") # Remove "Ġ"
|
60 |
-
token_type = tokenizer.convert_ids_to_tokens([tokenizer.convert_tokens_to_ids(token)])[0]
|
61 |
highlighted_text.append((clean_token, token_type))
|
62 |
|
63 |
return highlighted_text, generated_text # Return both tokenized and raw generated text
|
@@ -85,8 +85,10 @@ def generate_dependency_parse(generated_text):
|
|
85 |
return html_generated
|
86 |
|
87 |
# Full interface combining text generation and analysis, split across steps
|
88 |
-
def full_interface(prompt, max_new_tokens):
|
89 |
-
generated_highlight, generated_text = historical_generation(
|
|
|
|
|
90 |
|
91 |
# Dependency parse of input text
|
92 |
tokens_input, pos_count_input, html_input = text_analysis(prompt)
|
@@ -101,7 +103,13 @@ def reset_interface():
|
|
101 |
# Gradio interface components
|
102 |
with gr.Blocks() as iface:
|
103 |
prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt for historical text generation...", lines=3)
|
|
|
|
|
104 |
max_new_tokens = gr.Slider(label="Max New Tokens", minimum=50, maximum=1000, step=50, value=600)
|
|
|
|
|
|
|
|
|
105 |
|
106 |
# Output components
|
107 |
highlighted_text = gr.HighlightedText(label="Generated Historical Text", combine_adjacent=True, show_legend=True)
|
@@ -126,7 +134,7 @@ with gr.Blocks() as iface:
|
|
126 |
generate_button = gr.Button(value="Generate Text and Initial Outputs")
|
127 |
generate_button.click(
|
128 |
full_interface,
|
129 |
-
inputs=[prompt, max_new_tokens],
|
130 |
outputs=[highlighted_text, tokenizer_info, dependency_parse_input, send_button, dependency_parse_generated, generate_button, reset_button]
|
131 |
)
|
132 |
|
|
|
22 |
nlp = spacy.load("en_core_web_sm")
|
23 |
|
24 |
# Function for generating text and tokenizing
|
25 |
+
def historical_generation(prompt, max_new_tokens=600, top_k=50, temperature=0.7, top_p=0.95, repetition_penalty=1.0):
|
26 |
prompt = f"### Text ###\n{prompt}"
|
27 |
inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=1024)
|
28 |
input_ids = inputs["input_ids"].to(device)
|
29 |
attention_mask = inputs["attention_mask"].to(device)
|
30 |
|
31 |
+
# Generate text with customizable parameters
|
32 |
output = model.generate(
|
33 |
input_ids,
|
34 |
attention_mask=attention_mask,
|
35 |
max_new_tokens=max_new_tokens,
|
36 |
pad_token_id=tokenizer.eos_token_id,
|
37 |
+
top_k=top_k,
|
38 |
+
temperature=temperature,
|
39 |
+
top_p=top_p,
|
40 |
do_sample=True,
|
41 |
+
repetition_penalty=repetition_penalty,
|
42 |
bos_token_id=tokenizer.bos_token_id,
|
43 |
eos_token_id=tokenizer.eos_token_id
|
44 |
)
|
|
|
53 |
# Tokenize the generated text
|
54 |
tokens = tokenizer.tokenize(generated_text)
|
55 |
|
56 |
+
# Create highlighted text output, remove "Ġ" from both the token and token_type
|
57 |
highlighted_text = []
|
58 |
for token in tokens:
|
59 |
clean_token = token.replace("Ġ", "") # Remove "Ġ"
|
60 |
+
token_type = tokenizer.convert_ids_to_tokens([tokenizer.convert_tokens_to_ids(token)])[0].replace("Ġ", "")
|
61 |
highlighted_text.append((clean_token, token_type))
|
62 |
|
63 |
return highlighted_text, generated_text # Return both tokenized and raw generated text
|
|
|
85 |
return html_generated
|
86 |
|
87 |
# Full interface combining text generation and analysis, split across steps
|
88 |
+
def full_interface(prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty):
|
89 |
+
generated_highlight, generated_text = historical_generation(
|
90 |
+
prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty
|
91 |
+
)
|
92 |
|
93 |
# Dependency parse of input text
|
94 |
tokens_input, pos_count_input, html_input = text_analysis(prompt)
|
|
|
103 |
# Gradio interface components
|
104 |
with gr.Blocks() as iface:
|
105 |
prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt for historical text generation...", lines=3)
|
106 |
+
|
107 |
+
# Slider for model parameters
|
108 |
max_new_tokens = gr.Slider(label="Max New Tokens", minimum=50, maximum=1000, step=50, value=600)
|
109 |
+
top_k = gr.Slider(label="Top-k Sampling", minimum=1, maximum=100, step=1, value=50)
|
110 |
+
temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.5, step=0.1, value=0.7)
|
111 |
+
top_p = gr.Slider(label="Top-p (Nucleus Sampling)", minimum=0.1, maximum=1.0, step=0.05, value=0.95)
|
112 |
+
repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=0.5, maximum=2.0, step=0.1, value=1.0)
|
113 |
|
114 |
# Output components
|
115 |
highlighted_text = gr.HighlightedText(label="Generated Historical Text", combine_adjacent=True, show_legend=True)
|
|
|
134 |
generate_button = gr.Button(value="Generate Text and Initial Outputs")
|
135 |
generate_button.click(
|
136 |
full_interface,
|
137 |
+
inputs=[prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty],
|
138 |
outputs=[highlighted_text, tokenizer_info, dependency_parse_input, send_button, dependency_parse_generated, generate_button, reset_button]
|
139 |
)
|
140 |
|