File size: 4,693 Bytes
f4b1311 194b093 f4b1311 2a62a79 91d8e2e 2a62a79 91d8e2e f4b1311 2e89562 2de9666 2e89562 2de9666 f4b1311 e1f60ba f4b1311 e1f60ba f4b1311 2de9666 f4b1311 2de9666 267519a 4e55465 267519a f4b1311 2de9666 bb70cd9 2de9666 77c92b5 2a62a79 2de9666 bb70cd9 2de9666 bb70cd9 2de9666 2a62a79 2de9666 2a62a79 2de9666 2a62a79 29c69b2 2de9666 29c69b2 2de9666 75dd19d 2de9666 509301a 29c69b2 2a62a79 29c69b2 2a62a79 f4b1311 91d8e2e f4b1311 267519a f4b1311 edd1aeb 509301a edd1aeb 29c69b2 edd1aeb e1f60ba 7a2bd30 9c9b891 7a2bd30 e1f60ba f4b1311 267519a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
import os
import clip
import torch
import logging
import json
import pickle
from PIL import Image
import gradio as gr
from scipy.special import softmax
# Module logger. Nothing in this script configures logging, and Python's
# last-resort handler only emits WARNING and above, so without an explicit
# handler every ``logger.info`` diagnostic in this app would be dropped.
logger = logging.getLogger("basebody")
if not logger.handlers:
    _log_handler = logging.StreamHandler()
    _log_handler.setFormatter(
        logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")
    )
    logger.addHandler(_log_handler)
    # Our own handler already prints the records; don't also hand them to
    # the root logger in case something configures it later.
    logger.propagate = False
logger.setLevel(logging.INFO)

# CLIP backbone used for both image and text embeddings.
CLIP_MODEL_NAME = "ViT-B/16"
# Prompt files and pickled classifiers (file names only; they are looked up
# in the directory containing this script).
TEXT_PROMPTS_FILE_NAME = "text_prompts2.json"
HAIR_TEXT_PROMPTS_FILE_NAME = "text_prompts_hair.json"
LOGISTIC_REGRESSION_MODEL_FILE_NAME = "logistic_regression_l1_oct_26.pkl"
HAIR_RF_CLASSIFIER_MODEL_FILE_NAME = "hairclassifier_rf.pkl"
# Token for the Hugging Face flagging dataset; None when the env var is unset.
HF_TOKEN = os.getenv('HF_TOKEN')
# Flagging callback for the Gradio interface: flagged samples are saved to
# this Hugging Face dataset using HF_TOKEN.
hf_writer = gr.HuggingFaceDatasetSaver(HF_TOKEN, "Roblox/basebody_feedback")

# CLIP model on CPU plus its matching image-preprocessing transform.
clip_model, preprocess = clip.load(CLIP_MODEL_NAME, device="cpu")
# Directory containing this script; prompt files and model pickles live here.
_ASSET_DIR = os.path.dirname(__file__)


def _load_json_asset(file_name):
    """Return the parsed JSON stored in ``file_name`` next to this script."""
    # Explicit UTF-8 so decoding does not depend on the host's locale.
    with open(os.path.join(_ASSET_DIR, file_name), "r", encoding="utf-8") as f:
        return json.load(f)


def _load_pickle_asset(file_name):
    """Return the object unpickled from ``file_name`` next to this script.

    SECURITY: ``pickle.load`` can execute arbitrary code. It is only
    acceptable here because these files ship with the app; never point this
    at user-supplied data.
    """
    with open(os.path.join(_ASSET_DIR, file_name), "rb") as f:
        return pickle.load(f)


# Prompt groups; presumably {group name: [prompt_a, prompt_b]} (each value
# is expected to hold exactly two prompts) — confirm against the JSON files.
text_prompts = _load_json_asset(TEXT_PROMPTS_FILE_NAME)
hair_text_prompts = _load_json_asset(HAIR_TEXT_PROMPTS_FILE_NAME)
# Pre-trained classifiers applied to the per-pair prompt scores.
lr_model = _load_pickle_asset(LOGISTIC_REGRESSION_MODEL_FILE_NAME)
hair_rf_model = _load_pickle_asset(HAIR_RF_CLASSIFIER_MODEL_FILE_NAME)
# The message promises coefficients, so actually include them (None if the
# loaded model exposes no ``coef_`` attribute).
logger.info(
    "Logistic regression model loaded, coefficients: %s",
    getattr(lr_model, "coef_", None),
)
def _encode_prompt_pairs(prompts_by_key):
    """Encode every prompt pair with CLIP and stack the embeddings.

    Args:
        prompts_by_key: mapping whose values are lists of exactly two prompt
            strings; the mapping's iteration order fixes the row order.

    Returns:
        A CPU tensor with two consecutive rows per key (rows ``2*i`` and
        ``2*i + 1`` are the two prompts of the i-th key). This is the layout
        ``get_cosine_similarities`` reshapes into ``(n_keys, 2, ...)``.

    Raises:
        ValueError: if any entry does not contain exactly two prompts.
    """
    encoded = []
    with torch.no_grad():
        for key, prompts in prompts_by_key.items():
            # Explicit check instead of ``assert`` so the validation also
            # runs under ``python -O``; the pairwise reshape depends on it.
            if len(prompts) != 2:
                raise ValueError(
                    f"Prompt group {key!r} must contain exactly 2 prompts, "
                    f"got {len(prompts)}"
                )
            encoded.append(clip_model.encode_text(clip.tokenize(prompts)))
    return torch.cat(encoded, dim=0).cpu()


