Spaces:
Runtime error
Runtime error
from __future__ import annotations | |
import PIL.Image | |
import huggingface_hub | |
import numpy as np | |
import onnxruntime as rt | |
import pandas as pd | |
import torch | |
from transformers import AutoModelForCausalLM | |
from transformers import AutoProcessor | |
from . import dbimutils | |
from .singleton import Singleton | |
import torch | |
from clip_interrogator import Config, Interrogator | |
# Torch device string for inference in this module: "cuda" when a CUDA device
# is available to torch, otherwise "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"
class Models(object):
    """Lazy registry of the captioning / tagging models used by this module.

    Constructing the object loads nothing. The first read of one of the
    attributes below falls through to ``__getattr__``, which loads the
    resource and stores it on the instance, so every later read is an
    ordinary attribute lookup:

    * ``clip_vit_h_14_model``, ``clip_vit_l_14_model`` -- ``Interrogator``
      objects built by ``load_clip_model``.
    * ``swinv2_model``, ``convnext_model``, ``convnextv2_model``,
      ``vit_model`` -- ONNX Runtime sessions built by ``load_model``.
    * ``git_model``, ``git_processor`` -- loaded together by
      ``load_git_model``.
    * ``tag_names``, ``rating_indexes``, ``general_indexes``,
      ``character_indexes`` -- loaded together by ``load_w14_labels``.

    Any other missing attribute raises ``AttributeError``.
    """

    # WD14 models
    SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
    CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
    CONV2_MODEL_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
    VIT_MODEL_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
    MODEL_FILENAME = "model.onnx"
    LABEL_FILENAME = "selected_tags.csv"

    # CLIP models
    VIT_H_14_MODEL_REPO = "ViT-H-14/laion2b_s32b_b79k"  # Stable Diffusion 2.X
    VIT_L_14_MODEL_REPO = "ViT-L-14/openai"  # Stable Diffusion 1.X

    # Shared object handed out by ``instance()``; created on first request.
    _instance = None

    def __init__(self):
        pass

    @classmethod
    def instance(cls):
        """Return the shared ``Models`` object, creating it on first use.

        NOTE(review): the file imports ``Singleton`` from ``.singleton`` but
        the class as shown is not decorated with it. If that decorator is
        meant to provide ``instance()``, applying it would simply take
        precedence over this method -- confirm which is intended.
        """
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    @classmethod
    def load_clip_model(cls, model_repo):
        """Build a CLIP ``Interrogator`` for the given model name.

        Args:
            model_repo: Value assigned to ``Config.clip_model_name`` (one of
                the ``VIT_*_14_MODEL_REPO`` constants when called from
                ``__getattr__``).

        Returns:
            The constructed ``Interrogator``.
        """
        use_cuda = torch.cuda.is_available()
        config = Config()
        config.device = 'cuda' if use_cuda else 'cpu'
        # BLIP offloading is switched on only when running without CUDA.
        config.blip_offload = not use_cuda
        config.chunk_size = 2048
        config.flavor_intermediate_count = 512
        config.blip_num_beams = 64
        config.clip_model_name = model_repo
        return Interrogator(config)

    def __getattr__(self, item):
        """Load the resource named ``item`` on first access and cache it.

        Python calls this hook only after normal attribute lookup has
        failed, so reaching it means ``item`` has not been loaded yet (or is
        not a resource at all).

        Raises:
            AttributeError: ``item`` is not one of the lazily loaded names.
                Raising here (instead of retrying the lookup) is what keeps
                ``hasattr``/``getattr``-with-default and dunder probes from
                recursing forever.
        """
        clip_repos = {
            'clip_vit_h_14_model': self.VIT_H_14_MODEL_REPO,
            'clip_vit_l_14_model': self.VIT_L_14_MODEL_REPO,
        }
        wd14_repos = {
            'swinv2_model': self.SWIN_MODEL_REPO,
            'convnext_model': self.CONV_MODEL_REPO,
            'vit_model': self.VIT_MODEL_REPO,
            'convnextv2_model': self.CONV2_MODEL_REPO,
        }
        git_attrs = ('git_model', 'git_processor')
        label_attrs = (
            'tag_names', 'rating_indexes', 'general_indexes', 'character_indexes',
        )

        is_known = (
            item in clip_repos
            or item in wd14_repos
            or item in git_attrs
            or item in label_attrs
        )
        if not is_known:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {item!r}"
            )

        print(f"Loading {item}...")
        if item in clip_repos:
            setattr(self, item, self.load_clip_model(clip_repos[item]))
        elif item in wd14_repos:
            setattr(self, item, self.load_model(wd14_repos[item], self.MODEL_FILENAME))
        elif item in git_attrs:
            # Model and processor always load as a pair.
            self.git_model, self.git_processor = self.load_git_model()
        else:
            # All four label attributes come from a single CSV read.
            (
                self.tag_names,
                self.rating_indexes,
                self.general_indexes,
                self.character_indexes,
            ) = self.load_w14_labels()
        # The value now lives in the instance dict; read it from there so
        # this hook is not re-entered.
        return self.__dict__[item]

    @classmethod
    def load_git_model(cls):
        """Load the ``microsoft/git-large-coco`` model and its processor.

        Returns:
            A ``(model, processor)`` tuple.
        """
        model = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco")
        processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
        return model, processor

    @staticmethod
    def load_model(model_repo: str, model_filename: str) -> rt.InferenceSession:
        """Download ``model_filename`` from ``model_repo`` and open it.

        Declared static because it takes no instance/class argument yet is
        invoked as ``self.load_model(repo, filename)``.

        Returns:
            An ONNX Runtime ``InferenceSession`` over the downloaded file.
        """
        path = huggingface_hub.hf_hub_download(
            model_repo, model_filename,
        )
        return rt.InferenceSession(path)

    @classmethod
    def load_w14_labels(cls) -> list[list]:
        """Read the WD14 label CSV and split row indexes by category.

        The CSV is fetched from ``CONV2_MODEL_REPO``.
        NOTE(review): the result is used for every WD14 model, so this
        presumably relies on all four repos sharing one tag list -- confirm.

        Returns:
            ``[tag_names, rating_indexes, general_indexes,
            character_indexes]`` where ``tag_names`` is the ``name`` column
            and the index lists hold the row positions whose ``category``
            column equals 9, 0 and 4 respectively.
        """
        path = huggingface_hub.hf_hub_download(
            cls.CONV2_MODEL_REPO, cls.LABEL_FILENAME
        )
        df = pd.read_csv(path)
        tag_names = df["name"].tolist()
        rating_indexes = list(np.where(df["category"] == 9)[0])
        general_indexes = list(np.where(df["category"] == 0)[0])
        character_indexes = list(np.where(df["category"] == 4)[0])
        return [tag_names, rating_indexes, general_indexes, character_indexes]
