# Hugging Face Space status: Running on CPU Upgrade
# TODO:
# - Remove duplicated code used to generate markdown responses.
# - Periodically re-check tracked models to confirm they are still valid and public.
import os | |
import re | |
import sys | |
from functools import lru_cache | |
from pathlib import Path | |
from typing import Dict, List, Set, Union | |
import gradio as gr | |
from apscheduler.schedulers.background import BackgroundScheduler | |
from apscheduler.triggers.cron import CronTrigger | |
from cachetools import TTLCache, cached | |
from dotenv import load_dotenv | |
from huggingface_hub import ( | |
HfApi, | |
comment_discussion, | |
create_discussion, | |
dataset_info, | |
get_repo_discussions, | |
) | |
from huggingface_hub.utils import HFValidationError, RepositoryNotFoundError | |
from sqlitedict import SqliteDict | |
from toolz import concat, count, unique | |
from tqdm.auto import tqdm | |
from tqdm.contrib.concurrent import thread_map | |
# Treat macOS as a local development machine; every other platform
# (presumably the hosted Space) uses paths under /data.
local = bool(sys.platform.startswith("darwin"))
# NOTE(review): cache_location is not referenced elsewhere in the code visible here.
cache_location = "cache/" if local else "/data/cache"
# Directory that holds the SqliteDict database files.
save_dir = "test_data" if local else "/data/"
# Created at import time so the SqliteDict files can be opened inside it later.
Path(save_dir).mkdir(parents=True, exist_ok=True)
# Load a .env file first so the os.getenv calls below can see its values.
load_dotenv()
user_agent = os.getenv("USER_AGENT")
HF_TOKEN = os.getenv("HF_TOKEN")  # passed to the discussion API calls
REPO = "librarian-bots/dataset-to-model-monitor"  # where issues land
AUTHOR = "librarian-bot"  # who makes the issues
hf_api = HfApi(user_agent=user_agent)
# NOTE(review): ten_min_cache is not referenced elsewhere in the code visible here.
ten_min_cache = TTLCache(maxsize=5_000, ttl=600)
def get_datasets_for_user(username: str) -> List[str]:
    """Return the IDs of all datasets the Hub lists for ``username``.

    Args:
        username: Hub user or organization name.

    Returns:
        A list of dataset IDs.
    """
    # Materialize a list so the result matches the annotation and can be
    # iterated more than once (the original returned a one-shot generator).
    return [dataset.id for dataset in hf_api.list_datasets(author=username)]
def get_models_for_dataset(dataset_id):
    """Return the models matching the Hub filter ``dataset:<dataset_id>``.

    Args:
        dataset_id: Hub dataset ID.

    Returns:
        ``{dataset_id: [model_id, ...]}``. Model IDs are de-duplicated, and the
        list is empty when nothing matches.
    """
    models = hf_api.list_models(filter=f"dataset:{dataset_id}")
    # dict.fromkeys de-duplicates while keeping first-seen order, so the stored
    # list is deterministic (a set round-trip is not).
    return {dataset_id: list(dict.fromkeys(model.id for model in models))}
def generate_dataset_model_map(
    dataset_ids: List[str],
) -> dict[str, List[str]]:
    """Fetch the model list for every dataset in ``dataset_ids``.

    Args:
        dataset_ids: Hub dataset IDs to look up.

    Returns:
        A mapping ``dataset_id -> list of model IDs``. This corrects the
        previous annotation, which claimed a nested dict.
    """
    dataset_model_map: dict[str, List[str]] = {}
    # thread_map runs the per-dataset Hub queries on a thread pool.
    for partial_map in thread_map(get_models_for_dataset, dataset_ids):
        dataset_model_map.update(partial_map)
    return dataset_model_map
def maybe_update_datasets_to_model_map(dataset_id):
    """Begin storing the model list for ``dataset_id`` if it is not stored yet.

    Returns:
        The number of datasets now stored when ``dataset_id`` was newly added,
        otherwise ``False``.
    """
    with SqliteDict(f"{save_dir}/models_to_dataset.sqlite") as models_db:
        if dataset_id in models_db:
            return False
        fetched = get_models_for_dataset(dataset_id)
        models_db[dataset_id] = list(fetched[dataset_id])
        models_db.commit()
        return len(models_db)
