Spaces:
Running
Running
import math | |
from dataclasses import dataclass | |
from functools import lru_cache | |
from pathlib import Path | |
from typing import Optional | |
import numpy as np | |
import pandas as pd | |
import torch | |
from huggingface_hub import hf_hub_download | |
from huggingface_hub.utils import HfHubHTTPError | |
from PIL import Image | |
from torch import Tensor, nn | |
@dataclass
class Heatmap:
    """A label paired with its score and an associated image.

    Declared as a dataclass so it can be constructed with keyword or
    positional field arguments; without the decorator the class has no
    ``__init__`` and cannot be instantiated with these fields.
    """

    label: str  # label text this heatmap belongs to
    score: float  # score associated with ``label``
    image: Image.Image  # the heatmap image itself
@dataclass
class LabelData:
    """Tag vocabulary plus per-category index lists.

    Declared as a dataclass so the keyword construction used when loading
    labels (``LabelData(names=..., rating=..., ...)``) works; without the
    decorator the class has no ``__init__`` accepting these fields.
    """

    names: list[str]  # tag names in row order of the source table
    rating: list[np.int64]  # indices into ``names`` of rating tags
    general: list[np.int64]  # indices into ``names`` of general tags
    character: list[np.int64]  # indices into ``names`` of character tags
@dataclass
class ImageLabels:
    """Label results for one image, as strings and per-category score maps.

    Declared as a dataclass so it can be constructed from its fields;
    without the decorator the class has no ``__init__`` taking arguments.
    """

    caption: str  # string form of the labels (exact format set by the producer)
    booru: str  # alternate string form of the labels — presumably booru-style; confirm at producer
    rating: dict[str, float]  # presumably tag name -> score for rating tags
    general: dict[str, float]  # presumably tag name -> score for general tags
    character: dict[str, float]  # presumably tag name -> score for character tags
