kassemsabeh commited on
Commit
42466a5
·
1 Parent(s): b0cd2d3

Add application

Browse files
Files changed (2) hide show
  1. app.py +66 -0
  2. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import os
4
+ from transformers import AutoTokenizer, T5ForConditionalGeneration
5
+
6
+ model_id = 'ksabeh/t5-base-qpave'
7
+ max_input_length = 512
8
+ max_target_length = 20
9
+ auth_token = os.environ.get('TOKEN')
10
+
11
+ model = T5ForConditionalGeneration.from_pretrained(model_id, use_auth_token=auth_token)
12
+ tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=auth_token)
13
+
14
+ def predict(cg_attribute, text, fg_attribute, category):
15
+ input = f"{fg_attribute}: {text}"
16
+ model_input = tokenizer(input, max_length=max_input_length, truncation=True,
17
+ padding="max_length")
18
+ model_input = {k:torch.unsqueeze(torch.tensor(v),dim=0) for k,v in model_input.items()}
19
+ predictions = model.generate(**model_input, num_beams=4, do_sample=True, max_length=10)
20
+ return tokenizer.batch_decode(predictions, skip_special_tokens=True)[0]
21
+
22
+ # iface = gr.Interface(
23
+ # predict,
24
+ # inputs=["text", "text", "text", "text"],
25
+ # outputs=['text'],
26
+ # title="QPAVE",
27
+ # examples=[["Arriba Salsa Garlic and Cilantro, 16 oz", "Food"],
28
+ # ["MV Verholen Black GPS Ball Mount for BMW K1200S K1200R K1300S K1300R Black GPS Ball Mount VER-4901-10181", "Toys"],
29
+ # ["Mitsubishi 3000GT License Plate Frame (Zince Metal)", "Automotive"],
30
+ # ["Fun Fire Truck Pinata Personalized", "Toys"],
31
+ # ["White Chocolate Caramel Gourmet Popcorn Kelly", "Food"]
32
+ # ]
33
+ # )
34
+
35
+ # iface.launch()
36
+
37
+ demo = gr.Interface(
38
+ predict,
39
+ [
40
+ gr.Textbox(
41
+ label = "Coarse-grained Attribute",
42
+ info = "The coarse-grained attribute name",
43
+ lines = 1,
44
+ ),
45
+ gr.Textbox(
46
+ label = "Context",
47
+ info = "The value of the coarse-grained attribute",
48
+ lines = 1,
49
+ ),
50
+ gr.Textbox(
51
+ label = "Fine-grained Attribute",
52
+ info = "The target fine-grained attribute name",
53
+ lines = 1,
54
+ ),
55
+ gr.Textbox(
56
+ label = "Category",
57
+ info = "The product category",
58
+ lines = 1,
59
+ )
60
+ ],
61
+ "dataframe",
62
+ title="QPAVE",
63
+ examples=[["Processor", "3ghz intel core i5", "Brand Name", "Computers & Tablets"]
64
+ ],
65
+ cache_examples = True
66
+ )
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ transformers
2
+ torch