2006elad committed
Commit
de9b5a0
1 Parent(s): bedad03

Update app.py

Files changed (1):
  app.py: +71 -14
app.py CHANGED
@@ -1,16 +1,73 @@
 
 import gradio as gr
-from transformers import pipeline
-
-def predict(text, name):
-    pipe = pipeline("text2text-generation", model="flax-community/t5-recipe-generation")
-    output = pipe(text)[0]
-    print("Pipeline output:", output)  # Print the complete output dictionary
-    return output["translation_text"]
-
-demo = gr.Interface(
-    fn=predict,
-    inputs=[gr.inputs.Textbox(label="Text"), gr.inputs.Textbox(label="Name")],
-    outputs=gr.outputs.Textbox(),
-    allow_sandbox_access=False,  # Disable the sandbox access
 )
-demo.launch()
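
Note: the removed predict() read output["translation_text"], but a "text2text-generation" pipeline returns dictionaries keyed by "generated_text", so the old Space raised a KeyError (and gr.Interface has no allow_sandbox_access parameter). A minimal corrected sketch of that earlier pipeline approach (the ingredient string is illustrative, not from the commit) would have been:

    from transformers import pipeline

    pipe = pipeline("text2text-generation", model="flax-community/t5-recipe-generation")
    output = pipe("items: macaroni, butter, salt, bacon")[0]
    print(output["generated_text"])  # the pipeline emits "generated_text", not "translation_text"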
 
 
 
+from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer
 import gradio as gr
+
+MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
+tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
+model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)
+
+prefix = "items: "
+generation_kwargs = {
+    "max_length": 512,
+    "min_length": 64,
+    "no_repeat_ngram_size": 3,
+    "do_sample": True,
+    "top_k": 60,
+    "top_p": 0.95
+}
+
+special_tokens = tokenizer.all_special_tokens
+tokens_map = {
+    "<sep>": "--",
+    "<section>": "\n"
+}
+
+def skip_special_tokens(text, special_tokens):
+    for token in special_tokens:
+        text = text.replace(token, "")
+    return text
+
+def target_postprocessing(texts, special_tokens):
+    if not isinstance(texts, list):
+        texts = [texts]
+    new_texts = []
+    for text in texts:
+        text = skip_special_tokens(text, special_tokens)
+        for k, v in tokens_map.items():
+            text = text.replace(k, v)
+        new_texts.append(text)
+    return new_texts
+
+def generation_function(texts):
+    _inputs = texts if isinstance(texts, list) else [texts]
+    inputs = [prefix + inp for inp in _inputs]
+    inputs = tokenizer(
+        inputs,
+        max_length=256,
+        padding="max_length",
+        truncation=True,
+        return_tensors="jax"
+    )
+    input_ids = inputs.input_ids
+    attention_mask = inputs.attention_mask
+    output_ids = model.generate(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        **generation_kwargs
+    )
+    generated = output_ids.sequences
+    generated_recipe = target_postprocessing(
+        tokenizer.batch_decode(generated, skip_special_tokens=False),
+        special_tokens
+    )
+    return generated_recipe
+
+iface = gr.Interface(
+    fn=generation_function,
+    inputs="text",
+    outputs="text",
+    title="Recipe Generation",
+    description="Generate a recipe based on an input text."
 )
+
+if __name__ == "__main__":
+    iface.launch()
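
For a quick local check of the new generation path, something like the following should work (the ingredient list is hypothetical, and outputs vary between runs because do_sample=True):

    # Illustrative usage sketch, not part of the commit.
    recipes = generation_function("macaroni, butter, salt, bacon")
    print(recipes[0])  # one decoded string per input; <sep> rendered as "--", <section> as newlines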