import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification
import gradio as gr
import os
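# Model repo IDs on the Hugging Face Hub; an access token is read from the Space's
# secrets so the gated/private repos can be downloaded at startup.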
model_name = 'eliolio/bart-finetuned-yelpreviews'
bert_model_name = 'eliolio/bert-correlation-yelpreviews'
access_token = os.environ.get('private_token')
model = AutoModelForSeq2SeqLM.from_pretrained(
model_name, use_auth_token=access_token
)
tokenizer = AutoTokenizer.from_pretrained(
model_name, use_auth_token=access_token
)
bert_tokenizer = AutoTokenizer.from_pretrained(
bert_model_name, use_auth_token=access_token
)
bert_model = AutoModelForSequenceClassification.from_pretrained(
bert_model_name, use_auth_token=access_token
)
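# Use the fine-tuned BERT classifier to estimate how well a generated review
# matches the tabular attributes it was conditioned on (correlated vs. uncorrelated).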
def correlation_score(table, review):
# Compute the correlation score
    inputs = bert_tokenizer(table, review, padding=True, max_length=128, truncation=True, return_tensors="pt")
    with torch.no_grad():
        logits = bert_model(**inputs).logits
    probs = logits.softmax(dim=-1)
return {
"correlated": probs[:, 1].item(),
"uncorrelated": probs[:, 0].item()
}
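# Build the text-to-text prompt from the tabular attributes (same format used at fine-tuning time).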
def create_prompt(stars, useful, funny, cool):
return f"Generate review: stars: {stars}, useful: {useful}, funny: {funny}, cool: {cool}"
def postprocess(review):
    # Cut the generated text at the last full stop to drop any trailing incomplete sentence
    dot = review.rfind('.')
    return review[:dot + 1] if dot != -1 else review
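# Generate three candidate reviews for the given attributes via nucleus sampling,
# then score each one against the prompt with the correlation classifier.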
def generate_reviews(stars, useful, funny, cool):
text = create_prompt(stars, useful, funny, cool)
inputs = tokenizer(text, return_tensors='pt')
out = model.generate(
input_ids=inputs.input_ids,
attention_mask=inputs.attention_mask,
do_sample=True,
num_return_sequences=3,
temperature=1.2,
top_p=0.9
)
    # Decode each sampled sequence and trim trailing incomplete sentences
    reviews = [postprocess(tokenizer.decode(seq, skip_special_tokens=True)) for seq in out]
    # Strip the "Generate review: " prefix so only the attribute table is scored against each review
    attributes = text[len("Generate review: "):]
    scores = [correlation_score(attributes, review) for review in reviews]
return reviews[0], reviews[1], reviews[2], scores[0], scores[1], scores[2]
css = """
#ctr {text-align: center;}
#btn {color: white; background: linear-gradient( 90deg, rgba(255,166,0,1) 14.7%, rgba(255,99,97,1) 73% );}
"""
md_text = """<h1 style='text-align: center; margin-bottom: 1rem'>Generating Yelp reviews with BART-base ⭐⭐⭐</h1>
This space demonstrates how synthetic data generation can be performed on natural language columns, as found in the Yelp reviews dataset.
| review id | stars | useful | funny | cool | text |
|:---:|:---:|:---:|:---:|:---:|:---:|
| 0 | 5 | 1 | 0 | 1 | "Wow! Yummy, different, delicious. Our favorite is the lamb curry and korma. With 10 different kinds of naan!!! Don't let the outside deter you (because we almost changed our minds)...go in and try something new! You'll be glad you did!" |
The model is a fine-tuned version of [facebook/bart-base](https://huggingface.co/facebook/bart-base) on Yelp reviews with the following input-output pairs:
- **Input**: "Generate review: stars: 5, useful: 1, funny: 0, cool: 1"
- **Output**: "Wow! Yummy, different, delicious. Our favorite is the lamb curry and korma. With 10 different kinds of naan!!! Don't let the outside deter you (because we almost changed our minds)...go in and try something new! You'll be glad you did!"
"""
resources = """## Resources
- Code for training: [github repo](https://github.com/EliottZemour/yelp-reviews/)
- The Yelp reviews dataset can be found in json format [here](https://www.yelp.com/dataset)."""
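# Gradio Blocks UI: attribute sliders -> generate button -> three reviews with their correlation scores.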
demo = gr.Blocks(css=css)
with demo:
with gr.Row():
gr.Markdown(md_text)
with gr.Row():
        # gr.Slider is the current Blocks API; the legacy gr.inputs.Slider(..., default=...) form is deprecated
        stars = gr.Slider(minimum=0, maximum=5, step=1, value=0, label="stars")
        useful = gr.Slider(minimum=0, maximum=5, step=1, value=0, label="useful")
        funny = gr.Slider(minimum=0, maximum=5, step=1, value=0, label="funny")
        cool = gr.Slider(minimum=0, maximum=5, step=1, value=0, label="cool")
with gr.Row():
        button = gr.Button("Generate reviews!", elem_id='btn')
with gr.Row():
output1 = gr.Textbox(label="Review #1")
output2 = gr.Textbox(label="Review #2")
output3 = gr.Textbox(label="Review #3")
with gr.Row():
score1 = gr.Label(label="Correlation score #1")
score2 = gr.Label(label="Correlation score #2")
score3 = gr.Label(label="Correlation score #3")
with gr.Row():
gr.Markdown(resources)
button.click(
fn=generate_reviews,
inputs=[stars, useful, funny, cool],
outputs=[output1, output2, output3, score1, score2, score3]
)
demo.launch()