import ast
import os
import tarfile
from ast import AsyncFunctionDef, ClassDef, FunctionDef, Module
from io import BytesIO

import numpy as np
import requests
import torch
from transformers import Pipeline

API_HEADERS = {"Accept": "application/vnd.github+json"}
if os.environ.get("GITHUB_TOKEN") is None:
    print(
        "[!] Consider setting the GITHUB_TOKEN environment variable to avoid hitting rate limits\n"
        "For more info, see: "
        "https://docs.github.com/authentication/keeping-your-account-and-data-secure/creating-a-personal-access-token"
    )
else:
    API_HEADERS["Authorization"] = f"Bearer {os.environ['GITHUB_TOKEN']}"
    print("[+] Using GITHUB_TOKEN for authentication")


def extract_code_and_docs(text: str):
    """Extract code and documentation from a Python file.

    Args:
        text (str): Source code of a Python file

    Returns:
        tuple: A pair of sets. The first contains the unique source of every
            function and method (docstrings removed); the second contains the
            unique docstrings found on modules, classes, and functions.
    """
    root = ast.parse(text)
    def_nodes = [
        node
        for node in ast.walk(root)
        if isinstance(node, (AsyncFunctionDef, FunctionDef, ClassDef, Module))
    ]

    code_set = set()
    docs_set = set()
    for node in def_nodes:
        docs = ast.get_docstring(node)
        node_without_docs = node
        if docs is not None:
            docs_set.add(docs)
            # Drop the docstring statement so only executable code is unparsed
            node_without_docs.body = node_without_docs.body[1:]
        if isinstance(node, (AsyncFunctionDef, FunctionDef)):
            code_set.add(ast.unparse(node_without_docs))

    return code_set, docs_set
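
# Illustrative example (added as a comment so nothing runs at import time):
#
#     source = 'def add(a, b):\n    """Return a + b."""\n    return a + b\n'
#     code_set, docs_set = extract_code_and_docs(source)
#     # code_set -> {'def add(a, b):\n    return a + b'}
#     # docs_set -> {'Return a + b.'}

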
def get_topics(repo_name):
    api_url = f"https://api.github.com/repos/{repo_name}"
    print(f"[+] Getting topics for {repo_name}")
    try:
        response = requests.get(api_url, headers=API_HEADERS)
        response.raise_for_status()
    except requests.exceptions.HTTPError as e:
        print(f"[-] Failed to get topics for {repo_name}: {e}")
        return []

    metadata = response.json()
    topics = metadata.get("topics", [])
    if topics:
        print(f"[+] Topics found for {repo_name}: {topics}")

    return topics


def download_and_extract(repos):
    extracted_info = {}
    for repo_name in repos:
        extracted_info[repo_name] = {
            "funcs": set(),
            "docs": set(),
            "topics": get_topics(repo_name),
        }

        download_url = f"https://api.github.com/repos/{repo_name}/tarball"
        print(f"[+] Extracting functions and docstrings from {repo_name}")
        try:
            response = requests.get(download_url, headers=API_HEADERS, stream=True)
            response.raise_for_status()
        except requests.exceptions.HTTPError as e:
            print(f"[-] Failed to download {repo_name}: {e}")
            continue

        repo_bytes = BytesIO(response.raw.read())
        print(f"[+] Extracting {repo_name} info")
        # tarfile auto-detects the gzip compression of the GitHub tarball
        with tarfile.open(fileobj=repo_bytes) as tar:
            for member in tar.getmembers():
                if member.isfile() and member.name.endswith(".py"):
                    file_content = tar.extractfile(member).read().decode("utf-8")
                    try:
                        code_set, docs_set = extract_code_and_docs(file_content)
                    except SyntaxError as e:
                        print(f"[-] SyntaxError in {member.name}: {e}, skipping")
                        continue
                    extracted_info[repo_name]["funcs"].update(code_set)
                    extracted_info[repo_name]["docs"].update(docs_set)

    return extracted_info
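
# The returned mapping has (roughly) this shape, shown here as an illustrative
# comment rather than a guaranteed schema:
#
#     {
#         "owner/repo": {
#             "funcs": {"def f(): ...", ...},   # unique function sources
#             "docs": {"A docstring.", ...},    # unique docstrings
#             "topics": ["topic-a", "topic-b"],
#         },
#         ...
#     }

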
class RepoEmbeddingPipeline(Pipeline):
    def _sanitize_parameters(self, **kwargs):
        _forward_kwargs = {}
        if "max_length" in kwargs:
            _forward_kwargs["max_length"] = kwargs["max_length"]

        return {}, _forward_kwargs, {}
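
    # Note (added for clarity): transformers.Pipeline routes the three dicts returned
    # above to preprocess, _forward, and postprocess respectively, so a max_length
    # passed to __call__ ends up as the max_length argument of _forward below.
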
    def preprocess(self, inputs):
        if isinstance(inputs, str):
            inputs = (inputs,)

        extracted_infos = download_and_extract(inputs)

        return extracted_infos

    def encode(self, text, max_length):
        """
        Generates an embedding for an input string.

        Parameters:

        * `text` - The input string to be embedded.
        * `max_length` - The maximum total source sequence length after tokenization.
        """
        assert max_length < 1024

        tokenizer = self.tokenizer

        # Build the encoder-only input: [CLS] <encoder-only> [SEP] <tokens> [SEP]
        # (4 slots of max_length are reserved for the special tokens)
        tokens = (
            [tokenizer.cls_token, "<encoder-only>", tokenizer.sep_token]
            + tokenizer.tokenize(text)[: max_length - 4]
            + [tokenizer.sep_token]
        )
        tokens_id = tokenizer.convert_tokens_to_ids(tokens)
        source_ids = torch.tensor([tokens_id])

        # Mean-pool the token embeddings into a single sentence embedding
        token_embeddings = self.model(source_ids)[0]
        sentence_embeddings = token_embeddings.mean(dim=1)

        return sentence_embeddings
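
    # For example (illustrative): encode("def foo(): pass", 512) returns a tensor of
    # shape (1, hidden_size), which _forward below squeezes and converts to a list.
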
    def _forward(self, extracted_infos, max_length=512):
        repo_dataset = {}
        for repo_name, repo_info in extracted_infos.items():
            entry = {"topics": repo_info.get("topics")}

            print(f"[+] Generating embeddings for {repo_name}")
            if entry.get("code_embeddings") is None:
                code_embeddings = [
                    [func, self.encode(func, max_length).squeeze().tolist()]
                    for func in repo_info["funcs"]
                ]
                entry["code_embeddings"] = code_embeddings
                entry["mean_code_embeddings"] = (
                    np.mean([x[1] for x in code_embeddings], axis=0).tolist()
                    if code_embeddings
                    else None
                )
            if entry.get("doc_embeddings") is None:
                doc_embeddings = [
                    [doc, self.encode(doc, max_length).squeeze().tolist()]
                    for doc in repo_info["docs"]
                ]
                entry["doc_embeddings"] = doc_embeddings
                entry["mean_doc_embeddings"] = (
                    np.mean([x[1] for x in doc_embeddings], axis=0).tolist()
                    if doc_embeddings
                    else None
                )

            repo_dataset[repo_name] = entry

        return repo_dataset

    def postprocess(self, repo_dataset):
        return repo_dataset
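

# Minimal usage sketch (added for illustration, not part of the original module).
# It assumes this pipeline is paired with a compatible encoder checkpoint;
# "<your-encoder-checkpoint>" is a placeholder, not a real model identifier.
#
#     from transformers import AutoModel, AutoTokenizer
#
#     model = AutoModel.from_pretrained("<your-encoder-checkpoint>")
#     tokenizer = AutoTokenizer.from_pretrained("<your-encoder-checkpoint>")
#     pipeline = RepoEmbeddingPipeline(model=model, tokenizer=tokenizer)
#
#     dataset = pipeline("owner/repo", max_length=512)
#     entry = dataset["owner/repo"]
#     print(entry["topics"], len(entry["code_embeddings"]))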