hwajjala commited on
Commit
cb8ddf6
β€’
1 Parent(s): 75dd19d

Add additional check

Browse files
Files changed (1) hide show
  1. app.py +48 -21
app.py CHANGED
@@ -12,8 +12,10 @@ from scipy.special import softmax
12
  logger = logging.getLogger("basebody")
13
  CLIP_MODEL_NAME = "ViT-B/16"
14
 
 
15
  TEXT_PROMPTS_FILE_NAME = "text_prompts2.json"
16
  HAIR_TEXT_PROMPTS_FILE_NAME = "text_prompts_hair.json"
 
17
  LOGISTIC_REGRESSION_MODEL_FILE_NAME = "logistic_regression_l1_oct_26.pkl"
18
  HAIR_RF_CLASSIFIER_MODEL_FILE_NAME = "hairclassifier_rf.pkl"
19
 
@@ -36,6 +38,16 @@ with open(
36
  ) as f:
37
  hair_text_prompts = json.load(f)
38
 
 
 
 
 
 
 
 
 
 
 
39
  with open(
40
  os.path.join(
41
  os.path.dirname(__file__), LOGISTIC_REGRESSION_MODEL_FILE_NAME
@@ -44,6 +56,16 @@ with open(
44
  ) as f:
45
  lr_model = pickle.load(f)
46
 
 
 
 
 
 
 
 
 
 
 
47
  with open(
48
  os.path.join(
49
  os.path.dirname(__file__), HAIR_RF_CLASSIFIER_MODEL_FILE_NAME
@@ -55,27 +77,21 @@ with open(
55
 
56
  logger.info("Logistic regression model loaded, coefficients: ")
57
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
- all_text_features = []
60
- with torch.no_grad():
61
- for k, prompts in text_prompts.items():
62
- assert len(prompts) == 2
63
- inputs = clip.tokenize(prompts)
64
- outputs = clip_model.encode_text(inputs)
65
- all_text_features.append(outputs)
66
- all_text_features = torch.cat(all_text_features, dim=0)
67
- all_text_features = all_text_features.cpu()
68
-
69
-
70
- hair_text_features = []
71
- with torch.no_grad():
72
- for k, prompts in hair_text_prompts.items():
73
- assert len(prompts) == 2
74
- inputs = clip.tokenize(prompts)
75
- outputs = clip_model.encode_text(inputs)
76
- hair_text_features.append(outputs)
77
- hair_text_features = torch.cat(hair_text_features, dim=0)
78
- hair_text_features = hair_text_features.cpu()
79
 
80
 
81
  def get_cosine_similarities(image_features, text_features, text_prompts):
@@ -101,6 +117,9 @@ def predict_fn(input_img):
101
  hair_cosine_simlarities = get_cosine_similarities(
102
  image_features, hair_text_features, hair_text_prompts
103
  )
 
 
 
104
  # logger.info(f"cosine_simlarities shape: {cosine_simlarities.shape}")
105
  logger.info(f"cosine_simlarities: {base_body_cosine_simlarities}")
106
  probabilities = lr_model.predict_proba(
@@ -109,9 +128,13 @@ def predict_fn(input_img):
109
  hair_probabilities = hair_rf_model.predict_proba(
110
  hair_cosine_simlarities.reshape(1, -1)
111
  )
 
 
 
112
  logger.info(f"probabilities: {probabilities}")
113
  result_probabilty = float(probabilities[0][1].round(3))
114
  hair_result_probabilty = float(hair_probabilities[0][1].round(3))
 
115
  # get decision string
116
  if result_probabilty > 0.77:
117
  if hair_result_probabilty < 0.5:
@@ -120,7 +143,11 @@ def predict_fn(input_img):
120
  decision = "AUTO REJECT"
121
  else:
122
  decision = "AUTO ACCEPT"
123
- elif result_probabilty < 0.4:
 
 
 
 
124
  decision = "AUTO REJECT"
125
  else:
126
  decision = "MODERATION"
 
12
  logger = logging.getLogger("basebody")
13
  CLIP_MODEL_NAME = "ViT-B/16"
14
 
15
+ TEXT_PROMPTS_OLD_FILE_NAME = "text_prompts.json"
16
  TEXT_PROMPTS_FILE_NAME = "text_prompts2.json"
17
  HAIR_TEXT_PROMPTS_FILE_NAME = "text_prompts_hair.json"
18
+ LOGISTIC_REGRESSION_OLD_MODEL_FILE_NAME = "logistic_regression_l1.pkl"
19
  LOGISTIC_REGRESSION_MODEL_FILE_NAME = "logistic_regression_l1_oct_26.pkl"
20
  HAIR_RF_CLASSIFIER_MODEL_FILE_NAME = "hairclassifier_rf.pkl"
21
 
 
38
  ) as f:
39
  hair_text_prompts = json.load(f)
40
 
41
+
42
+ with open(
43
+ os.path.join(
44
+ os.path.dirname(__file__), TEXT_PROMPTS_OLD_FILE_NAME
45
+ ),
46
+ "r",
47
+ ) as f:
48
+ text_prompts_old = json.load(f)
49
+
50
+
51
  with open(
52
  os.path.join(
53
  os.path.dirname(__file__), LOGISTIC_REGRESSION_MODEL_FILE_NAME
 
56
  ) as f:
57
  lr_model = pickle.load(f)
58
 
59
+
60
+ with open(
61
+ os.path.join(
62
+ os.path.dirname(__file__), LOGISTIC_REGRESSION_OLD_MODEL_FILE_NAME
63
+ ),
64
+ "rb",
65
+ ) as f:
66
+ lr_old_model = pickle.load(f)
67
+
68
+
69
  with open(
70
  os.path.join(
71
  os.path.dirname(__file__), HAIR_RF_CLASSIFIER_MODEL_FILE_NAME
 
77
 
78
  logger.info("Logistic regression model loaded, coefficients: ")
79
 
80
+ def get_text_features(text_prompts):
81
+ all_text_features = []
82
+ with torch.no_grad():
83
+ for k, prompts in text_prompts.items():
84
+ assert len(prompts) == 2
85
+ inputs = clip.tokenize(prompts)
86
+ outputs = clip_model.encode_text(inputs)
87
+ all_text_features.append(outputs)
88
+ all_text_features = torch.cat(all_text_features, dim=0)
89
+ all_text_features = all_text_features.cpu()
90
+ return all_text_features
91
 
92
+ all_text_features = get_text_features(text_prompts)
93
+ hair_text_features = get_text_features(hair_text_prompts)
94
+ old_text_features = get_text_features(text_prompts_old)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
 
97
  def get_cosine_similarities(image_features, text_features, text_prompts):
 
117
  hair_cosine_simlarities = get_cosine_similarities(
118
  image_features, hair_text_features, hair_text_prompts
119
  )
120
+ old_cosine_simlarities = get_cosine_similarities(
121
+ image_features, old_text_features, text_prompts_old
122
+ )
123
  # logger.info(f"cosine_simlarities shape: {cosine_simlarities.shape}")
124
  logger.info(f"cosine_simlarities: {base_body_cosine_simlarities}")
125
  probabilities = lr_model.predict_proba(
 
128
  hair_probabilities = hair_rf_model.predict_proba(
129
  hair_cosine_simlarities.reshape(1, -1)
130
  )
131
+ old_lr_probabilities = lr_old_model.predict_proba(
132
+ old_cosine_simlarities.reshape(1, -1)
133
+ )
134
  logger.info(f"probabilities: {probabilities}")
135
  result_probabilty = float(probabilities[0][1].round(3))
136
  hair_result_probabilty = float(hair_probabilities[0][1].round(3))
137
+ old_result_probabilty = float(old_lr_probabilities[0][1].round(3))
138
  # get decision string
139
  if result_probabilty > 0.77:
140
  if hair_result_probabilty < 0.5:
 
143
  decision = "AUTO REJECT"
144
  else:
145
  decision = "AUTO ACCEPT"
146
+ elif result_probabilty < 0.2:
147
+ decision = "AUTO REJECT"
148
+ elif old_result_probabilty < 0.06:
149
+ logger.info("old_result_probabilty < 0.06")
150
+ result_probabilty = old_result_probabilty
151
  decision = "AUTO REJECT"
152
  else:
153
  decision = "MODERATION"