import json
import tempfile
import zipfile
from datetime import datetime
from pathlib import Path
from uuid import uuid4

import gradio as gr
import numpy as np
from PIL import Image

from huggingface_hub import CommitScheduler, InferenceClient


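# Each run of the app writes generated images and a metadata.jsonl file to a
# fresh local "train-<uuid>" folder; the scheduler pushes its content to the Hub.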
IMAGE_DATASET_DIR = Path("image_dataset_1M") / f"train-{uuid4()}"
IMAGE_DATASET_DIR.mkdir(parents=True, exist_ok=True)
IMAGE_JSONL_PATH = IMAGE_DATASET_DIR / "metadata.jsonl"


class ZipScheduler(CommitScheduler):
    """
    Example of a custom CommitScheduler with an overridden `push_to_hub` that zips images before pushing them to the Hub.

    Workflow:
        1. Read metadata + list PNG files.
        2. Zip the PNG files into a single archive.
        3. Create the commit (metadata + archive).
        4. Delete the local PNG files to avoid re-uploading them later.

    Only step 1 requires the lock to be held. Once the metadata is read, the lock is released and the rest of the
    process can run without blocking the Gradio app.
    """

    def push_to_hub(self):
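        # 1. List the PNG files and read the metadata while holding the lock,
        #    so the Gradio app cannot write to them at the same time.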
        with self.lock:
            png_files = list(self.folder_path.glob("*.png"))
            if len(png_files) == 0:
                return None
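            # Read the metadata file and delete it so that future entries are appended to a fresh file.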
            metadata = IMAGE_JSONL_PATH.read_text()
            try:
                IMAGE_JSONL_PATH.unlink()
            except Exception:
                pass

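        # 2. Zip the PNG files and the metadata into a single archive.
        #    The lock is released at this point, so the app is not blocked.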
        with tempfile.TemporaryDirectory() as tmpdir:
            archive_path = Path(tmpdir) / "train.zip"
            with zipfile.ZipFile(archive_path, "w", zipfile.ZIP_DEFLATED) as zip_file:
                for png_file in png_files:
                    zip_file.write(filename=png_file, arcname=png_file.name)

                tmp_metadata = Path(tmpdir) / "metadata.jsonl"
                tmp_metadata.write_text(metadata)
                zip_file.write(filename=tmp_metadata, arcname="metadata.jsonl")

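            # 3. Create the commit: upload the archive (images + metadata) to the Hub.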
            self.api.upload_file(
                repo_id=self.repo_id,
                repo_type=self.repo_type,
                revision=self.revision,
                path_in_repo=f"train-{uuid4()}.zip",
                path_or_fileobj=archive_path,
            )

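        # 4. Delete the local PNG files to avoid re-uploading them in the next push.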
        for png_file in png_files:
            try:
                png_file.unlink()
            except Exception:
                pass


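# The scheduler runs in a background thread and calls `push_to_hub` periodically
# (every 5 minutes by default).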
scheduler = ZipScheduler(
    repo_id="example-commit-scheduler-image-zip",
    repo_type="dataset",
    folder_path=IMAGE_DATASET_DIR,
)

client = InferenceClient()


def generate_image(prompt: str) -> Image.Image:
    return client.text_to_image(prompt)


def save_image(prompt: str, image_array: np.ndarray) -> None:
    print("Saving: " + prompt)
    image_path = IMAGE_DATASET_DIR / f"{uuid4()}.png"

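    # Hold the scheduler lock while writing so that a concurrent `push_to_hub`
    # does not see a half-written image or metadata line.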
    with scheduler.lock:
        Image.fromarray(image_array).save(image_path)
        with IMAGE_JSONL_PATH.open("a") as f:
            json.dump({"prompt": prompt, "file_name": image_path.name, "datetime": datetime.now().isoformat()}, f)
            f.write("\n")


def get_demo():
    with gr.Row():
        prompt_value = gr.Textbox(label="Prompt")
        image_value = gr.Image(label="Generated image")
    text_to_image_btn = gr.Button("Generate")
    text_to_image_btn.click(fn=generate_image, inputs=prompt_value, outputs=image_value).success(
        fn=save_image,
        inputs=[prompt_value, image_value],
        outputs=None,
    )

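
# get_demo() only lays out the UI, so it must be called inside a `gr.Blocks()`
# context. A minimal sketch for running this file on its own (assuming no other
# app wrapper around it) could look like this:
if __name__ == "__main__":
    with gr.Blocks() as demo:
        get_demo()
    demo.launch()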