|
import logging
import os
import shutil

import pretty_errors  # noqa: F401 -- imported for its side effect of prettifying tracebacks
from datasets import Dataset, load_dataset
from huggingface_hub import WebhooksServer, WebhookPayload
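
# Overview: a small webhook server that, whenever the source dataset is updated on the Hub,
# re-streams its rows and pushes a processed copy to a separate dataset repo.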
|
|
logger = logging.getLogger("basic_logger")
logger.setLevel(logging.INFO)

console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
|
|
DS_NAME = "amaye15/object-segmentation"
DATA_DIR = "data"
|
|
def get_data():
    """
    Generator function to stream data from the dataset.
    """
    ds = load_dataset(
        DS_NAME,
        cache_dir=os.path.join(os.getcwd(), DATA_DIR),
        streaming=True,
        download_mode="force_redownload",
    )
    for row in ds["train"]:
        yield row
|
|
def process_and_push_data():
    """
    Rebuild the dataset from the streamed rows and push it to the Hub.
    """
    p = os.path.join(os.getcwd(), DATA_DIR)

    # Start from a clean local cache so stale files from a previous run are not reused.
    if os.path.exists(p):
        shutil.rmtree(p)
    os.mkdir(p)

    ds_processed = Dataset.from_generator(get_data)
    # Requires an authenticated Hugging Face token with write access to the target repo.
    ds_processed.push_to_hub("amaye15/tmp")
|
|
# The secret must match the one configured for the webhook on the Hub; in a real deployment
# it would come from configuration (e.g. an environment variable) rather than source code.
app = WebhooksServer(webhook_secret="my_secret_key")
|
|
# Register the endpoint on the server instance created above (rather than the module-level
# default server that the bare `webhook_endpoint` decorator would use), so the webhook
# secret configured on `app` actually applies.
@app.add_webhook("/trigger_processing")
async def trigger_processing(payload: WebhookPayload):
    """
    Webhook endpoint that triggers data processing when a dataset is updated.
    """
    if payload.repo.type == "dataset" and payload.event.action == "update":
        logger.info(f"Dataset {payload.repo.name} updated. Triggering processing.")
        process_and_push_data()
        return {"message": "Data processing triggered successfully."}
    else:
        logger.info(f"Ignored event: {payload.event.action} on {payload.repo.name}")
        return {"message": "Event ignored."}
|
|
app.launch()
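
# One way to wire this up: create a webhook at https://huggingface.co/settings/webhooks that
# watches amaye15/object-segmentation, point it at the URL this server prints on startup, and
# reuse the same secret; an "update" event on that dataset then runs `trigger_processing`
# and republishes the processed copy to amaye15/tmp.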
|
|