asoria HF staff commited on
Commit
64583bd
·
1 Parent(s): 2aedb2c

Open in spaces

Browse files
Files changed (1) hide show
  1. app.py +84 -19
app.py CHANGED
@@ -1,5 +1,4 @@
1
- # These imports at the end because of torch/datamapplot issue in Zero GPU
2
- # import spaces
3
  import gradio as gr
4
 
5
  import logging
@@ -16,13 +15,10 @@ from bertopic import BERTopic
16
  from bertopic.representation import KeyBERTInspired
17
  from bertopic.representation import TextGeneration
18
 
19
- # Temporary disabling because of ZeroGPU does not support cuml
20
  from cuml.manifold import UMAP
21
  from cuml.cluster import HDBSCAN
22
 
23
- # from umap import UMAP
24
- # from hdbscan import HDBSCAN
25
- from huggingface_hub import HfApi
26
  from sklearn.feature_extraction.text import CountVectorizer
27
  from sentence_transformers import SentenceTransformer
28
  from prompts import REPRESENTATION_PROMPT
@@ -59,6 +55,7 @@ logging.basicConfig(
59
  MAX_ROWS = 50_000
60
  CHUNK_SIZE = 10_000
61
 
 
62
 
63
  session = requests.Session()
64
  sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
@@ -186,7 +183,6 @@ def _push_to_hub(
186
  logging.info(f"Pushing file to hub: {dataset_id} on file {file_path}")
187
 
188
  file_name = file_path.split("/")[-1]
189
- api = HfApi(token=HF_TOKEN)
190
  try:
191
  logging.info(f"About to push {file_path} - {dataset_id}")
192
  api.upload_file(
@@ -200,6 +196,44 @@ def _push_to_hub(
200
  raise
201
 
202
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  def generate_topics(dataset, config, split, column, nested_column, plot_type):
204
  logging.info(
205
  f"Generating topics for {dataset} with config {config} {split} {column} {nested_column}"
@@ -239,6 +273,7 @@ def generate_topics(dataset, config, split, column, nested_column, plot_type):
239
  gr.Plot(value=None, visible=True),
240
  gr.Label({message: rows_processed / limit}, visible=True),
241
  "",
 
242
  )
243
  while offset < limit:
244
  docs = get_docs_from_parquet(parquet_urls, column, offset, CHUNK_SIZE)
@@ -278,6 +313,7 @@ def generate_topics(dataset, config, split, column, nested_column, plot_type):
278
  docs=all_docs,
279
  reduced_embeddings=reduced_embeddings_array,
280
  title=dataset,
 
281
  width=800,
282
  height=700,
283
  arrowprops={
@@ -286,6 +322,7 @@ def generate_topics(dataset, config, split, column, nested_column, plot_type):
286
  "linewidth": 0,
287
  "fc": "#33333377",
288
  },
 
289
  dynamic_label_size=False,
290
  # label_wrap_width=12,
291
  # label_over_points=True,
@@ -299,6 +336,7 @@ def generate_topics(dataset, config, split, column, nested_column, plot_type):
299
  reduced_embeddings=reduced_embeddings_array,
300
  custom_labels=True,
301
  title=dataset,
 
302
  )
303
  )
304
 
@@ -317,6 +355,7 @@ def generate_topics(dataset, config, split, column, nested_column, plot_type):
317
  topic_plot,
318
  gr.Label({message: progress}, visible=True),
319
  "",
 
320
  )
321
 
322
  offset += CHUNK_SIZE
@@ -330,20 +369,42 @@ def generate_topics(dataset, config, split, column, nested_column, plot_type):
330
  topic_plot.write_image(plot_png)
331
 
332
  _push_to_hub(dataset, plot_png)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
  plot_png_link = (
334
  f"https://huggingface.co/datasets/{EXPORTS_REPOSITORY}/blob/main/{plot_png}"
335
  )
336
- # interactive_plot = datamapplot.create_interactive_plot(
337
- # reduced_embeddings_array,
338
- # *cord19_label_layers,
339
- # font_family="Cinzel",
340
- # enable_search=True,
341
- # inline_data=False,
342
- # offline_data_prefix="cord-large-1",
343
- # initial_zoom_fraction=0.4,
344
- # )
345
- # all_topics, _ = base_model.transform(all_topics)
346
- # logging.info(f"TAll opics: {all_topics[:5]}")
347
  yield (
348
  gr.Accordion(open=False),
349
  topics_info,
@@ -352,6 +413,7 @@ def generate_topics(dataset, config, split, column, nested_column, plot_type):
352
  {f"✅ Done: {rows_processed} rows have been processed": 1.0}, visible=True
353
  ),
354
  f"[![Download as PNG](https://img.shields.io/badge/Download_as-PNG-red)]({plot_png_link})",
 
355
  )
