| |
| """Build the AnimeTimm organization card assets from checked-in model data.""" |
|
|
| import csv |
| import json |
| from collections import OrderedDict |
| from pathlib import Path |
|
|
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| from matplotlib.patches import Rectangle |
|
|
# ROOT is the parent of the directory that contains this script.
ROOT = Path(__file__).resolve().parents[1]
DATA = ROOT.joinpath("data")  # checked-in model tables (CSV in, JSON views out)
ASSETS = ROOT.joinpath("assets")  # rendered PNG images
ASSETS.mkdir(exist_ok=True)
|
|
# Display order of model families in the "best_by_family" data view.
FAMILY_ORDER = """
    ConvNeXt EVA CAFormer SwinV2 ViT
    MobileNetV4 MobileNetV3 MobileViT ResNet
""".split()
|
|
|
|
def load_rows(path=None):
    """Load the dbv4-full model table, sorted best-first by Macro@Best F1.

    Args:
        path: CSV file to read (str or Path). Defaults to
            ``DATA / "dbv4_full_models.csv"``.

    Returns:
        A list of dicts, one per CSV row. ``rank_by_macro_best`` is converted
        to ``int``; ``params_m``, ``macro_best``, ``macro_040`` and
        ``micro_040`` are converted to ``float``; all other columns stay
        strings. Rows are ordered by ``macro_best`` descending.
    """
    if path is None:
        path = DATA / "dbv4_full_models.csv"
    # "utf-8-sig" drops a leading byte-order mark (typical of spreadsheet
    # exports) that would otherwise be glued onto the first header name and
    # break that column's lookup; BOM-less files decode exactly as "utf-8".
    with Path(path).open(newline="", encoding="utf-8-sig") as f:
        rows = list(csv.DictReader(f))
    for row in rows:
        row["rank_by_macro_best"] = int(row["rank_by_macro_best"])
        for key in ("params_m", "macro_best", "macro_040", "micro_040"):
            row[key] = float(row[key])
    # sorted() stays stable with reverse=True, so tied scores keep CSV order.
    return sorted(rows, key=lambda item: item["macro_best"], reverse=True)
|
|
|
|
def write_data_views(rows, *, family_order=None, data_dir=None):
    """Write the JSON data views and return the featured model subsets.

    Two files are written into ``data_dir``: ``dbv4_full_models.json`` with
    every row (best-first), and ``featured_models.json`` with the keys
    ``top5_macro_f1`` and ``best_by_family``.

    Args:
        rows: Model dicts as produced by ``load_rows``; each needs at least
            ``family`` and ``macro_best``. Input order does not matter: rows
            are ranked here by ``macro_best`` descending.
        family_order: Preferred family order for ``best_by_family``
            (defaults to ``FAMILY_ORDER``). Families present in ``rows`` but
            not listed are appended after the listed ones, in score order,
            instead of being silently dropped.
        data_dir: Output directory, str or Path (defaults to ``DATA``).

    Returns:
        ``(top5, family_rows)``: the five highest-scoring rows and the
        best-scoring row of each family, in the order described above.
    """
    if family_order is None:
        family_order = FAMILY_ORDER
    data_dir = DATA if data_dir is None else Path(data_dir)

    # Both "top 5" and "first row seen per family is the best" require
    # best-first order, so establish it here instead of trusting the caller.
    # sorted() is stable, so already-sorted input (load_rows output) is kept
    # exactly as it was.
    ranked = sorted(rows, key=lambda row: row["macro_best"], reverse=True)

    best = OrderedDict()
    for row in ranked:
        best.setdefault(row["family"], row)
    listed = [best[family] for family in family_order if family in best]
    known = set(family_order)
    unlisted = [row for family, row in best.items() if family not in known]
    family_rows = listed + unlisted
    top5 = ranked[:5]

    with (data_dir / "dbv4_full_models.json").open("w", encoding="utf-8") as f:
        json.dump(ranked, f, indent=2)
    with (data_dir / "featured_models.json").open("w", encoding="utf-8") as f:
        json.dump({"top5_macro_f1": top5, "best_by_family": family_rows}, f, indent=2)
    return top5, family_rows