def datasets_tracked_by_user(username):
    """Return the IDs of every dataset whose follower list contains ``username``."""
    tracked = []
    with SqliteDict(f"{save_dir}/tracked_dataset_to_users.sqlite") as users_db:
        for dataset, followers in users_db.items():
            if username in followers:
                tracked.append(dataset)
    return tracked
def update_tracked_dataset_to_users(dataset_id: str, username: str):
    """Record that ``username`` follows ``dataset_id``.

    Returns:
        The datasets ``username`` tracks after the update.
    """
    with SqliteDict(f"{save_dir}/tracked_dataset_to_users.sqlite") as users_db:
        if dataset_id not in users_db:
            users_db[dataset_id] = [username]
            users_db.commit()
        else:
            followers = users_db[dataset_id]
            if username not in followers:
                followers.append(username)
                # The set round-trip also drops any duplicates already stored.
                users_db[dataset_id] = list(set(followers))
                users_db.commit()
    return datasets_tracked_by_user(username)
# Matches "<name>/*", capturing <name> (letters, digits, "_" and "-").
HUB_ORG_OR_USERNAME_GLOB_PATTERN = re.compile(r"^([a-zA-Z0-9_-]+)\/\*$")


def match_org_user_glob_pattern(hub_id):
    """Return the user/org name from a ``name/*`` glob, or ``None`` otherwise."""
    found = HUB_ORG_OR_USERNAME_GLOB_PATTERN.match(hub_id)
    return found.group(1) if found else None
def grab_dataset_ids_for_user_or_org(hub_id: str) -> List[str]:
    """Return the IDs of the public datasets owned by user/org ``hub_id``."""
    # `is False` keeps only entries explicitly marked non-private; any other
    # value (e.g. None) is excluded.
    return [
        info.id
        for info in hf_api.list_datasets(author=hub_id)
        if info.private is False
    ]
def parse_hub_id_entry(hub_id: str) -> tuple[str | list[str], str | None]:
    """Resolve textbox input into one dataset ID or a list of dataset IDs.

    Args:
        hub_id: Either a dataset ID such as ``org/name`` or a glob ``org/*``.

    Returns:
        A ``(dataset_ids, match)`` tuple.

        - For a glob, ``dataset_ids`` is the list of public dataset IDs under
          the user/org, and ``match`` is that user/org name.
        - Otherwise ``dataset_ids`` is the validated dataset ID string, and
          ``match`` is ``None``.

    Raises:
        gr.Error: If the ID is malformed or the dataset does not exist.
    """
    # Strip surrounding whitespace so a stray space from the textbox does not
    # fail Hub ID validation.
    hub_id = hub_id.strip()
    if match := match_org_user_glob_pattern(hub_id):
        return grab_dataset_ids_for_user_or_org(match), match
    try:
        dataset_info(hub_id)
    except HFValidationError as e:
        raise gr.Error(f"Invalid format for Hugging Face Hub dataset ID. {e}") from e
    except RepositoryNotFoundError as e:
        raise gr.Error(f"{hub_id}: Invalid Hugging Face Hub dataset ID") from e
    return hub_id, None
def remove_user_from_tracking_datasets(dataset_id, profile: gr.OAuthProfile | None):
    """Stop tracking a dataset, or every dataset matched by an ``org/*`` glob.

    Args:
        dataset_id: Dataset ID or ``org/*`` glob, as typed by the user.
        profile: OAuth profile of the logged-in user, or ``None`` if nobody is
            logged in.

    Returns:
        A markdown status message.
    """
    # The username comes from the profile, so a profile is needed in every
    # environment. The old `not local` bypass just crashed with AttributeError
    # on `None.preferred_username`.
    if not profile:
        return "You must be logged in to remove a dataset"
    username = profile.preferred_username
    dataset_id, match = parse_hub_id_entry(dataset_id)
    if isinstance(dataset_id, str):
        return _remove_user_from_tracking_datasets(dataset_id, username)
    for dataset in dataset_id:
        _remove_user_from_tracking_datasets(dataset, username)
    return f"Stopped tracking datasets for username or org: {match}"