356
  cuda.empty_cache()
357
 
@@ -400,7 +462,9 @@ with gr.Blocks() as demo:
400
 
401
  gr.Markdown("## Data map")
402
  full_topics_generation_label = gr.Label(visible=False, show_label=False)
403
- open_png_label = gr.Markdown()
 
 
404
  topics_plot = gr.Plot()
405
  with gr.Accordion("Topics Info", open=False):
406
  topics_df = gr.DataFrame(interactive=False, visible=True)
@@ -420,6 +484,7 @@ with gr.Blocks() as demo:
420
  topics_plot,
421
  full_topics_generation_label,
422
  open_png_label,
 
423
  ],
424
  )
425
 
 
1
+ import spaces
 
2
  import gradio as gr
3
 
4
  import logging
 
15
  from bertopic.representation import KeyBERTInspired
16
  from bertopic.representation import TextGeneration
17
 
 
18
  from cuml.manifold import UMAP
19
  from cuml.cluster import HDBSCAN
20
 
21
+ from huggingface_hub import HfApi, SpaceCard
 
 
22
  from sklearn.feature_extraction.text import CountVectorizer
23
  from sentence_transformers import SentenceTransformer
24
  from prompts import REPRESENTATION_PROMPT
 
55
  MAX_ROWS = 50_000
56
  CHUNK_SIZE = 10_000
57
 
58
+ api = HfApi(token=HF_TOKEN)
59
 
60
  session = requests.Session()
61
  sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
 
183
  logging.info(f"Pushing file to hub: {dataset_id} on file {file_path}")
184
 
185
  file_name = file_path.split("/")[-1]
 
186
  try:
187
  logging.info(f"About to push {file_path} - {dataset_id}")
188
  api.upload_file(
 
196
  raise
197
 
198
 
199
+ def create_space_with_content(dataset_id, html_file_path):
200
+ # TODO: Parameterize organization name
201
+ repo_id = f"datasets-topics/{dataset_id.replace('/', '-')}"
202
+ logging.info(f"Creating space with content: {repo_id} on file {html_file_path}")
203
+ api.create_repo(
204
+ repo_id=repo_id,
205
+ repo_type="space",
206
+ private=False,
207
+ exist_ok=True,
208
+ token=HF_TOKEN,
209
+ space_sdk="static",
210
+ )
211
+ SPACE_REPO_CARD_CONTENT = """
212
+ ---
213
+ title: {dataset_id} topic modeling
214
+ sdk: static
215
+ pinned: false
216
+ datasets:
217
+ - {dataset_id}
218
+ ---
219
+
220
+ """
221
+
222
+ SpaceCard(
223
+ content=SPACE_REPO_CARD_CONTENT.format(dataset_id=dataset_id)
224
+ ).push_to_hub(repo_id=repo_id, repo_type="space", token=HF_TOKEN)
225
+
226
+ api.upload_file(
227
+ path_or_fileobj=html_file_path,
228
+ path_in_repo="index.html",
229
+ repo_type="space",
230
+ repo_id=repo_id,
231
+ token=HF_TOKEN,
232
+ )
233
+ logging.info(f"Space created done")
234
+ return repo_id
235
+
236
+
237
  def generate_topics(dataset, config, split, column, nested_column, plot_type):
238
  logging.info(
239
  f"Generating topics for {dataset} with config {config} {split} {column} {nested_column}"
 
273
  gr.Plot(value=None, visible=True),
274
  gr.Label({message: rows_processed / limit}, visible=True),
275
  "",
276
+ "",
277
  )
278
  while offset < limit:
279
  docs = get_docs_from_parquet(parquet_urls, column, offset, CHUNK_SIZE)
 
313
  docs=all_docs,
314
  reduced_embeddings=reduced_embeddings_array,
315
  title=dataset,
316
+ font_family="Montserrat Thin",
317
  width=800,
318
  height=700,
319
  arrowprops={
 
322
  "linewidth": 0,
323
  "fc": "#33333377",
324
  },
325
+ # TODO: Make it configurable in UI
326
  dynamic_label_size=False,
327
  # label_wrap_width=12,
328
  # label_over_points=True,
 
336
  reduced_embeddings=reduced_embeddings_array,
337
  custom_labels=True,
338
  title=dataset,
339
+ font_family="Montserrat Thin",
340
  )
