|
import os |
|
import clip |
|
import torch |
|
import logging |
|
import json |
|
import pickle |
|
from PIL import Image |
|
import gradio as gr |
|
from scipy.special import softmax |
|
|
|
|
|
logger = logging.getLogger("basebody")

# CLIP backbone used to embed both the uploaded image and the text prompts.
CLIP_MODEL_NAME = "ViT-B/16"

# Assets read from the directory that contains this source file.
TEXT_PROMPTS_FILE_NAME = "text_prompts.json"
HAIR_TEXT_PROMPTS_FILE_NAME = "text_prompts_hair.json"
LOGISTIC_REGRESSION_MODEL_FILE_NAME = "logistic_regression_l1_oct_2.pkl"
HAIR_RF_CLASSIFIER_MODEL_FILE_NAME = "hairclassifier_rf.pkl"

# Token for the Hugging Face dataset that receives user flags.
HF_TOKEN = os.getenv('HF_TOKEN')
if not HF_TOKEN:
    # Don't block startup (e.g. local runs), but surface the likely cause of
    # later flagging failures instead of letting them appear unexplained.
    logger.warning(
        "HF_TOKEN is not set; flagged samples may fail to upload to "
        "Roblox/basebody_feedback"
    )
hf_writer = gr.HuggingFaceDatasetSaver(
    HF_TOKEN, "Roblox/basebody_feedback"
)

# CLIP model on CPU, plus the image transform returned alongside it
# (applied to uploaded images in predict_fn).
clip_model, preprocess = clip.load(
    CLIP_MODEL_NAME, device="cpu"
)
|
|
|
def _asset_path(file_name):
    """Return the path of *file_name* in the directory containing this file."""
    return os.path.join(os.path.dirname(__file__), file_name)


def _load_json(file_name):
    """Load a JSON asset stored next to this file, decoded as UTF-8."""
    # Explicit encoding: the platform default (e.g. cp1252 on Windows) would
    # mis-decode any non-ASCII prompt text.
    with open(_asset_path(file_name), "r", encoding="utf-8") as f:
        return json.load(f)


def _load_pickle(file_name):
    """Unpickle a model asset stored next to this file.

    SECURITY: unpickling can execute arbitrary code. Only load trusted files
    shipped with this app, never user-supplied ones.
    """
    with open(_asset_path(file_name), "rb") as f:
        return pickle.load(f)


# Prompt pairs (each value is checked to hold exactly two prompts when encoded)
# and the two pickled classifiers consumed by predict_fn.
text_prompts = _load_json(TEXT_PROMPTS_FILE_NAME)
hair_text_prompts = _load_json(HAIR_TEXT_PROMPTS_FILE_NAME)
lr_model = _load_pickle(LOGISTIC_REGRESSION_MODEL_FILE_NAME)
hair_rf_model = _load_pickle(HAIR_RF_CLASSIFIER_MODEL_FILE_NAME)

# Log the coefficients when the model exposes them. `coef_` is the attribute
# name on scikit-learn linear models; presumably that is what was pickled --
# confirm. Falls back to None rather than failing startup.
logger.info(
    "Logistic regression model loaded, coefficients: %s",
    getattr(lr_model, "coef_", None),
)
|
|
|
|
|
def _encode_prompt_pairs(prompts_by_key):
    """Encode every prompt pair in *prompts_by_key* with the CLIP text encoder.

    Each value must be a sequence of exactly two prompt strings. Encodings are
    stacked in dict iteration order, so rows ``2*i`` and ``2*i + 1`` hold the
    two prompts of the i-th entry. ``get_cosine_similarities`` relies on this
    layout when it reshapes the scores into pairs.

    Args:
        prompts_by_key: mapping of entry name -> two prompt strings.

    Returns:
        CPU tensor with ``2 * len(prompts_by_key)`` rows of text embeddings.

    Raises:
        ValueError: if any entry does not contain exactly two prompts.
    """
    encoded = []
    with torch.no_grad():
        for key, prompts in prompts_by_key.items():
            # A real check rather than `assert`, which is stripped under -O.
            if len(prompts) != 2:
                raise ValueError(
                    f"Prompt entry {key!r} must contain exactly 2 prompts, "
                    f"got {len(prompts)}"
                )
            encoded.append(clip_model.encode_text(clip.tokenize(prompts)))
    return torch.cat(encoded, dim=0).cpu()


# Prompt-pair embeddings for the base-body and hair classifiers, computed once
# at startup and reused for every request.
all_text_features = _encode_prompt_pairs(text_prompts)
hair_text_features = _encode_prompt_pairs(hair_text_prompts)
|
|
|
|
|
def get_cosine_similarities(image_features, text_features, text_prompts):
    """Score image embedding(s) against each prompt pair via a pairwise softmax.

    Args:
        image_features: image embedding rows (e.g. from
            ``clip_model.encode_image``); moved to CPU before use.
        text_features: stacked prompt embeddings, two consecutive rows per
            entry of *text_prompts*.
        text_prompts: mapping whose length is the number of prompt pairs.

    Returns:
        Array of shape ``(len(text_prompts), n)``, where n is the number of
        image rows. Each value is the softmax weight of the *first* prompt of
        a pair, taken against the second.

    Note:
        The embeddings are not L2-normalized here, so these are raw dot
        products, not true cosine similarities. The result is fed directly to
        the pickled classifiers, which were presumably trained on these raw
        values; normalizing would change their inputs.
    """
    pair_count = len(text_prompts)
    image_rows = image_features.cpu()
    # One raw score per (prompt, image row).
    raw_scores = text_features @ image_rows.T
    # Regroup as (pair, prompt-within-pair, image row).
    grouped = raw_scores.squeeze().reshape(pair_count, 2, -1)
    pair_weights = softmax(grouped, axis=1)
    return pair_weights[:, 0, :]