# Module-wide model registry read by the image-to-text helpers in this file.
models = Models.instance()
def clip_image2text(image, mode_type='best', model_name='vit_h_14'):
    """Produce a prompt for ``image`` with a CLIP interrogator.

    Args:
        image: Image object exposing ``convert``; it is converted to RGB
            before interrogation.
        mode_type: ``'classic'``, ``'fast'`` or ``'negative'`` pick the
            matching ``interrogate_*`` method; every other value (including
            the default ``'best'``) uses plain ``interrogate``.
        model_name: Middle part of the attribute ``clip_<model_name>_model``
            looked up on the module-level ``models`` object.

    Returns:
        Whatever the chosen interrogator method returns for the image.
    """
    rgb_image = image.convert('RGB')
    interrogator = getattr(models, f'clip_{model_name}_model')

    if mode_type == 'classic':
        return interrogator.interrogate_classic(rgb_image)
    if mode_type == 'fast':
        return interrogator.interrogate_fast(rgb_image)
    if mode_type == 'negative':
        return interrogator.interrogate_negative(rgb_image)
    # Anything else falls back to the "best" mode.
    return interrogator.interrogate(rgb_image)
def git_image2text(input_image, max_length=50):
    """Caption ``input_image`` with the GIT model held by ``models``.

    Args:
        input_image: Image object exposing ``convert``; converted to RGB
            before being handed to the processor.
        max_length: Forwarded to ``generate`` as ``max_length``.

    Returns:
        The first decoded caption, with special tokens skipped.
    """
    rgb_image = input_image.convert('RGB')

    # Preprocess, then move the processor output to the module-level device.
    encoded = models.git_processor(images=rgb_image, return_tensors="pt")
    pixel_values = encoded.to(device).pixel_values

    # The model is moved to the same device on every call before generating.
    captioner = models.git_model.to(device)
    token_ids = captioner.generate(pixel_values=pixel_values, max_length=max_length)

    captions = models.git_processor.batch_decode(token_ids, skip_special_tokens=True)
    return captions[0]