341
  )
342
 
 
355
  topic_plot,
356
  gr.Label({message: progress}, visible=True),
357
  "",
358
+ "",
359
  )
360
 
361
  offset += CHUNK_SIZE
 
369
  topic_plot.write_image(plot_png)
370
 
371
  _push_to_hub(dataset, plot_png)
372
+
373
+ all_topics, _ = base_model.transform(all_docs)
374
+ topic_info = base_model.get_topic_info()
375
+
376
+ topic_names = {row["Topic"]: row["Name"] for index, row in topic_info.iterrows()}
377
+ topic_names_array = np.array(
378
+ [
379
+ topic_names.get(topic, "No Topic").split("_")[1].strip("-")
380
+ for topic in all_topics
381
+ ]
382
+ )
383
+ dataset_clear_name = dataset.replace("/", "-")
384
+ interactive_plot = datamapplot.create_interactive_plot(
385
+ reduced_embeddings_array,
386
+ topic_names_array,
387
+ hover_text=all_docs,
388
+ title=dataset,
389
+ enable_search=True,
390
+ font_family="Montserrat Thin",
391
+ # TODO: Export data to .arrow and also serve it
392
+ inline_data=True,
393
+ # offline_data_prefix=dataset_clear_name,
394
+ initial_zoom_fraction=0.9,
395
+ )
396
+ html_content = str(interactive_plot)
397
+ html_file_path = f"{dataset_clear_name}.html"
398
+ with open(html_file_path, "w", encoding="utf-8") as html_file:
399
+ html_file.write(html_content)
400
+
401
+ space_id = create_space_with_content(dataset, html_file_path)
402
+
403
  plot_png_link = (
404
  f"https://huggingface.co/datasets/{EXPORTS_REPOSITORY}/blob/main/{plot_png}"
405
  )
406
+
407
+ space_link = f"https://huggingface.co/spaces/{space_id}"
 
 
 
 
 
 
 
 
 
408
  yield (
409
  gr.Accordion(open=False),
410
  topics_info,
 
413
  {f"✅ Done: {rows_processed} rows have been processed": 1.0}, visible=True
414
  ),
415
  f"[![Download as PNG](https://img.shields.io/badge/Download_as-PNG-red)]({plot_png_link})",
416
+ f"[![Go to interactive plot](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Space-blue)]({space_link})",
417
  )
418
  cuda.empty_cache()
419
 
 
462
 
463
  gr.Markdown("## Data map")
464
  full_topics_generation_label = gr.Label(visible=False, show_label=False)
465
+ with gr.Row():
466
+ open_png_label = gr.Markdown()
467
+ open_space_label = gr.Markdown()
468
  topics_plot = gr.Plot()
469
  with gr.Accordion("Topics Info", open=False):
470
  topics_df = gr.DataFrame(interactive=False, visible=True)
 
484
  topics_plot,
485
  full_topics_generation_label,
486
  open_png_label,
487
+ open_space_label,
488
  ],
489
  )
490