asoria HF staff commited on
Commit
e739a24
·
1 Parent(s): edc66b4

Separate functions

Browse files
Files changed (1) hide show
  1. app.py +58 -28
app.py CHANGED
@@ -41,6 +41,34 @@ def get_docs_from_parquet(parquet_urls, column, offset, limit):
41
 
42
 
43
  @spaces.GPU
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  def generate_topics(dataset, config, split, column, nested_column):
45
  logging.info(
46
  f"Generating topics for {dataset} with config {config} {split} {column} {nested_column}"
@@ -67,43 +95,45 @@ def generate_topics(dataset, config, split, column, nested_column):
67
  while True:
68
  docs = get_docs_from_parquet(parquet_urls, column, offset, chunk_size)
69
  logging.info(f"------------> New chunk data {offset=} {chunk_size=}")
70
- embeddings = sentence_model.encode(docs, show_progress_bar=True, batch_size=100)
71
- logging.info(f"Embeddings shape: {embeddings.shape}")
72
  offset = offset + chunk_size
73
  if not docs or offset >= limit:
74
  break
75
 
76
- new_model = BERTopic(
77
- "english",
78
- embedding_model=sentence_model,
79
- representation_model=representation_model,
80
- min_topic_size=15, # umap_model=umap_model, hdbscan_model=hdbscan_model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  )
82
- logging.info("Fitting new model")
83
- new_model.fit(docs, embeddings)
84
- logging.info("End fitting new model")
85
- if base_model is not None:
86
- updated_model = BERTopic.merge_models([base_model, new_model])
87
- nr_new_topics = len(set(updated_model.topics_)) - len(
88
- set(base_model.topics_)
89
- )
90
- new_topics = list(updated_model.topic_labels_.values())[-nr_new_topics:]
91
- logging.info("The following topics are newly found:")
92
- logging.info(f"{new_topics}\n")
93
- base_model = updated_model
94
- else:
95
- base_model = new_model
96
- logging.info(base_model.get_topic_info())
97
- reduced_embeddings = UMAP(
98
- n_neighbors=10, n_components=2, min_dist=0.0, metric="cosine"
99
- ).fit_transform(embeddings)
100
- logging.info(f"Reduced embeddings shape: {reduced_embeddings.shape}")
101
  yield (
102
  base_model.get_topic_info(),
103
  new_model.visualize_documents(
104
  docs, embeddings=embeddings
105
- ), # TODO: Visualize the merged models
106
- )
107
  logging.info("Finished processing all data")
108
  return base_model.get_topic_info(), base_model.visualize_topics()
109
 
 
41
 
42
 
43
  @spaces.GPU
44
+ def calculate_embeddings(sentence_model, docs):
45
+ embeddings = sentence_model.encode(docs, show_progress_bar=True, batch_size=100)
46
+ logging.info(f"Embeddings shape: {embeddings.shape}")
47
+ return embeddings
48
+
49
+
50
+ @spaces.GPU
51
+ def fit_model(base_model, sentence_model, representation_model, docs, embeddings):
52
+ new_model = BERTopic(
53
+ "english",
54
+ embedding_model=sentence_model,
55
+ representation_model=representation_model,
56
+ min_topic_size=15, # umap_model=umap_model, hdbscan_model=hdbscan_model
57
+ )
58
+ logging.info("Fitting new model")
59
+ new_model.fit(docs, embeddings)
60
+ logging.info("End fitting new model")
61
+ if base_model is None:
62
+ return new_model, new_model
63
+
64
+ updated_model = BERTopic.merge_models([base_model, new_model])
65
+ nr_new_topics = len(set(updated_model.topics_)) - len(set(base_model.topics_))
66
+ new_topics = list(updated_model.topic_labels_.values())[-nr_new_topics:]
67
+ logging.info("The following topics are newly found:")
68
+ logging.info(f"{new_topics}\n")
69
+ return updated_model, new_model
70
+
71
+
72
  def generate_topics(dataset, config, split, column, nested_column):
73
  logging.info(
74
  f"Generating topics for {dataset} with config {config} {split} {column} {nested_column}"
 
95
  while True:
96
  docs = get_docs_from_parquet(parquet_urls, column, offset, chunk_size)
97
  logging.info(f"------------> New chunk data {offset=} {chunk_size=}")
98
+ embeddings = calculate_embeddings(sentence_model, docs)
 
99
  offset = offset + chunk_size
100
  if not docs or offset >= limit:
101
  break
102
 
103
+ # new_model = BERTopic(
104
+ # "english",
105
+ # embedding_model=sentence_model,
106
+ # representation_model=representation_model,
107
+ # min_topic_size=15, # umap_model=umap_model, hdbscan_model=hdbscan_model
108
+ # )
109
+ # logging.info("Fitting new model")
110
+ # new_model.fit(docs, embeddings)
111
+ # logging.info("End fitting new model")
112
+ # if base_model is not None:
113
+ # updated_model = BERTopic.merge_models([base_model, new_model])
114
+ # nr_new_topics = len(set(updated_model.topics_)) - len(
115
+ # set(base_model.topics_)
116
+ # )
117
+ # new_topics = list(updated_model.topic_labels_.values())[-nr_new_topics:]
118
+ # logging.info("The following topics are newly found:")
119
+ # logging.info(f"{new_topics}\n")
120
+ # base_model = updated_model
121
+ # else:
122
+ # base_model = new_model
123
+ # logging.info(base_model.get_topic_info())
124
+ base_model, new_model = fit_model(
125
+ base_model, sentence_model, representation_model, docs, embeddings
126
  )
127
+ # reduced_embeddings = UMAP(
128
+ # n_neighbors=10, n_components=2, min_dist=0.0, metric="cosine"
129
+ # ).fit_transform(embeddings)
130
+ # logging.info(f"Reduced embeddings shape: {reduced_embeddings.shape}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  yield (
132
  base_model.get_topic_info(),
133
  new_model.visualize_documents(
134
  docs, embeddings=embeddings
135
+ ), # TODO: Visualize the merged models
136
+ )
137
  logging.info("Finished processing all data")
138
  return base_model.get_topic_info(), base_model.visualize_topics()
139