"""Streamlit app that generates a recipe from a comma-separated list of food items.

The user's items are prefixed with ``"items: "``, tokenized, and fed to a
Flax T5 model. The sampled output is decoded and shown on the page.
"""

import numpy as np
import streamlit as st
from transformers import FlaxT5ForConditionalGeneration, T5Tokenizer

MODEL_NAME_OR_PATH = "t5-base"

# Task prefix prepended to the user's item list before tokenization.
prefix = "items: "

# Keyword arguments forwarded verbatim to ``model.generate``.
generation_kwargs = {
    "max_length": 512,
    "min_length": 64,
    "no_repeat_ngram_size": 3,
    "do_sample": True,
    "top_k": 60,
    "top_p": 0.95,
    "num_return_sequences": 1,
}


@st.cache_resource
def _load_tokenizer_and_model(model_name_or_path):
    """Load the T5 tokenizer and Flax model for *model_name_or_path* once.

    Streamlit re-executes the whole script on every widget interaction;
    caching the loaded objects avoids re-reading the weights on each rerun.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    loaded_tokenizer = T5Tokenizer.from_pretrained(model_name_or_path)
    loaded_model = FlaxT5ForConditionalGeneration.from_pretrained(model_name_or_path)
    return loaded_tokenizer, loaded_model


tokenizer, model = _load_tokenizer_and_model(MODEL_NAME_OR_PATH)


def skip_special_tokens(text, special_tokens):
    """Return *text* with every occurrence of each string in *special_tokens* removed.

    Args:
        text: The string to clean.
        special_tokens: Iterable of token strings to delete from *text*.
    """
    for token in special_tokens:
        text = text.replace(token, "")
    return text


def target_postprocessing(texts, special_tokens):
    """Remove special tokens from one text or a list of texts.

    Args:
        texts: A single string or a list of strings.
        special_tokens: Iterable of token strings to delete.

    Returns:
        A list of cleaned strings; a non-list input is wrapped in a list first.
    """
    if not isinstance(texts, list):
        texts = [texts]
    return [skip_special_tokens(text, special_tokens) for text in texts]


def generate_recipe(items):
    """Generate recipe text for a comma-separated string of food items.

    Args:
        items: User-supplied item list, e.g. ``"apple, cucumber"``.

    Returns:
        The generated text for the first (and only) returned sequence, with
        special tokens removed and surrounding whitespace stripped.
    """
    inputs = tokenizer(
        [prefix + items],
        max_length=256,
        padding="max_length",
        truncation=True,
        return_tensors="jax",
    )
    # NOTE(review): no ``prng_key`` is passed, so with ``do_sample=True`` the
    # Flax sampler presumably falls back to a fixed default key and returns the
    # same recipe for the same input — confirm whether that is intended.
    outputs = model.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        **generation_kwargs,
    )
    # Flax ``generate`` returns an output object (e.g. FlaxSampleOutput) whose
    # token ids live in ``.sequences``; fall back to the raw value otherwise.
    output_ids = np.asarray(getattr(outputs, "sequences", outputs))

    # Decode with special tokens kept, then strip them explicitly using the
    # tokenizer's full special-token list.
    decoded = tokenizer.batch_decode(output_ids, skip_special_tokens=False)
    cleaned = target_postprocessing(decoded, tokenizer.all_special_tokens)
    # Removing tokens such as "<pad>" / "</s>" can leave stray edge whitespace.
    return cleaned[0].strip()


def main():
    """Render the UI: an item text box, a generate button, and the result."""
    st.title("Recipe Generation")
    items = st.text_input("Enter food items separated by comma (e.g., apple, cucumber):")
    if st.button("Generate Recipe"):
        # Avoid running an expensive generation on an empty prompt.
        if not items.strip():
            st.warning("Please enter at least one food item.")
            return
        with st.spinner("Generating recipe..."):
            generated_recipe = generate_recipe(items)
        st.write(generated_recipe)


if __name__ == "__main__":
    main()