Add additional check
Browse files
app.py
CHANGED
@@ -12,8 +12,10 @@ from scipy.special import softmax
|
|
12 |
logger = logging.getLogger("basebody")
|
13 |
CLIP_MODEL_NAME = "ViT-B/16"
|
14 |
|
|
|
15 |
TEXT_PROMPTS_FILE_NAME = "text_prompts2.json"
|
16 |
HAIR_TEXT_PROMPTS_FILE_NAME = "text_prompts_hair.json"
|
|
|
17 |
LOGISTIC_REGRESSION_MODEL_FILE_NAME = "logistic_regression_l1_oct_26.pkl"
|
18 |
HAIR_RF_CLASSIFIER_MODEL_FILE_NAME = "hairclassifier_rf.pkl"
|
19 |
|
@@ -36,6 +38,16 @@ with open(
|
|
36 |
) as f:
|
37 |
hair_text_prompts = json.load(f)
|
38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
with open(
|
40 |
os.path.join(
|
41 |
os.path.dirname(__file__), LOGISTIC_REGRESSION_MODEL_FILE_NAME
|
@@ -44,6 +56,16 @@ with open(
|
|
44 |
) as f:
|
45 |
lr_model = pickle.load(f)
|
46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
with open(
|
48 |
os.path.join(
|
49 |
os.path.dirname(__file__), HAIR_RF_CLASSIFIER_MODEL_FILE_NAME
|
@@ -55,27 +77,21 @@ with open(
|
|
55 |
|
56 |
logger.info("Logistic regression model loaded, coefficients: ")
|
57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
-
all_text_features =
|
60 |
-
|
61 |
-
|
62 |
-
assert len(prompts) == 2
|
63 |
-
inputs = clip.tokenize(prompts)
|
64 |
-
outputs = clip_model.encode_text(inputs)
|
65 |
-
all_text_features.append(outputs)
|
66 |
-
all_text_features = torch.cat(all_text_features, dim=0)
|
67 |
-
all_text_features = all_text_features.cpu()
|
68 |
-
|
69 |
-
|
70 |
-
hair_text_features = []
|
71 |
-
with torch.no_grad():
|
72 |
-
for k, prompts in hair_text_prompts.items():
|
73 |
-
assert len(prompts) == 2
|
74 |
-
inputs = clip.tokenize(prompts)
|
75 |
-
outputs = clip_model.encode_text(inputs)
|
76 |
-
hair_text_features.append(outputs)
|
77 |
-
hair_text_features = torch.cat(hair_text_features, dim=0)
|
78 |
-
hair_text_features = hair_text_features.cpu()
|
79 |
|
80 |
|
81 |
def get_cosine_similarities(image_features, text_features, text_prompts):
|
@@ -101,6 +117,9 @@ def predict_fn(input_img):
|
|
101 |
hair_cosine_simlarities = get_cosine_similarities(
|
102 |
image_features, hair_text_features, hair_text_prompts
|
103 |
)
|
|
|
|
|
|
|
104 |
# logger.info(f"cosine_simlarities shape: {cosine_simlarities.shape}")
|
105 |
logger.info(f"cosine_simlarities: {base_body_cosine_simlarities}")
|
106 |
probabilities = lr_model.predict_proba(
|
@@ -109,9 +128,13 @@ def predict_fn(input_img):
|
|
109 |
hair_probabilities = hair_rf_model.predict_proba(
|
110 |
hair_cosine_simlarities.reshape(1, -1)
|
111 |
)
|
|
|
|
|
|
|
112 |
logger.info(f"probabilities: {probabilities}")
|
113 |
result_probabilty = float(probabilities[0][1].round(3))
|
114 |
hair_result_probabilty = float(hair_probabilities[0][1].round(3))
|
|
|
115 |
# get decision string
|
116 |
if result_probabilty > 0.77:
|
117 |
if hair_result_probabilty < 0.5:
|
@@ -120,7 +143,11 @@ def predict_fn(input_img):
|
|
120 |
decision = "AUTO REJECT"
|
121 |
else:
|
122 |
decision = "AUTO ACCEPT"
|
123 |
-
elif result_probabilty < 0.
|
|
|
|
|
|
|
|
|
124 |
decision = "AUTO REJECT"
|
125 |
else:
|
126 |
decision = "MODERATION"
|
|
|
12 |
logger = logging.getLogger("basebody")
|
13 |
CLIP_MODEL_NAME = "ViT-B/16"
|
14 |
|
15 |
+
TEXT_PROMPTS_OLD_FILE_NAME = "text_prompts.json"
|
16 |
TEXT_PROMPTS_FILE_NAME = "text_prompts2.json"
|
17 |
HAIR_TEXT_PROMPTS_FILE_NAME = "text_prompts_hair.json"
|
18 |
+
LOGISTIC_REGRESSION_OLD_MODEL_FILE_NAME = "logistic_regression_l1.pkl"
|
19 |
LOGISTIC_REGRESSION_MODEL_FILE_NAME = "logistic_regression_l1_oct_26.pkl"
|
20 |
HAIR_RF_CLASSIFIER_MODEL_FILE_NAME = "hairclassifier_rf.pkl"
|
21 |
|
|
|
38 |
) as f:
|
39 |
hair_text_prompts = json.load(f)
|
40 |
|
41 |
+
|
42 |
+
with open(
|
43 |
+
os.path.join(
|
44 |
+
os.path.dirname(__file__), TEXT_PROMPTS_OLD_FILE_NAME
|
45 |
+
),
|
46 |
+
"r",
|
47 |
+
) as f:
|
48 |
+
text_prompts_old = json.load(f)
|
49 |
+
|
50 |
+
|
51 |
with open(
|
52 |
os.path.join(
|
53 |
os.path.dirname(__file__), LOGISTIC_REGRESSION_MODEL_FILE_NAME
|
|
|
56 |
) as f:
|
57 |
lr_model = pickle.load(f)
|
58 |
|
59 |
+
|
60 |
+
with open(
|
61 |
+
os.path.join(
|
62 |
+
os.path.dirname(__file__), LOGISTIC_REGRESSION_OLD_MODEL_FILE_NAME
|
63 |
+
),
|
64 |
+
"rb",
|
65 |
+
) as f:
|
66 |
+
lr_old_model = pickle.load(f)
|
67 |
+
|
68 |
+
|
69 |
with open(
|
70 |
os.path.join(
|
71 |
os.path.dirname(__file__), HAIR_RF_CLASSIFIER_MODEL_FILE_NAME
|
|
|
77 |
|
78 |
logger.info("Logistic regression model loaded, coefficients: ")
|
79 |
|
80 |
+
def get_text_features(text_prompts):
|
81 |
+
all_text_features = []
|
82 |
+
with torch.no_grad():
|
83 |
+
for k, prompts in text_prompts.items():
|
84 |
+
assert len(prompts) == 2
|
85 |
+
inputs = clip.tokenize(prompts)
|
86 |
+
outputs = clip_model.encode_text(inputs)
|
87 |
+
all_text_features.append(outputs)
|
88 |
+
all_text_features = torch.cat(all_text_features, dim=0)
|
89 |
+
all_text_features = all_text_features.cpu()
|
90 |
+
return all_text_features
|
91 |
|
92 |
+
all_text_features = get_text_features(text_prompts)
|
93 |
+
hair_text_features = get_text_features(hair_text_prompts)
|
94 |
+
old_text_features = get_text_features(text_prompts_old)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
|
96 |
|
97 |
def get_cosine_similarities(image_features, text_features, text_prompts):
|
|
|
117 |
hair_cosine_simlarities = get_cosine_similarities(
|
118 |
image_features, hair_text_features, hair_text_prompts
|
119 |
)
|
120 |
+
old_cosine_simlarities = get_cosine_similarities(
|
121 |
+
image_features, old_text_features, text_prompts_old
|
122 |
+
)
|
123 |
# logger.info(f"cosine_simlarities shape: {cosine_simlarities.shape}")
|
124 |
logger.info(f"cosine_simlarities: {base_body_cosine_simlarities}")
|
125 |
probabilities = lr_model.predict_proba(
|
|
|
128 |
hair_probabilities = hair_rf_model.predict_proba(
|
129 |
hair_cosine_simlarities.reshape(1, -1)
|
130 |
)
|
131 |
+
old_lr_probabilities = lr_old_model.predict_proba(
|
132 |
+
old_cosine_simlarities.reshape(1, -1)
|
133 |
+
)
|
134 |
logger.info(f"probabilities: {probabilities}")
|
135 |
result_probabilty = float(probabilities[0][1].round(3))
|
136 |
hair_result_probabilty = float(hair_probabilities[0][1].round(3))
|
137 |
+
old_result_probabilty = float(old_lr_probabilities[0][1].round(3))
|
138 |
# get decision string
|
139 |
if result_probabilty > 0.77:
|
140 |
if hair_result_probabilty < 0.5:
|
|
|
143 |
decision = "AUTO REJECT"
|
144 |
else:
|
145 |
decision = "AUTO ACCEPT"
|
146 |
+
elif result_probabilty < 0.2:
|
147 |
+
decision = "AUTO REJECT"
|
148 |
+
elif old_result_probabilty < 0.06:
|
149 |
+
logger.info("old_result_probabilty < 0.06")
|
150 |
+
result_probabilty = old_result_probabilty
|
151 |
decision = "AUTO REJECT"
|
152 |
else:
|
153 |
decision = "MODERATION"
|