def _remove_user_from_tracking_datasets(dataset_id: str, username):
    """Remove ``username`` from the followers of ``dataset_id``.

    When the last follower is removed, the dataset is also dropped from the
    model map, so the scheduled check no longer queries it.

    Returns:
        A status message string.
    """
    with SqliteDict(
        f"{save_dir}/tracked_dataset_to_users.sqlite"
    ) as tracked_dataset_to_users_db:
        users = tracked_dataset_to_users_db.get(dataset_id)
        if users is None:
            return "Dataset not being tracked"
        try:
            users.remove(username)
        except ValueError:
            # username was not among this dataset's followers.
            return "No longer tracking dataset"
        if users:
            tracked_dataset_to_users_db[dataset_id] = users
        else:
            del tracked_dataset_to_users_db[dataset_id]
            with SqliteDict(
                f"{save_dir}/models_to_dataset.sqlite"
            ) as dataset_to_models_db:
                # pop() tolerates a dataset missing from the model map, where
                # `del` raised KeyError and left the follower table uncommitted.
                dataset_to_models_db.pop(dataset_id, None)
                dataset_to_models_db.commit()
        tracked_dataset_to_users_db.commit()
    return "Dataset no longer being tracked"
def user_unsubscribe_all(username):
    """Stop tracking every dataset ``username`` currently follows.

    Returns:
        A short status message.
    """
    datasets_tracked = datasets_tracked_by_user(username)
    for dataset_id in datasets_tracked:
        # Use the username-based helper. The public
        # remove_user_from_tracking_datasets takes (dataset_id, profile), and
        # the previous call passed (username, dataset_id), which always crashed.
        _remove_user_from_tracking_datasets(dataset_id, username)
    # Explicit check instead of `assert`, which is stripped under `python -O`
    # and would otherwise crash the request handler.
    if remaining := datasets_tracked_by_user(username):
        return f"Could not unsubscribe from {len(remaining)} datasets"
    return f"Unsubscribed from {len(datasets_tracked)} datasets"
def user_update(hub_id, profile: gr.OAuthProfile | None):
    """Start tracking a dataset, or every dataset under an ``org/*`` glob.

    Args:
        hub_id: Dataset ID or ``org/*`` glob, as typed by the user.
        profile: OAuth profile of the logged-in user, or ``None``.

    Returns:
        A markdown summary of what is now tracked.
    """
    # A profile is required everywhere: the username comes from it, and the old
    # `not local` bypass crashed on `None.preferred_username`.
    if not profile:
        return "Please login to track a dataset"
    username = profile.preferred_username
    hub_id, match = parse_hub_id_entry(hub_id)
    if isinstance(hub_id, str):
        return _user_update(hub_id, username)
    return glob_update_tracked_datasets(hub_id, username, match)
def glob_update_tracked_datasets(hub_ids, username, match):
    """Track every dataset in ``hub_ids`` for ``username``.

    Args:
        hub_ids: Dataset IDs resolved from an ``org/*`` glob.
        username: Hub username of the requester.
        match: The user/org name the glob matched.

    Returns:
        A markdown summary of the user's tracked datasets.
    """
    for id_ in tqdm(hub_ids):
        _user_update(id_, username)
    response = "## Dataset tracking summary \n\n"
    if hub_ids:
        response += (
            f"All datasets under the user or organization: {match} are being tracked \n\n"
        )
    else:
        # Don't claim everything is tracked when the glob matched nothing.
        response += (
            f"No public datasets found for the user or organization: {match} \n\n"
        )
    # Reuse the shared listing so this summary matches "List my tracked datasets".
    response += _user_stats(username)
    return response
