2006elad committed
Commit 587e383
1 Parent(s): 1cb6077

Update app.py

Files changed (1)
  1. app.py +8 -16
app.py CHANGED
@@ -1,8 +1,8 @@
-from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer
+from transformers import T5Tokenizer, FlaxT5ForConditionalGeneration
 
-MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
-tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
-model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)
+MODEL_NAME_OR_PATH = "t5-base"
+tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME_OR_PATH)
+model = FlaxT5ForConditionalGeneration.from_pretrained(MODEL_NAME_OR_PATH)
 
 prefix = "items: "
 generation_kwargs = {
@@ -15,12 +15,6 @@ generation_kwargs = {
     "num_return_sequences": 1
 }
 
-special_tokens = tokenizer.all_special_tokens
-tokens_map = {
-    "<sep>": "--",
-    "<section>": "\n"
-}
-
 def skip_special_tokens(text, special_tokens):
     for token in special_tokens:
         text = text.replace(token, "")
@@ -32,8 +26,6 @@ def target_postprocessing(texts, special_tokens):
     new_texts = []
     for text in texts:
         text = skip_special_tokens(text, special_tokens)
-        for k, v in tokens_map.items():
-            text = text.replace(k, v)
         new_texts.append(text)
     return new_texts
 
@@ -44,18 +36,18 @@ def generate_recipe(items):
         max_length=256,
         padding="max_length",
         truncation=True,
-        return_tensors="pt"
+        return_tensors="jax"
     )
     input_ids = inputs.input_ids
     attention_mask = inputs.attention_mask
+
     output_ids = model.generate(
         input_ids=input_ids,
         attention_mask=attention_mask,
         **generation_kwargs
     )
-    generated = output_ids.squeeze().tolist()
-    generated_recipe = tokenizer.batch_decode(generated, skip_special_tokens=False)
-    generated_recipe = target_postprocessing(generated_recipe, special_tokens)
+    generated_recipe = tokenizer.batch_decode(output_ids, skip_special_tokens=False)
+    generated_recipe = target_postprocessing(generated_recipe, tokenizer.all_special_tokens)
     return generated_recipe[0]
 
 # Example usage
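
For reference, a minimal end-to-end sketch of how the updated generate_recipe path could be exercised. It mirrors the new version of app.py above, but the remaining generation_kwargs entries, the .sequences access on the Flax generate output, the use of skip_special_tokens=True in place of the target_postprocessing helper, and the __main__ block are assumptions for illustration, not part of this commit.

from transformers import T5Tokenizer, FlaxT5ForConditionalGeneration

MODEL_NAME_OR_PATH = "t5-base"
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME_OR_PATH)
model = FlaxT5ForConditionalGeneration.from_pretrained(MODEL_NAME_OR_PATH)

prefix = "items: "
generation_kwargs = {
    # Only this key is visible in the diff context; the full file defines more settings.
    "num_return_sequences": 1,
}

def generate_recipe(items):
    # Prefix the ingredient list and tokenize to JAX arrays, as in the updated app.py.
    inputs = tokenizer(
        prefix + items,
        max_length=256,
        padding="max_length",
        truncation=True,
        return_tensors="jax",
    )
    output = model.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        **generation_kwargs,
    )
    # Assumption: the Flax generate() call returns an output object whose token ids
    # live in `.sequences`; here the tokenizer strips special tokens directly instead
    # of routing through the target_postprocessing helper.
    recipes = tokenizer.batch_decode(output.sequences, skip_special_tokens=True)
    return recipes[0]

if __name__ == "__main__":
    print(generate_recipe("eggs, flour, butter, sugar"))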