# pr-test-analyzer / utils.py
# Uploaded by sayakpaul (HF staff) in commit bc69905 ("Upload 3 files", verified).
import requests
import zipfile
import tempfile
def group_tests_by_duration(file_path: str, phase: "str | None" = "call") -> dict:
    """Bucket the entries of a pytest durations report by how long they took.

    Each report line is expected to look like ``"<seconds>s <phase> <test id>"``,
    e.g. ``"12.34s call     tests/test_x.py::test_y"``. Lines that do not match
    that shape (headers, summaries, blank lines) are skipped.

    Args:
        file_path: Path to the durations text file.
        phase: Only keep entries of this pytest phase (``"setup"``, ``"call"`` or
            ``"teardown"``). Defaults to ``"call"`` so every test is counted once,
            consistent with ``extract_top_n_tests``. Pass ``None`` to keep all
            phases.

    Returns:
        Dict mapping each bucket label (``"0-5s"``, ``"5-10s"``, ``"10-15s"``,
        ``"15-20s"``, ``">20s"``) to a list of ``(duration, test_name)`` tuples
        in file order. All five labels are always present, possibly with empty
        lists.
    """
    # Half-open [start, end) second ranges, paired index-wise with their labels.
    buckets = [(0, 5), (5, 10), (10, 15), (15, 20), (20, float("inf"))]
    bucket_names = ["0-5s", "5-10s", "10-15s", "15-20s", ">20s"]
    test_groups = {name: [] for name in bucket_names}

    with open(file_path, "r", encoding="utf-8") as file:
        for line in file:
            parts = line.split()
            if len(parts) < 3 or not parts[0].endswith("s"):
                continue
            # Without this filter a test's setup/teardown lines would be counted
            # as extra "tests" in the buckets.
            if phase is not None and parts[1] != phase:
                continue
            try:
                duration = float(parts[0].rstrip("s"))
            except ValueError:
                # First token ends in "s" but is not a number: not a duration line.
                continue
            test_name = " ".join(parts[2:])
            for (start, end), bucket_name in zip(buckets, bucket_names):
                if start <= duration < end:
                    test_groups[bucket_name].append((duration, test_name))
                    break
    return test_groups
def extract_top_n_tests(file_path, n=10):
    """Return the ``n`` slowest test calls from a pytest durations report.

    Only ``call``-phase lines (``"<seconds>s call <test id>"``) are considered,
    so setup and teardown time is excluded. Lines that do not parse are skipped.

    Args:
        file_path: Path to the durations text file.
        n: Maximum number of tests to return (``None`` returns all of them).

    Returns:
        Dict mapping test id to its duration formatted as ``"<seconds>s"``
        (e.g. ``"12.5s"``). Entries are ordered slowest first; ties keep file
        order because the sort is stable.
    """
    test_durations = []
    with open(file_path, "r", encoding="utf-8") as file:
        for line in file:
            parts = line.split()
            if len(parts) < 3 or parts[1] != "call":
                continue
            try:
                duration = float(parts[0].rstrip("s"))
            except ValueError:
                # Duration token is not numeric; not a real report line.
                continue
            test_durations.append((duration, " ".join(parts[2:])))

    slowest = sorted(test_durations, key=lambda entry: entry[0], reverse=True)[:n]
    return {test_name: f"{duration}s" for duration, test_name in slowest}
def _find_artifact_download_url(repo_id, run_id, artifact_name, headers, timeout):
    """Return the ``archive_download_url`` of the run artifact named ``artifact_name``, or None."""
    owner_repo = repo_id.split("/")
    url = f"https://api.github.com/repos/{owner_repo[0]}/{owner_repo[1]}/actions/runs/{run_id}/artifacts"
    # The endpoint pages its results (30 per page by default). Ask for the
    # largest page size and follow pagination so a later artifact is still found.
    params = {"per_page": 100}
    while url:
        response = requests.get(url, headers=headers, params=params, timeout=timeout)
        response.raise_for_status()
        for artifact in response.json().get("artifacts", []):
            if artifact["name"] == artifact_name:
                return artifact["archive_download_url"]
        # ``Response.links`` parses the ``Link`` header. The "next" URL already
        # carries its query string, so ``params`` must not be sent again.
        url = response.links.get("next", {}).get("url")
        params = None
    return None


