Lazyhope committed on
Commit c5dc40f
1 Parent(s): 2ab1f7c

Update pipeline


Fix a bug when reading the license and topics from the metadata; store the results as a list of dictionaries instead of a single dictionary keyed by repo name, moving the repo name inside each entry; add repo stars to the result.

Files changed (1): pipeline.py (+32 -22)
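For context, this is roughly how the shape of the value returned by download_and_extract changes in this commit; the repository name, topics, license, and star count below are made-up illustrations, not taken from the diff:

# Before this commit: a single dict keyed by repository name (illustrative values)
old_style = {
    "owner/repo": {"funcs": set(), "docs": set(), "topics": ["nlp"], "license": "MIT"},
}

# After this commit: a list of per-repo dicts, with the name pushed inside and stars added
new_style = [
    {"name": "owner/repo", "funcs": set(), "docs": set(),
     "topics": ["nlp"], "license": "MIT", "stars": 42},
]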
pipeline.py CHANGED
@@ -19,16 +19,12 @@ def extract_code_and_docs(text: str):
         tuple: A tuple of two sets, the first is the code set, and the second is the docs set,
         each set contains unique code string or docstring, respectively.
     """
-    root = ast.parse(text)
-    def_nodes = [
-        node
-        for node in ast.walk(root)
-        if isinstance(node, (AsyncFunctionDef, FunctionDef, ClassDef, Module))
-    ]
-
     code_set = set()
     docs_set = set()
-    for node in def_nodes:
+    root = ast.parse(text)
+    for node in ast.walk(root):
+        if not isinstance(node, (AsyncFunctionDef, FunctionDef, ClassDef, Module)):
+            continue
         docs = ast.get_docstring(node)
         node_without_docs = node
         if docs is not None:
@@ -55,16 +51,22 @@ def get_metadata(repo_name, headers=None):


 def download_and_extract(repos, headers=None):
-    extracted_info = {}
+    extracted_infos = []
     for repo_name in tqdm(repos, disable=len(repos) <= 1):
         # Get metadata
         metadata = get_metadata(repo_name, headers=headers)
-        extracted_info[repo_name] = {
+        repo_info = {
+            "name": repo_name,
             "funcs": set(),
             "docs": set(),
-            "topics": metadata.get("topics", []),
-            "license": metadata.get("license", {}).get("spdx_id", None),
+            "topics": [],
+            "license": None,
+            "stars": metadata.get("stargazers_count"),
         }
+        if metadata.get("topics"):
+            repo_info["topics"] = metadata["topics"]
+        if metadata.get("license"):
+            repo_info["license"] = metadata["license"]["spdx_id"]

         # Download repo tarball bytes
         download_url = f"https://api.github.com/repos/{repo_name}/tarball"
@@ -86,8 +88,8 @@ def download_and_extract(repos, headers=None):
                     file_content = tar.extractfile(member).read().decode("utf-8")
                     code_set, docs_set = extract_code_and_docs(file_content)

-                    extracted_info[repo_name]["funcs"].update(code_set)
-                    extracted_info[repo_name]["docs"].update(docs_set)
+                    repo_info["funcs"].update(code_set)
+                    repo_info["docs"].update(docs_set)
                 except UnicodeDecodeError as e:
                     tqdm.write(
                         f"[-] UnicodeDecodeError in {member.name}, skipping: \n{e}"
@@ -95,7 +97,9 @@ def download_and_extract(repos, headers=None):
                 except SyntaxError as e:
                     tqdm.write(f"[-] SyntaxError in {member.name}, skipping: \n{e}")

-    return extracted_info
+        extracted_infos.append(repo_info)
+
+    return extracted_infos


 class RepoEmbeddingPipeline(Pipeline):
@@ -140,6 +144,7 @@ class RepoEmbeddingPipeline(Pipeline):

         if self.st_messager:
             self.st_messager.info("[*] Downloading and extracting repos...")
+
         extracted_infos = download_and_extract(inputs, headers=self.API_HEADERS)

         return extracted_infos
@@ -171,14 +176,19 @@ class RepoEmbeddingPipeline(Pipeline):
         return sentence_embeddings

     def _forward(self, extracted_infos, max_length=512, st_progress=None):
-        repo_dataset = {}
-        num_texts = sum(
-            len(x["funcs"]) + len(x["docs"]) for x in extracted_infos.values()
-        )
+        repo_dataset = []
+        num_texts = sum(len(x["funcs"]) + len(x["docs"]) for x in extracted_infos)
         with tqdm(total=num_texts) as pbar:
-            for repo_name, repo_info in extracted_infos.items():
+            for repo_info in extracted_infos:
+                repo_name = repo_info["name"]
+                entry = {
+                    "name": repo_name,
+                    "topics": repo_info["topics"],
+                    "license": repo_info["license"],
+                    "stars": repo_info["stars"],
+                }
+
                 pbar.set_description(f"Processing {repo_name}")
-                entry = {"topics": repo_info["topics"], "license": repo_info["license"]}

                 message = f"[*] Generating embeddings for {repo_name}"
                 tqdm.write(message)
@@ -219,7 +229,7 @@ class RepoEmbeddingPipeline(Pipeline):
                     else None
                 )

-                repo_dataset[repo_name] = entry
+                repo_dataset.append(entry)

         return repo_dataset
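A minimal usage sketch of the updated download_and_extract, assuming pipeline.py is importable from the working directory; the repository name and token header below are placeholders and are not part of this commit:

from pipeline import download_and_extract  # assumes pipeline.py is on the import path

# Hypothetical inputs: any public GitHub repos work; the auth header is optional
repos = ["owner/repo"]
headers = {"Authorization": "token <YOUR_GITHUB_TOKEN>"}

for repo_info in download_and_extract(repos, headers=headers):
    # Each entry now carries its own name, plus topics, license, and stars
    print(repo_info["name"], repo_info["stars"], repo_info["license"])
    print(len(repo_info["funcs"]), "functions,", len(repo_info["docs"]), "docstrings")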