microhum commited on
Commit
eaf66ef
·
1 Parent(s): 3970cec

add toggle button

Browse files
Files changed (2) hide show
  1. app.py +19 -10
  2. models/transformers.py +1 -1
app.py CHANGED
@@ -1,6 +1,7 @@
1
  from typing import Optional
2
  import streamlit as st
3
  from generate import ttf_to_image
 
4
  from PIL import Image
5
  import os
6
 
@@ -8,6 +9,23 @@ LOADED_TTF_KEY = "loaded_ttf"
8
  SET_IMG_KEY = "set_img"
9
  OUTPUT_IMG_KEY = "output_img"
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  def get_ttf(key: str) -> Optional[any]:
12
  if key in st.session_state:
13
  return st.session_state[key]
@@ -45,19 +63,10 @@ def generate_button(prefix, file_input, version, **kwargs):
45
  value="1,2,3,4,5,6,7,8",
46
  key=f"{prefix}-ref_char_ids",
47
  )
48
- enable_attention_slicing = st.checkbox(
49
- "Enable attention slicing (enables higher resolutions but is slower)",
50
- key=f"{prefix}-attention-slicing",
51
- )
52
- enable_cpu_offload = st.checkbox(
53
- "Enable CPU offload (if you run out of memory, e.g. for XL model)",
54
- key=f"{prefix}-cpu-offload",
55
- value=False,
56
- )
57
 
58
  if st.button("Generate image", key=f"{prefix}-btn"):
59
  with st.spinner("⏳ Generating image..."):
60
- image = ttf_to_image(file_input, n_samples, ref_char_ids, version)
61
  set_img(OUTPUT_IMG_KEY, image.copy())
62
  st.image(image)
63
 
 
1
  from typing import Optional
2
  import streamlit as st
3
  from generate import ttf_to_image
4
+ from threading import Thread
5
  from PIL import Image
6
  import os
7
 
 
9
  SET_IMG_KEY = "set_img"
10
  OUTPUT_IMG_KEY = "output_img"
11
 
12
+ # For multithreading toggle (prevent function from running too many time)
13
+ process_runnning = False
14
+ process_thread = None
15
+
16
+ def toggle_process(process_running, process_thread, run_process):
17
+ if process_runnning:
18
+ # Toggle off
19
+ process_running = False
20
+ st.write("Cancled")
21
+ if process_thread: # Kill Thread
22
+ process_thread.join()
23
+ else:
24
+ # Toggle on
25
+ process_running = True
26
+ process_thread = Thread(target=run_process)
27
+ process_thread.start()
28
+
29
  def get_ttf(key: str) -> Optional[any]:
30
  if key in st.session_state:
31
  return st.session_state[key]
 
63
  value="1,2,3,4,5,6,7,8",
64
  key=f"{prefix}-ref_char_ids",
65
  )
 
 
 
 
 
 
 
 
 
66
 
67
  if st.button("Generate image", key=f"{prefix}-btn"):
68
  with st.spinner("⏳ Generating image..."):
69
+ image = toggle_process( ttf_to_image(file_input, n_samples, ref_char_ids, version) )
70
  set_img(OUTPUT_IMG_KEY, image.copy())
71
  st.image(image)
72
 
models/transformers.py CHANGED
@@ -196,7 +196,7 @@ class Transformer_decoder(nn.Module):
196
  self.decoder_layers_parallel = clones(DecoderLayer(512, c(attn), c(attn), c(ff), dropout=0.0), 1)
197
  self.decoder_norm_parallel = nn.LayerNorm(512)
198
  if opts.ref_nshot == 52:
199
- self.cls_embedding = nn.Embedding(92,512)
200
  else:
201
  self.cls_embedding = nn.Embedding(52,512)
202
  self.cls_token = nn.Parameter(torch.zeros(1, 1, 512))
 
196
  self.decoder_layers_parallel = clones(DecoderLayer(512, c(attn), c(attn), c(ff), dropout=0.0), 1)
197
  self.decoder_norm_parallel = nn.LayerNorm(512)
198
  if opts.ref_nshot == 52:
199
+ self.cls_embedding = nn.Embedding(96,512)
200
  else:
201
  self.cls_embedding = nn.Embedding(52,512)
202
  self.cls_token = nn.Parameter(torch.zeros(1, 1, 512))