Bugfix length of text prompts
Browse files
app.py
CHANGED
@@ -78,7 +78,7 @@ with torch.no_grad():
|
|
78 |
hair_text_features = hair_text_features.cpu()
|
79 |
|
80 |
|
81 |
-
def get_cosine_similarities(image_features, text_features):
|
82 |
cosine_simlarities = softmax(
|
83 |
(text_features @ image_features.cpu().T)
|
84 |
.squeeze()
|
@@ -96,10 +96,10 @@ def predict_fn(input_img):
|
|
96 |
with torch.no_grad():
|
97 |
image_features = clip_model.encode_image(image)
|
98 |
base_body_cosine_simlarities = get_cosine_similarities(
|
99 |
-
image_features, all_text_features
|
100 |
)
|
101 |
hair_cosine_simlarities = get_cosine_similarities(
|
102 |
-
image_features, hair_text_features
|
103 |
)
|
104 |
# logger.info(f"cosine_simlarities shape: {cosine_simlarities.shape}")
|
105 |
logger.info(f"cosine_simlarities: {base_body_cosine_simlarities}")
|
|
|
78 |
hair_text_features = hair_text_features.cpu()
|
79 |
|
80 |
|
81 |
+
def get_cosine_similarities(image_features, text_features, text_prompts):
|
82 |
cosine_simlarities = softmax(
|
83 |
(text_features @ image_features.cpu().T)
|
84 |
.squeeze()
|
|
|
96 |
with torch.no_grad():
|
97 |
image_features = clip_model.encode_image(image)
|
98 |
base_body_cosine_simlarities = get_cosine_similarities(
|
99 |
+
image_features, all_text_features, text_prompts
|
100 |
)
|
101 |
hair_cosine_simlarities = get_cosine_similarities(
|
102 |
+
image_features, hair_text_features, hair_text_prompts
|
103 |
)
|
104 |
# logger.info(f"cosine_simlarities shape: {cosine_simlarities.shape}")
|
105 |
logger.info(f"cosine_simlarities: {base_body_cosine_simlarities}")
|