|
|
|
|
def draw_banner():
    """Render the dark "AnimeTimm" banner to ``assets/animetimm-banner.png``.

    The hidden axes span 1200 x 360 data units; all decorations and text
    below are positioned in those coordinates.
    """
    background = "#12131c"
    palette = ["#ff7ab6", "#8f7cff", "#55d6be", "#ffd166", "#7bdff2"]
    width, height = 1200, 360
    band_width = 240
    grid_step = 60

    fig, ax = plt.subplots(figsize=(12, 3.6), dpi=160)
    fig.patch.set_facecolor(background)
    ax.set_facecolor(background)
    ax.set_xlim(0, width)
    ax.set_ylim(0, height)
    ax.axis("off")

    # Translucent vertical colour bands, a little more opaque toward the right.
    for band, shade in enumerate(palette):
        ax.add_patch(
            Rectangle(
                (band * band_width, 0),
                band_width,
                height,
                color=shade,
                alpha=0.10 + 0.02 * band,
                lw=0,
            )
        )

    # Faint grid: all vertical lines first, then all horizontal lines.
    faint = dict(color="white", alpha=0.045, lw=0.7)
    for gx in range(0, width + 1, grid_step):
        ax.plot([gx, gx], [0, height], **faint)
    for gy in range(0, height + 1, grid_step):
        ax.plot([0, width], [gy, gy], **faint)

    # Star "sparkles": (x, y, nominal size, colour); marker area is size**2 / 4.
    sparkles = (
        (930, 240, 52, "#ff7ab6"),
        (1010, 165, 38, "#55d6be"),
        (1080, 260, 28, "#ffd166"),
        (845, 120, 32, "#8f7cff"),
    )
    for sx, sy, size, shade in sparkles:
        ax.scatter(
            [sx],
            [sy],
            s=size * size / 4,
            marker="*",
            color=shade,
            alpha=0.85,
            edgecolors="white",
            linewidths=0.8,
        )

    # Keyword chips; each takes the matching palette colour, the last one white.
    chips = [
        ("tags", (720, 78)),
        ("timm", (812, 285)),
        ("anime", (935, 65)),
        ("F1", (1045, 104)),
        ("ONNX", (705, 210)),
        ("safetensors", (1085, 205)),
    ]
    chip_fills = palette + ["#ffffff"]
    for (label, (cx, cy)), fill in zip(chips, chip_fills):
        chip_box = dict(
            boxstyle="round,pad=0.35,rounding_size=0.18",
            fc=fill,
            ec="white",
            alpha=0.35,
            lw=0.8,
        )
        ax.text(
            cx,
            cy,
            label,
            color="#f7f7fb",
            fontsize=13,
            fontweight="bold",
            ha="center",
            va="center",
            bbox=chip_box,
        )

    # Left-aligned title block.
    anchor = {"ha": "left", "va": "center"}
    ax.text(70, 218, "AnimeTimm", color="white", fontsize=54, fontweight="bold", **anchor)
    ax.text(
        75,
        157,
        "timm-based vision models for anime-style image tagging",
        color="#e6e7f3",
        fontsize=20,
        **anchor,
    )
    ax.text(76, 108, "A DeepGHS research-and-hobbyist project", color="#ffcfdf", fontsize=15, **anchor)

    # Rising decorative line with highlighted points.
    trend_x = [720, 800, 890, 980, 1080]
    trend_y = [135, 175, 205, 225, 252]
    ax.plot(trend_x, trend_y, color="#ffffff", lw=2.2, alpha=0.78)
    ax.scatter(trend_x, trend_y, s=75, color="#55d6be", edgecolor="white", linewidth=1.2, zorder=3)

    fig.savefig(ASSETS / "animetimm-banner.png", bbox_inches="tight", pad_inches=0)
    plt.close(fig)
