2006elad committed on
Commit
3acf68b
1 Parent(s): 5dcff48

Update app.py

Files changed (1)
  1. app.py +69 -21
app.py CHANGED
@@ -1,25 +1,73 @@
- import streamlit as st
- from transformers import T5Tokenizer, T5ForConditionalGeneration
 
  MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
- tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
- model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME_OR_PATH)
-
- def generate_recipe(input_items):
-     prefix = "items: "
-     input_text = prefix + input_items
-     input_ids = tokenizer.encode(input_text, return_tensors="pt")
-     output_ids = model.generate(input_ids)
-     generated_recipe = tokenizer.decode(output_ids[0], skip_special_tokens=True)
-     return generated_recipe
-
- def main():
-     st.title("Recipe Generation")
-     input_items = st.text_area("Enter the recipe instructions:")
-     if st.button("Generate Recipe"):
-         generated_recipe = generate_recipe(input_items)
-         st.subheader("Generated Recipe:")
-         st.text(generated_recipe)
 
  if __name__ == "__main__":
-     main()
 
+ import gradio as gr
+ from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer
 
  MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
+ model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)
+
+ prefix = "items: "
+ generation_kwargs = {
+     "max_length": 512,
+     "min_length": 64,
+     "no_repeat_ngram_size": 3,
+     "do_sample": True,
+     "top_k": 60,
+     "top_p": 0.95,
+ }
+
+ special_tokens = tokenizer.all_special_tokens
+ tokens_map = {
+     "<sep>": "--",
+     "<section>": "\n",
+ }
+
+ def skip_special_tokens(text, special_tokens):
+     for token in special_tokens:
+         text = text.replace(token, "")
+     return text
+
+ def target_postprocessing(texts, special_tokens):
+     if not isinstance(texts, list):
+         texts = [texts]
+     new_texts = []
+     for text in texts:
+         text = skip_special_tokens(text, special_tokens)
+         for k, v in tokens_map.items():
+             text = text.replace(k, v)
+         new_texts.append(text)
+     return new_texts
+
+ def generation_function(texts):
+     _inputs = texts if isinstance(texts, list) else [texts]
+     inputs = [prefix + inp for inp in _inputs]
+     inputs = tokenizer(
+         inputs,
+         max_length=256,
+         padding="max_length",
+         truncation=True,
+         return_tensors="jax",
+     )
+     input_ids = inputs.input_ids
+     attention_mask = inputs.attention_mask
+     output_ids = model.generate(
+         input_ids=input_ids,
+         attention_mask=attention_mask,
+         **generation_kwargs,
+     )
+     generated = output_ids.sequences
+     generated_recipe = target_postprocessing(
+         tokenizer.batch_decode(generated, skip_special_tokens=False),
+         special_tokens,
+     )
+     return generated_recipe[0]  # Return the first generated recipe as a string
+
+ iface = gr.Interface(
+     fn=generation_function,
+     inputs="text",
+     outputs="text",
+     title="Recipe Generation",
+     description="Generate a recipe based on an input text.",
+ )
 
  if __name__ == "__main__":
+     iface.launch()
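
A minimal usage sketch (not part of the commit): assuming app.py imports cleanly with transformers, flax, and gradio installed, generation_function can be called directly without launching the Gradio UI, since the launch is guarded by the __main__ check. The ingredient string below is a hypothetical example; the "items: " prefix is prepended inside generation_function.

    from app import generation_function

    # Comma-separated ingredients; generation_function adds the "items: " prefix itself.
    recipe = generation_function("macaroni, butter, salt, bacon")
    print(recipe)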