Add execution logic
Browse files
app.py
CHANGED
@@ -4,7 +4,9 @@ import torch
|
|
4 |
import logging
|
5 |
import json
|
6 |
import pickle
|
|
|
7 |
import gradio as gr
|
|
|
8 |
|
9 |
|
10 |
logger = logging.getLogger("basebody")
|
@@ -14,7 +16,6 @@ TEXT_PROMPTS_FILE_NAME = "text_prompts.json"
|
|
14 |
LOGISTIC_REGRESSION_MODEL_FILE_NAME = "logistic_regression_l1_oct_2.pkl"
|
15 |
|
16 |
|
17 |
-
|
18 |
clip_model, preprocess = clip.load(CLIP_MODEL_NAME, device="cpu")
|
19 |
|
20 |
with open(
|
@@ -44,9 +45,29 @@ with torch.no_grad():
|
|
44 |
|
45 |
|
46 |
def predict_fn(input_img):
|
47 |
-
|
48 |
-
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
|
52 |
iface = gr.Interface(
|
|
|
4 |
import logging
|
5 |
import json
|
6 |
import pickle
|
7 |
+
from PIL import Image
|
8 |
import gradio as gr
|
9 |
+
from scipy.special import softmax
|
10 |
|
11 |
|
12 |
logger = logging.getLogger("basebody")
|
|
|
16 |
LOGISTIC_REGRESSION_MODEL_FILE_NAME = "logistic_regression_l1_oct_2.pkl"
|
17 |
|
18 |
|
|
|
19 |
clip_model, preprocess = clip.load(CLIP_MODEL_NAME, device="cpu")
|
20 |
|
21 |
with open(
|
|
|
45 |
|
46 |
|
47 |
def predict_fn(input_img):
|
48 |
+
input_img = Image.fromarray(input_img.astype("uint8"), "RGB")
|
49 |
+
image = preprocess(
|
50 |
+
input_img
|
51 |
+
).unsqueeze(0)
|
52 |
+
with torch.no_grad():
|
53 |
+
image_features = clip_model.encode_image(image)
|
54 |
+
cosine_simlarities = softmax(
|
55 |
+
(all_text_features @ image_features.cpu().T)
|
56 |
+
.squeeze()
|
57 |
+
.reshape(len(text_prompts), 2, -1),
|
58 |
+
axis=1,
|
59 |
+
)[:, 0, :]
|
60 |
+
# logger.info(f"cosine_simlarities shape: {cosine_simlarities.shape}")
|
61 |
+
logger.info(f"cosine_simlarities: {cosine_simlarities}")
|
62 |
+
probabilities = lr_model.predict_proba(
|
63 |
+
cosine_simlarities.reshape(1, -1)
|
64 |
+
)
|
65 |
+
logger.info(f"probabilities: {probabilities}")
|
66 |
+
decision_json = json.dumps(
|
67 |
+
{"is_base_body": float(probabilities[0][1])}
|
68 |
+
).encode("utf-8")
|
69 |
+
logger.info(f"decision_json: {decision_json}")
|
70 |
+
return decision_json
|
71 |
|
72 |
|
73 |
iface = gr.Interface(
|