README / scripts /build_org_card.py
narugo1992's picture
Update AnimeTimm organization card
9dca5a9 verified
#!/usr/bin/env python3
"""Build the AnimeTimm organization card assets from checked-in model data."""
import csv
import json
from collections import OrderedDict
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
# Repository root: two levels above this file (parents[1] of the resolved path).
ROOT = Path(__file__).resolve().parent.parent
# Generated images are written here; JSON views and the source CSV live in DATA.
ASSETS = ROOT.joinpath("assets")
DATA = ROOT.joinpath("data")
# Created eagerly at import time so every draw_* function can save into it.
ASSETS.mkdir(exist_ok=True)

# Display order for the "best model per family" view; families that are not
# listed here are left out of that view.
FAMILY_ORDER = [
    "ConvNeXt", "EVA", "CAFormer", "SwinV2", "ViT",
    "MobileNetV4", "MobileNetV3", "MobileViT", "ResNet",
]
def load_rows():
    """Read the dbv4-full model table and return it ranked by Macro@Best F1.

    Numeric columns are converted in place: ``rank_by_macro_best`` becomes an
    ``int``; ``params_m``, ``macro_best``, ``macro_040`` and ``micro_040``
    become ``float``. The returned list is sorted best-first by
    ``macro_best``; ties keep their CSV order because the sort is stable.
    """
    csv_path = DATA / "dbv4_full_models.csv"
    with csv_path.open(newline="", encoding="utf-8") as handle:
        records = list(csv.DictReader(handle))

    float_columns = ("params_m", "macro_best", "macro_040", "micro_040")
    for record in records:
        record["rank_by_macro_best"] = int(record["rank_by_macro_best"])
        record.update({column: float(record[column]) for column in float_columns})

    records.sort(key=lambda record: record["macro_best"], reverse=True)
    return records
def write_data_views(rows):
    """Write JSON views of the model table and select the featured models.

    ``rows`` is expected to be sorted best-first (as ``load_rows`` returns it).
    The first five rows form the top-5. The first row seen for each family is
    taken as that family's best.

    Writes ``dbv4_full_models.json`` (every row) and ``featured_models.json``
    (top-5 plus best-per-family) into ``DATA``.

    Returns:
        ``(top5, family_rows)``. ``family_rows`` follows ``FAMILY_ORDER``;
        families absent from ``FAMILY_ORDER`` are omitted.
    """
    full_view = DATA / "dbv4_full_models.json"
    with full_view.open("w", encoding="utf-8") as handle:
        json.dump(rows, handle, indent=2)

    # Keep only the first (i.e. highest-ranked) row encountered per family.
    best_per_family = {}
    for row in rows:
        if row["family"] not in best_per_family:
            best_per_family[row["family"]] = row
    family_rows = [
        best_per_family[name] for name in FAMILY_ORDER if name in best_per_family
    ]

    top5 = rows[:5]
    featured_view = DATA / "featured_models.json"
    with featured_view.open("w", encoding="utf-8") as handle:
        json.dump({"top5_macro_f1": top5, "best_by_family": family_rows}, handle, indent=2)
    return top5, family_rows
def draw_banner():
    """Render the organization banner and save it as ``animetimm-banner.png`` in ``ASSETS``.

    The plot uses a 1200x360 data-unit canvas with axes hidden. It contains
    five tinted vertical stripes, a faint 60-unit grid, star markers, keyword
    "pills", the title block and a small rising polyline.
    """
    background = "#12131c"
    width, height = 1200, 360
    palette = ["#ff7ab6", "#8f7cff", "#55d6be", "#ffd166", "#7bdff2"]

    fig, ax = plt.subplots(figsize=(12, 3.6), dpi=160)
    fig.patch.set_facecolor(background)
    ax.set_facecolor(background)
    ax.set_xlim(0, width)
    ax.set_ylim(0, height)
    ax.axis("off")

    # Equal-width stripes, each slightly more opaque than the previous one.
    stripe_width = width // len(palette)
    for i, tint in enumerate(palette):
        ax.add_patch(
            Rectangle(
                (i * stripe_width, 0),
                stripe_width,
                height,
                color=tint,
                alpha=0.10 + 0.02 * i,
                lw=0,
            )
        )

    # Faint grid: vertical lines first, then horizontal ones.
    grid_style = {"color": "white", "alpha": 0.045, "lw": 0.7}
    for gx in range(0, width + 1, 60):
        ax.plot([gx, gx], [0, height], **grid_style)
    for gy in range(0, height + 1, 60):
        ax.plot([0, width], [gy, gy], **grid_style)

    stars = [
        (930, 240, 52, "#ff7ab6"),
        (1010, 165, 38, "#55d6be"),
        (1080, 260, 28, "#ffd166"),
        (845, 120, 32, "#8f7cff"),
    ]
    for sx, sy, size, tint in stars:
        ax.scatter(
            [sx],
            [sy],
            s=size * size / 4,
            marker="*",
            color=tint,
            alpha=0.85,
            edgecolors="white",
            linewidths=0.8,
        )

    pills = zip(
        [(720, 78), (812, 285), (935, 65), (1045, 104), (705, 210), (1085, 205)],
        ["tags", "timm", "anime", "F1", "ONNX", "safetensors"],
        palette + ["#ffffff"],
    )
    for (px, py), word, tint in pills:
        box = {
            "boxstyle": "round,pad=0.35,rounding_size=0.18",
            "fc": tint,
            "ec": "white",
            "alpha": 0.35,
            "lw": 0.8,
        }
        ax.text(px, py, word, color="#f7f7fb", fontsize=13, fontweight="bold",
                ha="center", va="center", bbox=box)

    # Title, subtitle and tagline, all left-aligned and vertically centered.
    headings = [
        (70, 218, "AnimeTimm", {"color": "white", "fontsize": 54, "fontweight": "bold"}),
        (75, 157, "timm-based vision models for anime-style image tagging",
         {"color": "#e6e7f3", "fontsize": 20}),
        (76, 108, "A DeepGHS research-and-hobbyist project",
         {"color": "#ffcfdf", "fontsize": 15}),
    ]
    for hx, hy, caption, style in headings:
        ax.text(hx, hy, caption, ha="left", va="center", **style)

    trend_x = [720, 800, 890, 980, 1080]
    trend_y = [135, 175, 205, 225, 252]
    ax.plot(trend_x, trend_y, color="#ffffff", lw=2.2, alpha=0.78)
    ax.scatter(trend_x, trend_y, s=75, color="#55d6be", edgecolor="white",
               linewidth=1.2, zorder=3)

    fig.savefig(ASSETS / "animetimm-banner.png", bbox_inches="tight", pad_inches=0)
    plt.close(fig)