def w14_image2text(
    image: PIL.Image.Image,
    model_name: str,
    general_threshold: float,
    character_threshold: float,
):
    """Tag ``image`` with a WD14 tagger and return the tags in several forms.

    Args:
        image: Input picture; transparency is flattened onto white.
        model_name: Lower-cased and suffixed with ``_model`` to pick the
            attribute read from the module-level ``models`` object.
        general_threshold: A general tag is kept only when its score is
            strictly greater than this value.
        character_threshold: Same rule, applied to character tags.

    Returns:
        A 6-tuple ``(escaped, comma_joined, space_joined, rating,
        character_res, general_res)``:

        * ``escaped`` -- kept general tags, highest score first, joined with
          ``", "``, underscores turned into spaces and parentheses
          backslash-escaped.
        * ``comma_joined`` -- same tags and order, joined with ``", "``.
        * ``space_joined`` -- same tags and order, joined with ``" "``.
        * ``rating`` -- dict of every rating tag to its score.
        * ``character_res`` -- dict of kept character tags to scores.
        * ``general_res`` -- dict of kept general tags to scores, in label
          order (not sorted by score).
    """
    tag_names: list[str] = models.tag_names
    rating_indexes: list[np.int64] = models.rating_indexes
    general_indexes: list[np.int64] = models.general_indexes
    character_indexes: list[np.int64] = models.character_indexes

    model = getattr(models, "{}_model".format(model_name.lower()))
    # The input shape is unpacked as four dimensions and only the second is
    # used, for both squaring and resizing.
    # NOTE(review): presumably the layout is (batch, height, width, channels)
    # with height == width -- confirm against the model.
    _, height, _, _ = model.get_inputs()[0].shape

    # Alpha to white
    image = image.convert("RGBA")
    new_image = PIL.Image.new("RGBA", image.size, "WHITE")
    new_image.paste(image, mask=image)
    image = new_image.convert("RGB")
    image = np.asarray(image)

    # PIL RGB to OpenCV BGR
    image = image[:, :, ::-1]

    image = dbimutils.make_square(image, height)
    image = dbimutils.smart_resize(image, height)
    image = image.astype(np.float32)
    image = np.expand_dims(image, 0)

    input_name = model.get_inputs()[0].name
    label_name = model.get_outputs()[0].name
    probs = model.run([label_name], {input_name: image})[0]

    # Pair every tag name with its score for the single image in the batch.
    labels = list(zip(tag_names, probs[0].astype(float)))

    # Ratings: every rating tag is returned with its score; no argmax is
    # applied here, the caller decides how to present them.
    rating = dict(labels[i] for i in rating_indexes)

    # General tags: keep any whose score is above the threshold.
    general_res = {
        name: score
        for name, score in (labels[i] for i in general_indexes)
        if score > general_threshold
    }

    # Character tags: keep any whose score is above the threshold.
    character_res = {
        name: score
        for name, score in (labels[i] for i in character_indexes)
        if score > character_threshold
    }

    # Kept general tag names, highest score first (ties keep label order).
    ranked_names = sorted(general_res, key=general_res.get, reverse=True)

    comma_joined = ", ".join(ranked_names)
    space_joined = " ".join(ranked_names)
    # Raw strings give the same backslash-paren text without relying on an
    # invalid escape sequence.
    escaped = (
        comma_joined
        .replace("_", " ")
        .replace("(", r"\(")
        .replace(")", r"\)")
    )

    return escaped, comma_joined, space_joined, rating, character_res, general_res