def _user_update(hub_id: str, username: str) -> str:
    """Update the user's tracked datasets and return a response string.

    Args:
        hub_id: A single dataset ID.
        username: Hub username of the requester.

    Returns:
        A markdown message describing the result and listing the user's
        tracked datasets.
    """
    if number_datasets_being_tracked := maybe_update_datasets_to_model_map(hub_id):
        response = (
            "New dataset being tracked! Now tracking"
            f" {number_datasets_being_tracked} datasets \n\n"
        )
    else:
        response = f"Dataset {hub_id} is already being tracked. \n\n"
    # Named user_datasets so it does not shadow the module-level
    # datasets_tracked_by_user() function.
    user_datasets = update_tracked_dataset_to_users(hub_id, username)
    response += (
        "You are currently tracking whether new models have been trained on"
        f" {len(user_datasets)} datasets."
    )
    if user_datasets:
        response += (
            "\nYou are currently monitoring whether new models have been trained on the"
            " following datasets:\n"
        )
        response += "".join(
            f"- [{dataset}](https://huggingface.co/datasets/{dataset})\n"
            for dataset in user_datasets
        )
    else:
        # Separator added: this used to run straight onto "...datasets."
        response += "\n\nYou are not currently tracking any datasets."
    return response
def check_for_new_models_for_dataset_and_update() -> Dict[str, Set[str]]:
    """Re-query the Hub for every stored dataset and record the new model lists.

    Returns:
        ``{dataset_id: {model_id, ...}}`` for datasets whose current model
        list contains IDs absent from the stored list.
    """
    with SqliteDict(f"{save_dir}/models_to_dataset.sqlite") as old_results_db:
        # Snapshot once so each stored list is read from SQLite a single time.
        old_results = dict(old_results_db.items())
        new_results = generate_dataset_model_map(list(old_results))
        models_to_notify_about = {}
        for dataset_id, models in new_results.items():
            # Compare by set difference, not by length. A model removed and
            # another added in the same window leaves the count unchanged but
            # is still a new model.
            if new_models := set(models).difference(old_results[dataset_id]):
                models_to_notify_about[dataset_id] = new_models
            # Skip datasets untracked while the Hub was being queried, so they
            # are not silently re-added.
            if dataset_id in old_results_db:
                old_results_db[dataset_id] = models
        old_results_db.commit()
    return models_to_notify_about
def get_repo_discussion_by_author_and_type(
    repo, author, token, repo_type="space", include_prs=False
):
    """Lazily yield the discussions on ``repo`` opened by ``author``.

    Pull requests are skipped unless ``include_prs`` is true.
    """
    for item in get_repo_discussions(repo, repo_type=repo_type, token=token):
        if item.author != author:
            continue
        if not include_prs and item.is_pull_request:
            continue
        yield item
def create_discussion_text_body(dataset_id, new_models, users_to_notify):
    """Build the markdown body of a new-model notification.

    Args:
        dataset_id: Hub ID of the dataset the models were trained on.
        new_models: Iterable of new model IDs. A single ID string is also
            accepted.
        users_to_notify: Usernames to @-mention. ``None`` is treated as no
            users.

    Returns:
        The discussion body as a markdown string.
    """
    # maybe_create_discussion accepts a str here; without this, a string would
    # be listed one character per line.
    if isinstance(new_models, str):
        new_models = [new_models]
    # Sorted for stable output; callers typically pass a set.
    models = sorted(new_models)
    usernames_string = ", ".join(
        f"@{username}" for username in users_to_notify or []
    )
    dataset_id_markdown_url = (
        f"[{dataset_id}](https://huggingface.co/datasets/{dataset_id})"
    )
    description = (
        f"Hey {usernames_string}! Librarian Bot found new models trained on the"
        f" {dataset_id_markdown_url} dataset!\n\n"
    )
    noun = "model" if len(models) == 1 else "models"
    description += f"New {noun} trained on {dataset_id}:\n"
    description += "\n".join(
        f"- {hub_id_to_huggingface_hub_url_markdown(model)}" for model in models
    )
    description += """\n\n This discussion was created by the [Dataset to Model Monitor](https://huggingface.co/spaces/librarian-bots/dataset-to-model-monitor) Space. You can modify your alerts using this Space."""
    return description
