liujch1998 committed on
Commit
7802ab3
1 Parent(s): 39aed69

Barebone demo

Files changed (1)
  1. app.py +127 -10
app.py CHANGED
@@ -1,15 +1,132 @@
 import gradio as gr
-from transformers import pipeline

-pipeline = pipeline(task="image-classification", model="julien-c/hotdog-not-hotdog")

-def predict(image):
-    predictions = pipeline(image)
-    return {p["label"]: p["score"] for p in predictions}

 gr.Interface(
-    predict,
-    inputs=gr.inputs.Image(label="Upload hot dog candidate", type="filepath"),
-    outputs=gr.outputs.Label(num_top_classes=2),
-    title="Hot Dog? Or Not?",
-).launch()
 
 import gradio as gr
+import torch
+import transformers

+def reduce_sum(value, mask, axis=None):
+    if axis is None:
+        return torch.sum(value * mask)
+    return torch.sum(value * mask, axis)
+def reduce_mean(value, mask, axis=None):
+    if axis is None:
+        return torch.sum(value * mask) / torch.sum(mask)
+    return reduce_sum(value, mask, axis) / torch.sum(mask, axis)

+device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+max_input_len = 256
+max_output_len = 32
+m = 10  # number of knowledge statements sampled per question
+top_p = 0.5  # nucleus sampling threshold for knowledge generation
+
+class InteractiveRainier:
+
+    def __init__(self):
+        self.tokenizer = transformers.AutoTokenizer.from_pretrained('allenai/unifiedqa-t5-large')
+        self.rainier_model = transformers.AutoModelForSeq2SeqLM.from_pretrained('liujch1998/rainier-large').to(device)
+        self.qa_model = transformers.AutoModelForSeq2SeqLM.from_pretrained('allenai/unifiedqa-t5-large').to(device)
+        self.loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction='none')
+
+    def parse_choices(self, s):
+        '''
+        s: serialized_choices '(A) ... (B) ... (C) ...'
+        '''
+        choices = []
+        key = 'A' if s.find('(A)') != -1 else 'a'
+        while True:
+            pos = s.find(f'({chr(ord(key) + 1)})')
+            if pos == -1:
+                break
+            choice = s[3:pos]
+            s = s[pos:]
+            choice = choice.strip(' ')
+            choices.append(choice)
+            key = chr(ord(key) + 1)
+        choice = s[3:]
+        choice = choice.strip(' ')
+        choices.append(choice)
+        return choices
+
+    def run(self, question):
+        tokenized = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation='longest_first', max_length=max_input_len).to(device)  # (1, L)
+        knowledges_ids = self.rainier_model.generate(
+            input_ids=tokenized.input_ids,
+            max_length=max_output_len + 1,
+            min_length=3,
+            do_sample=True,
+            num_return_sequences=m,
+            top_p=top_p,
+        )  # (K, L); begins with 0 ([BOS]); ends with 1 ([EOS])
+        knowledges_ids = knowledges_ids[:, 1:].contiguous()  # no beginning; ends with 1 ([EOS])
+        knowledges = self.tokenizer.batch_decode(knowledges_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+        knowledges = list(set(knowledges))
+        knowledges = [''] + knowledges  # index 0 is the no-knowledge baseline
+
+        prompts = [question + (f' \\n {knowledge}' if knowledge != '' else '') for knowledge in knowledges]
+        choices = self.parse_choices(question.split('\\n')[1].strip(' '))
+        prompts = [prompt.lower() for prompt in prompts]
+        choices = [choice.lower() for choice in choices]
+        answer_logitss = []
+        for choice in choices:
+            tokenized_prompts = self.tokenizer(prompts, return_tensors='pt', padding='max_length', truncation='longest_first', max_length=max_input_len).to(device)  # (1+K, L)
+            tokenized_choices = self.tokenizer([choice], return_tensors='pt', padding='max_length', truncation='longest_first', max_length=max_input_len).to(device)  # (1, L)
+            pad_mask = (tokenized_choices.input_ids == self.tokenizer.pad_token_id)
+            tokenized_choices.input_ids[pad_mask] = -100
+            tokenized_choices.input_ids = tokenized_choices.input_ids.repeat(len(knowledges), 1)  # (1+K, L)
+
+            with torch.no_grad():
+                logits = self.qa_model(
+                    input_ids=tokenized_prompts.input_ids,
+                    attention_mask=tokenized_prompts.attention_mask,
+                    labels=tokenized_choices.input_ids,
+                ).logits  # (1+K, L, V)
+
+            losses = self.loss_fct(logits.view(-1, logits.size(-1)), tokenized_choices.input_ids.view(-1))
+            losses = losses.view(tokenized_choices.input_ids.shape)  # (1+K, L)
+            losses = reduce_mean(losses, ~pad_mask, axis=-1)  # (1+K)
+            answer_logitss.append(-losses)
+        answer_logitss = torch.stack(answer_logitss, dim=1)  # (1+K, C)
+        answer_probss = answer_logitss.softmax(dim=1)  # (1+K, C)
+
+        # Ensemble
+        knowless_pred = answer_probss[0, :].argmax(dim=0).item()
+        knowless_pred = choices[knowless_pred]
+
+        answer_probs = answer_probss.max(dim=0).values  # (C)
+        knowful_pred = answer_probs.argmax(dim=0).item()
+        knowful_pred = choices[knowful_pred]
+        selected_knowledge_ix = answer_probss.max(dim=1).values.argmax(dim=0).item()
+        selected_knowledge = knowledges[selected_knowledge_ix]
+
+        return {
+            'question': question,
+            'knowledges': knowledges,
+            'knowless_pred': knowless_pred,
+            'knowful_pred': knowful_pred,
+            'selected_knowledge': selected_knowledge,
+        }
+
+rainier = InteractiveRainier()
+
+def predict(question, choices):
+    result = rainier.run(f'{question} \\n {choices}')
+    output = ''
+    output += f'QA model answer without knowledge: {result["knowless_pred"]}\n'
+    output += f'QA model answer with knowledge: {result["knowful_pred"]}\n'
+    output += '\n'
+    output += f'All generated knowledges:\n'
+    for knowledge in result['knowledges']:
+        output += f'  {knowledge}\n'
+    output += '\n'
+    output += f'Knowledge selected to make the prediction: {result["selected_knowledge"]}\n'
+    return output
+
+input_question = gr.inputs.Textbox(label='Question:')
+input_choices = gr.inputs.Textbox(label='Choices:')
+output_text = gr.outputs.Textbox(label='Output')

 gr.Interface(
+    fn=predict,
+    inputs=[input_question, input_choices],
+    outputs=output_text,
+    title="Rainier",
+).launch()
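
For reference, a minimal local smoke test of the new predict() entrypoint is sketched below. The question and choices are illustrative placeholders (not from this commit), and the sketch assumes predict is importable from app.py with the gr.Interface(...).launch() call disabled (e.g. guarded behind if __name__ == '__main__'); the models download on first use and fall back to CPU when no GPU is available.

    # Illustrative smoke test, not part of this commit.
    # Assumes `from app import predict` works with the .launch() call disabled.
    from app import predict

    question = 'Where would you most likely find a jellyfish?'  # hypothetical example
    choices = '(A) desert (B) ocean (C) attic'                  # UnifiedQA-style serialized choices
    print(predict(question, choices))

As the code above shows, each choice is scored by the negative masked-mean cross-entropy of the UnifiedQA model over the no-knowledge prompt and each of the m sampled knowledge prompts; the ensemble takes, per choice, the maximum probability across prompts, and the prompt whose most confident choice has the highest probability determines the reported selected knowledge.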