RAG-Diffusion / matrix.py
znchen
Add application file
8fb99cf
raw
history blame
10.1 kB
import colorsys # Polygon regions.
from PIL import Image, ImageChops
from pprint import pprint
import cv2 # Polygon regions.
import numpy as np
import PIL
import torch
# Separators used inside ratio strings: rows first, then cells within a row.
SPLROW = ";"
SPLCOL = ","

# Keywords recognised inside prompts.
KEYROW = "ADDROW"
KEYCOL = "ADDCOL"
KEYBASE = "ADDBASE"
KEYCOMM = "ADDCOMM"
KEYBRK = "BREAK"

NLN = "\n"

# Out/in, horizontal/vertical or row/col first.
DKEYINOUT = {
    ("out", False): KEYROW,
    ("in", False): KEYCOL,
    ("out", True): KEYCOL,
    ("in", True): KEYROW,
}


def fidentity(x):
    """Return the argument unchanged."""
    return x


def ffloatd(c):
    """Return a one-argument converter that calls floatdef with default c."""
    def convert(x):
        return floatdef(x, c)
    return convert


def fspace(x):
    """Return x formatted with one space on each side."""
    return " {} ".format(x)


def fcountbrk(x):
    """Return how many times KEYBRK occurs in x."""
    return x.count(KEYBRK)


def fint(x):
    """Return int(x)."""
    return int(x)
def floatdef(x, vdef):
    """Convert x to float, falling back to vdef when x is not numeric.

    Intended for ratio strings with empty entries (e.g. double commas).
    A message is printed whenever the fallback is used. Only ValueError
    is handled; any other conversion error propagates.
    """
    try:
        value = float(x)
    except ValueError:
        print("'{}' is not a number, converted to {}".format(x,vdef))
        value = vdef
    return value
class Region:
    """One cell of a row: a range plus the prompt weighting data attached to it."""

    def __init__(self, st, ed, base, breaks):
        """Store the cell range, its base weight and its break count.

        st / ed -- start and end of the cell's range (cols only).
        base    -- how much of the base prompt is applied (difference).
        breaks  -- how many unrelated breaks the prompt contains.
        """
        self.start, self.end = st, ed
        self.base = base
        self.breaks = breaks
class Row:
    """A row of the layout: its own range plus the list of cells it holds."""

    def __init__(self, st, ed, cols):
        """Store the row range and its cells.

        st / ed -- start and end of the row's range.
        cols    -- list of cells belonging to this row.
        """
        self.start, self.end = st, ed
        self.cols = cols
def is_l2(l):
    """Return True when l is a two-level list, i.e. its first element is a list.

    Only the first element is inspected. An empty sequence has no first
    element and is reported as not two-level instead of raising IndexError.
    """
    return len(l) > 0 and isinstance(l[0], list)
def l2_count(l):
    """Return the total number of elements across all rows of a two-level list.

    The previous implementation evaluated `cnt + cnt + len(row)` without
    assigning the result, so it always returned 0.
    """
    return sum(len(row) for row in l)
def list_percentify(l):
    """Scale values so that each row sums to 1.

    For a two-level list every row is normalised independently; for a flat
    list the normalisation is applied once to the whole list. A new list is
    returned and the input is left untouched.

    Raises ZeroDivisionError if a non-empty row sums to zero.
    """
    def _normalise(row):
        # Compute the row total once instead of once per element.
        total = sum(row)
        return [v / total for v in row]

    if is_l2(l):
        return [_normalise(row) for row in l]
    return _normalise(l)
def list_cumsum(l):
    """Return running totals, i.e. result[n] == sum(row[0:n + 1]).

    For a two-level list the running total restarts on every row; a flat
    list is treated as a single row. New lists are always returned: the
    previous implementation overwrote the rows of a two-level input in
    place while copying a flat input, which was inconsistent and modified
    the caller's data.
    """
    # Function-scope import keeps this block self-contained.
    from itertools import accumulate

    if is_l2(l):
        return [list(accumulate(row)) for row in l]
    return list(accumulate(l))
def list_rangify(l):
    """Turn a list of end points into consecutive [start, end] pairs.

    The first pair starts at 0 and every following pair starts where the
    previous one ended. A two-level list is converted row by row; a flat
    list is converted as a single row.
    """
    def _pairs(ends):
        # Pair every end point with the one before it (0 for the first).
        return [[lo, hi] for lo, hi in zip([0] + ends, ends)]

    if is_l2(l):
        return [_pairs(row) for row in l]
    return _pairs(l)
def ratiosdealer(split_ratio2,split_ratio2r):
    """Convert raw cell and row ratios into [start, end] ranges.

    Both arguments go through the same pipeline: normalise to fractions,
    accumulate, then pair up into ranges. Returns (cell_ranges, row_ranges).
    """
    cell_ranges = list_rangify(list_cumsum(list_percentify(split_ratio2)))
    row_ranges = list_rangify(list_cumsum(list_percentify(split_ratio2r)))
    return cell_ranges, row_ranges