def _expand_limits(values, lower, upper, *, below, above, log=False):
    """Return ``(lower, upper)`` widened only as far as needed to show ``values``.

    When every value already lies inside ``[lower, upper]``, the defaults are
    returned unchanged. A value outside the range pushes that edge out to
    leave room for its marker and label:
    - on a linear axis, ``below``/``above`` are subtracted/added;
    - on a log axis (``log=True``), they are used as divisor/multiplier.
    """
    if not values:
        return lower, upper
    smallest, largest = min(values), max(values)
    if smallest < lower:
        lower = smallest / below if log else smallest - below
    if largest > upper:
        upper = largest * above if log else largest + above
    return lower, upper


def draw_snapshot(top5, family_rows):
    """Render the dbv4-full model snapshot chart to ``ASSETS``.

    The chart plots the top-5 rows followed by any ``family_rows`` not already
    among them, sorted by ``macro_best`` so the best model is at the top.
    - Macro@Best F1 is drawn as horizontal bars.
    - Parameter counts (millions) are overlaid as points on a log-scale twin
      x-axis.

    Axis limits default to the fixed ranges 0.28-0.63 (F1) and 12-900
    (params, M). They are widened only when a plotted model falls outside
    them, so no model is silently clipped off the chart.

    Args:
        top5: Highest-ranked rows; drawn as purple bars.
        family_rows: Best-per-family rows; those not in ``top5`` are drawn as
            blue bars.
    """
    plot_rows = top5 + [row for row in family_rows if row not in top5]
    # Ascending order: barh draws index 0 at the bottom, so the best ends on top.
    plot_rows = sorted(plot_rows, key=lambda row: row["macro_best"])
    scores = [row["macro_best"] for row in plot_rows]
    params = [row["params_m"] for row in plot_rows]

    fig, ax1 = plt.subplots(figsize=(11.5, 7.2), dpi=160)
    fig.patch.set_facecolor("#ffffff")
    ax1.set_facecolor("#fbfbff")
    y = list(range(len(plot_rows)))
    labels = [row["model"].replace(".dbv4-full", "") for row in plot_rows]
    bar_colors = ["#8f7cff" if row in top5 else "#55b9d6" for row in plot_rows]
    ax1.barh(y, scores, color=bar_colors, alpha=0.88)
    ax1.set_yticks(y)
    ax1.set_yticklabels(labels, fontsize=9)
    # Previously a fixed (0.28, 0.63), which hid bars/labels outside that window.
    ax1.set_xlim(*_expand_limits(scores, 0.28, 0.63, below=0.02, above=0.03))
    ax1.set_xlabel("Macro@Best F1", fontsize=11)
    ax1.grid(axis="x", alpha=0.22)
    for idx, row in enumerate(plot_rows):
        ax1.text(row["macro_best"] + 0.004, idx, f"{row['macro_best']:.3f}",
                 va="center", fontsize=8, color="#222222")

    # Twin x-axis on top shares the y positions; parameters span orders of
    # magnitude, hence the log scale.
    ax2 = ax1.twiny()
    ax2.scatter(params, y, color="#ff7ab6", s=46, edgecolor="white",
                linewidth=0.8, zorder=5)
    ax2.set_xscale("log")
    ax2.set_xlabel("Parameters (M, log scale)", fontsize=11)
    # Previously a fixed (12, 900), which dropped smaller/larger models' points.
    ax2.set_xlim(*_expand_limits(params, 12, 900, below=1.25, above=1.6, log=True))
    for idx, row in enumerate(plot_rows):
        ax2.text(row["params_m"] * 1.04, idx + 0.14, f"{row['params_m']:.1f}M",
                 va="center", fontsize=7, color="#8a1f55")

    ax1.set_title("AnimeTimm dbv4-full Model Snapshot", fontsize=16, fontweight="bold", pad=14)
    fig.text(0.125, 0.028, "Purple bars mark the top-5 Macro@Best F1 models. Blue bars are best-per-family representatives not already in top-5.", fontsize=8.5, color="#555555")
    # Reserve the bottom strip for the footnote above.
    fig.tight_layout(rect=[0, 0.045, 1, 1])
    fig.savefig(ASSETS / "dbv4-full-model-snapshot.png", bbox_inches="tight")
    plt.close(fig)
def main():
    """Regenerate the JSON data views, the banner image and the snapshot chart."""
    top5, family_rows = write_data_views(load_rows())
    draw_banner()
    draw_snapshot(top5, family_rows)


if __name__ == "__main__":
    main()