Remove st_messager from the pipeline for streamlit data cachability
Browse files- pipeline.py +4 -19
pipeline.py
CHANGED
@@ -103,31 +103,22 @@ def download_and_extract(repos, headers=None):
|
|
103 |
|
104 |
|
105 |
class RepoEmbeddingPipeline(Pipeline):
|
106 |
-
def __init__(self, github_token=None,
|
107 |
super().__init__(*args, **kwargs)
|
108 |
|
109 |
-
# Streamlit single element container created by st.empty()
|
110 |
-
self.st_messager = st_messager
|
111 |
-
|
112 |
self.API_HEADERS = {"Accept": "application/vnd.github+json"}
|
113 |
if not github_token:
|
114 |
-
|
115 |
"[*] Consider setting GitHub token to avoid hitting rate limits. \n"
|
116 |
"For more info, see: "
|
117 |
"https://docs.github.com/authentication/keeping-your-account-and-data-secure/creating-a-personal-access-token"
|
118 |
)
|
119 |
-
print(message)
|
120 |
-
if self.st_messager:
|
121 |
-
self.st_messager.info(message)
|
122 |
else:
|
123 |
self.set_github_token(github_token)
|
124 |
|
125 |
def set_github_token(self, github_token):
|
126 |
self.API_HEADERS["Authorization"] = f"Bearer {github_token}"
|
127 |
-
|
128 |
-
print(message)
|
129 |
-
if self.st_messager:
|
130 |
-
self.st_messager.success(message)
|
131 |
|
132 |
def _sanitize_parameters(self, **kwargs):
|
133 |
_forward_kwargs = {}
|
@@ -142,9 +133,6 @@ class RepoEmbeddingPipeline(Pipeline):
|
|
142 |
if isinstance(inputs, str):
|
143 |
inputs = [inputs]
|
144 |
|
145 |
-
if self.st_messager:
|
146 |
-
self.st_messager.info("[*] Downloading and extracting repos...")
|
147 |
-
|
148 |
extracted_infos = download_and_extract(inputs, headers=self.API_HEADERS)
|
149 |
|
150 |
return extracted_infos
|
@@ -190,10 +178,7 @@ class RepoEmbeddingPipeline(Pipeline):
|
|
190 |
|
191 |
pbar.set_description(f"Processing {repo_name}")
|
192 |
|
193 |
-
|
194 |
-
tqdm.write(message)
|
195 |
-
if self.st_messager:
|
196 |
-
self.st_messager.info(message)
|
197 |
|
198 |
code_embeddings = []
|
199 |
for func in repo_info["funcs"]:
|
|
|
103 |
|
104 |
|
105 |
class RepoEmbeddingPipeline(Pipeline):
|
106 |
+
def __init__(self, github_token=None, *args, **kwargs):
|
107 |
super().__init__(*args, **kwargs)
|
108 |
|
|
|
|
|
|
|
109 |
self.API_HEADERS = {"Accept": "application/vnd.github+json"}
|
110 |
if not github_token:
|
111 |
+
print(
|
112 |
"[*] Consider setting GitHub token to avoid hitting rate limits. \n"
|
113 |
"For more info, see: "
|
114 |
"https://docs.github.com/authentication/keeping-your-account-and-data-secure/creating-a-personal-access-token"
|
115 |
)
|
|
|
|
|
|
|
116 |
else:
|
117 |
self.set_github_token(github_token)
|
118 |
|
119 |
def set_github_token(self, github_token):
|
120 |
self.API_HEADERS["Authorization"] = f"Bearer {github_token}"
|
121 |
+
print("[+] GitHub token set")
|
|
|
|
|
|
|
122 |
|
123 |
def _sanitize_parameters(self, **kwargs):
|
124 |
_forward_kwargs = {}
|
|
|
133 |
if isinstance(inputs, str):
|
134 |
inputs = [inputs]
|
135 |
|
|
|
|
|
|
|
136 |
extracted_infos = download_and_extract(inputs, headers=self.API_HEADERS)
|
137 |
|
138 |
return extracted_infos
|
|
|
178 |
|
179 |
pbar.set_description(f"Processing {repo_name}")
|
180 |
|
181 |
+
tqdm.write(f"[*] Generating embeddings for {repo_name}")
|
|
|
|
|
|
|
182 |
|
183 |
code_embeddings = []
|
184 |
for func in repo_info["funcs"]:
|