def round_dim(x,y):
    """Return x divided by y, rounded to the nearest integer with 0.5 rounded up.

    Seems that dimensions which are exactly 0.5 are rounded up - see 680x488,
    second iter.

    The remainder is compared against half the divisor as `2 * rem >= y`.
    The earlier `rem >= y // 2` floored the threshold, so odd divisors
    rounded up too eagerly (4 / 3 gave 2) and y == 1 always added one.
    Results for even divisors are unchanged.
    """
    return x // y + (2 * (x % y) >= y)
def keyconverter(self,split_ratio,usebase):
    '''Replace BREAK keywords in self.SR_prompt with ADDBASE/ADDCOL/ADDROW.

    The layout is read from split_ratio: one ADDCOL is produced between
    neighbouring cells of a row and one ADDROW between neighbouring rows,
    optionally preceded by ADDBASE when usebase is true. The keys are then
    substituted, in order, for successive BREAK occurrences.
    self.SR_prompt is updated in place; nothing is returned.
    '''
    if SPLROW in split_ratio:
        (row_ratios, cell_ratios) = split_l2(split_ratio, SPLROW, SPLCOL,
                                             indsingles = True, map_function = ffloatd(1))
    else:
        # Commas only - interpret as 1d.
        cell_ratios = split_l2(split_ratio, SPLROW, SPLCOL, map_function = ffloatd(1))
        row_ratios = [1]
    (cell_ratios, row_ratios) = ratiosdealer(cell_ratios, row_ratios)

    col_sep = fspace(DKEYINOUT[("in", False)]) + NLN
    row_sep = fspace(DKEYINOUT[("out", False)]) + NLN

    # A row with n cells needs n - 1 column keys.
    per_row = [col_sep * (len(cells) - 1) for cells in cell_ratios]
    template = row_sep.join(per_row)
    if usebase:
        template = fspace(KEYBASE) + NLN + template

    # The piece after the final newline is dropped.
    keys = [piece.strip() for piece in template.split(NLN)][:-1]

    for key in keys:
        # Do not add a second ADDBASE if the prompt already has one.
        if key == KEYBASE and KEYBASE in self.SR_prompt:
            continue
        self.SR_prompt = self.SR_prompt.replace(KEYBRK, key, 1)
def split_l2(s, key_row, key_col, indsingles = False, map_function = fidentity, split_struct = None):
    """Split a string into a two-level list (rows of cells).

    Without split_struct, s is split on key_row, every row is split on
    key_col, and map_function is applied to each piece.

    With split_struct (a list of rows of which only the lengths are used),
    str(s) is split the same way, but the values are redistributed so the
    result has one row per row of split_struct, each of the same length:
    a short row repeats its last value, a long row spills its surplus into
    the following rows, surplus values beyond the last row are ignored, and
    missing rows are filled with the last value parsed.

    When indsingles is true, the first value of each row (or of each chunk
    consumed in the structured case) is taken out as that row's own value,
    unless it is the only value, in which case it is also kept as the cell.
    The return value is then the tuple (lsingles, lcells) instead of the
    plain list of rows.

    NOTE(review): with indsingles and an empty split_struct, lsingles stays
    empty and lsingles[-1] raises IndexError - confirm callers never do this.
    """
    lret = []
    if split_struct is None:
        # Free-form split: the string alone defines the structure.
        lrows = s.split(key_row)
        lrows = [row.split(key_col) for row in lrows]
        # print(lrows)
        for r in lrows:
            cell = [map_function(x) for x in r]
            lret.append(cell)
        if indsingles:
            lsingles = [row[0] for row in lret]
            # A row holding a single value keeps it as its only cell.
            lcells = [row[1:] if len(row) > 1 else row for row in lret]
            lret = (lsingles,lcells)
    else:
        # Structured split: fit the parsed values into the shape of split_struct.
        lrows = str(s).split(key_row)
        r = 0 # Index of the next split_struct row to fill.
        lcells = []
        lsingles = []
        vlast = 1 # Last parsed value, used to fill missing rows.
        for row in lrows:
            row2 = row.split(key_col)
            row2 = [map_function(x) for x in row2]
            vlast = row2[-1]
            indstop = False
            # One parsed row may fill several structure rows when it overflows.
            while not indstop:
                if (r >= len(split_struct) # Too many cell values, ignore.
                or (len(row2) == 0 and len(split_struct) > 0)): # Cell exhausted.
                    indstop = True
                if not indstop:
                    if indsingles: # Singles split.
                        lsingles.append(row2[0]) # Row ratio.
                        if len(row2) > 1:
                            row2 = row2[1:]
                    if len(split_struct[r]) >= len(row2): # Repeat last value.
                        indstop = True
                        broadrow = row2 + [row2[-1]] * (len(split_struct[r]) - len(row2))
                        r = r + 1
                        lcells.append(broadrow)
                    else: # Overfilled this row, cut and move to next.
                        broadrow = row2[:len(split_struct[r])]
                        row2 = row2[len(split_struct[r]):]
                        r = r + 1
                        lcells.append(broadrow)
        # If not enough new rows, repeat the last one for entire base, preserving structure.
        cur = len(lcells)
        while cur < len(split_struct):
            lcells.append([vlast] * len(split_struct[cur]))
            cur = cur + 1
        lret = lcells
        if indsingles:
            # Pad the per-row values with the last one so every row has one.
            lsingles = lsingles + [lsingles[-1]] * (len(split_struct) - len(lsingles))
            lret = (lsingles,lcells)
    return lret
