# space_to_dataset_saver / app_1M_image.py
# Wauplin (HF staff) · "Upload 4 files" · commit 3531f81 · 3.6 kB
import json
import tempfile
import zipfile
from datetime import datetime
from pathlib import Path
from uuid import uuid4
import gradio as gr
import numpy as np
from PIL import Image
from huggingface_hub import CommitScheduler, InferenceClient
# Each import of this module gets its own "train-<uuid>" sub-folder, created
# right away (parents included) so images can be written into it.
IMAGE_DATASET_DIR = Path("image_dataset_1M", f"train-{uuid4()}")
IMAGE_DATASET_DIR.mkdir(parents=True, exist_ok=True)

# JSON Lines file, stored next to the images, holding one metadata row per image.
IMAGE_JSONL_PATH = IMAGE_DATASET_DIR.joinpath("metadata.jsonl")
class ZipScheduler(CommitScheduler):
    """
    Example of a custom CommitScheduler with overwritten `push_to_hub` to zip images before pushing them to the Hub.

    Workflow:
    1. Read metadata + list PNG files.
    2. Zip PNG files and metadata in a single archive.
    3. Create commit (upload the archive).
    4. Delete local PNG files to avoid re-uploading them later.

    Only step 1 requires the lock. Once the metadata is read, the lock is released and the rest of the
    process can be done without blocking the Gradio app. If zipping or uploading fails, the metadata is
    written back (under the lock) so that the next scheduled push can retry with complete data.
    """

    def push_to_hub(self):
        """Zip the pending PNG files with their metadata and upload the archive to the Hub.

        Returns:
            None. Returns early, without uploading, when the folder contains no PNG file.

        Raises:
            FileNotFoundError: if PNG files exist but the metadata file does not.
            Exception: anything raised while building or uploading the archive is re-raised
                after the metadata rows have been restored on disk.
        """
        # 1. Read metadata + list PNG files
        with self.lock:
            png_files = list(self.folder_path.glob("*.png"))
            if not png_files:
                return None  # return early if nothing to commit

            # Take ownership of the rows written so far; rows appended after the lock is
            # released land in a fresh file and belong to the next archive.
            metadata = IMAGE_JSONL_PATH.read_text()
            try:
                IMAGE_JSONL_PATH.unlink()
            except OSError:
                # Best effort: the rows stay on disk and must not be restored a second time.
                metadata_removed = False
            else:
                metadata_removed = True

        with tempfile.TemporaryDirectory() as tmpdir:
            archive_path = Path(tmpdir) / "train.zip"
            try:
                # 2. Zip png files + metadata in a single archive
                with zipfile.ZipFile(archive_path, "w", zipfile.ZIP_DEFLATED) as archive:
                    for png_file in png_files:
                        archive.write(filename=png_file, arcname=png_file.name)
                    archive.writestr("metadata.jsonl", metadata)

                # 3. Create commit
                self.api.upload_file(
                    repo_id=self.repo_id,
                    repo_type=self.repo_type,
                    revision=self.revision,
                    path_in_repo=f"train-{uuid4()}.zip",
                    path_or_fileobj=archive_path,
                )
            except Exception:
                # The PNG files are still on disk; without their rows the next push would
                # fail or upload images with no prompt attached.
                if metadata_removed:
                    self._restore_metadata(metadata)
                raise

        # 4. Delete local png files to avoid re-uploading them later
        for png_file in png_files:
            try:
                png_file.unlink()
            except OSError:
                pass  # best effort: a file that cannot be removed is left in place

    def _restore_metadata(self, metadata: str) -> None:
        """Write `metadata` back in front of any rows appended since it was read."""
        with self.lock:
            try:
                newer_rows = IMAGE_JSONL_PATH.read_text()
            except FileNotFoundError:
                newer_rows = ""
            IMAGE_JSONL_PATH.write_text(metadata + newer_rows)
# Scheduler watching the local image folder; its `push_to_hub` uploads the
# folder content as zip archives to the dataset repo named below.
scheduler = ZipScheduler(
    folder_path=IMAGE_DATASET_DIR,
    repo_type="dataset",
    repo_id="example-commit-scheduler-image-zip",
)

# Inference client used by `generate_image` for text-to-image requests.
client = InferenceClient()
def generate_image(prompt: str) -> Image.Image:
    """Generate an image for `prompt` using the module-level inference client.

    Args:
        prompt: Text description forwarded to `client.text_to_image`.

    Returns:
        Whatever `client.text_to_image` returns, passed through unchanged. The annotation
        names the `PIL.Image.Image` class (the bare `Image` name is the PIL module).
    """
    return client.text_to_image(prompt)
def save_image(prompt: str, image_array: np.ndarray) -> None:
    """Store a generated image as PNG and append its metadata row.

    Args:
        prompt: Prompt that produced the image; stored in the metadata row.
        image_array: Image pixels, converted with `Image.fromarray` before saving.

    The image file and its metadata row are written while holding the scheduler
    lock, so a push running in the scheduler cannot list the PNG before its row exists.
    """
    print("Saving: " + prompt)
    target = IMAGE_DATASET_DIR / f"{uuid4()}.png"

    with scheduler.lock:
        Image.fromarray(image_array).save(target)

        with IMAGE_JSONL_PATH.open("a") as metadata_file:
            record = {"prompt": prompt, "file_name": target.name, "datetime": datetime.now().isoformat()}
            # One JSON object per line (JSON Lines).
            metadata_file.write(json.dumps(record) + "\n")
def get_demo():
    """Build the text-to-image widgets and wire their events.

    Creates a prompt textbox, an image output and a "Generate" button in one row.
    Clicking the button runs `generate_image`; only when that call succeeds is
    `save_image` invoked with the prompt and the generated image.

    Presumably called inside an already open Gradio layout context, since no
    top-level container is created here.
    """
    with gr.Row():
        prompt_value = gr.Textbox(label="Prompt")
        image_value = gr.Image(label="Generated image")
        text_to_image_btn = gr.Button("Generate")

    # Step 1: generate the image and show it.
    generation = text_to_image_btn.click(fn=generate_image, inputs=prompt_value, outputs=image_value)
    # Step 2: chained on success, persist the prompt and the displayed image.
    generation.success(fn=save_image, inputs=[prompt_value, image_value], outputs=None)