Rzhishchev commited on
Commit
68edb67
1 Parent(s): 5a4923b

Rename app.py to gpt2.py

Browse files
Files changed (1) hide show
  1. app.py → gpt2.py +16 -7
app.py → gpt2.py RENAMED
@@ -8,24 +8,33 @@ model_path = "zhvanetsky_model"
8
  tokenizer = GPT2Tokenizer.from_pretrained(model_path)
9
  model = GPT2LMHeadModel.from_pretrained(model_path).to(DEVICE)
10
 
11
- def generate_text(input_text):
12
  model.eval()
13
  input_ids = tokenizer.encode(input_text, return_tensors="pt").to(DEVICE)
14
  with torch.no_grad():
15
  out = model.generate(input_ids,
16
  do_sample=True,
17
- num_beams=10,
18
- temperature=2.2,
19
- top_p=0.85,
20
  top_k=500,
21
- max_length=100,
22
  no_repeat_ngram_size=3,
23
  num_return_sequences=3,
24
  )
25
  return tokenizer.decode(out[0], skip_special_tokens=True)
26
 
 
27
  st.title("GPT-2 Text Generator")
 
28
  user_input = st.text_area("Input Text", "Введите ваш текст")
 
 
 
 
 
 
 
29
  if st.button("Generate"):
30
- generated_output = generate_text(user_input)
31
- st.text_area("Generated Text", generated_output)
 
8
  tokenizer = GPT2Tokenizer.from_pretrained(model_path)
9
  model = GPT2LMHeadModel.from_pretrained(model_path).to(DEVICE)
10
 
11
+ def generate_text(input_text, num_beams, temperature, max_length, top_p):
12
  model.eval()
13
  input_ids = tokenizer.encode(input_text, return_tensors="pt").to(DEVICE)
14
  with torch.no_grad():
15
  out = model.generate(input_ids,
16
  do_sample=True,
17
+ num_beams=num_beams,
18
+ temperature=temperature,
19
+ top_p=top_p,
20
  top_k=500,
21
+ max_length=max_length,
22
  no_repeat_ngram_size=3,
23
  num_return_sequences=3,
24
  )
25
  return tokenizer.decode(out[0], skip_special_tokens=True)
26
 
27
+ # Streamlit interface
28
  st.title("GPT-2 Text Generator")
29
+
30
  user_input = st.text_area("Input Text", "Введите ваш текст")
31
+
32
+ # Add sliders or input boxes for model parameters
33
+ num_beams = st.slider("Number of Beams", min_value=1, max_value=20, value=10)
34
+ temperature = st.slider("Temperature", min_value=0.1, max_value=3.0, value=1.0, step=0.1)
35
+ max_length = st.number_input("Max Length", min_value=10, max_value=300, value=100)
36
+ top_p = st.slider("Top P", min_value=0.1, max_value=1.0, value=0.85, step=0.05)
37
+
38
  if st.button("Generate"):
39
+ generated_output = generate_text(user_input, num_beams, temperature, max_length, top_p)
40
+ st.text_area("Generated Text", generated_output)