Show progress bar when running the model
Files changed: pipeline.py (+27 -15)
pipeline.py
CHANGED
@@ -6,6 +6,7 @@ from io import BytesIO
|
|
6 |
import numpy as np
|
7 |
import requests
|
8 |
import torch
|
|
|
9 |
from transformers import Pipeline
|
10 |
|
11 |
|
@@ -154,26 +155,37 @@ class RepoEmbeddingPipeline(Pipeline):
|
|
154 |
|
155 |
def _forward(self, extracted_infos, max_length=512):
|
156 |
repo_dataset = {}
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
166 |
entry["code_embeddings"] = code_embeddings
|
167 |
entry["mean_code_embedding"] = (
|
168 |
np.mean([x[1] for x in code_embeddings], axis=0).tolist()
|
169 |
if code_embeddings
|
170 |
else None
|
171 |
)
|
172 |
-
|
173 |
-
doc_embeddings = [
|
174 |
-
|
175 |
-
|
176 |
-
|
|
|
|
|
|
|
177 |
entry["doc_embeddings"] = doc_embeddings
|
178 |
entry["mean_doc_embedding"] = (
|
179 |
np.mean([x[1] for x in doc_embeddings], axis=0).tolist()
|
@@ -181,7 +193,7 @@ class RepoEmbeddingPipeline(Pipeline):
|
|
181 |
else None
|
182 |
)
|
183 |
|
184 |
-
|
185 |
|
186 |
return repo_dataset
|
187 |
|
|
|
6 |
import numpy as np
|
7 |
import requests
|
8 |
import torch
|
9 |
+
from tqdm import tqdm
|
10 |
from transformers import Pipeline
|
11 |
|
12 |
|
|
|
155 |
|
156 |
def _forward(self, extracted_infos, max_length=512):
    """Embed every function and doc text of each repository.

    A single progress bar spans all texts of all repositories, so its
    total is computed up front from the sizes of the "funcs" and "docs"
    collections.

    Args:
        extracted_infos: Mapping of repository name to a dict providing
            "funcs" and "docs" (sized iterables of texts handed to
            ``self.encode``) and, optionally, "topics".
        max_length: Forwarded unchanged to ``self.encode`` for each text.

    Returns:
        A dict keyed by repository name. Each entry holds "topics",
        "code_embeddings" and "doc_embeddings" (lists of
        ``[text, embedding]`` pairs), plus "mean_code_embedding" and
        "mean_doc_embedding" (element-wise mean over the embeddings, or
        ``None`` when the corresponding list is empty).

    Raises:
        KeyError: If a repository dict lacks "funcs" or "docs".
    """
    repo_dataset = {}
    num_texts = sum(
        len(x["funcs"]) + len(x["docs"]) for x in extracted_infos.values()
    )

    with tqdm(total=num_texts) as pbar:

        def embed_all(texts):
            # Encode one text at a time so the bar advances per text.
            pairs = []
            for text in texts:
                pairs.append(
                    [text, self.encode(text, max_length).squeeze().tolist()]
                )
                pbar.update(1)
            return pairs

        def mean_embedding(pairs):
            # np.mean over an empty list would yield NaN with a warning,
            # so an empty collection is reported as None instead.
            if not pairs:
                return None
            return np.mean([x[1] for x in pairs], axis=0).tolist()

        for repo_name, repo_info in extracted_infos.items():
            pbar.set_description(f"Processing {repo_name}")
            entry = {"topics": repo_info.get("topics")}

            # tqdm.write prints to stdout like print(), but does so
            # without corrupting the progress bar that is being drawn.
            tqdm.write(f"[+] Generating embeddings for {repo_name}")

            code_embeddings = embed_all(repo_info["funcs"])
            entry["code_embeddings"] = code_embeddings
            entry["mean_code_embedding"] = mean_embedding(code_embeddings)

            doc_embeddings = embed_all(repo_info["docs"])
            entry["doc_embeddings"] = doc_embeddings
            entry["mean_doc_embedding"] = mean_embedding(doc_embeddings)

            repo_dataset[repo_name] = entry

    return repo_dataset
|
199 |
|