import os

import numpy as np
import streamlit as st
import torch
from PIL import Image
from torchvision.io import read_image, ImageReadMode
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
from torchvision.transforms.functional import InterpolationMode
from transformers import MBart50TokenizerFast


class Toc:
    """Collect headings rendered through Streamlit and emit a table of contents.

    Each heading method renders the heading immediately via ``st.markdown``
    and records a matching bullet entry. Call :meth:`placeholder` where the
    table of contents should appear, and :meth:`generate` once all headings
    have been emitted to fill that placeholder.
    """

    def __init__(self):
        # Markdown bullet lines, one per heading, in emission order.
        self._items = []
        # Streamlit container reserved by ``placeholder()``; None until then.
        self._placeholder = None

    def title(self, text: str) -> None:
        """Render *text* as an ``h1`` heading and record it (no indent)."""
        self._markdown(text, "h1")

    def header(self, text: str) -> None:
        """Render *text* as an ``h2`` heading and record it (2-space indent)."""
        self._markdown(text, "h2", " " * 2)

    def subheader(self, text: str) -> None:
        """Render *text* as an ``h3`` heading and record it (4-space indent)."""
        self._markdown(text, "h3", " " * 4)

    def subsubheader(self, text: str) -> None:
        """Render *text* as an ``h4`` heading and record it (8-space indent)."""
        self._markdown(text, "h4", " " * 8)

    def placeholder(self, sidebar: bool = False) -> None:
        """Reserve an empty Streamlit slot for the table of contents.

        Args:
            sidebar: When true the slot is created in ``st.sidebar``;
                otherwise in the main page body.
        """
        self._placeholder = st.sidebar.empty() if sidebar else st.empty()

    def generate(self) -> None:
        """Write the collected entries into the reserved placeholder.

        Does nothing when ``placeholder()`` has not been called.
        """
        if self._placeholder:
            self._placeholder.markdown(
                "\n".join(self._items), unsafe_allow_html=True
            )

    def _markdown(self, text: str, level: str, space: str = "") -> None:
        """Render one heading and append its table-of-contents entry.

        Args:
            text: Heading text, inserted into the HTML as-is (not escaped).
            level: HTML heading tag name, e.g. ``"h2"``.
            space: Leading whitespace for the bullet, controlling nesting.
        """
        # Anchor id: the alphanumeric characters of the text, lower-cased.
        key = "".join(filter(str.isalnum, text)).lower()

        # The heading element must be closed, otherwise the emitted HTML is
        # malformed and the following content is parsed as heading content.
        st.markdown(
            f"<{level} id='{key}'>{text}</{level}>", unsafe_allow_html=True
        )
        # Link the entry to the heading's id so the table of contents is
        # navigable; ``generate`` renders the list with HTML enabled.
        self._items.append(f"{space}* <a href='#{key}'>{text}</a>")


# Loaded once at import time; from_pretrained fetches the files on first use
# (or reads them from the local cache).
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50")


class Transform(torch.nn.Module):
    """Image preprocessing pipeline: resize, center-crop, to float, normalize."""

    def __init__(self, image_size):
        """Build the pipeline.

        Args:
            image_size: Target size passed to both ``Resize`` (as a
                one-element list) and ``CenterCrop``.
        """
        super().__init__()
        self.transforms = torch.nn.Sequential(
            Resize([image_size], interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
            ConvertImageDtype(torch.float),
            # Per-channel mean and std for three channels.
            # NOTE(review): these look like the CLIP image statistics - confirm.
            Normalize(
                (0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711),
            ),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the pipeline to *x* with gradient tracking disabled."""
        with torch.no_grad():
            x = self.transforms(x)
        return x


# Shared module-level preprocessing instance used by get_transformed_image.
transform = Transform(224)


def get_transformed_image(image):
    """Preprocess *image* with the module-level ``transform``.

    Args:
        image: Image data accepted by ``torch.tensor``. A NumPy array whose
            last axis has length 3 is transposed with ``(2, 0, 1)`` first
            (assumes a 3-D height x width x channel array - TODO confirm
            against callers).

    Returns:
        A NumPy array with a leading batch axis of size 1 and the channel
        axis moved to the end.
    """
    if isinstance(image, np.ndarray) and image.shape[-1] == 3:
        # Move the trailing channel axis to the front before transforming.
        image = image.transpose(2, 0, 1)
    image = torch.tensor(image)
    # Add a batch axis, then move axis 1 (channels) to the last position.
    return transform(image).unsqueeze(0).permute(0, 2, 3, 1).numpy()


def read_markdown(path, parent="./sections/"):
    """Return the text of the file at ``parent/path``.

    Args:
        path: File name (or relative path) inside *parent*.
        parent: Directory the file is resolved against.

    Returns:
        The full file contents as a string.
    """
    # Decode explicitly as UTF-8 rather than the platform default encoding,
    # so the result does not vary between operating systems.
    with open(os.path.join(parent, path), encoding="utf-8") as f:
        return f.read()


# Two-letter language code -> locale-style language code.
# NOTE(review): presumably the mBART-50 tokenizer language codes - confirm.
language_mapping = {
    "en": "en_XX",
    "de": "de_DE",
    "fr": "fr_XX",
    "es": "es_XX",
}

# Two-letter language code -> English display name.
code_to_name = {
    "en": "English",
    "fr": "French",
    "de": "German",
    "es": "Spanish",
}