File size: 2,462 Bytes
f4b1311 194b093 f4b1311 2a62a79 91d8e2e 2a62a79 91d8e2e f4b1311 e1f60ba f4b1311 e1f60ba f4b1311 267519a 4e55465 267519a f4b1311 77c92b5 2a62a79 7a2bd30 2a62a79 f4b1311 91d8e2e f4b1311 267519a f4b1311 e1f60ba 7a2bd30 e1f60ba f4b1311 267519a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
import os
import clip
import torch
import logging
import json
import pickle
from PIL import Image
import gradio as gr
from scipy.special import softmax
# Module-level setup: logger, configuration constants, the Hugging Face
# flagging writer, the CLIP model, the text prompts and the classifier.
logger = logging.getLogger("basebody")

CLIP_MODEL_NAME = "ViT-B/16"
TEXT_PROMPTS_FILE_NAME = "text_prompts.json"
LOGISTIC_REGRESSION_MODEL_FILE_NAME = "logistic_regression_l1_oct_2.pkl"
HF_TOKEN = os.getenv('HF_TOKEN')

# Flagged predictions are handed to this writer (see the Interface below).
hf_writer = gr.HuggingFaceDatasetSaver(
    HF_TOKEN, "Roblox/basebody_feedback"
)

# The model is pinned to CPU; `preprocess` is the matching image transform
# returned by clip.load.
clip_model, preprocess = clip.load(
    CLIP_MODEL_NAME, device="cpu"
)

# JSON text is read as UTF-8 explicitly so that the result does not depend
# on the platform's default locale encoding.
with open(
    os.path.join(os.path.dirname(__file__), TEXT_PROMPTS_FILE_NAME),
    "r",
    encoding="utf-8",
) as f:
    text_prompts = json.load(f)

# NOTE(security): pickle.load can execute arbitrary code. This is only safe
# because the file is shipped next to this script; never point it at
# untrusted data.
with open(
    os.path.join(
        os.path.dirname(__file__), LOGISTIC_REGRESSION_MODEL_FILE_NAME
    ),
    "rb",
) as f:
    lr_model = pickle.load(f)

# NOTE(review): the message announces coefficients but none are logged.
logger.info("Logistic regression model loaded, coefficients: ")

# Encode every prompt group once at start-up so that predict_fn only has to
# encode the image. predict_fn reshapes the stacked features to
# (len(text_prompts), 2, -1), so each group must hold exactly two prompts.
all_text_features = []
with torch.no_grad():
    for key, prompts in text_prompts.items():
        # Explicit check instead of `assert`, which is stripped under -O.
        if len(prompts) != 2:
            raise ValueError(
                f"text prompt group {key!r} must contain exactly 2 "
                f"prompts, got {len(prompts)}"
            )
        inputs = clip.tokenize(prompts)
        outputs = clip_model.encode_text(inputs)
        all_text_features.append(outputs)
all_text_features = torch.cat(all_text_features, dim=0)
all_text_features = all_text_features.cpu()
def predict_fn(input_img):
    """Score an image and return the base-body probability as JSON bytes.

    The image is encoded with the CLIP model, compared against the
    precomputed text features, and the resulting per-prompt-pair scores
    are fed to the logistic regression model.

    Args:
        input_img: Image array; it is cast to ``uint8`` and interpreted as
            RGB. Presumably the numpy array Gradio supplies for an
            ``"image"`` input -- confirm against the Gradio version in use.

    Returns:
        UTF-8 encoded JSON bytes of the form ``{"is_base_body": <float>}``,
        where the value is column 1 of ``lr_model.predict_proba`` for the
        single sample, rounded to three decimals.
    """
    pil_image = Image.fromarray(input_img.astype("uint8"), "RGB")
    image_batch = preprocess(pil_image).unsqueeze(0)

    with torch.no_grad():
        image_features = clip_model.encode_image(image_batch)
        # Dot product of every text feature with the image feature,
        # regrouped so that axis 1 indexes the two prompts of each pair.
        raw_scores = all_text_features @ image_features.cpu().T
        paired_scores = raw_scores.squeeze().reshape(
            len(text_prompts), 2, -1
        )
        # Softmax within each pair, then keep the first prompt's share.
        pair_weights = softmax(paired_scores, axis=1)
        first_prompt_weights = pair_weights[:, 0, :]

    logger.info(f"cosine_simlarities: {first_prompt_weights}")

    # The classifier expects a single row holding one feature per pair.
    feature_row = first_prompt_weights.reshape(1, -1)
    probabilities = lr_model.predict_proba(feature_row)
    logger.info(f"probabilities: {probabilities}")

    payload = {"is_base_body": float(probabilities[0][1].round(3))}
    decision_json = json.dumps(payload).encode("utf-8")
    logger.info(f"decision_json: {decision_json}")
    return decision_json
# Labels a reviewer can pick when manually flagging a prediction; they
# describe how far off the returned probability was.
_flagging_choices = [
    "probability error > 0.5",
    "0.2 < probability error < 0.5",
    "probability error < 0.2",
]

# Image in, text out; flagged samples are passed to hf_writer.
iface = gr.Interface(
    fn=predict_fn,
    inputs="image",
    outputs="text",
    allow_flagging="manual",
    flagging_options=_flagging_choices,
    flagging_callback=hf_writer,
)

# Start the Gradio app.
iface.launch()
|