# examples/research_projects/multi_token_textual_inversion/multi_token_clip.py
""" | |
The main idea for this code is to provide a way for users to not need to bother with the hassle of multiple tokens for a concept by typing | |
a photo of <concept>_0 <concept>_1 ... and so on | |
and instead just do | |
a photo of <concept> | |
which gets translated to the above. This needs to work for both inference and training. | |
For inference, | |
the tokenizer encodes the text. So, we would want logic for our tokenizer to replace the placeholder token with | |
it's underlying vectors | |
For training, | |
we would want to abstract away some logic like | |
1. Adding tokens | |
2. Updating gradient mask | |
3. Saving embeddings | |
to our Util class here. | |
so | |
TODO: | |
1. have tokenizer keep track of concept, multiconcept pairs and replace during encode call x | |
2. have mechanism for adding tokens x | |
3. have mech for saving emebeddings x | |
4. get mask to update x | |
5. Loading tokens from embedding x | |
6. Integrate to training x | |
7. Test | |
""" | |
import copy
import random

from transformers import CLIPTokenizer


class MultiTokenCLIPTokenizer(CLIPTokenizer):
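    """A CLIPTokenizer that lets one placeholder token stand in for several underlying tokens.

    `token_map` records each placeholder token and the tokens it expands to; `__call__` and
    `encode` perform the expansion before delegating to the parent CLIPTokenizer.
    """
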
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.token_map = {}

    def try_adding_tokens(self, placeholder_token, *args, **kwargs):
        """Add `placeholder_token` to the vocabulary, raising if it is already present."""
        num_added_tokens = super().add_tokens(placeholder_token, *args, **kwargs)
        if num_added_tokens == 0:
            raise ValueError(
                f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
                " `placeholder_token` that is not already in the tokenizer."
            )

    def add_placeholder_tokens(self, placeholder_token, *args, num_vec_per_token=1, **kwargs):
        """Register `placeholder_token`, backing it with `num_vec_per_token` new vocabulary tokens."""
        # Check for conflicts before adding anything, so a failure leaves the vocabulary untouched.
        # This catches a new placeholder token that contains an existing one, e.g. adding
        # "<concept>_extra" when "<concept>" is already registered.
        for token in self.token_map:
            if token in placeholder_token:
                raise ValueError(
                    f"The tokenizer already has a placeholder token {token} that can get confused with"
                    f" {placeholder_token}. Keep placeholder tokens independent."
                )
        output = []
        if num_vec_per_token == 1:
            self.try_adding_tokens(placeholder_token, *args, **kwargs)
            output.append(placeholder_token)
        else:
            for i in range(num_vec_per_token):
                ith_token = placeholder_token + f"_{i}"
                self.try_adding_tokens(ith_token, *args, **kwargs)
                output.append(ith_token)
        self.token_map[placeholder_token] = output

    def replace_placeholder_tokens_in_text(self, text, vector_shuffle=False, prop_tokens_to_load=1.0):
        """
        Replace the placeholder tokens recorded in `token_map` with their underlying tokens so that
        the text encoder can encode them.

        `vector_shuffle` was inspired by https://github.com/rinongal/textual_inversion/pull/119,
        where shuffling the tokens was found to force the model to learn the concepts more descriptively.
        """
        if isinstance(text, list):
            return [
                self.replace_placeholder_tokens_in_text(
                    t, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
                )
                for t in text
            ]
        for placeholder_token in self.token_map:
            if placeholder_token in text:
                tokens = self.token_map[placeholder_token]
                # Optionally keep only a leading fraction of the vectors (always at least one).
                tokens = tokens[: 1 + int(len(tokens) * prop_tokens_to_load)]
                if vector_shuffle:
                    tokens = copy.copy(tokens)
                    random.shuffle(tokens)
                text = text.replace(placeholder_token, " ".join(tokens))
        return text

    def __call__(self, text, *args, vector_shuffle=False, prop_tokens_to_load=1.0, **kwargs):
        """Tokenize `text` after expanding any placeholder tokens it contains."""
        return super().__call__(
            self.replace_placeholder_tokens_in_text(
                text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
            ),
            *args,
            **kwargs,
        )

    def encode(self, text, *args, vector_shuffle=False, prop_tokens_to_load=1.0, **kwargs):
        """Encode `text` after expanding any placeholder tokens it contains."""
        return super().encode(
            self.replace_placeholder_tokens_in_text(
                text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
            ),
            *args,
            **kwargs,
        )