Eitan177 committed on
Commit
a80dee5
1 Parent(s): ffec387

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Streamlit demo for CONCH: embed an uploaded image and compare it to user texts.

The script logs in to the Hugging Face Hub with the token stored in the ``hf``
environment variable, loads the ``conch_ViT-B-16`` model once, and then, for an
uploaded image, reports the cosine similarity between the image embedding and
the embedding of each text prompt entered by the user.
"""
import os

import streamlit as st
import torch
from huggingface_hub import login
from PIL import Image

import conch
from conch.open_clip_custom import create_model_from_pretrained, get_tokenizer, tokenize

# Access token for the Hugging Face Hub, read from the environment.
hf = os.getenv('hf')
if hf:
    # Only log in when a token is actually configured; calling login() with
    # None would fall back to an interactive prompt, which a Streamlit app
    # cannot answer.
    login(hf)


@st.cache_resource
def load_model():
    """Create the pretrained CONCH model and its image preprocessing transform.

    Decorated with ``st.cache_resource`` so the model is built once and reused
    across Streamlit reruns instead of being re-created on every interaction.

    Returns:
        A ``(model, preprocess)`` tuple as produced by
        ``create_model_from_pretrained``.
    """
    model, preprocess = create_model_from_pretrained(
        'conch_ViT-B-16', "hf_hub:MahmoodLab/conch", hf_auth_token=hf
    )
    # The app only runs inference, so switch off training-time behaviour.
    # NOTE(review): assumes the returned model is a torch.nn.Module - confirm.
    model.eval()
    return model, preprocess


model, preprocess = load_model()

st.title("CONCH - Image Captioning and Retrieval")

# Upload an image
uploaded_file = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])

if uploaded_file:
    # Display uploaded image
    image = Image.open(uploaded_file)
    st.image(image, caption='Uploaded Image', use_column_width=True)

    # Preprocess and add a leading batch dimension before encoding.
    image = preprocess(image).unsqueeze(0)

    with torch.no_grad():
        image_embs = model.encode_image(image, proj_contrast=True, normalize=True)

    st.write("Image embeddings generated successfully.")

    # Text retrieval example
    # Number of text prompts to collect (chosen by the user).
    num_inputs = st.number_input("How many text inputs?", min_value=1, max_value=10, value=3)

    # Dynamically generate the text input fields and collect their values.
    input_list = []
    for i in range(int(num_inputs)):
        input_list.append(st.text_input(f"Input Text {i+1}"))

    # Only run the text branch once at least one prompt is non-blank.
    if any(text.strip() for text in input_list):
        tokenizer = get_tokenizer()  # load tokenizer
        text_tokens = tokenize(texts=input_list, tokenizer=tokenizer)  # tokenize the text

        # Inference only: keep autograd off, as for the image embedding above.
        with torch.no_grad():
            text_embs = model.encode_text(text_tokens)

        st.write("Text embeddings generated successfully.")

        # One similarity score per text input, in input order.
        similarity = torch.cosine_similarity(image_embs, text_embs)
        st.write("Similarity check completed.")
        st.write(similarity)