AIChef / app.py
2006elad's picture
Update app.py
1cb6077
raw
history blame
1.85 kB
from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer
# Identifier (name or path) handed to from_pretrained for both the tokenizer
# and the Flax sequence-to-sequence model.
MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=MODEL_NAME_OR_PATH,
    use_fast=True,  # ask for the "fast" tokenizer implementation
)
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
    pretrained_model_name_or_path=MODEL_NAME_OR_PATH,
)
# Tag prepended to every ingredient string before it is tokenized.
prefix = "items: "

# Keyword arguments forwarded verbatim to ``model.generate``: sampling with
# top-k / nucleus filtering, a 64..512 token length window, no repeated
# 3-grams, and a single returned sequence.
generation_kwargs = dict(
    max_length=512,
    min_length=64,
    no_repeat_ngram_size=3,
    do_sample=True,
    top_k=60,
    top_p=0.95,
    num_return_sequences=1,
)
# Every special-token string the tokenizer knows about; these are deleted
# from decoded output during post-processing.
special_tokens = tokenizer.all_special_tokens

# Markers left in the decoded text, mapped to the plain text that replaces
# them during post-processing (applied in this order).
tokens_map = {"<sep>": "--", "<section>": "\n"}
def skip_special_tokens(text, special_tokens):
    """Return *text* with every occurrence of each entry in *special_tokens* removed.

    Entries are removed one after another in the order given, each with a
    single ``str.replace`` pass over the result of the previous one.
    """
    remaining = text
    for tok in special_tokens:
        if tok in remaining:
            remaining = remaining.replace(tok, "")
    return remaining
def target_postprocessing(texts, special_tokens, replacements=None):
    """Turn decoded model output into plain recipe text.

    Each text first has every string in *special_tokens* removed (via
    ``skip_special_tokens``). Each marker in *replacements* is then rewritten
    to its plain-text form, in the mapping's iteration order.

    Args:
        texts: A single decoded string, or a list/tuple of them.
        special_tokens: Strings to delete from each text.
        replacements: Optional ``{marker: replacement}`` mapping. Defaults to
            the module-level ``tokens_map``.

    Returns:
        A new list of cleaned strings, one per input text. A single
        (non-list, non-tuple) input yields a one-element list.
    """
    if replacements is None:
        replacements = tokens_map
    # Wrap a lone value so callers may pass one string or a batch; tuples
    # (e.g. from zip or a tokenizer returning tuples) count as batches too.
    if not isinstance(texts, (list, tuple)):
        texts = [texts]
    cleaned = []
    for text in texts:
        text = skip_special_tokens(text, special_tokens)
        for marker, plain in replacements.items():
            text = text.replace(marker, plain)
        cleaned.append(text)
    return cleaned
def generate_recipe(items):
    """Generate one recipe for the given ingredients.

    Args:
        items: Ingredients as a single comma-separated string (e.g.
            ``"apple, cucumber"``) or as a list/tuple of ingredient strings,
            which is joined with ``", "``. The result is prefixed with
            ``prefix`` before tokenization.

    Returns:
        The first generated recipe as plain text: the tokenizer's special
        tokens are stripped and ``tokens_map`` markers are replaced.
    """
    if isinstance(items, (list, tuple)):
        items = ", ".join(items)
    encoded = tokenizer(
        [prefix + items],
        max_length=256,
        padding="max_length",
        truncation=True,
        # A Flax model consumes JAX/NumPy arrays. "pt" would produce torch
        # tensors and also require PyTorch to be installed.
        return_tensors="jax",
    )
    output = model.generate(
        input_ids=encoded.input_ids,
        attention_mask=encoded.attention_mask,
        **generation_kwargs,
    )
    # Flax ``generate`` returns an output object rather than an array. The
    # token ids live in ``.sequences`` with shape (batch, seq_len). Decode the
    # 2-D batch as-is: squeezing it to 1-D would make batch_decode decode
    # every token id as its own "sequence".
    decoded = tokenizer.batch_decode(output.sequences, skip_special_tokens=False)
    recipes = target_postprocessing(decoded, special_tokens)
    return recipes[0]
# Example usage. Guarded so that importing this module (e.g. from a UI
# wrapper or a test) does not trigger a full generation run.
if __name__ == "__main__":
    input_items = "apple, cucumber"
    generated_recipe = generate_recipe(input_items)
    print(generated_recipe)