basebody / app.py
hwajjala's picture
Update app.py
a59bf63 verified
import os
import clip
import torch
import logging
import json
import pickle
from PIL import Image
import gradio as gr
from scipy.special import softmax
# Root logging: show INFO and above as "<time> - <logger> - <level> - <message>".
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Named logger for this app's messages.
logger = logging.getLogger("basebody")
# CLIP variant handed to clip.load().
CLIP_MODEL_NAME = "ViT-B/16"

# Prompt-pair JSON files (looked up next to this file).
TEXT_PROMPTS_FILE_NAME = "text_prompts2.json"
TEXT_PROMPTS_OLD_FILE_NAME = "text_prompts.json"
HAIR_TEXT_PROMPTS_FILE_NAME = "text_prompts_hair.json"

# Pickled classifier files (looked up next to this file).
LOGISTIC_REGRESSION_MODEL_FILE_NAME = "logistic_regression_l1_oct_26.pkl"
LOGISTIC_REGRESSION_OLD_MODEL_FILE_NAME = "logistic_regression_l1_oct_2.pkl"
HAIR_RF_CLASSIFIER_MODEL_FILE_NAME = "hairclassifier_rf.pkl"
def _artifact_path(file_name):
    """Return the path of *file_name* inside this file's directory."""
    return os.path.join(os.path.dirname(__file__), file_name)


def _load_json(file_name):
    """Parse and return the JSON document stored in *file_name*."""
    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(_artifact_path(file_name), "r", encoding="utf-8") as f:
        return json.load(f)


def _load_pickle(file_name):
    """Unpickle and return the object stored in *file_name*.

    NOTE(security): unpickling runs arbitrary code from the file. Only load
    files that ship with this app; never point this at user-supplied data.
    """
    with open(_artifact_path(file_name), "rb") as f:
        return pickle.load(f)


# Load the CLIP model and its image preprocessing transform on CPU.
clip_model, preprocess = clip.load(CLIP_MODEL_NAME, device="cpu")

# Prompt-pair dictionaries, one per classifier.
text_prompts = _load_json(TEXT_PROMPTS_FILE_NAME)
hair_text_prompts = _load_json(HAIR_TEXT_PROMPTS_FILE_NAME)
text_prompts_old = _load_json(TEXT_PROMPTS_OLD_FILE_NAME)

# Pickled classifiers: current, previous, and hair.
lr_model = _load_pickle(LOGISTIC_REGRESSION_MODEL_FILE_NAME)
lr_old_model = _load_pickle(LOGISTIC_REGRESSION_OLD_MODEL_FILE_NAME)
hair_rf_model = _load_pickle(HAIR_RF_CLASSIFIER_MODEL_FILE_NAME)

# The message announces coefficients, so actually log them.
# Assumes a scikit-learn style estimator exposing ``coef_`` -- falls back to a
# placeholder if the attribute is missing.
logger.info(
    "Logistic regression model loaded, coefficients: %s",
    getattr(lr_model, "coef_", "<unavailable>"),
)
def get_text_features(text_prompts):
    """Encode every prompt pair with the CLIP text encoder and stack the results.

    Args:
        text_prompts: Mapping whose values are sequences of exactly two
            prompt strings.

    Returns:
        A CPU tensor made by concatenating the per-pair text encodings along
        dim 0, in the mapping's iteration order.

    Raises:
        ValueError: If the mapping is empty or any entry does not contain
            exactly two prompts.
    """
    # Validate with real exceptions: ``assert`` is stripped under ``python -O``
    # and a malformed entry would otherwise surface later as a reshape error.
    if not text_prompts:
        raise ValueError("text_prompts must contain at least one prompt pair")
    for name, prompts in text_prompts.items():
        if len(prompts) != 2:
            raise ValueError(
                f"text_prompts[{name!r}] must hold exactly 2 prompts, "
                f"got {len(prompts)}"
            )

    # Inference only: no gradients needed for the text encodings.
    with torch.no_grad():
        encoded_pairs = [
            clip_model.encode_text(clip.tokenize(prompts))
            for prompts in text_prompts.values()
        ]
    return torch.cat(encoded_pairs, dim=0).cpu()
