gamingflexer
Add arXiv scrapper module
ec6a480
raw
history blame
2.16 kB
import requests
from requests.adapters import HTTPAdapter, Retry
import logging
from typing import Union, Any, Optional
import re
"""
Usage: get_paper_id("8-bit matrix multiplication for transformers at scale") -> "2208.07339"
"""
# Matches an arXiv abstract URL and captures the numeric paper ID
# (e.g. "2208.07339"). The dots in the host name are escaped so that they
# match only a literal "." instead of any character.
paper_id_re = re.compile(r'https://arxiv\.org/abs/(\d+\.\d+)')
def retry_request_session(retries: Optional[int] = 5):
    """Create a ``requests`` session that retries on transient HTTP errors.

    :param retries: Total number of retries allowed per request (passed as
        ``total`` to the retry policy), defaults to 5
    :type retries: Optional[int], optional
    :return: A session whose ``https://`` and ``http://`` adapters both use
        the retry policy
    :rtype: requests.Session
    """
    # Retry on common transient errors. A separate name is used so the
    # integer ``retries`` argument is not rebound to a different type.
    retry_policy = Retry(
        total=retries,
        backoff_factor=0.1,
        status_forcelist=[
            408,  # request timeout
            500,  # internal server error
            502,  # bad gateway
            503,  # service unavailable
            504,  # gateway timeout
        ],
    )
    adapter = HTTPAdapter(max_retries=retry_policy)

    # Mount the adapter for both schemes; mounting only "https://" would
    # leave plain-HTTP requests on the default (non-retrying) adapter.
    session = requests.Session()
    for scheme in ("https://", "http://"):
        session.mount(scheme, adapter)
    return session
def get_paper_id(query: str, handle_not_found: bool = True,
                 timeout: Optional[float] = 10.0) -> Optional[str]:
    """Get the arXiv paper ID for a query by searching Google.

    The query is sent to Google search and the ID is taken from the first
    ``https://arxiv.org/abs/<id>`` link found in the response body.

    :param query: The query to search with
    :type query: str
    :param handle_not_found: Whether to return None (instead of raising)
        if no paper is found, defaults to True
    :type handle_not_found: bool, optional
    :param timeout: Seconds to wait for the search response; ``None``
        waits indefinitely, defaults to 10.0
    :type timeout: Optional[float], optional
    :raises LookupError: If no paper is found and ``handle_not_found`` is
        False (``LookupError`` is an ``Exception`` subclass, so existing
        ``except Exception`` handlers keep working)
    :return: The paper ID, or None if nothing was found and
        ``handle_not_found`` is True
    :rtype: Optional[str]
    """
    # Let requests URL-encode the query: a hand-written escape table misses
    # characters such as "&", "#", "+" and "%", which would corrupt the URL.
    search_params = {"q": query, "sclient": "gws-wiz-serp"}

    # The context manager closes the session's connection pool when done.
    with retry_request_session() as session:
        res = session.get(
            "https://www.google.com/search",
            params=search_params,
            # without a timeout a stalled connection would block forever
            timeout=timeout,
        )
        page = res.text

    # extract the paper id from the first arXiv abstract link, if any
    match = paper_id_re.search(page)
    if match is not None:
        return match.group(1)

    if handle_not_found:
        # if no paper is found, return None
        return None
    # if no paper is found, raise an error
    raise LookupError(f'No paper found for query: {query}')