Spaces:
Runtime error
Runtime error
File size: 2,007 Bytes
6742988 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
""" DalleBart processor """
import jax.numpy as jnp
from .configuration import DalleBartConfig
from .text import TextNormalizer
from .tokenizer import DalleBartTokenizer
from .utils import PretrainedFromWandbMixin
class DalleBartProcessorBase:
    """Turn batches of text prompts into padded token batches for DalleBart.

    Besides the tokenized prompts, each returned batch also carries
    "unconditional" tokens (the tokenization of the empty string) repeated
    once per prompt; these are used only with super conditioning.
    """

    def __init__(
        self,
        tokenizer: "DalleBartTokenizer",
        normalize_text: bool,
        max_text_length: int,
    ):
        """Store the settings and pre-tokenize the unconditional (empty) prompt.

        Args:
            tokenizer: callable tokenizer; it is always invoked with
                ``return_tensors="jax"``, ``padding="max_length"``,
                ``truncation=True`` and ``max_length=max_text_length``.
            normalize_text: when True, every prompt is passed through a
                ``TextNormalizer`` before tokenization.
            max_text_length: length every sequence is padded/truncated to.
        """
        # The tokenizer annotation is quoted so defining the class does not
        # require resolving the tokenizer type; runtime behavior is unchanged.
        self.tokenizer = tokenizer
        self.normalize_text = normalize_text
        self.max_text_length = max_text_length
        # Always define the attribute (None when unused) instead of leaving it
        # missing; it is only consulted when normalize_text is True.
        self.text_processor = TextNormalizer() if normalize_text else None

        # create unconditional tokens: the empty prompt, padded to full length
        uncond = self._tokenize("")
        self.input_ids_uncond = uncond["input_ids"]
        self.attention_mask_uncond = uncond["attention_mask"]

    def _tokenize(self, text):
        """Call the tokenizer with this processor's fixed settings; return ``.data``."""
        return self.tokenizer(
            text,
            return_tensors="jax",
            padding="max_length",
            truncation=True,
            max_length=self.max_text_length,
        ).data

    def __call__(self, text: "list[str]" = None):
        """Tokenize a batch of prompts.

        Args:
            text: an iterable of prompt strings (list, tuple, generator, ...).
                A bare ``str`` is rejected because it would otherwise be
                treated as a sequence of single-character prompts.

        Returns:
            The ``.data`` mapping produced by the tokenizer (including
            ``input_ids`` and ``attention_mask``), extended with
            ``input_ids_uncond`` and ``attention_mask_uncond``: the
            unconditional tokens repeated once per prompt along axis 0.

        Raises:
            AssertionError: if ``text`` is a single string.
        """
        # Explicit raise rather than ``assert`` so the check survives
        # ``python -O``; AssertionError is kept for backward compatibility.
        if isinstance(text, str):
            raise AssertionError("text must be a list of strings")
        # Materialize once: a generator would otherwise be consumed by
        # normalization/tokenization and then break ``len`` below.
        text = list(text)
        if self.normalize_text:
            text = [self.text_processor(t) for t in text]
        res = self._tokenize(text)
        # tokens used only with super conditioning
        n = len(text)
        res["input_ids_uncond"] = jnp.repeat(self.input_ids_uncond, n, axis=0)
        res["attention_mask_uncond"] = jnp.repeat(
            self.attention_mask_uncond, n, axis=0
        )
        return res

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Build a processor from a pretrained tokenizer and config.

        The same ``args``/``kwargs`` are forwarded to both
        ``DalleBartTokenizer.from_pretrained`` and
        ``DalleBartConfig.from_pretrained``; the config supplies
        ``normalize_text`` and ``max_text_length``.
        """
        tokenizer = DalleBartTokenizer.from_pretrained(*args, **kwargs)
        config = DalleBartConfig.from_pretrained(*args, **kwargs)
        return cls(tokenizer, config.normalize_text, config.max_text_length)
class DalleBartProcessor(PretrainedFromWandbMixin, DalleBartProcessorBase):
    """Public DalleBart processor: DalleBartProcessorBase plus PretrainedFromWandbMixin.

    The mixin is listed first among the bases, so any method it defines takes
    precedence over the same-named method of DalleBartProcessorBase.
    NOTE(review): presumably the mixin overrides ``from_pretrained`` to allow
    loading from Weights & Biases artifacts -- confirm in ``.utils``.
    """
|