Spaces:
Runtime error
Runtime error
import requests | |
from requests.adapters import HTTPAdapter, Retry | |
import logging | |
from typing import Union, Any, Optional | |
import re | |
""" | |
Usage : get_paper_id("8-bit matrix multiplication for transformers at scale") -> 2106.09680 | |
""" | |
paper_id_re = re.compile(r'https://arxiv.org/abs/(\d+\.\d+)') | |
def retry_request_session(
    retries: Optional[int] = 5,
    *,
    backoff_factor: float = 0.1,
    status_forcelist: tuple = (408, 500, 502, 503, 504),
) -> requests.Session:
    """Create a ``requests.Session`` that retries transient failures.

    :param retries: Total number of retries, passed as ``Retry(total=...)``,
        defaults to 5. Per urllib3, ``None`` removes the total cap and falls
        back on the other retry counts (none of which are set here).
    :type retries: int, optional
    :param backoff_factor: Backoff factor between retry attempts, passed to
        ``Retry``, defaults to 0.1
    :type backoff_factor: float, optional
    :param status_forcelist: HTTP status codes that trigger a retry; defaults
        to 408 (request timeout), 500 (internal server error), 502 (bad
        gateway), 503 (service unavailable) and 504 (gateway timeout)
    :type status_forcelist: tuple, optional
    :return: A session with the retrying adapter mounted for both
        ``https://`` and ``http://`` URLs
    :rtype: requests.Session
    """
    # Named separately from ``retries`` so the int parameter is not shadowed
    # by the Retry object.
    retry_strategy = Retry(
        total=retries,
        backoff_factor=backoff_factor,
        status_forcelist=list(status_forcelist),
    )
    adapter = HTTPAdapter(max_retries=retry_strategy)
    session = requests.Session()
    # Mount on both schemes so plain-http URLs get the same retry policy
    # instead of silently falling back to the default no-retry adapter.
    session.mount('https://', adapter)
    session.mount('http://', adapter)
    return session
def get_paper_id(
    query: str,
    handle_not_found: bool = True,
    *,
    timeout: Optional[float] = 10.0,
) -> Optional[str]:
    """Get the arXiv paper ID for a free-text query via a Google search.

    Searches Google for *query* and returns the ID captured by ``paper_id_re``
    from the first matching arXiv abstract link in the raw results HTML.

    :param query: The query to search with
    :type query: str
    :param handle_not_found: Whether to return None if no paper is found
        (otherwise an exception is raised), defaults to True
    :type handle_not_found: bool, optional
    :param timeout: Seconds to wait for the search response, passed to
        ``Session.get``; ``None`` waits indefinitely. Defaults to 10.0.
    :type timeout: float, optional
    :return: The paper ID, or None when nothing is found and
        ``handle_not_found`` is True
    :rtype: str or None
    :raises LookupError: If no paper is found and ``handle_not_found`` is
        False (a subclass of ``Exception``, so existing handlers still work)
    """
    # Let requests percent-encode the query string. This yields the same
    # encoding for ":", "|", "," and " " as the old hand-written table, and
    # also escapes characters it missed ("&", "#", "+", "%", "=", ...) that
    # would otherwise corrupt the search URL.
    params = {"q": query, "sclient": "gws-wiz-serp"}
    # The ``with`` block closes the session and its pooled connections. The
    # (non-streamed) response body is already read, so ``res.text`` remains
    # usable afterwards.
    with retry_request_session() as session:
        res = session.get(
            "https://www.google.com/search", params=params, timeout=timeout
        )
    # NOTE(review): the HTTP status is not checked, so an error page that is
    # not in the retry list (e.g. 429) is searched like a normal page and will
    # usually be reported as "not found".
    match = paper_id_re.search(res.text)
    if match is not None:
        return match.group(1)
    if handle_not_found:
        # no paper found: signal with None
        return None
    raise LookupError(f'No paper found for query: {query}')