def _extract_duration_file(archive):
    """Extract the first non-directory zip member whose name contains "duration".

    The member is written relative to the current working directory. Returns
    the path it was written to, or None if no member matches.
    """
    with zipfile.ZipFile(archive, "r") as zip_ref:
        for info in zip_ref.infolist():
            # Skip directory entries such as "durations/": they would "match"
            # but are not readable report files.
            if not info.is_dir() and "duration" in info.filename:
                # extract() sanitises the member name (drops "..", leading "/")
                # and returns the path it actually wrote.
                return zip_ref.extract(info, ".")
    return None


def fetch_test_duration_artifact(repo_id, token, run_id, artifact_name):
    """Download a GitHub Actions run artifact and extract its durations file.

    Args:
        repo_id: Repository in ``"owner/name"`` form.
        token: GitHub token, sent as ``Authorization: token <token>``.
        run_id: ID of the workflow run whose artifacts are searched.
        artifact_name: Exact name of the artifact to download.

    Returns:
        Path, relative to the current working directory, of the extracted
        archive member whose name contains ``"duration"``.

    Raises:
        requests.HTTPError: If a GitHub API or download request fails.
        ValueError: If the run has no artifact called ``artifact_name``, or the
            artifact contains no file whose name contains ``"duration"``.
    """
    timeout = 60  # seconds per connect/read, so a stalled request cannot hang forever
    headers = {"Authorization": f"token {token}"}

    download_url = _find_artifact_download_url(repo_id, run_id, artifact_name, headers, timeout)
    if download_url is None:
        raise ValueError(f"Error 🥲: artifact {artifact_name!r} not found in run {run_id} of {repo_id}.")

    with requests.get(download_url, headers=headers, stream=True, timeout=timeout) as download_response:
        download_response.raise_for_status()
        # Spool the archive to an anonymous temp file, deleted on close,
        # instead of leaving "<artifact_name>.zip" in the working directory.
        with tempfile.TemporaryFile() as archive:
            for chunk in download_response.iter_content(chunk_size=64 * 1024):
                archive.write(chunk)
            archive.seek(0)
            duration_file = _extract_duration_file(archive)

    if duration_file is None:
        raise ValueError(f"Error 🥲: artifact {artifact_name!r} has no file with 'duration' in its name.")
    return duration_file
def format_to_markdown_str(test_bucket_map, top_n_slow_tests, repo_id, run_id, artifact_name):
    """Render the duration analysis of one CI run as a Markdown report.

    Args:
        test_bucket_map: Mapping of bucket label (e.g. ``"0-5s"``) to the number
            of tests in that bucket.
        top_n_slow_tests: Mapping of test id to formatted duration string, in
            the order the tests should be listed. Only the part of each id
            after its last ``"/"`` is shown.
        repo_id: Repository in ``"owner/name"`` form, used to build the run URL.
        run_id: GitHub Actions run ID, used to build the run URL.
        artifact_name: Artifact name shown in the heading.

    Returns:
        The Markdown text: a slow-test list, the bucket counts, and a link to
        the run.
    """
    run_url = f"https://github.com/{repo_id}/actions/runs/{run_id}/"

    # Explicit strings rather than triple-quoted literals, so source
    # indentation can never leak into the Markdown.
    parts = [f"\n## Top {len(top_n_slow_tests)} slow test for {artifact_name}\n\n"]
    parts.extend(f"* {test.split('/')[-1]}: {duration}\n" for test, duration in top_n_slow_tests.items())

    parts.append("\n## Bucketed durations of the tests\n\n")
    for bucket, num_tests in test_bucket_map.items():
        # An unescaped ">" after the list marker would start a blockquote.
        # "\\" is a literal backslash; the old "\{" was an invalid escape sequence.
        label = f"\\{bucket}" if ">" in bucket else bucket
        parts.append(f"* {label}: {num_tests} tests\n")

    parts.append(f"\nRun URL: [{run_url}]({run_url}).")
    return "".join(parts)
def analyze_tests(repo_id, token, run_id, artifact_name, top_n):
    """Fetch a run's test-duration artifact and summarise it as Markdown.

    Downloads the durations file, counts tests per duration bucket, and picks
    the ``top_n`` slowest tests. Both intermediate results are printed to
    stdout before the Markdown report is returned.
    """
    durations_path = fetch_test_duration_artifact(
        repo_id=repo_id,
        token=token,
        run_id=run_id,
        artifact_name=artifact_name,
    )

    bucket_counts = {
        label: len(entries)
        for label, entries in group_tests_by_duration(durations_path).items()
    }
    print(bucket_counts)

    slowest = extract_top_n_tests(durations_path, n=top_n)
    print(slowest)

    return format_to_markdown_str(bucket_counts, slowest, repo_id, run_id, artifact_name)