import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

from pprint import pprint
import re


# Load the tokenizer and model once up front (downloaded from the Hub on first use).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)


def skip_special_tokens_and_prettify(text, tokenizer):
    """Strip special tokens from raw generated text and split it into
    title / ingredients / directions sections."""
    recipe_maps = {"<sep>": "--", "<section>": "\n"}
    recipe_map_pattern = "|".join(map(re.escape, recipe_maps.keys()))

    # Escape the tokenizer's special tokens as well: some of them can contain
    # regex metacharacters that would otherwise corrupt the pattern.
    special_token_pattern = "|".join(map(re.escape, tokenizer.all_special_tokens))
    text = re.sub(
        recipe_map_pattern,
        lambda m: recipe_maps[m.group()],
        re.sub(special_token_pattern, "", text),
    )

    data = {"title": "", "ingredients": [], "directions": []}
    for section in text.split("\n"):
        section = section.strip()
        if section.startswith("title:"):
            data["title"] = section.replace("title:", "").strip()
        elif section.startswith("ingredients:"):
            data["ingredients"] = [s.strip() for s in section.replace("ingredients:", "").split('--')]
        elif section.startswith("directions:"):
            data["directions"] = [s.strip() for s in section.replace("directions:", "").split('--')]

    return data
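

# A minimal sketch of what the parser above handles, assuming this raw output
# format (hypothetical text, shown before the token replacement step):
#   "title: potato cheese bake <section> ingredients: 2 potatoes <sep> 1 cup cheese <section> directions: bake. <sep> serve."
# parses into:
#   {"title": "potato cheese bake",
#    "ingredients": ["2 potatoes", "1 cup cheese"],
#    "directions": ["bake.", "serve."]}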


def post_generator(output_tensors, tokenizer):
    """Decode pipeline outputs (generated token ids) into structured recipe dicts."""
    token_ids = [output["generated_token_ids"] for output in output_tensors]
    # Keep special tokens at this stage; skip_special_tokens_and_prettify removes
    # them itself while mapping <sep>/<section> to separators.
    texts = tokenizer.batch_decode(token_ids, skip_special_tokens=False)
    return [skip_special_tokens_and_prettify(text, tokenizer) for text in texts]


# Example
generate_kwargs = {
    "max_length": 512,
    "min_length": 64,
    "no_repeat_ngram_size": 3,  # never repeat the same 3-gram within one output
    "do_sample": True,          # sample instead of greedy/beam decoding
    "top_k": 60,
    "top_p": 0.95,
    "num_return_sequences": 3,  # return three candidate recipes per prompt
}

items = "potato, cheese"
generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer, device=device)
generated = generator(items, return_tensors=True, return_text=False, **generate_kwargs)
outputs = post_generator(generated, tokenizer)
pprint(outputs)
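
# `outputs` is a list of `num_return_sequences` dicts shaped like the sketch
# after skip_special_tokens_and_prettify above; because do_sample=True, the
# generated recipes differ from run to run.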