Lazyhope commited on
Commit
561a94f
1 Parent(s): d9e7a5e

Show progress bar when running the model

Browse files
Files changed (1) hide show
  1. pipeline.py +27 -15
pipeline.py CHANGED
@@ -6,6 +6,7 @@ from io import BytesIO
6
  import numpy as np
7
  import requests
8
  import torch
 
9
  from transformers import Pipeline
10
 
11
 
@@ -154,26 +155,37 @@ class RepoEmbeddingPipeline(Pipeline):
154
 
155
  def _forward(self, extracted_infos, max_length=512):
156
  repo_dataset = {}
157
- for repo_name, repo_info in extracted_infos.items():
158
- entry = {"topics": repo_info.get("topics")}
159
-
160
- print(f"[+] Generating embeddings for {repo_name}")
161
- if entry.get("code_embeddings") is None:
162
- code_embeddings = [
163
- [func, self.encode(func, max_length).squeeze().tolist()]
164
- for func in repo_info["funcs"]
165
- ]
 
 
 
 
 
 
 
 
166
  entry["code_embeddings"] = code_embeddings
167
  entry["mean_code_embedding"] = (
168
  np.mean([x[1] for x in code_embeddings], axis=0).tolist()
169
  if code_embeddings
170
  else None
171
  )
172
- if entry.get("doc_embeddings") is None:
173
- doc_embeddings = [
174
- [doc, self.encode(doc, max_length).squeeze().tolist()]
175
- for doc in repo_info["docs"]
176
- ]
 
 
 
177
  entry["doc_embeddings"] = doc_embeddings
178
  entry["mean_doc_embedding"] = (
179
  np.mean([x[1] for x in doc_embeddings], axis=0).tolist()
@@ -181,7 +193,7 @@ class RepoEmbeddingPipeline(Pipeline):
181
  else None
182
  )
183
 
184
- repo_dataset[repo_name] = entry
185
 
186
  return repo_dataset
187
 
 
6
  import numpy as np
7
  import requests
8
  import torch
9
+ from tqdm import tqdm
10
  from transformers import Pipeline
11
 
12
 
 
155
 
156
  def _forward(self, extracted_infos, max_length=512):
157
  repo_dataset = {}
158
+ num_texts = sum(
159
+ len(x["funcs"]) + len(x["docs"]) for x in extracted_infos.values()
160
+ )
161
+ with tqdm(total=num_texts) as pbar:
162
+ for repo_name, repo_info in extracted_infos.items():
163
+ pbar.set_description(f"Processing {repo_name}")
164
+ entry = {"topics": repo_info.get("topics")}
165
+
166
+ print(f"[+] Generating embeddings for {repo_name}")
167
+
168
+ code_embeddings = []
169
+ for func in repo_info["funcs"]:
170
+ code_embeddings.append(
171
+ [func, self.encode(func, max_length).squeeze().tolist()]
172
+ )
173
+ pbar.update(1)
174
+
175
  entry["code_embeddings"] = code_embeddings
176
  entry["mean_code_embedding"] = (
177
  np.mean([x[1] for x in code_embeddings], axis=0).tolist()
178
  if code_embeddings
179
  else None
180
  )
181
+
182
+ doc_embeddings = []
183
+ for doc in repo_info["docs"]:
184
+ doc_embeddings.append(
185
+ [doc, self.encode(doc, max_length).squeeze().tolist()]
186
+ )
187
+ pbar.update(1)
188
+
189
  entry["doc_embeddings"] = doc_embeddings
190
  entry["mean_doc_embedding"] = (
191
  np.mean([x[1] for x in doc_embeddings], axis=0).tolist()
 
193
  else None
194
  )
195
 
196
+ repo_dataset[repo_name] = entry
197
 
198
  return repo_dataset
199