2006elad commited on
Commit
1cb6077
·
1 Parent(s): 741c4df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -23
app.py CHANGED
@@ -1,5 +1,4 @@
1
  from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer
2
- import gradio as gr
3
 
4
  MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
5
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
@@ -12,7 +11,8 @@ generation_kwargs = {
12
  "no_repeat_ngram_size": 3,
13
  "do_sample": True,
14
  "top_k": 60,
15
- "top_p": 0.95
 
16
  }
17
 
18
  special_tokens = tokenizer.all_special_tokens
@@ -37,15 +37,14 @@ def target_postprocessing(texts, special_tokens):
37
  new_texts.append(text)
38
  return new_texts
39
 
40
- def generation_function(texts):
41
- _inputs = texts if isinstance(texts, list) else [texts]
42
- inputs = [prefix + inp for inp in _inputs]
43
  inputs = tokenizer(
44
  inputs,
45
  max_length=256,
46
  padding="max_length",
47
  truncation=True,
48
- return_tensors="jax"
49
  )
50
  input_ids = inputs.input_ids
51
  attention_mask = inputs.attention_mask
@@ -54,21 +53,12 @@ def generation_function(texts):
54
  attention_mask=attention_mask,
55
  **generation_kwargs
56
  )
57
- generated = output_ids.sequences
58
- generated_recipe = target_postprocessing(
59
- tokenizer.batch_decode(generated, skip_special_tokens=False),
60
- special_tokens
61
- )
62
- return generated_recipe[0] # Return the first generated recipe as a string
63
-
64
-
65
- iface = gr.Interface(
66
- fn=generation_function,
67
- inputs="text",
68
- outputs="text",
69
- title="Recipe Generation",
70
- description="Generate a recipe based on an input text."
71
- )
72
 
73
- if __name__ == "__main__":
74
- iface.launch()
 
 
 
1
  from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer
 
2
 
3
  MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
4
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
 
11
  "no_repeat_ngram_size": 3,
12
  "do_sample": True,
13
  "top_k": 60,
14
+ "top_p": 0.95,
15
+ "num_return_sequences": 1
16
  }
17
 
18
  special_tokens = tokenizer.all_special_tokens
 
37
  new_texts.append(text)
38
  return new_texts
39
 
40
+ def generate_recipe(items):
41
+ inputs = [prefix + items]
 
42
  inputs = tokenizer(
43
  inputs,
44
  max_length=256,
45
  padding="max_length",
46
  truncation=True,
47
+ return_tensors="pt"
48
  )
49
  input_ids = inputs.input_ids
50
  attention_mask = inputs.attention_mask
 
53
  attention_mask=attention_mask,
54
  **generation_kwargs
55
  )
56
+ generated = output_ids.squeeze().tolist()
57
+ generated_recipe = tokenizer.batch_decode(generated, skip_special_tokens=False)
58
+ generated_recipe = target_postprocessing(generated_recipe, special_tokens)
59
+ return generated_recipe[0]
 
 
 
 
 
 
 
 
 
 
 
60
 
61
+ # Example usage
62
+ input_items = "apple, cucumber"
63
+ generated_recipe = generate_recipe(input_items)
64
+ print(generated_recipe)