# Streamlit front-end for the "Chef Transformer" recipe generator.
#
# The user enters a comma-separated ingredient list; a text2text model turns
# it into a recipe (title / ingredients / directions) which is then drawn on
# top of a template image and displayed.
import streamlit as st

import torch
from transformers import pipeline, set_seed
from transformers import AutoTokenizer

from PIL import (
    Image,
    ImageFont,
    ImageDraw
)

import re
from examples import EXAMPLES
import meta
from utils import (
    remote_css,
    local_css,
    load_image,
    pure_comma_separation
)


class TextGeneration:
    """Generate a recipe from ingredients and render it onto an image.

    When ``self.debug`` is true no model is loaded and ``generate`` returns
    the canned ``self.dummy_output`` recipe instead.
    """

    def __init__(self):
        self.debug = True
        # Canned recipe returned by ``generate`` while ``debug`` is enabled.
        self.dummy_output = {
            'directions': [
                'peel the potato and slice thinly.',
                'place in a microwave safe dish.',
                'cover with plastic wrap and microwave on high for 5 minutes.',
                'remove from the microwave and sprinkle with cheese.',
                'return to the microwave for 1 minute or until cheese is melted.',
                'return to the microwave for 1 minute or until cheese is melted.',
                'return to the microwave for 1 minute or until cheese is melted.'
            ],
            # Every entry is comma-terminated: adjacent string literals
            # without a comma are silently concatenated by Python.
            'ingredients': [
                '1 potato',
                '1 slice cheese',
                '1 potato',
                '1 slice cheese',
                '1 potato',
                '1 slice cheese',
                '1 slice cheese',
                '1 potato',
                '1 slice cheese',
                '1 potato',
                '1 slice cheese',
            ],
            'title': 'Cheese Potatoes'
        }
        self.tokenizer = None
        self.generator = None
        self.task = "text2text-generation"
        self.model_name_or_path = "flax-community/t5-recipe-generation"
        self.h1_font = ImageFont.truetype("asset/fonts/PT_Serif/PTSerif-Bold.ttf", 100)
        self.h2_font = ImageFont.truetype("asset/fonts/PT_Serif/PTSerif-Bold.ttf", 50)
        self.p_font = ImageFont.truetype("asset/fonts/PT_Serif/PTSerif-Regular.ttf", 30)
        set_seed(42)

    @staticmethod
    def _split_items(text):
        """Split ``text`` on ``--`` and return the non-empty, stripped items."""
        return [item.strip() for item in text.split('--') if item.strip()]

    def _skip_special_tokens_and_prettify(self, text):
        """Turn raw decoded model output into a structured recipe.

        Args:
            text: Decoded model output, still containing special tokens.

        Returns:
            A dict with ``"title"`` (str), ``"ingredients"`` (list of str) and
            ``"directions"`` (list of str). Sections missing from ``text``
            keep their empty defaults.
        """
        # Structural marker tokens emitted by the model: one separates the
        # items inside a section, the other separates the sections.
        # NOTE(review): these literals follow the post-processing shown on the
        # flax-community/t5-recipe-generation model card; with an empty key
        # the pattern would match at every position in the text.
        recipe_maps = {"<sep>": "--", "<section>": "\n"}

        # Strip the tokenizer's special tokens, escaped so they are matched
        # literally. The structural markers are kept even if the tokenizer
        # lists them as special, because they are still needed below.
        special_tokens = [
            re.escape(token)
            for token in self.tokenizer.all_special_tokens
            if token not in recipe_maps
        ]
        if special_tokens:
            text = re.sub("|".join(special_tokens), "", text)

        recipe_map_pattern = "|".join(map(re.escape, recipe_maps))
        text = re.sub(recipe_map_pattern, lambda m: recipe_maps[m.group()], text)

        data = {"title": "", "ingredients": [], "directions": []}
        for section in text.split("\n"):
            section = section.strip()
            # Only the leading label is removed, so a label that happens to
            # re-appear inside the section body is left untouched.
            if section.startswith("title:"):
                words = section[len("title:"):].split()
                data["title"] = " ".join(word.capitalize() for word in words)
            elif section.startswith("ingredients:"):
                data["ingredients"] = self._split_items(section[len("ingredients:"):])
            elif section.startswith("directions:"):
                data["directions"] = self._split_items(section[len("directions:"):])

        return data

    def load(self):
        """Load the tokenizer and generation pipeline unless in debug mode."""
        if self.debug:
            return
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
        self.generator = pipeline(
            self.task,
            model=self.model_name_or_path,
            tokenizer=self.model_name_or_path
        )

    def prepare_frame(self, recipe, frame):
        """Draw the recipe's title, ingredients and directions onto ``frame``.

        Args:
            recipe: Dict with ``"title"``, ``"ingredients"`` and
                ``"directions"`` entries.
            frame: Image to draw on; it is drawn on directly.

        Returns:
            The same ``frame`` object, with the text drawn on it.
        """
        text_color = (61, 61, 70)
        im_editable = ImageDraw.Draw(frame)

        # Title
        ws, hs = 120, 500
        im_editable.text((ws, hs), recipe["title"], text_color, font=self.h1_font)

        # Ingredients
        hs = hs + 200
        im_editable.text((ws, hs), "Ingredients", text_color, font=self.h2_font)
        hs = hs + 80
        im_editable.text(
            (ws + 10, hs),
            "\n".join([f"- {item}" for item in recipe["ingredients"]]),
            text_color,
            font=self.p_font,
        )

        # Directions (fixed vertical offset below the ingredients heading)
        hs = hs + 400
        im_editable.text((ws, hs), "Directions", text_color, font=self.h2_font)
        hs = hs + 80
        im_editable.text(
            (ws + 10, hs),
            "\n".join([f"- {item}" for item in recipe["directions"]]),
            text_color,
            font=self.p_font,
        )
        return frame

    def generate(self, items, generation_kwargs):
        """Generate a recipe for ``items``.

        Args:
            items: Ingredient string passed to the generation pipeline.
            generation_kwargs: Generation settings forwarded to the pipeline.
                The mapping is not modified.

        Returns:
            The structured recipe dict, or ``self.dummy_output`` in debug mode.
        """
        print(generation_kwargs)
        if self.debug:
            return self.dummy_output

        # Work on a copy: callers pass shared module-level presets, which
        # must not pick up the pipeline-specific overrides below.
        kwargs = dict(generation_kwargs)
        kwargs["num_return_sequences"] = 1
        kwargs["return_tensors"] = True
        kwargs["return_text"] = False

        generated_ids = self.generator(items, **kwargs)[0]["generated_token_ids"]
        # Special tokens are kept here because the prettifier relies on the
        # structural markers among them.
        recipe = self.tokenizer.decode(generated_ids, skip_special_tokens=False)
        return self._skip_special_tokens_and_prettify(recipe)

    def generate_frame(self, recipe):
        """Load the recipe template image and draw ``recipe`` onto it."""
        frame = load_image("asset/images/recipe-post.png")
        return self.prepare_frame(recipe, frame)


