AhmedSSabir commited on
Commit
3d34bd5
1 Parent(s): 6c54368

Upload demo.py

Browse files
Files changed (1) hide show
  1. demo.py +56 -0
demo.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ from doctest import OutputChecker
3
+ import sys
4
+ import argparse
5
+ import torch
6
+ import re
7
+ import os
8
+ import gradio as gr
9
+ from sentence_transformers import SentenceTransformer, util
10
+ from sklearn.metrics.pairwise import cosine_similarity
11
+ from lm_scorer.models.auto import AutoLMScorer as LMScorer
12
+ from sentence_transformers import SentenceTransformer, util
13
+ from sklearn.metrics.pairwise import cosine_similarity
14
+
15
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
16
+ model = SentenceTransformer('stsb-distilbert-base', device=device)
17
+ batch_size = 1
18
+ scorer = LMScorer.from_pretrained('gpt2' , device=device, batch_size=batch_size)
19
+
20
+
21
+ def cos_sim(a, b):
22
+ return np.inner(a, b) / (np.linalg.norm(a) * (np.linalg.norm(b)))
23
+
24
+
25
+
26
+ def Visual_re_ranker(caption, visual_context_label, visual_context_prob):
27
+ caption = caption
28
+ visual_context_label= visual_context_label
29
+ visual_context_prob = visual_context_prob
30
+ caption_emb = model.encode(caption, convert_to_tensor=True)
31
+ visual_context_label_emb = model.encode(visual_context_label, convert_to_tensor=True)
32
+
33
+
34
+ sim = cosine_scores = util.pytorch_cos_sim(caption_emb, visual_context_label_emb)
35
+ sim = sim.cpu().numpy()
36
+ sim = str(sim)[1:-1]
37
+ sim = str(sim)[1:-1]
38
+
39
+ LM = scorer.sentence_score(caption, reduce="mean")
40
+ score = pow(float(LM),pow((1-float(sim))/(1+ float(sim)),1-float(visual_context_prob)))
41
+
42
+
43
+ #return {"LM": float(LM)/1, "sim": float(sim)/1, "score": float(score)/1 }
44
+ return {"init hypothesis": float(LM)/1, "Visual Belief Revision": float(score)/1 }
45
+ #return LM, sim, score
46
+
47
+
48
+
49
+ demo = gr.Interface(
50
+ fn=Visual_re_ranker,
51
+ description="Demo for Belief Revision based Caption Re-ranker with Visual Semantic Information",
52
+ inputs=[gr.Textbox(value="a city street filled with traffic at night") , gr.Textbox(value="traffic"), gr.Textbox(value="0.7458009")],
53
+ #outputs=[gr.Textbox(value="Language Model Score") , gr.Textbox(value="Semantic Similarity Score"), gr.Textbox(value="Belief revision score via visual context")],
54
+ outputs="label",
55
+ )
56
+ demo.launch()