Tonic committed on
Commit
be6c757
·
verified ·
1 Parent(s): 3582cae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -11
app.py CHANGED
@@ -22,23 +22,23 @@ os.system('python -m spacy download en_core_web_sm')
22
  nlp = spacy.load("en_core_web_sm")
23
 
24
  # Function for generating text and tokenizing
25
- def historical_generation(prompt, max_new_tokens=600):
26
  prompt = f"### Text ###\n{prompt}"
27
  inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=1024)
28
  input_ids = inputs["input_ids"].to(device)
29
  attention_mask = inputs["attention_mask"].to(device)
30
 
31
- # Generate text
32
  output = model.generate(
33
  input_ids,
34
  attention_mask=attention_mask,
35
  max_new_tokens=max_new_tokens,
36
  pad_token_id=tokenizer.eos_token_id,
37
- top_k=50,
38
- temperature=0.3,
39
- top_p=0.95,
40
  do_sample=True,
41
- repetition_penalty=1.5,
42
  bos_token_id=tokenizer.bos_token_id,
43
  eos_token_id=tokenizer.eos_token_id
44
  )
@@ -53,11 +53,11 @@ def historical_generation(prompt, max_new_tokens=600):
53
  # Tokenize the generated text
54
  tokens = tokenizer.tokenize(generated_text)
55
 
56
- # Create highlighted text output
57
  highlighted_text = []
58
  for token in tokens:
59
  clean_token = token.replace("Ġ", "") # Remove "Ġ"
60
- token_type = tokenizer.convert_ids_to_tokens([tokenizer.convert_tokens_to_ids(token)])[0]
61
  highlighted_text.append((clean_token, token_type))
62
 
63
  return highlighted_text, generated_text # Return both tokenized and raw generated text
@@ -85,8 +85,10 @@ def generate_dependency_parse(generated_text):
85
  return html_generated
86
 
87
  # Full interface combining text generation and analysis, split across steps
88
- def full_interface(prompt, max_new_tokens):
89
- generated_highlight, generated_text = historical_generation(prompt, max_new_tokens)
 
 
90
 
91
  # Dependency parse of input text
92
  tokens_input, pos_count_input, html_input = text_analysis(prompt)
@@ -101,7 +103,13 @@ def reset_interface():
101
  # Gradio interface components
102
  with gr.Blocks() as iface:
103
  prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt for historical text generation...", lines=3)
 
 
104
  max_new_tokens = gr.Slider(label="Max New Tokens", minimum=50, maximum=1000, step=50, value=600)
 
 
 
 
105
 
106
  # Output components
107
  highlighted_text = gr.HighlightedText(label="Generated Historical Text", combine_adjacent=True, show_legend=True)
@@ -126,7 +134,7 @@ with gr.Blocks() as iface:
126
  generate_button = gr.Button(value="Generate Text and Initial Outputs")
127
  generate_button.click(
128
  full_interface,
129
- inputs=[prompt, max_new_tokens],
130
  outputs=[highlighted_text, tokenizer_info, dependency_parse_input, send_button, dependency_parse_generated, generate_button, reset_button]
131
  )
132
 
 
22
  nlp = spacy.load("en_core_web_sm")
23
 
24
  # Function for generating text and tokenizing
25
+ def historical_generation(prompt, max_new_tokens=600, top_k=50, temperature=0.7, top_p=0.95, repetition_penalty=1.0):
26
  prompt = f"### Text ###\n{prompt}"
27
  inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=1024)
28
  input_ids = inputs["input_ids"].to(device)
29
  attention_mask = inputs["attention_mask"].to(device)
30
 
31
+ # Generate text with customizable parameters
32
  output = model.generate(
33
  input_ids,
34
  attention_mask=attention_mask,
35
  max_new_tokens=max_new_tokens,
36
  pad_token_id=tokenizer.eos_token_id,
37
+ top_k=top_k,
38
+ temperature=temperature,
39
+ top_p=top_p,
40
  do_sample=True,
41
+ repetition_penalty=repetition_penalty,
42
  bos_token_id=tokenizer.bos_token_id,
43
  eos_token_id=tokenizer.eos_token_id
44
  )
 
53
  # Tokenize the generated text
54
  tokens = tokenizer.tokenize(generated_text)
55
 
56
+ # Create highlighted text output, remove "Ġ" from both the token and token_type
57
  highlighted_text = []
58
  for token in tokens:
59
  clean_token = token.replace("Ġ", "") # Remove "Ġ"
60
+ token_type = tokenizer.convert_ids_to_tokens([tokenizer.convert_tokens_to_ids(token)])[0].replace("Ġ", "")
61
  highlighted_text.append((clean_token, token_type))
62
 
63
  return highlighted_text, generated_text # Return both tokenized and raw generated text
 
85
  return html_generated
86
 
87
  # Full interface combining text generation and analysis, split across steps
88
+ def full_interface(prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty):
89
+ generated_highlight, generated_text = historical_generation(
90
+ prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty
91
+ )
92
 
93
  # Dependency parse of input text
94
  tokens_input, pos_count_input, html_input = text_analysis(prompt)
 
103
  # Gradio interface components
104
  with gr.Blocks() as iface:
105
  prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt for historical text generation...", lines=3)
106
+
107
+ # Slider for model parameters
108
  max_new_tokens = gr.Slider(label="Max New Tokens", minimum=50, maximum=1000, step=50, value=600)
109
+ top_k = gr.Slider(label="Top-k Sampling", minimum=1, maximum=100, step=1, value=50)
110
+ temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.5, step=0.1, value=0.7)
111
+ top_p = gr.Slider(label="Top-p (Nucleus Sampling)", minimum=0.1, maximum=1.0, step=0.05, value=0.95)
112
+ repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=0.5, maximum=2.0, step=0.1, value=1.0)
113
 
114
  # Output components
115
  highlighted_text = gr.HighlightedText(label="Generated Historical Text", combine_adjacent=True, show_legend=True)
 
134
  generate_button = gr.Button(value="Generate Text and Initial Outputs")
135
  generate_button.click(
136
  full_interface,
137
+ inputs=[prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty],
138
  outputs=[highlighted_text, tokenizer_info, dependency_parse_input, send_button, dependency_parse_generated, generate_button, reset_button]
139
  )
140