File size: 1,598 Bytes
587e383
de9b5a0
587e383
 
 
de9b5a0
 
 
 
 
 
 
 
1cb6077
 
de9b5a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1cb6077
 
de9b5a0
 
 
 
 
587e383
de9b5a0
 
 
587e383
de9b5a0
 
 
 
 
587e383
 
1cb6077
de9b5a0
1cb6077
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from transformers import T5Tokenizer, FlaxT5ForConditionalGeneration

# Checkpoint identifier used to load both the tokenizer and the model below.
# NOTE(review): "t5-base" is the generic checkpoint, while the "items: " prefix
# looks like it targets a recipe fine-tuned model -- confirm this is intended.
MODEL_NAME_OR_PATH = "t5-base"
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME_OR_PATH)
model = FlaxT5ForConditionalGeneration.from_pretrained(MODEL_NAME_OR_PATH)

# Text prepended to the ingredient string before it is tokenized.
prefix = "items: "
# Keyword arguments forwarded unchanged to model.generate() by generate_recipe.
generation_kwargs = {
    "max_length": 512,
    "min_length": 64,
    "no_repeat_ngram_size": 3,
    "do_sample": True,
    "top_k": 60,
    "top_p": 0.95,
    "num_return_sequences": 1
}

def skip_special_tokens(text, special_tokens):
    """Return ``text`` with every occurrence of each special token removed.

    Tokens are stripped one after another, in the order given, so removing
    one token can expose a match for a later one.
    """
    stripped = text
    for marker in special_tokens:
        # Plain substring removal; surrounding whitespace is left untouched.
        stripped = stripped.replace(marker, "")
    return stripped

def target_postprocessing(texts, special_tokens):
    """Strip special tokens from one decoded text or a list of them.

    Anything that is not a ``list`` is treated as a single text and wrapped
    in a one-element list. The result is always a new list, in input order.
    """
    batch = texts if isinstance(texts, list) else [texts]
    return [skip_special_tokens(entry, special_tokens) for entry in batch]

def generate_recipe(items: str) -> str:
    """Generate text for an ingredient string and return it cleaned up.

    The input is prefixed with the module-level ``prefix``, tokenized to a
    fixed length of 256 (padded or truncated), and passed to
    ``model.generate`` with the module-level ``generation_kwargs``.

    Args:
        items: Ingredient text appended to ``prefix``, e.g. "apple, cucumber".

    Returns:
        The first generated sequence, decoded, with every string from
        ``tokenizer.all_special_tokens`` removed.
    """
    encoded = tokenizer(
        [prefix + items],
        max_length=256,
        padding="max_length",
        truncation=True,
        return_tensors="jax",
    )

    generation_output = model.generate(
        input_ids=encoded.input_ids,
        attention_mask=encoded.attention_mask,
        **generation_kwargs,
    )
    # Flax generate() returns an output object whose token ids live under
    # ``.sequences``; decoding the object itself would iterate its field
    # names instead of id rows. Fall back to the raw value in case a plain
    # array of ids is returned.
    sequences = getattr(generation_output, "sequences", generation_output)

    # Decode with special tokens kept, then strip them as plain substrings.
    decoded = tokenizer.batch_decode(sequences, skip_special_tokens=False)
    cleaned = target_postprocessing(decoded, tokenizer.all_special_tokens)
    return cleaned[0]

# Example usage. This sits at module top level with no __main__ guard, so it
# runs (generating and printing one result) whenever the module is executed
# or imported.
input_items = "apple, cucumber"
generated_recipe = generate_recipe(input_items)
print(generated_recipe)