shpotes committed
Commit
fed60af
1 Parent(s): 427f0cb

add clip-base app

Files changed (2)
  1. app.py +81 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,81 @@
+ import os
+ import sys
+ import jax
+ import streamlit as st
+ import transformers
+ from huggingface_hub import snapshot_download
+ from transformers import AutoTokenizer
+ import torch
+ from torchvision.io import ImageReadMode, read_image
+
+
+ # Download the model repository and make its src/ package importable.
+ LOCAL_PATH = snapshot_download("flax-community/medclip")
+ sys.path.append(LOCAL_PATH)
+
+ from src.modeling_medclip import FlaxMedCLIP
+
+
+ def prepare_image(image_path, model):
+     image = read_image(image_path, mode=ImageReadMode.RGB)
+     # NOTE: Transform is neither imported nor defined in this diff; see the sketch below.
+     preprocess = Transform(model.config.vision_config.image_size)
+     preprocess = torch.jit.script(preprocess)
+     preprocessed_image = preprocess(image)
+     # NCHW -> NHWC, the layout the Flax vision encoder expects.
+     pixel_values = torch.stack([preprocessed_image]).permute(0, 2, 3, 1).numpy()
+     return pixel_values
+
+
+ def prepare_text(text, tokenizer):
+     return tokenizer(text, return_tensors="np")
+
+
+ def save_file_to_disk(uploaded_file):
+     temp_file = os.path.join("/tmp", uploaded_file.name)
+     with open(temp_file, "wb") as f:
+         f.write(uploaded_file.getbuffer())
+     return temp_file
+
+
+ @st.cache(
+     hash_funcs={
+         # SciBERT ships a BERT tokenizer, so hash that class rather than RobertaTokenizerFast.
+         transformers.models.bert.tokenization_bert_fast.BertTokenizerFast: id,
+         FlaxMedCLIP: id,
+     },
+     show_spinner=False,
+ )
+ def load_tokenizer_and_model():
+     # Load the tokenizer and saved model once; st.cache reuses them across reruns.
+     tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")
+     model = FlaxMedCLIP.from_pretrained(LOCAL_PATH)
+     return tokenizer, model
+
+
+ def run_inference(image_path, text, model, tokenizer):
+     pixel_values = prepare_image(image_path, model)
+     input_text = prepare_text(text, tokenizer)
+     model_output = model(
+         input_text["input_ids"],
+         pixel_values,
+         attention_mask=input_text["attention_mask"],
+         train=False,
+         return_dict=True,
+     )
+     logits = model_output["logits_per_image"]
+     # Squash the image-text logit into a [0, 1] score.
+     score = jax.nn.sigmoid(logits)[0][0]
+     return score
+
+
+ tokenizer, model = load_tokenizer_and_model()
+ st.title("Caption Scoring")
+ uploaded_file = st.file_uploader("Choose an image...", type=["png", "jpg"])
+ text_input = st.text_input("Type a caption")
+ if uploaded_file is not None and text_input:
+     local_image_path = None
+     try:
+         local_image_path = save_file_to_disk(uploaded_file)
+         score = run_inference(local_image_path, text_input, model, tokenizer).tolist()
+         st.image(
+             uploaded_file,
+             caption=text_input,
+             width=None,
+             use_column_width=None,
+             clamp=False,
+             channels="RGB",
+             output_format="auto",
+         )
+         st.write(f"## Score: {score:.2f}")
+     finally:
+         # Remove the temporary file even if inference fails.
+         if local_image_path:
+             os.remove(local_image_path)
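Note: prepare_image calls a Transform class that this diff neither imports nor defines, so app.py would raise a NameError as committed unless the class is supplied by the downloaded snapshot. Below is a minimal sketch of the missing preprocessing module, assuming the Transform from the Flax hybrid_clip example in transformers and CLIP's standard normalization constants (both are assumptions, not part of this commit):

import torch
from torchvision.transforms import (
    CenterCrop,
    ConvertImageDtype,
    InterpolationMode,
    Normalize,
    Resize,
)


class Transform(torch.nn.Module):
    # Hypothetical reconstruction; the committed app may obtain this class elsewhere.
    def __init__(self, image_size):
        super().__init__()
        self.transforms = torch.nn.Sequential(
            Resize([image_size], interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
            ConvertImageDtype(torch.float),
            # CLIP's usual image mean/std (an assumption here).
            Normalize(
                (0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711),
            ),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Keep preprocessing out of autograd.
        with torch.no_grad():
            x = self.transforms(x)
        return x

Keeping the pipeline as a scriptable torch.nn.Module is what lets the torch.jit.script(preprocess) call in prepare_image work unchanged.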
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ flax==0.3.4
+ huggingface-hub==0.0.12
+ jax==0.2.17
+ streamlit==0.84.1
+ torch==1.9.0
+ torchvision==0.10.0
+ transformers==4.8.2
+ watchdog==2.1.3
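To try the demo locally, the standard Streamlit workflow should apply: install the pinned dependencies with pip install -r requirements.txt, then launch with streamlit run app.py. One caveat: jax==0.2.17 also needs a compatible jaxlib wheel, which this requirements file does not pin.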