def matrixdealer(self, split_ratio, baseratio):
    """Build the Row/Region layout for self.SR_prompt and store it on self.

    The ADDROW/ADDCOL keys in the prompt define the structure (rows of
    cells); split_ratio and baseratio are parsed and broadcast onto that
    structure. On completion self.split_ratio holds a list of Row objects,
    each with its Region cells, and self.baseratio holds the per-cell base
    values. Nothing is returned.

    NOTE(review): the nesting below was reconstructed from a copy without
    indentation. As laid out here, a prompt containing neither ADDCOL nor
    ADDROW leaves lbreaks/split_ratio2/baseratio2 unassigned, so the code
    after the `if` would raise UnboundLocalError - confirm against the
    original file and the callers.
    """
    # print(split_ratio, baseratio)
    prompt = self.SR_prompt
    # Only the text after the first ADDBASE key takes part in the layout.
    if KEYBASE in prompt: prompt = prompt.split(KEYBASE,1)[1]
    if (KEYCOL in prompt.upper() or KEYROW in prompt.upper()):
        # breaks = prompt.count(KEYROW) + prompt.count(KEYCOL) + int(self.usebase)
        # Prompt anchors, count breaks between special keywords.
        # print('prompt:', prompt)
        # lbreaks has one entry per cell: the number of BREAKs in that cell's text.
        lbreaks = split_l2(prompt, KEYROW, KEYCOL, map_function = fcountbrk)
        # print('lbreaks', lbreaks)
        if (SPLROW not in split_ratio and (KEYROW in prompt.upper()) != (KEYCOL in prompt.upper())):
            # By popular demand, 1d integrated into 2d.
            # This works by either adding a single row value (inner),
            # or setting flip to the reverse (outer).
            # Only applies when using just ADDROW / ADDCOL keys, and commas in ratio.
            # Here a leading "1" is prepended so the ratio string carries a row value.
            split_ratio = "1" + SPLCOL + split_ratio
            (split_ratio2r,split_ratio2) = split_l2(split_ratio, SPLROW, SPLCOL, indsingles = True,
            map_function = ffloatd(1), split_struct = lbreaks)
        else: # Standard ratios, split to rows and cols.
            (split_ratio2r,split_ratio2) = split_l2(split_ratio, SPLROW, SPLCOL, indsingles = True,
            map_function = ffloatd(1), split_struct = lbreaks)
        # print('split_ratio2r', split_ratio2r)
        # print('split_ratio2', split_ratio2)
        # More like "bweights", applied per cell only.
        baseratio2 = split_l2(baseratio, SPLROW, SPLCOL, map_function = ffloatd(0), split_struct = lbreaks)
    # print(baseratio2)
    # Turn the raw ratios into [start, end] ranges for cells and for rows.
    (split_ratio,split_ratior) = ratiosdealer(split_ratio2,split_ratio2r)
    baseratio = baseratio2
    # Merge various L2s to cells and rows.
    drows = []
    for r,_ in enumerate(lbreaks):
        dcells = []
        for c,_ in enumerate(lbreaks[r]):
            # Cell range, base value and break count all share the same [r][c] index.
            d = Region(split_ratio[r][c][0], split_ratio[r][c][1], baseratio[r][c], lbreaks[r][c])
            dcells.append(d)
        drow = Row(split_ratior[r][0], split_ratior[r][1], dcells)
        drows.append(drow)
    self.split_ratio = drows
    self.baseratio = baseratio
# Example usage (kept commented out). The helpers read and write the
# `SR_prompt` attribute, so the stand-in object must provide it.
# class test:
#     def __init__(self, prompt, split_ratio=None, baseratio=0.2, usebase=False):
#         self.SR_prompt = prompt
#         self.split_ratio = split_ratio
#         self.baseratio = baseratio
#         self.usebase = usebase
# test_prompt='a girl BREAK a cute boy BREAK a dog BREAK a tree.'
# split_ratio='1,1,1;1,1,1'
# x=test(test_prompt,split_ratio)
# keyconverter(x,split_ratio,usebase=False)
# print(x.SR_prompt)
# matrixdealer(x, split_ratio, 0.2)