gamingflexer committed on
Commit
7cbc824
·
1 Parent(s): 2c4c422

Refactor Arxiv class and add new methods

Browse files
Files changed (1) hide show
  1. src/scrapper/arxiv.py +221 -3
src/scrapper/arxiv.py CHANGED
@@ -1,8 +1,15 @@
1
- import requests
2
- from requests.adapters import HTTPAdapter, Retry
3
  import logging
4
- from typing import Union, Any, Optional
5
  import re
 
 
 
 
 
 
 
6
 
7
  """
8
  Usage : get_paper_id("8-bit matrix multiplication for transformers at scale") -> 2106.09680
@@ -64,3 +71,214 @@ def get_paper_id(query: str, handle_not_found: bool = True):
64
  # if no paper is found, raise an error
65
  raise Exception(f'No paper found for query: {query}')
66
  return paper_id
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
  import logging
4
+ from typing import Optional
5
  import re
6
+ import requests
7
+ from requests.adapters import HTTPAdapter, Retry
8
+ import arxiv
9
+ import PyPDF2
10
+ import requests
11
+ from tqdm.auto import tqdm
12
+ from decouple import config
13
 
14
  """
15
  Usage : get_paper_id("8-bit matrix multiplication for transformers at scale") -> 2106.09680
 
71
  # if no paper is found, raise an error
72
  raise Exception(f'No paper found for query: {query}')
73
  return paper_id