def maybe_create_discussion(
    repo: str,
    dataset_id: str,
    new_models: Union[List, str],
    users_to_notify: List[str],
    author: str,
    token: str,
):
    """Post a new-model notification for ``dataset_id`` on the Space ``repo``.

    If ``author`` already opened the per-dataset discussion, a comment is added
    to it; otherwise a new discussion is created.
    """
    title = f"Discussion tracking new models trained on {dataset_id}"
    description = create_discussion_text_body(dataset_id, new_models, users_to_notify)
    # Use the caller-supplied token for the lookup too. It previously used the
    # module-level HF_TOKEN and silently ignored this parameter.
    discussions = get_repo_discussion_by_author_and_type(repo, author, token)
    existing = next(
        (discussion for discussion in discussions if discussion.title == title),
        None,
    )
    if existing is not None:
        comment_discussion(
            repo, existing.num, description, token=token, repo_type="space"
        )
    else:
        create_discussion(
            repo,
            title,
            token=token,
            description=description,
            repo_type="space",
        )
def hub_id_to_huggingface_hub_url_markdown(hub_id: str) -> str:
    """Render ``hub_id`` as a markdown link to its page on huggingface.co."""
    return "[{0}](https://huggingface.co/{0})".format(hub_id)
def notify_about_new_models():
    """Scheduled job: find newly trained models and ping their followers."""
    print("running notifications")
    if models_to_notify_about := check_for_new_models_for_dataset_and_update():
        # Read every follower list in one pass instead of re-opening the DB
        # for each dataset.
        with SqliteDict(
            f"{save_dir}/tracked_dataset_to_users.sqlite"
        ) as tracked_dataset_to_users_db:
            users_by_dataset = {
                dataset_id: tracked_dataset_to_users_db.get(dataset_id)
                for dataset_id in models_to_notify_about
            }
        for dataset_id, new_models in models_to_notify_about.items():
            users_to_notify = users_by_dataset[dataset_id]
            if not users_to_notify:
                # The dataset lost its last follower; nobody to ping, and
                # building the body with None would raise TypeError.
                continue
            try:
                maybe_create_discussion(
                    REPO, dataset_id, new_models, users_to_notify, AUTHOR, HF_TOKEN
                )
            except Exception as e:  # job boundary: keep notifying the rest
                # NOTE(review): the model map is already updated at this point,
                # so this notification is not retried on the next run.
                print(f"failed to notify about new models for {dataset_id}: {e!r}")
    print("notified about new models")
def number_of_users_tracking_datasets():
    """Return how many distinct usernames follow at least one dataset."""
    with SqliteDict(f"{save_dir}/tracked_dataset_to_users.sqlite") as users_db:
        distinct_users = {
            user for followers in users_db.values() for user in followers
        }
    return len(distinct_users)
def number_of_datasets_tracked():
    """Return how many datasets have a stored model list."""
    with SqliteDict(f"{save_dir}/models_to_dataset.sqlite") as models_db:
        total = len(models_db)
    return total
def generate_summary_stats():
    """Return a one-line summary of app-wide tracking usage."""
    user_total = number_of_users_tracking_datasets()
    dataset_total = number_of_datasets_tracked()
    return (
        f"Currently there are {user_total} users tracking"
        f" datasets with a total of {dataset_total} datasets being"
        " tracked"
    )
def _user_stats(username: str):
    """Return markdown listing the datasets ``username`` tracks."""
    tracked = datasets_tracked_by_user(username)
    if not tracked:
        return "You are not currently tracking any datasets"
    header = (
        "You are currently tracking whether new models have been trained on"
        f" {len(tracked)} datasets.\n\n"
        "### Datasets being tracked \n\n"
        "You are currently monitoring whether new models have been trained on the"
        " following datasets:\n"
    )
    links = "".join(
        f"- [{name}](https://huggingface.co/datasets/{name})\n" for name in tracked
    )
    return header + links
def user_stats(profile: gr.OAuthProfile | None):
    """Return markdown listing the logged-in user's tracked datasets.

    Args:
        profile: OAuth profile of the logged-in user, or ``None``.
    """
    # A profile is required everywhere: the old `not local` bypass crashed on
    # `None.preferred_username`.
    if not profile:
        return "You must be logged in to view datasets you are tracking"
    return _user_stats(profile.preferred_username)