@st.cache(allow_output_mutation=True)
def load_text_generator():
    """Build and load a ``TextGeneration`` instance (cached by Streamlit)."""
    generator = TextGeneration()
    generator.load()
    return generator


# Sampling-based generation preset (used for "Chef Scheherazade").
chef_top = {
    "max_length": 512,
    "min_length": 64,
    "no_repeat_ngram_size": 3,
    "do_sample": True,
    "top_k": 60,
    "top_p": 0.95,
    "num_return_sequences": 1
}

# Beam-search generation preset (used for any other chef).
chef_beam = {
    "max_length": 512,
    "min_length": 64,
    "no_repeat_ngram_size": 3,
    "early_stopping": True,
    "num_beams": 5,
    "length_penalty": 1.5,
    "num_return_sequences": 1
}


def main():
    """Render the Streamlit page and generate a recipe on request."""
    st.set_page_config(
        page_title="Chef Transformer",
        page_icon="🍲",
        layout="wide",
        initial_sidebar_state="expanded"
    )
    generator = load_text_generator()
    local_css("asset/css/style.css")

    st.sidebar.image(load_image("asset/images/chef-transformer-transparent.png"), width=310)
    st.sidebar.title("Choose your own chef")
    chef = st.sidebar.selectbox("Chef", index=0, options=["Chef Scheherazade", "Chef Giovanni"])

    st.markdown(meta.HEADER_INFO)

    prompts = list(EXAMPLES.keys()) + ["Custom"]
    prompt = st.selectbox('Examples', prompts, index=len(prompts) - 1)
    prompt_box = "" if prompt == "Custom" else EXAMPLES[prompt]

    items = st.text_input(
        'Add custom ingredients here (separated by `,`): ',
        pure_comma_separation(prompt_box, return_list=False),
        key="custom_keywords",
        max_chars=1000
    )
    items = pure_comma_separation(items, return_list=False)
    entered_items = st.empty()

    if st.button('Get Recipe!'):
        entered_items.markdown("**Generate recipe for:** " + items)
        with st.spinner("Generating recipe..."):
            gen_kw = chef_top if chef == "Chef Scheherazade" else chef_beam
            generated_recipe = generator.generate(items, gen_kw)
            recipe_post = generator.generate_frame(generated_recipe)

            # ``st.beta_columns`` was renamed to ``st.columns``; prefer the
            # new name and fall back to the old one on older Streamlit.
            make_columns = getattr(st, "columns", None) or st.beta_columns
            col1, col2, col3 = make_columns([1, 6, 1])
            with col1:
                st.write("")
            with col2:
                st.image(
                    recipe_post,
                    caption="Your recipe",
                    use_column_width="auto",
                    output_format="PNG"
                )
            with col3:
                st.write("")


if __name__ == '__main__':
    main()