def load_labels_hf(
    repo_id: str,
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> LabelData:
    """Fetch ``selected_tags.csv`` from a Hugging Face repo and index its tags.

    Args:
        repo_id: Hub repository to download ``selected_tags.csv`` from.
        revision: Optional revision passed through to ``hf_hub_download``.
        token: Optional auth token passed through to ``hf_hub_download``.

    Returns:
        A ``LabelData`` whose ``names`` are the CSV ``name`` column in row
        order, and whose ``rating`` / ``general`` / ``character`` lists hold
        the row positions whose ``category`` is 9 / 0 / 4 respectively.

    Raises:
        FileNotFoundError: if the download fails with ``HfHubHTTPError``.
    """
    try:
        downloaded = hf_hub_download(
            repo_id=repo_id,
            filename="selected_tags.csv",
            revision=revision,
            token=token,
        )
    except HfHubHTTPError as err:
        raise FileNotFoundError(f"selected_tags.csv failed to download from {repo_id}") from err
    csv_path = Path(downloaded).resolve()

    table: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
    categories = table["category"]

    def positions_of(code: int) -> list[np.int64]:
        # Row positions (not index labels) of every tag in the given category.
        return list(np.where(categories == code)[0])

    return LabelData(
        names=table["name"].tolist(),
        rating=positions_of(9),
        general=positions_of(0),
        character=positions_of(4),
    )
def mcut_threshold(probs: np.ndarray) -> float:
    """
    Maximum Cut Thresholding (MCut)
    Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
    for Multi-label Classification. In 11th International Symposium, IDA 2012
    (pp. 172-183).

    Sorts the scores in descending order, finds the largest drop between two
    consecutive scores, and returns the midpoint of that drop. On ties for the
    largest drop, the first one (highest scores) wins.

    Args:
        probs: 1-D array (or any array-like) of scores. It is not modified.

    Returns:
        The threshold as a Python ``float``.

    Raises:
        ValueError: if fewer than two scores are given, since no gap exists.
    """
    # asarray lets plain lists through; np.sort returns a copy, so the
    # caller's array is left untouched.
    scores = np.sort(np.asarray(probs))[::-1]
    if scores.size < 2:
        raise ValueError("mcut_threshold needs at least two scores to find a gap")
    gaps = scores[:-1] - scores[1:]
    idx = int(gaps.argmax())
    return float((scores[idx] + scores[idx + 1]) / 2)
def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    """Return ``image`` in RGB mode, flattening any transparency onto white.

    - RGB images are returned unchanged.
    - Images with transparency are converted to RGBA and alpha-composited
      over an opaque white canvas before being converted to RGB. This covers
      RGBA images, LA/PA images, and images with a ``transparency`` entry in
      ``image.info``.
    - Every other mode is converted directly to RGB.
    """
    # LA/PA keep alpha in a dedicated band rather than in the "transparency"
    # info key. Checking only the info key made these images go straight to
    # RGB, which silently dropped their alpha instead of flattening onto white.
    has_alpha = "transparency" in image.info or image.mode in ("LA", "PA")
    # convert to RGB/RGBA if not already (deals with palette images etc.)
    if image.mode not in ["RGB", "RGBA"]:
        image = image.convert("RGBA") if has_alpha else image.convert("RGB")
    # convert RGBA to RGB with white background
    if image.mode == "RGBA":
        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")
    return image
def pil_pad_square(
    image: Image.Image,
    fill: tuple[int, int, int] = (255, 255, 255),
) -> Image.Image:
    """Center ``image`` on a new square RGB canvas filled with ``fill``.

    The canvas side equals the image's larger dimension, so the image is only
    padded, never cropped or resized. Offsets are floor-divided, so any odd
    leftover pixel of padding ends up on the right/bottom edge.
    """
    width, height = image.size
    side = width if width >= height else height
    left = (side - width) // 2
    top = (side - height) // 2

    square = Image.new("RGB", (side, side), fill)
    square.paste(image, (left, top))
    return square
def preprocess_image(
    image: Image.Image,
    size_px: int | tuple[int, int],
    upscale: bool = True,
) -> Image.Image:
    """
    Preprocess an image to be square and centered on a white background.

    Steps: force RGB, pad to a square, then resize. Images smaller than the
    target in either dimension are upscaled with LANCZOS to exactly
    ``size_px``. Images larger in either dimension are shrunk in place with
    ``thumbnail`` (BICUBIC).

    Raises:
        ValueError: if the padded image is smaller than ``size_px`` and
            ``upscale`` is ``False``.

    NOTE(review): for a non-square ``size_px`` the upscale path stretches to
    exactly ``size_px``, while ``thumbnail`` keeps the square aspect ratio.
    Confirm that callers only pass square sizes.
    """
    target = (size_px, size_px) if isinstance(size_px, int) else size_px

    squared = pil_pad_square(pil_ensure_rgb(image))

    width, height = squared.size
    if width < target[0] or height < target[1]:
        if upscale is False:
            raise ValueError("Image is smaller than target size, and upscaling is disabled")
        squared = squared.resize(target, Image.LANCZOS)

    width, height = squared.size
    if width > target[0] or height > target[1]:
        squared.thumbnail(target, Image.BICUBIC)
    return squared
def pil_make_grid(
    images: list[Image.Image],
    max_cols: int = 8,
    padding: int = 4,
    bg_color: tuple[int, int, int] = (40, 42, 54),  # dracula background color
    partial_rows: bool = True,
) -> Image.Image:
    """Tile ``images`` row-major into one padded RGB grid image.

    The column count is ``min(isqrt(len(images)), max_cols)`` and the row
    count is whatever is needed to fit every image. When ``partial_rows`` is
    ``False`` and the last row would be incomplete, that row is dropped and
    the images that would have filled it are left out.

    Every cell is assumed to be the size of ``images[0]``; cells are separated
    (and the grid is framed) by ``padding`` pixels of ``bg_color``.

    Raises:
        ValueError: if ``images`` is empty or ``max_cols`` is less than 1
            (previously these surfaced as ZeroDivisionError/IndexError).
    """
    if not images:
        raise ValueError("pil_make_grid needs at least one image")
    if max_cols < 1:
        raise ValueError("max_cols must be at least 1")

    # isqrt is the exact integer square root; for non-empty input it is >= 1.
    n_cols = min(math.isqrt(len(images)), max_cols)
    n_rows = math.ceil(len(images) / n_cols)
    # if the final row is not full and partial_rows is False, remove a row
    if n_cols * n_rows > len(images) and not partial_rows:
        n_rows -= 1

    # assumes all images are same size
    image_width, image_height = images[0].size
    canvas_width = ((image_width + padding) * n_cols) + padding
    canvas_height = ((image_height + padding) * n_rows) + padding
    canvas = Image.new("RGB", (canvas_width, canvas_height), bg_color)

    # Only paste images that have a cell. Images from a dropped partial row
    # would otherwise be pasted off-canvas for nothing.
    for i, img in enumerate(images[: n_cols * n_rows]):
        x = (i % n_cols) * (image_width + padding) + padding
        y = (i // n_cols) * (image_height + padding) + padding
        canvas.paste(img, (x, y))
    return canvas
# Kaomoji (text-emoticon) tag names, copied from the wd14-tagger extension:
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
# NOTE(review): every entry contains "_" as part of the face itself, so these
# are presumably exempted from underscore-to-space tag prettifying — confirm
# at the usage site. Order is preserved from the upstream list.
kaomojis = [
    "0_0", "(o)_(o)", "+_+", "+_-", "._.",
    "<o>_<o>", "<|>_<|>", "=_=", ">_<", "3_3",
    "6_9", ">_o", "@_@", "^_^", "o_o",
    "u_u", "x_x", "|_|", "||_||",
]