# Intro text shown at the top of the app. The check frequency stated here must
# match the production CronTrigger ("0 */12 * * *", i.e. every 12 hours).
markdown_text = """
The Hugging Face Hub allows users to specify the dataset used to train a model in the model metadata.
This metadata allows you to find models trained on a particular dataset.
These links can be very powerful for finding models that might be suitable for a particular task.

This Gradio app allows you to track datasets hosted on the Hugging Face Hub and get a notification when new models are trained on the dataset you are tracking.

1. Submit the Hugging Face Hub ID for the dataset you are interested in tracking.
2. If a new model is listed as being trained on this dataset, Librarian Bot will ping you in a discussion on the Hugging Face Hub to let you know.
3. Librarian Bot will check for new models for a particular dataset every 12 hours.

**NOTE** This app is a proof of concept and is intended to validate how much interest there is for a feature like this.
If you have feedback, please add it to this [discussion](https://huggingface.co/spaces/librarian-bots/dataset-to-model-monitor/discussions/2).

### Tips

- You might find the [Hugging Face Datasets Semantic Search](https://huggingface.co/spaces/librarian-bots/huggingface-datasets-semantic-search) Space useful for finding datasets to track.
- You can use a wildcard `*` to track all datasets for a user or organization on the hub. For example, `biglam/*` will create alerts for all the datasets under the biglam Hugging Face organization.
- You need to be logged in to your Hugging Face account to use this app. If you don't have a Hugging Face Hub account you can get one <a href="https://huggingface.co/join">here</a>.
"""
# Gradio UI. Handlers that need the logged-in user declare a
# `gr.OAuthProfile | None` parameter, which Gradio fills in automatically.
with gr.Blocks() as demo:
    gr.Markdown(
        '<div style="text-align: center;"><h1> 🤖 Librarian Bot Dataset-to-Model'
        ' Monitor 🤖 </h1><i><p style="font-size: 20px;">✨ Get alerts when a new'
        " model is created from a dataset you are interested in! ✨</p></i></div>"
    )
    with gr.Row():
        gr.Markdown(markdown_text)
    with gr.Row():
        # The example goes in `placeholder`. As the first positional argument it
        # was the textbox VALUE, so clicking "Track" right away submitted
        # "i.e. biglam/brill_iconclass" and failed validation.
        hub_id = gr.Textbox(
            placeholder="e.g. biglam/brill_iconclass",
            label="Hugging Face Hub ID for dataset to track",
            max_lines=1,
        )
        with gr.Column():
            track_button = gr.Button("Track new models for dataset")
            with gr.Row():
                remove_specific_datasets = gr.Button("Stop tracking dataset")
                # NOTE(review): remove_all has no click handler.
                # user_unsubscribe_all takes a username, not an OAuth profile,
                # so it cannot be attached directly; a profile-aware wrapper is
                # needed.
                remove_all = gr.Button("⛔️ Unsubscribe from all datasets ⛔️")
    with gr.Row(variant="compact"):
        gr.LoginButton(size="sm")
        gr.LogoutButton(size="sm")
        summary_stats_btn = gr.Button(
            "Summary stats for datasets being tracked by this app", size="sm"
        )
        user_stats_btn = gr.Button("List my tracked datasets", size="sm")
    with gr.Row():
        output = gr.Markdown()
    track_button.click(user_update, [hub_id], output)
    remove_specific_datasets.click(
        remove_user_from_tracking_datasets, [hub_id], output
    )
    summary_stats_btn.click(generate_summary_stats, [], output)
    user_stats_btn.click(user_stats, [], output)
# Run the notification job in the background alongside the Gradio server.
scheduler = BackgroundScheduler()
if local:
    # Local development: run every 5 minutes for quick feedback.
    scheduler.add_job(notify_about_new_models, "interval", minutes=5)
else:
    # Production: cron "0 */12 * * *" fires at minute 0 of hours 0 and 12,
    # i.e. twice a day.
    scheduler.add_job(
        notify_about_new_models,
        CronTrigger.from_crontab("0 */12 * * *"),
    )
# Started before demo.launch(), which presumably blocks the main thread.
scheduler.start()
# Cap the Gradio event queue at 5 pending events.
demo.queue(max_size=5)
demo.launch()