|
|
|
|
def draw_snapshot(top5, family_rows):
    """Render the model snapshot chart to ``assets/dbv4-full-model-snapshot.png``.

    Plots every ``top5`` row plus each ``family_rows`` row not already in
    ``top5`` as a horizontal Macro@Best F1 bar, best model at the top. A twin
    top axis overlays each model's parameter count (millions, log scale) as
    a pink dot.

    Args:
        top5: Row dicts drawn as purple bars.
        family_rows: Row dicts drawn as blue bars unless also in ``top5``.
            Every row needs ``model``, ``macro_best`` and ``params_m``.
    """
    plot_rows = top5 + [row for row in family_rows if row not in top5]
    # Ascending order so barh() places the best model on the top row.
    plot_rows = sorted(plot_rows, key=lambda row: row["macro_best"])
    scores = [row["macro_best"] for row in plot_rows]
    params = [row["params_m"] for row in plot_rows]

    # The fixed windows fit the current data. Widen a bound only when a value
    # falls outside it: an out-of-range bar/dot would otherwise vanish while
    # its (unclipped) text label lands on top of the tick labels.
    score_lo, score_hi = 0.28, 0.63
    if scores and min(scores) < score_lo:
        score_lo = min(scores) - 0.02
    if scores and max(scores) > score_hi:
        score_hi = max(scores) + 0.03  # room for the value label right of the bar
    param_lo, param_hi = 12, 900
    if params and min(params) < param_lo:
        param_lo = min(params) / 1.25
    if params and max(params) > param_hi:
        param_hi = max(params) * 1.5  # room for the "NNN.NM" label right of the dot

    fig, ax1 = plt.subplots(figsize=(11.5, 7.2), dpi=160)
    fig.patch.set_facecolor("#ffffff")
    ax1.set_facecolor("#fbfbff")
    y = list(range(len(plot_rows)))
    labels = [row["model"].replace(".dbv4-full", "") for row in plot_rows]
    bar_colors = ["#8f7cff" if row in top5 else "#55b9d6" for row in plot_rows]
    ax1.barh(y, scores, color=bar_colors, alpha=0.88)
    ax1.set_yticks(y)
    ax1.set_yticklabels(labels, fontsize=9)
    ax1.set_xlim(score_lo, score_hi)
    ax1.set_xlabel("Macro@Best F1", fontsize=11)
    ax1.grid(axis="x", alpha=0.22)
    for idx, score in enumerate(scores):
        ax1.text(score + 0.004, idx, f"{score:.3f}", va="center", fontsize=8, color="#222222")

    # Twin axis shares ax1's y axis, so dots line up with their bars.
    ax2 = ax1.twiny()
    ax2.scatter(params, y, color="#ff7ab6", s=46, edgecolor="white", linewidth=0.8, zorder=5)
    ax2.set_xscale("log")
    ax2.set_xlabel("Parameters (M, log scale)", fontsize=11)
    ax2.set_xlim(param_lo, param_hi)
    for idx, count in enumerate(params):
        ax2.text(count * 1.04, idx + 0.14, f"{count:.1f}M", va="center", fontsize=7, color="#8a1f55")

    ax1.set_title("AnimeTimm dbv4-full Model Snapshot", fontsize=16, fontweight="bold", pad=14)
    fig.text(0.125, 0.028, "Purple bars mark the top-5 Macro@Best F1 models. Blue bars are best-per-family representatives not already in top-5.", fontsize=8.5, color="#555555")
    # Reserve a strip at the bottom for the caption above.
    fig.tight_layout(rect=[0, 0.045, 1, 1])
    fig.savefig(ASSETS / "dbv4-full-model-snapshot.png", bbox_inches="tight")
    plt.close(fig)
|
|
|
|
def main():
    """Regenerate the JSON data views and both PNG card assets.

    Loads the checked-in CSV, writes the JSON views, then renders the banner
    followed by the snapshot chart built from the featured subsets.
    """
    featured = write_data_views(load_rows())  # -> (top5, family_rows)
    draw_banner()
    draw_snapshot(*featured)
|
|
|
|
# Build the assets only when run as a script. Note that importing the module
# still creates the assets directory (module-level mkdir above).
if __name__ == "__main__":
    main()
|
|