import spaces
import gradio as gr
import torch
import torchvision
from PIL import Image
import numpy as np
import os
import shutil
from huggingface_hub import HfApi, HfFolder
from transformers import AutoModelForImageClassification, Trainer, TrainingArguments
from datasets import load_dataset
from sklearn.metrics import accuracy_score


@spaces.GPU
def dummy_gpu():
    # No-op decorated with @spaces.GPU so the Space satisfies the ZeroGPU
    # requirement; training itself runs on CPU (see TrainingArguments below).
    pass


# Defaults; each can be overridden by an environment variable of the same name.
HF_MODEL = os.getenv("HF_MODEL") or "google/vit-base-patch16-224"
HF_DATASET = os.getenv("HF_DATASET") or "verytuffcat/recaptcha-dataset"
HF_REPO = os.getenv("HF_REPO") or ""
HF_TOKEN = os.getenv("HF_TOKEN", "")
OUT_DIR = "./new_model"


def pil_to_torch(image: Image.Image) -> torch.Tensor:
    # Resize to the 224x224 input size expected by ViT and convert to a
    # float tensor scaled to [0, 1].
    return torchvision.transforms.functional.to_tensor(
        image.convert("RGB").resize((224, 224), Image.BICUBIC)
    )


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return dict(accuracy=accuracy_score(labels, predictions))


def collate_fn(batch):
    # Stack PIL images into one pixel tensor and gather the integer labels.
    pixel_values = torch.stack([pil_to_torch(x["image"]) for x in batch])
    labels = torch.tensor([x["label"] for x in batch])
    return {"pixel_values": pixel_values, "labels": labels}


def train(model_id: str, dataset_id: str, repo_id: str, hf_token: str,
          slice_size: int, log_md: str, progress=gr.Progress(track_tqdm=True)):
    try:
        if not model_id or not dataset_id or not repo_id:
            raise gr.Error("Fill in all fields.")
        if not hf_token:
            hf_token = HF_TOKEN
        if not hf_token:
            raise gr.Error("Enter an HF write token.")
        HfFolder.save_token(hf_token)
        api = HfApi(token=hf_token)

        # If a slice size is given, load only that many rows (the split syntax
        # "train[1:N+1]" takes rows 1..N); otherwise load the whole train split.
        if slice_size >= 1:
            dataset = load_dataset(dataset_id, split=f"train[1:{int(slice_size) + 1}]", num_proc=8)
        else:
            dataset = load_dataset(dataset_id, split="train", num_proc=8)

        # Build label maps from the dataset's ClassLabel feature so the
        # fine-tuned model carries human-readable class names.
        labels = dataset.features["label"].names
        label2id = {label: i for i, label in enumerate(labels)}
        id2label = {i: label for i, label in enumerate(labels)}

        # ignore_mismatched_sizes lets the source model's classification head
        # be replaced with one sized for this dataset's labels.
        model = AutoModelForImageClassification.from_pretrained(
            model_id, label2id=label2id, id2label=id2label, ignore_mismatched_sizes=True
        )

        training_args = TrainingArguments(
            output_dir=OUT_DIR,
            use_cpu=True,  # CPU-only training ("no_cuda" is the deprecated alias)
            # fp16=True,
            optim="adamw_torch",
            lr_scheduler_type="linear",
            learning_rate=5e-5,
            per_device_train_batch_size=8,
            num_train_epochs=3,
            gradient_accumulation_steps=1,
            use_ipex=True,  # requires intel_extension_for_pytorch
            # eval_strategy="epoch",
            eval_strategy="no",
            logging_strategy="epoch",
            remove_unused_columns=False,  # keep the "image" column for collate_fn
            push_to_hub=False,
            save_total_limit=2,
            report_to="none",
        )
        trainer = Trainer(
            model=model,
            args=training_args,
            data_collator=collate_fn,
            compute_metrics=compute_metrics,
            train_dataset=dataset,
            # eval_dataset=dataset,
            eval_dataset=None,
        )
        train_results = trainer.train()
        trainer.save_model(OUT_DIR)
        trainer.log_metrics("train", train_results.metrics)
        trainer.save_metrics("train", train_results.metrics)
        trainer.save_state()

        # Push the trained model to a private repo, then clean up locally.
        # Use the user-supplied token here, not the env default, so uploads
        # work even when HF_TOKEN is unset.
        api.create_repo(repo_id=repo_id, private=True, exist_ok=True, token=hf_token)
        api.upload_folder(repo_id=repo_id, folder_path=OUT_DIR, path_in_repo="", token=hf_token)
        shutil.rmtree(OUT_DIR)
        return log_md
    except Exception as e:
        raise gr.Error(f"Error occurred: {e}")


with gr.Blocks() as demo:
    with gr.Row():
        model_id = gr.Textbox(label="Source model", value=HF_MODEL, lines=1)
        dataset_id = gr.Textbox(label="Source dataset", value=HF_DATASET, lines=1)
    with gr.Row():
        repo_id = gr.Textbox(label="Output repo", value=HF_REPO, lines=1)
        hf_token = gr.Textbox(label="HF write token", value="", lines=1)
    with gr.Accordion("Advanced", open=False):
        slice_size = gr.Number(label="Slice dataset", info="If 0, use whole dataset",
                               minimum=0, maximum=999999, step=1, value=0)
    train_btn = gr.Button("Train")
    log_md = gr.Markdown(label="Log", value="")

") train_btn.click(train, [model_id, dataset_id, repo_id, hf_token, slice, log_md], [log_md]) demo.queue().launch()