|
|
|
|
|
# Decision thresholds applied by _decide. The same numbers appear in the text
# description given to gr.Interface; keep the two in sync.
BASE_BODY_ACCEPT_THRESHOLD = 0.77
BASE_BODY_REJECT_THRESHOLD = 0.4
HAIR_REJECT_THRESHOLD = 0.5


def _decide(base_body_probability, hair_probability):
    """Map the two classifier probabilities to a moderation decision.

    Args:
        base_body_probability: (rounded) score from the base-body model.
        hair_probability: (rounded) score from the hair classifier.

    Returns:
        ``(reported_probability, decision)``. Decision is one of
        "AUTO ACCEPT", "AUTO REJECT" or "MODERATION". The reported
        probability is the base-body score, except when a likely base body
        is rejected by the hair check; then the hair score is reported.
    """
    if base_body_probability > BASE_BODY_ACCEPT_THRESHOLD:
        if hair_probability < HAIR_REJECT_THRESHOLD:
            # NOTE(review): the hair score replaces the base-body score in the
            # "is_base_body" field here, presumably so the reported number
            # explains the rejection -- confirm this is intended.
            return hair_probability, "AUTO REJECT"
        return base_body_probability, "AUTO ACCEPT"
    if base_body_probability < BASE_BODY_REJECT_THRESHOLD:
        return base_body_probability, "AUTO REJECT"
    return base_body_probability, "MODERATION"


def predict_fn(input_img):
    """Classify one uploaded image and return the decision as UTF-8 JSON bytes.

    Args:
        input_img: image array from the Gradio ``"image"`` input. It is cast
            to ``uint8`` and interpreted as RGB.

    Returns:
        UTF-8 encoded JSON ``{"is_base_body": float, "decision": str}``
        (see ``_decide`` for the possible decisions).

    Raises:
        gr.Error: if no image was provided, so the UI shows a clear message
            instead of an opaque AttributeError.
    """
    if input_img is None:
        raise gr.Error("Please upload an image.")

    pil_img = Image.fromarray(input_img.astype("uint8"), "RGB")
    image = preprocess(pil_img).unsqueeze(0)

    with torch.no_grad():
        image_features = clip_model.encode_image(image)
        base_body_similarities = get_cosine_similarities(
            image_features, all_text_features, text_prompts
        )
        hair_similarities = get_cosine_similarities(
            image_features, hair_text_features, hair_text_prompts
        )

    logger.info("cosine_simlarities: %s", base_body_similarities)

    # Each classifier receives a single feature row: one score per prompt pair.
    probabilities = lr_model.predict_proba(
        base_body_similarities.reshape(1, -1)
    )
    hair_probabilities = hair_rf_model.predict_proba(
        hair_similarities.reshape(1, -1)
    )
    logger.info("probabilities: %s", probabilities)
    logger.info("hair_probabilities: %s", hair_probabilities)

    # Column 1 of each predict_proba row is used as the score (presumably the
    # positive class -- confirm against the training labels). Rounding happens
    # before thresholding, so the reported value is exactly what was compared.
    base_body_probability = float(probabilities[0][1].round(3))
    hair_probability = float(hair_probabilities[0][1].round(3))

    result_probability, decision = _decide(
        base_body_probability, hair_probability
    )
    logger.info("decision: %s", decision)

    decision_json = json.dumps(
        {"is_base_body": result_probability, "decision": decision}
    ).encode("utf-8")
    logger.info("decision_json: %s", decision_json)
    # NOTE(review): bytes are handed to a "text" output, which presumably
    # renders them via str() as "b'...'"; kept as-is because API consumers
    # may depend on that exact format -- confirm before changing to str.
    return decision_json
|
|
|
|
|
# Gradio UI: one image in, predict_fn's JSON decision out. Manually flagged
# samples are sent to the Hugging Face dataset via hf_writer.
iface = gr.Interface(
    fn=predict_fn,
    inputs="image",
    outputs="text",
    description="""
The model returns the probability of the image being a base body. If
probability > 0.77, the image can be automatically tagged as a base body,
unless a secondary hair classifier scores it below 0.5; in that case the image
is automatically REJECTED and the hair classifier's score is reported instead.
If probability < 0.4, the image can be automatically REJECTED as NOT a base
body. All other cases will be submitted for moderation.

Please flag if you think the decision is wrong.
""",
    allow_flagging="manual",
    # These labels are stored with each flag in the dataset; keep them stable.
    flagging_options=[
        ": decision should be accept",
        ": decision should be reject",
        ": decision should be moderation",
    ],
    flagging_callback=hf_writer,
)

iface.launch()
|
|