# Precompute the text encodings once at import time, one tensor per prompt set
# (current, hair, previous), so requests only need to encode the image.
all_text_features, hair_text_features, old_text_features = (
    get_text_features(prompt_set)
    for prompt_set in (text_prompts, hair_text_prompts, text_prompts_old)
)
def get_cosine_similarities(image_features, text_features, text_prompts):
    """Score an image against each prompt pair and return the first prompt's share.

    The text features are multiplied with the transposed image features, the
    scores are grouped into consecutive pairs (one pair per entry of
    ``text_prompts``), and a softmax is taken within each pair. Only the
    value belonging to the first prompt of every pair is returned.

    Note: no normalisation is applied in this function, so the scores are
    plain dot products unless the inputs were normalised by the caller.

    Args:
        image_features: Image feature tensor; moved to CPU before use.
        text_features: Text feature tensor with two rows per prompt pair.
        text_prompts: The prompt mapping; only its length is used.

    Returns:
        Array of shape ``(len(text_prompts), n)`` where ``n`` is inferred by
        the reshape (1 for a single image).
    """
    pair_count = len(text_prompts)
    raw_scores = text_features @ image_features.cpu().T
    grouped_scores = raw_scores.squeeze().reshape(pair_count, 2, -1)
    # Softmax across the two prompts of each pair.
    pair_probabilities = softmax(grouped_scores, axis=1)
    return pair_probabilities[:, 0, :]
def _decision_for(probability, accept_above=0.9, reject_below=0.2):
    """Map a base-body probability to its routing decision string."""
    if probability > accept_above:
        return "AUTO ACCEPT"
    if probability < reject_below:
        return "AUTO REJECT"
    return "MODERATION"


def predict_fn(input_img):
    """Classify an image as base body / not base body and pick a decision.

    The image is encoded with CLIP, scored against the current prompt pairs,
    and the scores are fed to the current logistic-regression model. The hair
    and previous-generation models are not consulted: their outputs never
    influenced the result, so they are no longer computed per request.

    Args:
        input_img: Image as a numeric array (as delivered by the Gradio
            "image" input).

    Returns:
        UTF-8 encoded JSON bytes with keys ``is_base_body`` (probability
        rounded to 3 decimals) and ``decision`` (``"AUTO ACCEPT"``,
        ``"AUTO REJECT"`` or ``"MODERATION"``).

    Raises:
        ValueError: If no image was supplied.
    """
    if input_img is None:
        raise ValueError("No image was provided")

    # Let Pillow infer the mode from the array shape, then normalise to RGB so
    # grayscale or RGBA arrays are handled as well as plain RGB ones.
    pil_image = Image.fromarray(input_img.astype("uint8")).convert("RGB")
    image = preprocess(pil_image).unsqueeze(0)  # add the batch dimension
    with torch.no_grad():
        image_features = clip_model.encode_image(image)

    base_body_cosine_simlarities = get_cosine_similarities(
        image_features, all_text_features, text_prompts
    )
    logger.info("cosine_simlarities: %s", base_body_cosine_simlarities)

    # One sample, one feature per prompt pair.
    probabilities = lr_model.predict_proba(
        base_body_cosine_simlarities.reshape(1, -1)
    )
    logger.info("probabilities: %s", probabilities)
    # Column 1 is taken as the positive (base body) class.
    result_probabilty = float(probabilities[0][1].round(3))

    decision = _decision_for(result_probabilty)
    if decision == "AUTO REJECT":
        logger.info("result_probabilty < 0.2")
    logger.info("decision: %s", decision)

    # NOTE(review): the payload is returned as bytes, not str; kept as-is so
    # existing consumers of this endpoint see the same output.
    decision_json = json.dumps(
        {"is_base_body": result_probabilty, "decision": decision}
    ).encode("utf-8")
    logger.info("decision_json: %s", decision_json)
    return decision_json
# Text shown in the UI explaining the decision thresholds.
_INTERFACE_DESCRIPTION = """
The model returns the probability of the image being a base body. If
probability > 0.9, the image can be automatically tagged as a base body. If
probability < 0.2, the image can be automatically REJECTED as NOT as base
body. All other cases will be submitted for moderation.
Please flag if you think the decision is wrong.
"""

# Choices offered to a reviewer who flags a result manually.
_FLAGGING_OPTIONS = [
    ": decision should be accept",
    ": decision should be reject",
    ": decision should be moderation",
]

# Image in, decision text out, with manual flagging enabled.
iface = gr.Interface(
    fn=predict_fn,
    inputs="image",
    outputs="text",
    description=_INTERFACE_DESCRIPTION,
    allow_flagging="manual",
    flagging_options=_FLAGGING_OPTIONS,
)
iface.launch()