Spaces:
Restarting
on
Zero
Restarting
on
Zero
import torch | |
class LatentRebatch:
    """Node that regroups a list of LATENT batches into batches of ``batch_size``.

    ``INPUT_IS_LIST`` means every input arrives as a list (so ``batch_size`` is
    a one-element list), and ``OUTPUT_IS_LIST`` means the single output is a
    list of LATENT dicts. Consecutive latents are concatenated until a batch is
    full; a change in spatial size always starts a new batch.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "latents": ("LATENT",),
                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
                              }}

    RETURN_TYPES = ("LATENT",)
    INPUT_IS_LIST = True
    OUTPUT_IS_LIST = (True, )

    FUNCTION = "rebatch"

    CATEGORY = "latent/batch"

    @staticmethod
    def get_batch(latents, list_ind, offset):
        '''Prepare a batch out of the list of latents.

        Args:
            latents: list of LATENT dicts; ``latents[list_ind]['samples']`` is
                indexed as a 4-D tensor (batch, channels, height, width).
            list_ind: index of the entry to convert.
            offset: number of samples already processed; used to number the
                samples when the entry carries no ``batch_index``.

        Returns:
            ``(samples, mask, batch_inds)`` where ``mask`` has shape
            ``(N, 1, 8*H, 8*W)`` with exactly one row per sample, and
            ``batch_inds`` holds one index per sample.
        '''
        latent = latents[list_ind]
        samples = latent['samples']
        shape = samples.shape
        if 'noise_mask' in latent:
            mask = latent['noise_mask']
        else:
            mask = torch.ones((shape[0], 1, shape[2]*8, shape[3]*8), device='cpu')

        # Normalise to (N, 1, H, W), the layout of the all-ones default above,
        # so masks from different entries can be concatenated in cat_batch.
        mask = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))

        # The mask is kept at 8x the latent's spatial resolution; resize it
        # (and keep the result) when it does not match.
        target_size = (shape[-2] * 8, shape[-1] * 8)
        if tuple(mask.shape[-2:]) != target_size:
            mask = torch.nn.functional.interpolate(mask, size=target_size, mode="bilinear")

        # Masks and samples are later sliced in lockstep, so their counts must
        # match: tile a short mask, then trim to exactly one mask per sample.
        if mask.shape[0] < shape[0]:
            mask = mask.repeat((shape[0] - 1) // mask.shape[0] + 1, 1, 1, 1)
        mask = mask[:shape[0]]

        if 'batch_index' in latent:
            batch_inds = latent['batch_index']
        else:
            batch_inds = [x + offset for x in range(shape[0])]
        return samples, mask, batch_inds

    @staticmethod
    def get_slices(indexable, num, batch_size):
        '''Divide an indexable object into num slices of length batch_size, and a remainder.

        Returns ``(slices, remainder)``; ``remainder`` is None when nothing is
        left over after the ``num`` slices.
        '''
        slices = [indexable[i*batch_size:(i+1)*batch_size] for i in range(num)]
        if num * batch_size < len(indexable):
            return slices, indexable[num * batch_size:]
        return slices, None

    @staticmethod
    def slice_batch(batch, num, batch_size):
        '''Apply get_slices to every element of a (samples, mask, batch_inds) triple.

        Returns a two-element list: the per-element slice lists, then the
        per-element remainders (each None if nothing is left over).
        '''
        result = [LatentRebatch.get_slices(x, num, batch_size) for x in batch]
        return list(zip(*result))

    @staticmethod
    def cat_batch(batch1, batch2):
        '''Concatenate two (samples, mask, batch_inds) triples element-wise.

        Tensors are joined with ``torch.cat`` along dim 0 and other elements
        with ``+``. If ``batch1`` is empty (first element None), ``batch2`` is
        returned unchanged.
        '''
        if batch1[0] is None:
            return batch2
        return [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2
                for b1, b2 in zip(batch1, batch2)]

    @staticmethod
    def _emit_batches(output_list, batch, num, batch_size):
        '''Append num slices of batch to output_list as LATENT dicts; return the remainder triple.'''
        sliced, remainder = LatentRebatch.slice_batch(batch, num, batch_size)
        for samples, mask, batch_inds in zip(*sliced):
            output_list.append({'samples': samples, 'noise_mask': mask, 'batch_index': batch_inds})
        return remainder

    def rebatch(self, latents, batch_size):
        '''Regroup latents into batches of ``batch_size[0]`` samples.

        Returns a one-element tuple holding the list of LATENT dicts. An
        output ``noise_mask`` whose mean is exactly 1.0 (for example the
        all-ones default filled in by get_batch) is removed from its dict.
        '''
        batch_size = batch_size[0]

        output_list = []
        current_batch = (None, None, None)
        processed = 0

        for list_ind in range(len(latents)):
            next_batch = self.get_batch(latents, list_ind, processed)
            processed += len(next_batch[2])

            if current_batch[0] is None:
                current_batch = next_batch
            elif next_batch[0].shape[-2:] != current_batch[0].shape[-2:]:
                # Spatial size changed: the samples cannot be concatenated,
                # so flush the pending batch (never larger than batch_size here).
                self._emit_batches(output_list, current_batch, 1, batch_size)
                current_batch = next_batch
            else:
                current_batch = self.cat_batch(current_batch, next_batch)

            # Once the pending batch outgrows the target, emit every full batch
            # and keep only the leftover samples.
            if current_batch[0].shape[0] > batch_size:
                num = current_batch[0].shape[0] // batch_size
                current_batch = self._emit_batches(output_list, current_batch, num, batch_size)

        # Flush whatever is still pending.
        if current_batch[0] is not None:
            self._emit_batches(output_list, current_batch, 1, batch_size)

        # Drop masks that are all ones on average (no effective masking).
        for s in output_list:
            if s['noise_mask'].mean() == 1.0:
                del s['noise_mask']
        return (output_list,)
class ImageRebatch:
    """Node that regroups a list of IMAGE batches into batches of ``batch_size`` images.

    ``INPUT_IS_LIST`` means every input arrives as a list (so ``batch_size`` is
    a one-element list), and ``OUTPUT_IS_LIST`` means the single output is a
    list of image tensors.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "images": ("IMAGE",),
                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
                              }}

    RETURN_TYPES = ("IMAGE",)
    INPUT_IS_LIST = True
    OUTPUT_IS_LIST = (True, )

    FUNCTION = "rebatch"

    CATEGORY = "image/batch"

    def rebatch(self, images, batch_size):
        '''Regroup images into tensors of at most ``batch_size[0]`` images each.

        Every input tensor is split along dim 0 into single-image slices,
        which are re-concatenated in order; the last batch may be shorter.
        Returns a one-element tuple holding the list of batches.
        '''
        batch_size = batch_size[0]

        # One slice per image; slicing with i:i+1 keeps the leading batch dim.
        all_images = [img[i:i + 1] for img in images for i in range(img.shape[0])]

        output_list = [torch.cat(all_images[start:start + batch_size], dim=0)
                       for start in range(0, len(all_images), batch_size)]
        return (output_list,)
# Node name -> implementing class, for the two rebatch nodes defined above.
NODE_CLASS_MAPPINGS = dict(
    RebatchLatents=LatentRebatch,
    RebatchImages=ImageRebatch,
)

# Display names for the same node names.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    RebatchLatents="Rebatch Latents",
    RebatchImages="Rebatch Images",
)