# Text embeddings for the base-body prompts and the hair prompts.
all_text_features = _encode_prompt_pairs(text_prompts)
hair_text_features = _encode_prompt_pairs(hair_text_prompts)
def get_cosine_similarities(image_features, text_features, text_prompts):
    """Score image embeddings against pairs of text-prompt embeddings.

    Despite the name, no cosine similarity is computed: the features are not
    normalised here. The raw dot products are used as logits, and a softmax
    is taken within each prompt pair.

    Args:
        image_features: tensor of image embeddings, one row per image (it is
            moved to CPU before use).
        text_features: tensor of text embeddings laid out as consecutive
            pairs, i.e. two rows per entry of ``text_prompts``.
        text_prompts: only its length (the number of prompt pairs) is used.

    Returns:
        Array of shape ``(len(text_prompts), n_images)`` holding, for each
        pair, the softmax probability of the pair's first prompt.
    """
    pair_count = len(text_prompts)
    logits = text_features @ image_features.cpu().T
    # Group as (pair, prompt-in-pair, image). squeeze() only drops the image
    # axis when there is a single image; the -1 restores it.
    grouped = logits.squeeze().reshape(pair_count, 2, -1)
    pair_probs = softmax(grouped, axis=1)
    return pair_probs[:, 0, :]
# Decision thresholds applied to the rounded probabilities; keep them in sync
# with the numbers quoted in the Gradio interface description.
_ACCEPT_THRESHOLD = 0.77
_REJECT_THRESHOLD = 0.4
_HAIR_THRESHOLD = 0.5


def _decide(result_probabilty, hair_result_probabilty):
    """Map the two rounded probabilities to ``(reported_probability, decision)``.

    A base-body score above ``_ACCEPT_THRESHOLD`` is auto-accepted only when
    the hair classifier scores at least ``_HAIR_THRESHOLD``. Otherwise the
    image is auto-rejected, and the *hair* probability is reported in place
    of the base-body probability. Scores below ``_REJECT_THRESHOLD`` are
    auto-rejected. Everything in between goes to moderation.
    """
    if result_probabilty > _ACCEPT_THRESHOLD:
        if hair_result_probabilty < _HAIR_THRESHOLD:
            logger.info("hair_result_probabilty < 0.5")
            return hair_result_probabilty, "AUTO REJECT"
        return result_probabilty, "AUTO ACCEPT"
    if result_probabilty < _REJECT_THRESHOLD:
        return result_probabilty, "AUTO REJECT"
    return result_probabilty, "MODERATION"


def predict_fn(input_img):
    """Classify an uploaded image as base body / not base body.

    Args:
        input_img: image array from the Gradio ``"image"`` input. It is cast
            to uint8 and interpreted as RGB (assumes an (H, W, 3) array —
            TODO confirm for other image modes).

    Returns:
        UTF-8 encoded JSON bytes of the form
        ``{"is_base_body": <probability>, "decision": <decision>}``, where the
        decision is "AUTO ACCEPT", "AUTO REJECT" or "MODERATION" and the
        probability is rounded to 3 decimals.

    NOTE(review): the result is ``bytes`` while the interface output is
    ``"text"``, so it is presumably rendered as ``b'...'``. Confirm whether
    API consumers depend on that before changing the return type.
    """
    pil_img = Image.fromarray(input_img.astype("uint8"), "RGB")
    image = preprocess(pil_img).unsqueeze(0)  # batch of one image
    with torch.no_grad():
        image_features = clip_model.encode_image(image)

    base_body_scores = get_cosine_similarities(
        image_features, all_text_features, text_prompts
    )
    hair_scores = get_cosine_similarities(
        image_features, hair_text_features, hair_text_prompts
    )
    logger.info("cosine_simlarities: %s", base_body_scores)

    # Each classifier expects one flat feature row of per-pair scores.
    probabilities = lr_model.predict_proba(base_body_scores.reshape(1, -1))
    hair_probabilities = hair_rf_model.predict_proba(hair_scores.reshape(1, -1))
    logger.info("probabilities: %s", probabilities)

    # Column 1 is taken as the positive-class probability.
    result_probabilty = float(probabilities[0][1].round(3))
    hair_result_probabilty = float(hair_probabilities[0][1].round(3))

    result_probabilty, decision = _decide(result_probabilty, hair_result_probabilty)
    logger.info("decision: %s", decision)

    decision_json = json.dumps(
        {"is_base_body": result_probabilty, "decision": decision}
    ).encode("utf-8")
    logger.info("decision_json: %s", decision_json)
    return decision_json
# Gradio UI: one image in, predict_fn's JSON decision shown as text. Manual
# flagging sends the flagged sample to the Hugging Face dataset via hf_writer.
iface = gr.Interface(
    fn=predict_fn,
    inputs="image",
    outputs="text",
    description="""
The model returns the probability of the image being a base body. If
probability > 0.77, the image can be automatically tagged as a base body,
unless the hair check fails, in which case it is automatically REJECTED and
the hair-check probability is reported instead. If probability < 0.4, the
image can be automatically REJECTED as NOT a base body. All other cases will
be submitted for moderation.
Please flag if you think the decision is wrong.
""",
    allow_flagging="manual",
    # NOTE(review): these exact strings (leading ": " included) are presumably
    # what gets recorded with each flagged sample, so keep them stable.
    flagging_options=[
        ": decision should be accept",
        ": decision should be reject",
        ": decision should be moderation"
    ],
    flagging_callback=hf_writer
)
iface.launch()
|