Commit
•
453d528
1
Parent(s):
1d650f1
Update app.py
Browse files
app.py
CHANGED
@@ -31,23 +31,6 @@ if torch.cuda.is_available():
|
|
31 |
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
32 |
style = "<style>.user-message,.system-message{display:flex;margin:10px}.user-message .message-content{background-color:#c2e3f7;color:#000}.system-message .message-content{background-color:#f5f5f5;color:#000}.message-content{padding:10px;border-radius:10px;max-width:70%;word-wrap:break-word}.container{display:flex;justify-content:space-between}.column{width:48%}</style>"
|
33 |
|
34 |
-
client = rg.Argilla(api_url="https://davidberenstein1957-argilla-gradio.hf.space", api_key="owner.apikey")
|
35 |
-
|
36 |
-
required_settings = rg.Settings(
|
37 |
-
fields=[rg.TextField(name="conversation")],
|
38 |
-
questions=[
|
39 |
-
rg.TextQuestion(name="chosen"),
|
40 |
-
rg.TextQuestion(name="rejected"),
|
41 |
-
],
|
42 |
-
)
|
43 |
-
name = "test"
|
44 |
-
if client.datasets(name=name).exists():
|
45 |
-
dataset: rg.Dataset = client.datasets(name=name)
|
46 |
-
|
47 |
-
else:
|
48 |
-
dataset = rg.Dataset(name=name, settings=required_settings)
|
49 |
-
dataset.create()
|
50 |
-
|
51 |
|
52 |
@spaces.GPU
|
53 |
def generate(
|
@@ -101,7 +84,6 @@ chat_interface = ChatInterface(
|
|
101 |
height=450, label="GEITje-SPIN", show_share_button=True, avatar_images=(None, "geitje-logo.jpg")
|
102 |
),
|
103 |
css=style,
|
104 |
-
rg_dataset=dataset,
|
105 |
cache_examples=False,
|
106 |
additional_inputs=[
|
107 |
gr.Slider(
|
|
|
31 |
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
32 |
style = "<style>.user-message,.system-message{display:flex;margin:10px}.user-message .message-content{background-color:#c2e3f7;color:#000}.system-message .message-content{background-color:#f5f5f5;color:#000}.message-content{padding:10px;border-radius:10px;max-width:70%;word-wrap:break-word}.container{display:flex;justify-content:space-between}.column{width:48%}</style>"
|
33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
@spaces.GPU
|
36 |
def generate(
|
|
|
84 |
height=450, label="GEITje-SPIN", show_share_button=True, avatar_images=(None, "geitje-logo.jpg")
|
85 |
),
|
86 |
css=style,
|
|
|
87 |
cache_examples=False,
|
88 |
additional_inputs=[
|
89 |
gr.Slider(
|