Lazyhope committed on
Commit
c0f0c0a
1 Parent(s): e063f35

Move API header warning to init of pipeline

Browse files
Files changed (1) hide show
  1. pipeline.py +20 -17
pipeline.py CHANGED
@@ -9,17 +9,6 @@ import requests
9
  import torch
10
  from transformers import Pipeline
11
 
12
- API_HEADERS = {"Accept": "application/vnd.github+json"}
13
- if os.environ.get("GITHUB_TOKEN") is None:
14
- print(
15
- "[!] Consider setting GITHUB_TOKEN environment variable to avoid hitting rate limits\n"
16
- "For more info, see:"
17
- "https://docs.github.com/authentication/keeping-your-account-and-data-secure/creating-a-personal-access-token"
18
- )
19
- else:
20
- API_HEADERS["Authorization"] = f"Bearer {os.environ['GITHUB_TOKEN']}"
21
- print("[+] Using GITHUB_TOKEN for authentication")
22
-
23
 
24
  def extract_code_and_docs(text: str):
25
  """Extract code and documentation from a Python file.
@@ -53,11 +42,11 @@ def extract_code_and_docs(text: str):
53
  return code_set, docs_set
54
 
55
 
56
- def get_topics(repo_name):
57
  api_url = f"https://api.github.com/repos/{repo_name}"
58
  print(f"[+] Getting topics for {repo_name}")
59
  try:
60
- response = requests.get(api_url, headers=API_HEADERS)
61
  response.raise_for_status()
62
  except requests.exceptions.HTTPError as e:
63
  print(f"[-] Failed to get topics for {repo_name}: {e}")
@@ -71,19 +60,19 @@ def get_topics(repo_name):
71
  return topics
72
 
73
 
74
- def download_and_extract(repos):
75
  extracted_info = {}
76
  for repo_name in repos:
77
  extracted_info[repo_name] = {
78
  "funcs": set(),
79
  "docs": set(),
80
- "topics": get_topics(repo_name),
81
  }
82
 
83
  download_url = f"https://api.github.com/repos/{repo_name}/tarball"
84
  print(f"[+] Extracting functions and docstrings from {repo_name}")
85
  try:
86
- response = requests.get(download_url, headers=API_HEADERS, stream=True)
87
  response.raise_for_status()
88
  except requests.exceptions.HTTPError as e:
89
  print(f"[-] Failed to download {repo_name}: {e}")
@@ -107,6 +96,20 @@ def download_and_extract(repos):
107
 
108
 
109
  class RepoEmbeddingPipeline(Pipeline):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  def _sanitize_parameters(self, **kwargs):
111
  _forward_kwargs = {}
112
  if "max_length" in kwargs:
@@ -118,7 +121,7 @@ class RepoEmbeddingPipeline(Pipeline):
118
  if isinstance(inputs, str):
119
  inputs = (inputs,)
120
 
121
- extracted_infos = download_and_extract(inputs)
122
 
123
  return extracted_infos
124
 
 
9
  import torch
10
  from transformers import Pipeline
11
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  def extract_code_and_docs(text: str):
14
  """Extract code and documentation from a Python file.
 
42
  return code_set, docs_set
43
 
44
 
45
+ def get_topics(repo_name, headers=None):
46
  api_url = f"https://api.github.com/repos/{repo_name}"
47
  print(f"[+] Getting topics for {repo_name}")
48
  try:
49
+ response = requests.get(api_url, headers=headers)
50
  response.raise_for_status()
51
  except requests.exceptions.HTTPError as e:
52
  print(f"[-] Failed to get topics for {repo_name}: {e}")
 
60
  return topics
61
 
62
 
63
+ def download_and_extract(repos, headers=None):
64
  extracted_info = {}
65
  for repo_name in repos:
66
  extracted_info[repo_name] = {
67
  "funcs": set(),
68
  "docs": set(),
69
+ "topics": get_topics(repo_name, headers=headers),
70
  }
71
 
72
  download_url = f"https://api.github.com/repos/{repo_name}/tarball"
73
  print(f"[+] Extracting functions and docstrings from {repo_name}")
74
  try:
75
+ response = requests.get(download_url, headers=headers, stream=True)
76
  response.raise_for_status()
77
  except requests.exceptions.HTTPError as e:
78
  print(f"[-] Failed to download {repo_name}: {e}")
 
96
 
97
 
98
  class RepoEmbeddingPipeline(Pipeline):
99
+ def __init__(self, *args, **kwargs):
100
+ super().__init__(*args, **kwargs)
101
+ self.API_HEADERS = {"Accept": "application/vnd.github+json"}
102
+ if os.environ.get("GITHUB_TOKEN") is None:
103
+ print(
104
+ "[!] Consider setting GITHUB_TOKEN environment variable to avoid hitting rate limits\n"
105
+ "For more info, see:"
106
+ "https://docs.github.com/authentication/keeping-your-account-and-data-secure/creating-a-personal-access-token"
107
+ )
108
+ else:
109
+ self.API_HEADERS["Authorization"] = f"Bearer {os.environ['GITHUB_TOKEN']}"
110
+ print("[+] Using GITHUB_TOKEN for authentication")
111
+
112
+
113
  def _sanitize_parameters(self, **kwargs):
114
  _forward_kwargs = {}
115
  if "max_length" in kwargs:
 
121
  if isinstance(inputs, str):
122
  inputs = (inputs,)
123
 
124
+ extracted_infos = download_and_extract(inputs, headers=self.API_HEADERS)
125
 
126
  return extracted_infos
127