EliottZemour commited on
Commit
9df4338
1 Parent(s): c88b60d

add correlation score

Browse files
Files changed (1) hide show
  1. app.py +50 -11
app.py CHANGED
@@ -1,23 +1,48 @@
1
  import torch
2
  import transformers
3
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
  import gradio as gr
5
  import os
6
 
7
  model_name = 'eliolio/bart-finetuned-yelpreviews'
 
8
 
9
  access_token = os.environ.get('private_token')
10
 
11
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name, use_auth_token=access_token)
12
- tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=access_token)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  def create_prompt(stars, useful, funny, cool):
15
  return f"Generate review: stars: {stars}, useful: {useful}, funny: {funny}, cool: {cool}"
16
 
 
17
  def postprocess(review):
18
  dot = review.rfind('.')
19
  return review[:dot+1]
20
 
 
21
  def generate_reviews(stars, useful, funny, cool):
22
  text = create_prompt(stars, useful, funny, cool)
23
  inputs = tokenizer(text, return_tensors='pt')
@@ -30,10 +55,15 @@ def generate_reviews(stars, useful, funny, cool):
30
  top_p=0.9
31
  )
32
  reviews = []
 
33
  for review in out:
34
  reviews.append(postprocess(tokenizer.decode(review, skip_special_tokens=True)))
 
 
 
 
 
35
 
36
- return reviews[0], reviews[1], reviews[2]
37
 
38
  css = """
39
  #ctr {text-align: center;}
@@ -65,12 +95,16 @@ demo = gr.Blocks(css=css)
65
  with demo:
66
  with gr.Row():
67
  gr.Markdown(md_text)
68
-
69
  with gr.Row():
70
- stars = gr.inputs.Slider(minimum=0, maximum=5, step=1, default=0, label="stars")
71
- useful = gr.inputs.Slider(minimum=0, maximum=5, step=1, default=0, label="useful")
72
- funny = gr.inputs.Slider(minimum=0, maximum=5, step=1, default=0, label="funny")
73
- cool = gr.inputs.Slider(minimum=0, maximum=5, step=1, default=0, label="cool")
 
 
 
 
74
  with gr.Row():
75
  button = gr.Button("Generate reviews !", elem_id='btn')
76
 
@@ -79,13 +113,18 @@ with demo:
79
  output2 = gr.Textbox(label="Review #2")
80
  output3 = gr.Textbox(label="Review #3")
81
 
 
 
 
 
 
82
  with gr.Row():
83
  gr.Markdown(resources)
84
 
85
  button.click(
86
  fn=generate_reviews,
87
  inputs=[stars, useful, funny, cool],
88
- outputs=[output1, output2, output3]
89
  )
90
 
91
- demo.launch()
 
1
  import torch
2
  import transformers
3
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification
4
  import gradio as gr
5
  import os
6
 
7
  model_name = 'eliolio/bart-finetuned-yelpreviews'
8
+ bert_model_name = 'eliolio/bert-correlation-yelpreviews'
9
 
10
  access_token = os.environ.get('private_token')
11
 
12
+ model = AutoModelForSeq2SeqLM.from_pretrained(
13
+ model_name, use_auth_token=access_token
14
+ )
15
+ tokenizer = AutoTokenizer.from_pretrained(
16
+ model_name, use_auth_token=access_token
17
+ )
18
+
19
+ bert_tokenizer = AutoTokenizer.from_pretrained(
20
+ bert_model_name, use_auth_token=access_token
21
+ )
22
+ bert_model = AutoModelForSequenceClassification.from_pretrained(
23
+ bert_model_name, use_auth_token=access_token
24
+ )
25
+
26
+
27
+ def correlation_score(table, review):
28
+ # Compute the correlation score
29
+ args = ((table, review))
30
+ inputs = bert_tokenizer(*args, padding=True, max_length=128, truncation=True, return_tensors="pt")
31
+ logits = model(**inputs).logits
32
+ probs = logits.softmax(dim=-1)
33
+ return {
34
+ "correlation": probs[:, 1].item()
35
+ }
36
 
37
  def create_prompt(stars, useful, funny, cool):
38
  return f"Generate review: stars: {stars}, useful: {useful}, funny: {funny}, cool: {cool}"
39
 
40
+
41
  def postprocess(review):
42
  dot = review.rfind('.')
43
  return review[:dot+1]
44
 
45
+
46
  def generate_reviews(stars, useful, funny, cool):
47
  text = create_prompt(stars, useful, funny, cool)
48
  inputs = tokenizer(text, return_tensors='pt')
 
55
  top_p=0.9
56
  )
57
  reviews = []
58
+ scores = []
59
  for review in out:
60
  reviews.append(postprocess(tokenizer.decode(review, skip_special_tokens=True)))
61
+ scores.append(
62
+ correlation_score(text[17:], review)
63
+ )
64
+
65
+ return reviews[0], reviews[1], reviews[2], scores[0], scores[1], scores[2]
66
 
 
67
 
68
  css = """
69
  #ctr {text-align: center;}
 
95
  with demo:
96
  with gr.Row():
97
  gr.Markdown(md_text)
98
+
99
  with gr.Row():
100
+ stars = gr.inputs.Slider(minimum=0, maximum=5,
101
+ step=1, default=0, label="stars")
102
+ useful = gr.inputs.Slider(
103
+ minimum=0, maximum=5, step=1, default=0, label="useful")
104
+ funny = gr.inputs.Slider(minimum=0, maximum=5,
105
+ step=1, default=0, label="funny")
106
+ cool = gr.inputs.Slider(minimum=0, maximum=5,
107
+ step=1, default=0, label="cool")
108
  with gr.Row():
109
  button = gr.Button("Generate reviews !", elem_id='btn')
110
 
 
113
  output2 = gr.Textbox(label="Review #2")
114
  output3 = gr.Textbox(label="Review #3")
115
 
116
+ with gr.Row():
117
+ score1 = gr.Label(label="Correlation score #1")
118
+ score2 = gr.Label(label="Correlation score #2")
119
+ score3 = gr.Label(label="Correlation score #3")
120
+
121
  with gr.Row():
122
  gr.Markdown(resources)
123
 
124
  button.click(
125
  fn=generate_reviews,
126
  inputs=[stars, useful, funny, cool],
127
+ outputs=[output1, output2, output3, score1, score2, score3]
128
  )
129
 
130
+ demo.launch()