74
+
75
+
76
class Arxiv:
    """Download an ArXiv paper and expose its text, metadata and references.

    Typical flow: construct with a paper ID, call :meth:`load` to populate
    ``self.content`` and the metadata attributes, then optionally call
    :meth:`chunker` followed by :meth:`save_chunks`, or :meth:`get_refs`.
    """

    # Matches a standalone "References" heading line. Because the pattern has
    # a capture group, ``split`` returns [before, heading, after]; the last
    # element is the references section (or the whole text if no heading).
    refs_re = re.compile(r'\n(References|REFERENCES)\n')
    # Class-level fallback only; __init__ gives every instance its own list
    # so that mutating one paper's references never leaks into another.
    references = []

    llm = None

    def __init__(self, paper_id: str):
        """Object to handle the extraction of an ArXiv paper and its
        relevant information.

        :param paper_id: The ID of the paper to extract
        :type paper_id: str
        """
        self.id = paper_id
        self.url = f"https://export.arxiv.org/pdf/{paper_id}.pdf"
        # per-instance list (see note on the class attribute above)
        self.references = []
        # initialize the requests session
        self.session = requests.Session()

    def load(self, save: bool = False):
        """Load the paper from the ArXiv API or from a local file
        if it already exists. Stores the paper's text content and
        meta data in self.content and other attributes.

        :param save: Whether to save the paper to a local file,
                     defaults to False
        :type save: bool, optional
        :raises requests.HTTPError: If the PDF download returns an
                                    error status
        :raises ValueError: If no metadata is found for the paper ID
        """
        cache_path = f'papers/{self.id}.json'
        # check if the paper was already saved locally
        if os.path.exists(cache_path):
            print(f'Loading papers/{self.id}.json from file')
            with open(cache_path, 'r', encoding='utf-8') as fp:
                attributes = json.load(fp)
            for key, value in attributes.items():
                setattr(self, key, value)
        else:
            res = self.session.get(self.url)
            # fail here rather than handing an HTML error page to the
            # PDF parser
            res.raise_for_status()
            # temp.pdf is the scratch file read by _convert_pdf_to_text
            with open('temp.pdf', 'wb') as fp:
                fp.write(res.content)
            # extract text content
            self._convert_pdf_to_text()
            # get meta for PDF
            self._download_meta()
            # NOTE(review): only a freshly downloaded paper is saved; a
            # paper read from the cache file is already on disk.
            if save:
                self.save()

    def get_refs(self, extractor, text_splitter):
        """Get the references for the paper, extracting them on first use.

        :param extractor: The LLMChain extractor model
        :type extractor: LLMChain
        :param text_splitter: The text splitter to use
        :type text_splitter: TokenTextSplitter
        :return: The references for the paper
        :rtype: list
        """
        if not self.references:
            self._download_refs(extractor, text_splitter)
        return self.references

    def _download_refs(self, extractor, text_splitter):
        """Download the references for the paper. Stores them in
        the self.references attribute.

        Each non-empty line returned by the extractor is expected to look
        like ``title | authors | year``; lines with a different number of
        fields, or for which no paper ID is found, are skipped.

        :param extractor: The LLMChain extractor model
        :type extractor: LLMChain
        :param text_splitter: The text splitter to use
        :type text_splitter: TokenTextSplitter
        """
        # get references section of paper
        refs = self.refs_re.split(self.content)[-1]
        # we don't need the full thing, just the first page
        pages = text_splitter.split_text(refs)
        if not pages:
            # nothing to extract from (e.g. empty paper text)
            logging.debug("Extracted 0 references")
            self.references = []
            return
        # use LLM extractor to extract references
        out = extractor.run(refs=pages[0])
        meta = []
        for line in out.split('\n'):
            if line == '':
                continue
            fields = line.split(' | ')
            # in case we're missing some fields
            if len(fields) != 3:
                continue
            # Look the ID up only for well-formed lines and keep it paired
            # with its own line; filtering after the lookup would shift the
            # IDs relative to the references.
            paper_id = get_paper_id(line)
            if paper_id is None:
                continue
            title, authors, year = fields
            meta.append({
                'id': paper_id,
                'title': title,
                'authors': authors,
                'year': year
            })
        logging.debug(f"Extracted {len(meta)} references")
        self.references = meta

    def _convert_pdf_to_text(self):
        """Convert the PDF to text and store it in the self.content
        attribute. Reads the ``temp.pdf`` file written by :meth:`load`.
        """
        with open("temp.pdf", 'rb') as f:
            # create a PDF object
            pdf = PyPDF2.PdfReader(f)
            # extract text from every page while the file is still open
            pages = [page.extract_text() for page in pdf.pages]
        self.content = "\n".join(pages)

    def _download_meta(self):
        """Download the meta information for the paper from the
        ArXiv API and store it in the self attributes.

        :raises ValueError: If the API returns no result for the paper ID
        """
        search = arxiv.Search(
            query=f'id:{self.id}',
            max_results=1,
            sort_by=arxiv.SortCriterion.SubmittedDate
        )
        results = list(search.results())
        if not results:
            raise ValueError(f"No paper found for paper '{self.id}'")
        result = results[0]
        self.authors = [author.name for author in result.authors]
        self.categories = result.categories
        self.comment = result.comment
        self.journal_ref = result.journal_ref
        # remove 'v1', 'v2', etc. from the end of the pdf_url
        self.source = re.sub(r'v\d+$', '', result.pdf_url)
        self.primary_category = result.primary_category
        self.published = result.published.strftime('%Y%m%d')
        self.summary = result.summary
        self.title = result.title
        self.updated = result.updated.strftime('%Y%m%d')
        logging.debug(f"Downloaded metadata for paper '{self.id}'")

    def save(self):
        """Save the paper to a local JSON file at ``papers/<id>.json``,
        creating the ``papers`` directory if it does not exist yet.
        """
        os.makedirs('papers', exist_ok=True)
        with open(f'papers/{self.id}.json', 'w', encoding='utf-8') as fp:
            json.dump(self.to_dict(), fp, indent=4)

    def save_chunks(
        self,
        include_metadata: bool = True,
        path: str = "chunks"
    ):
        """Save the paper's chunks to a local JSONL file.

        Requires :meth:`chunker` to have been called so that
        ``self.dataset`` exists. The chunks held in ``self.dataset`` are
        left unmodified; metadata is merged only into the written records.

        :param include_metadata: Whether to include the paper's
                                 metadata in the chunks, defaults
                                 to True
        :type include_metadata: bool, optional
        :param path: The path to save the file to, defaults to "chunks"
        :type path: str, optional
        """
        os.makedirs(path, exist_ok=True)
        # metadata is identical for every chunk, so build it once
        meta = self.get_meta() if include_metadata else {}
        with open(f'{path}/{self.id}.jsonl', 'w') as fp:
            for chunk in self.dataset:
                # metadata keys win on collision, as dict.update would
                record = {**chunk, **meta}
                fp.write(json.dumps(record) + '\n')
        logging.debug(f"Saved paper to '{path}/{self.id}.jsonl'")

    def get_meta(self):
        """Returns the meta information for the paper.

        :return: The meta information for the paper
        :rtype: dict
        """
        fields = self.to_dict()
        # drop content field because it's big
        fields.pop('content')
        return fields

    def chunker(self, chunk_size=300):
        """Build ``self.dataset`` from the cleaned paper text.

        Only a single chunk holding the whole text is produced for now.

        :param chunk_size: Currently unused
        :type chunk_size: int, optional
        """
        clean_paper = self._clean_text(self.content)
        self.dataset = [{
            'doi': self.id,
            'chunk-id': 1,
            'chunk': clean_paper
        }]

    def _clean_text(self, text):
        """Re-join words that were hyphenated across a line break.

        :param text: The text to clean
        :type text: str
        :return: The text with every ``-\\n`` sequence removed
        :rtype: str
        """
        return re.sub(r'-\n', '', text)

    def to_dict(self):
        """Return the paper's serialisable attributes as a new dict.

        :return: The id, metadata, content and references of the paper
        :rtype: dict
        :raises AttributeError: If the paper has not been loaded yet
        """
        return {
            'id': self.id,
            'title': self.title,
            'summary': self.summary,
            'source': self.source,
            'authors': self.authors,
            'categories': self.categories,
            'comment': self.comment,
            'journal_ref': self.journal_ref,
            'primary_category': self.primary_category,
            'published': self.published,
            'updated': self.updated,
            'content': self.content,
            'references': self.references
        }

    def __dict__(self):
        # Kept only so existing callers of ``paper.__dict__()`` keep
        # working; this name shadows the built-in instance ``__dict__``
        # attribute, so new code should call to_dict() instead.
        return self.to_dict()

    def __repr__(self):
        return f"Arxiv(paper_id='{self.id}')"