gsarti committed
Commit b4cb26c
1 Parent(s): df52f66

Add tqdm for attribute_context

Files changed (1)
  1. app.py +19 -3
app.py CHANGED
@@ -1,5 +1,6 @@
 import os
 import re
+import threading
 
 import bm25s
 import gradio as gr
@@ -12,6 +13,7 @@ from lxt.functional import add2, mul2, softmax
 from lxt.models.llama import LlamaForCausalLM, attnlrp
 from rerankers import Reranker
 from style import custom_css
+from tqdm import tqdm
 from transformers import AutoTokenizer
 
 from inseq import load_model, register_step_function
@@ -120,7 +122,7 @@ def generate(
     rag_setting,
     custom_context,
     model_size,
-    progress=gr.Progress(),
+    progress=gr.Progress(track_tqdm=True),
 ):
     global model, model_id
     if rag_setting == "Use Custom Context":
@@ -144,7 +146,8 @@ def generate(
     if model is None or model.model_name != curr_model_id:
         progress(0.2, desc="Loading model...")
         model = get_model(model_size)
-    progress(0.3, desc="Attributing with LXT...")
+    estimated_time = 20
+    tstep = 1
     lm_rag_prompting_example = AttributeContextArgs(
         model_name_or_path=model_id,
         input_context_text="\n\n".join(docs),
@@ -169,7 +172,20 @@ def generate(
         save_path=os.path.join(os.path.dirname(__file__), "outputs/output.json"),
         viz_path=os.path.join(os.path.dirname(__file__), "outputs/output.html"),
     )
-    out = attribute_context_with_model(lm_rag_prompting_example, model)
+
+    ret = [None]
+
+    def run_attribute_context():
+        ret[0] = attribute_context_with_model(lm_rag_prompting_example, model)
+
+    thread = threading.Thread(target=run_attribute_context)
+    pbar = tqdm(total=estimated_time, desc="Attributing with LXT...")
+    thread.start()
+    while thread.is_alive():
+        thread.join(timeout=tstep)
+        pbar.update(tstep)
+    pbar.close()
+    out = ret[0]
     html = visualize_attribute_context(out, show_viz=False, return_html=True)
     return [
         gradio_iframe.iFrame(html, height=500, visible=True),
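
The core of the change is a generic pattern: run the blocking attribute_context_with_model call in a worker thread and advance a tqdm bar on a fixed tick against an estimated duration, so that gr.Progress(track_tqdm=True) can mirror the bar in the Gradio UI. Below is a minimal standalone sketch of the same idea; slow_task is a hypothetical stand-in for the attribution call, and the 20-second estimate mirrors the value hard-coded in the commit.

import threading
import time

from tqdm import tqdm


def slow_task():
    # Hypothetical stand-in for the blocking attribute_context_with_model call.
    time.sleep(5)
    return "done"


def run_with_progress(estimated_time=20, tstep=1):
    ret = [None]  # one-element list lets the worker hand its result back

    def worker():
        ret[0] = slow_task()

    thread = threading.Thread(target=worker)
    # tqdm drives the bar; inside a Gradio handler declared with
    # progress=gr.Progress(track_tqdm=True), the same bar is mirrored in the UI.
    pbar = tqdm(total=estimated_time, desc="Attributing with LXT...")
    thread.start()
    while thread.is_alive():
        thread.join(timeout=tstep)  # wait at most tstep seconds per tick
        pbar.update(tstep)
    pbar.close()
    return ret[0]


if __name__ == "__main__":
    print(run_with_progress())

The thread.join(timeout=tstep) call caps each wait at one tick, so the bar keeps moving even though the total is only an estimate; if the task outlives the estimate, tqdm simply keeps counting past its total until the thread finishes.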