Lazyhope committed
Commit ea248b5
1 Parent(s): 561a94f

Add streamlit widget support to the pipeline

Files changed (1)
  pipeline.py  +31 -8
pipeline.py CHANGED
@@ -6,7 +6,7 @@ from io import BytesIO
 import numpy as np
 import requests
 import torch
-from tqdm import tqdm
+from tqdm.auto import tqdm
 from transformers import Pipeline
 
 
@@ -96,26 +96,38 @@ def download_and_extract(repos, headers=None):
 
 
 class RepoEmbeddingPipeline(Pipeline):
-    def __init__(self, github_token=None, *args, **kwargs):
+    def __init__(self, github_token=None, st_messager=None, *args, **kwargs):
         super().__init__(*args, **kwargs)
+
+        # Streamlit single element container created by st.empty()
+        self.st_messager = st_messager
+
         self.API_HEADERS = {"Accept": "application/vnd.github+json"}
         if not github_token:
-            print(
-                "[!] Consider setting GitHub token to avoid hitting rate limits\n"
-                "For more info, see:"
+            message = (
+                "[*] Consider setting GitHub token to avoid hitting rate limits. \n"
+                "For more info, see: "
                 "https://docs.github.com/authentication/keeping-your-account-and-data-secure/creating-a-personal-access-token"
             )
+            print(message)
+            if self.st_messager:
+                self.st_messager.info(message)
         else:
             self.set_github_token(github_token)
 
     def set_github_token(self, github_token):
         self.API_HEADERS["Authorization"] = f"Bearer {github_token}"
-        print("[+] GitHub token set")
+        message = "[+] GitHub token set"
+        print(message)
+        if self.st_messager:
+            self.st_messager.success(message)
 
     def _sanitize_parameters(self, **kwargs):
         _forward_kwargs = {}
         if "max_length" in kwargs:
             _forward_kwargs["max_length"] = kwargs["max_length"]
+        if "st_progress" in kwargs:
+            _forward_kwargs["st_progress"] = kwargs["st_progress"]
 
         return {}, _forward_kwargs, {}
 
@@ -123,6 +135,8 @@ class RepoEmbeddingPipeline(Pipeline):
         if isinstance(inputs, str):
             inputs = (inputs,)
 
+        if self.st_messager:
+            self.st_messager.info("[*] Downloading and extracting repos...")
         extracted_infos = download_and_extract(inputs, headers=self.API_HEADERS)
 
         return extracted_infos
@@ -153,7 +167,7 @@ class RepoEmbeddingPipeline(Pipeline):
 
         return sentence_embeddings
 
-    def _forward(self, extracted_infos, max_length=512):
+    def _forward(self, extracted_infos, max_length=512, st_progress=None):
         repo_dataset = {}
         num_texts = sum(
             len(x["funcs"]) + len(x["docs"]) for x in extracted_infos.values()
@@ -163,14 +177,20 @@ class RepoEmbeddingPipeline(Pipeline):
                 pbar.set_description(f"Processing {repo_name}")
                 entry = {"topics": repo_info.get("topics")}
 
-                print(f"[+] Generating embeddings for {repo_name}")
+                message = f"[*] Generating embeddings for {repo_name}"
+                tqdm.write(message)
+                if self.st_messager:
+                    self.st_messager.info(message)
 
                 code_embeddings = []
                 for func in repo_info["funcs"]:
                     code_embeddings.append(
                         [func, self.encode(func, max_length).squeeze().tolist()]
                     )
+
                     pbar.update(1)
+                    if st_progress:
+                        st_progress.progress(pbar.n / pbar.total)
 
                 entry["code_embeddings"] = code_embeddings
                 entry["mean_code_embedding"] = (
@@ -184,7 +204,10 @@ class RepoEmbeddingPipeline(Pipeline):
                 doc_embeddings.append(
                     [doc, self.encode(doc, max_length).squeeze().tolist()]
                 )
+
                 pbar.update(1)
+                if st_progress:
+                    st_progress.progress(pbar.n / pbar.total)
 
                 entry["doc_embeddings"] = doc_embeddings
                 entry["mean_doc_embedding"] = (
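For context, a minimal sketch of how a Streamlit app might drive the new hooks. The checkpoint id, the trust_remote_code loading path, and the GITHUB_TOKEN environment variable are assumptions for illustration and are not part of this commit; only the st_messager / st_progress wiring comes from the diff above.

# app.py - hypothetical Streamlit front end for this pipeline (illustrative only)
import os

import streamlit as st
from transformers import pipeline

# Assumed checkpoint id; replace with the actual repo hosting this custom pipeline.
repo_pipeline = pipeline(
    model="Lazyhope/RepoSim",
    trust_remote_code=True,
    github_token=os.environ.get("GITHUB_TOKEN"),  # optional; helps avoid GitHub rate limits
    st_messager=st.empty(),  # single-element container for info/success messages
)

repo_url = st.text_input("GitHub repository", "huggingface/transformers")
if st.button("Generate embeddings"):
    progress_bar = st.progress(0.0)  # filled from pbar.n / pbar.total inside _forward
    outputs = repo_pipeline(repo_url, st_progress=progress_bar)
    st.write(outputs)

Because preprocessing downloads and extracts the repositories before _forward runs, the st_messager container reports download status first, and the progress bar only starts moving once embedding generation begins.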