# ControlNet-Face-Chinese / pred_color.py
# Author: svjack
# Commit 6859b03: "Update pred_color.py"
###
'''
!git clone https://huggingface.co/spaces/radames/SPIGA-face-alignment-headpose-estimator
!cp -r SPIGA-face-alignment-headpose-estimator/SPIGA .
!pip install -r SPIGA/requirements.txt
!pip install datasets
!pip install "retinaface-py>=0.0.2"
!pip install bounding-box
!huggingface-cli login
'''
import sys
# The SPIGA package is copied into ./SPIGA by the setup commands at the top of
# this file; put it first on the import path so `spiga` below resolves to it.
sys.path.insert(0, "SPIGA")
import numpy as np
from datasets import load_dataset
from spiga.inference.config import ModelConfig
from spiga.inference.framework import SPIGAFramework
# Module-level SPIGA model, constructed once when this module is imported
# (so importing the module loads the model). The drawing code below slices
# landmarks into groups covering indices 0..67, so "300wpublic" is presumably
# the 68-point 300W layout — TODO confirm against SPIGA's ModelConfig.
processor = SPIGAFramework(ModelConfig("300wpublic"))
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.path import Path
# NOTE(review): `PIL.Image` is used as an attribute later in this file, but
# `import PIL` alone does not load that submodule; it is presumably imported
# transitively (e.g. by matplotlib) — confirm before relying on it.
import PIL
def get_patch(landmarks, color='lime', closed=False):
    """Build a matplotlib patch tracing a polyline through ``landmarks``.

    Args:
        landmarks: Sequence of (x, y) points, either a list of pairs or an
            (N, 2) array. The input is never modified.
        color: Edge colour; also used as the fill colour when ``closed``.
        closed: If true, close the polyline back onto its first point and
            fill it; otherwise the fill is fully transparent.

    Returns:
        A ``matplotlib.patches.PathPatch`` drawn with line width 4.
    """
    # Work on a fresh list. The old code appended to the caller's object when
    # closing the contour: it mutated caller-owned lists, and it crashed on
    # numpy arrays, which have no ``append``.
    contour = list(landmarks)
    ops = [Path.MOVETO] + [Path.LINETO] * (len(contour) - 1)
    facecolor = (0, 0, 0, 0)  # Transparent fill for open polylines.
    if closed:
        # CLOSEPOLY needs a vertex slot of its own (matplotlib ignores its
        # value), so repeat the first point.
        contour.append(contour[0])
        ops.append(Path.CLOSEPOLY)
        facecolor = color
    path = Path(contour, ops)
    return patches.PathPatch(path, facecolor=facecolor, edgecolor=color, lw=4)
def conditioning_from_landmarks(landmarks, size=512):
    """Render facial landmarks as a colour-coded conditioning image.

    Each landmark group is drawn as a polyline (line width 4) on a black
    ``size`` x ``size`` canvas. The groups are face contour, eyebrows, nose,
    eyes and lips. Eyes and lips are closed and filled.

    Args:
        landmarks: Sequence of at least 68 (x, y) points, indexed by the
            slices in ``segments`` below. Points are plotted in the pixel
            coordinates of the canvas.
        size: Output width and height in pixels.

    Returns:
        An RGB ``PIL.Image`` of size ``(size, size)``.

    Raises:
        RuntimeError: If the rendered canvas does not come out at exactly
            ``size`` x ``size`` pixels.
    """
    # (landmark slice, colour, closed?) per feature group. Groups are drawn in
    # this order, so later ones overlap earlier ones.
    segments = (
        (slice(0, 17), 'lime', False),     # face contour
        (slice(17, 22), 'yellow', False),  # left eyebrow
        (slice(22, 27), 'yellow', False),  # right eyebrow
        (slice(27, 31), 'orange', False),  # nose (vertical part)
        (slice(31, 36), 'orange', False),  # nose (horizontal part)
        (slice(36, 42), 'magenta', True),  # left eye
        (slice(42, 48), 'magenta', True),  # right eye
        (slice(48, 60), 'cyan', True),     # outer lips
        (slice(60, 68), 'blue', True),     # inner lips
    )

    # Figure size in inches * dpi == requested pixel size.
    dpi = 72
    fig, ax = plt.subplots(1, figsize=[size / dpi, size / dpi], tight_layout={'pad': 0})
    try:
        fig.set_dpi(dpi)
        ax.imshow(np.zeros((size, size, 3)))  # black background
        for seg, color, closed in segments:
            ax.add_patch(get_patch(landmarks[seg], color=color, closed=closed))
        ax.axis('off')
        fig.canvas.draw()
        # print_to_buffer returns a copy of the RGBA bytes, so the figure can
        # be closed before the buffer is used.
        buffer, (width, height) = fig.canvas.print_to_buffer()
    finally:
        # Always release the figure; otherwise pyplot keeps it alive on error.
        plt.close(fig)

    # Explicit check (not assert) so it still runs under `python -O`.
    if (width, height) != (size, size):
        raise RuntimeError(
            f"rendered canvas is {width}x{height}, expected {size}x{size}")
    rgba = np.frombuffer(buffer, np.uint8).reshape((height, width, 4))
    return PIL.Image.fromarray(rgba[:, :, 0:3])  # drop the alpha channel
import retinaface
import spiga.demo.analyze.track.retinasort.config as cfg
# RetinaFace settings taken from SPIGA's RetinaSort demo configuration.
config = cfg.cfg_retinasort
# NOTE(review): detector device is hard-coded to CPU.
device = "cpu"
# Module-level face detector, constructed once when this module is imported.
face_detector = retinaface.RetinaFaceDetector(
    model=config['retina']['model_name'],
    device=device,
    extra_features=config['retina']['extra_features'],
    cfg_postreat=config['retina']['postreat'],
)
import cv2
# Import the PIL.Image submodule explicitly. `PIL.Image` is not loaded by a bare
# `import PIL`; previously it only existed because another library had already
# imported it. This also makes `PIL.Image` safe to use elsewhere in this file.
from PIL import Image
import os
def _load_rgb_image(image):
    """Coerce a file path, array, or PIL-like image into an RGB PIL image.

    Raises:
        FileNotFoundError: If ``image`` is a path that does not exist.
    """
    if isinstance(image, (str, os.PathLike)):
        if not os.path.exists(image):
            raise FileNotFoundError(f"image file not found: {image}")
        # Context manager closes the underlying file once converted.
        with Image.open(image) as opened:
            return opened.convert("RGB")
    if hasattr(image, "shape"):
        # Array-like input (e.g. numpy), handed to PIL as-is.
        return Image.fromarray(image).convert("RGB")
    # Anything else is assumed to be PIL-compatible (has ``convert``).
    return image.convert("RGB")


def single_pred_features(image):
    """Detect faces with RetinaFace and estimate SPIGA landmarks for them.

    The image is converted to RGB and resized to 512x512 before detection, so
    all returned coordinates refer to that 512x512 frame.

    Args:
        image: A path (``str`` or ``os.PathLike``) to an image file, an array
            accepted by ``PIL.Image.fromarray``, or a PIL image.

    Returns:
        ``(features, face_features)``, or ``None`` when no face is detected.

        - ``features`` is the raw RetinaFace output. Its ``"bbox"`` entries
          start with ``x1, y1, x2, y2``.
        - ``face_features`` is the SPIGA output, extended with ``"spiga"``
          (landmarks of the first face) and ``"spiga_seg"`` (the rendered
          conditioning image for those landmarks).

    Raises:
        FileNotFoundError: If a path is given that does not exist.
    """
    image = _load_rgb_image(image).resize((512, 512))
    # SPIGA is fed an OpenCV-style BGR array; RetinaFace gets the PIL image.
    cv2_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    # PIL's size is (width, height); the detector is given height first.
    face_detector.set_input_shape(image.size[1], image.size[0])
    features = face_detector.inference(image)
    # Reject both a falsy result and one with an empty bbox list. The latter
    # used to be forwarded to SPIGA with no boxes, and then fail when indexing
    # the first face's landmarks.
    if not features or len(features['bbox']) == 0:
        return None
    # Convert corner boxes (x1, y1, x2, y2, ...) to (x, y, width, height).
    bboxes_xywh = [
        [x1, y1, x2 - x1, y2 - y1]
        for x1, y1, x2, y2 in (bbox[:4] for bbox in features['bbox'])
    ]
    face_features = processor.inference(cv2_image, bboxes_xywh)
    landmarks = face_features["landmarks"][0]  # first detected face only
    face_features["spiga"] = landmarks
    face_features['spiga_seg'] = conditioning_from_landmarks(landmarks)
    return features, face_features
def produce_center_crop_image(features, face_features, draw_rect=False):
    """Crop the first detected face from the conditioning image and centre it.

    The crop region is the first RetinaFace box, padded by 20 px on every side
    and clipped to the image bounds. It is cut out of
    ``face_features["spiga_seg"]``, resized to 256x256, and placed in the
    middle of a black 512x512 canvas.

    Args:
        features: RetinaFace output as returned by ``single_pred_features``.
            Only ``features["bbox"][0][:4]`` (x1, y1, x2, y2) is used.
        face_features: SPIGA output holding the ``"spiga_seg"`` RGB image.
        draw_rect: If true, draw a red rectangle around the padded face box
            before cropping. This needs the ``bounding_box`` package.

    Returns:
        A 512x512 RGB ``PIL.Image``.

    Raises:
        ValueError: If the padded box does not overlap the image at all.
    """
    step = 20  # padding in pixels around the detected face box
    # Slice like single_pred_features does, so extra trailing fields are fine.
    left, top, right, bottom = features["bbox"][0][:4]
    # np.array (not np.asarray) yields a writable copy. An array viewed from a
    # PIL image may be read-only, which breaks bb.add's in-place drawing.
    img = np.array(face_features["spiga_seg"])
    if draw_rect:
        # Imported lazily: only needed when drawing.
        from bounding_box import bounding_box as bb
        bb.add(img, left - step, top - step, right + step, bottom + step, "", "red")

    # Clamp the padded box to the image. A negative start index would wrap
    # around to the far edge and yield an empty or wrong crop.
    height, width = img.shape[:2]
    y0, y1 = max(0, int(top - step)), min(height, int(bottom + step))
    x0, x1 = max(0, int(left - step)), min(width, int(right + step))
    if y1 <= y0 or x1 <= x0:
        raise ValueError("face box does not overlap the conditioning image")

    crop_img = Image.fromarray(img[y0:y1, x0:x1, :]).resize((256, 256))
    # Centre the 256x256 crop on a black 512x512 canvas (128 px border).
    canvas = np.zeros((512, 512, 3), dtype=np.uint8)
    canvas[128:384, 128:384] = np.asarray(crop_img)
    return Image.fromarray(canvas)
'''
from pred_color import *
img = "babyxiang_ai.png"
img = "Protector_Cromwell_style.png"
features ,face_features = single_pred_features(img)
fix_img = produce_center_crop_image(features ,face_features, draw_rect = False)
fix_img
fix_r_img = produce_center_crop_image(features ,face_features, draw_rect = True)
fix_r_img
from pred_color import *
img = "babyxiang_ai.png"
img = "Protector_Cromwell_style.png"
features ,face_features = single_pred_features(img)
left, top, right, bottom, _ = features["bbox"][0]
color = "red"
label = ""
from bounding_box import bounding_box as bb
img = np.asarray(face_features["spiga_seg"])
step = 20
bb.add(img, left - step, top - step, right + step, bottom + step, label, color)
Image.fromarray(img)
crop_img = Image.fromarray(img[ int(top - step):int(bottom + step) ,int(left - step):int(right + step), :])
crop_img = crop_img.resize((256, 256))
crop_img
Image.fromarray(
np.concatenate(
[np.full([512, 128, 3], fill_value=0),
np.concatenate([np.full([128, 256, 3], fill_value=0) ,np.asarray(crop_img),
np.full([128, 256, 3], fill_value=0)], axis = 0),
np.full([512, 128, 3], fill_value=0)
], axis = 1
).astype(np.uint8))
'''
if __name__ == "__main__":
    # Smoke test: run the pipeline on a local image and on the first training
    # sample of the captioned face dataset, and report what was found.
    from datasets import load_dataset
    ds = load_dataset("svjack/facesyntheticsspigacaptioned_en_zh_1")
    dss = ds["train"]
    xiangbaobao = PIL.Image.open("babyxiang.png")
    # single_pred_features returns (features, face_features), or None when no
    # face is found. Indexing its result with "spiga_seg" directly (as before)
    # raises TypeError, because the result is a tuple.
    for label, source in (
        ("babyxiang.png", xiangbaobao.resize((512, 512))),
        ("train[0]", dss[0]["image"]),
    ):
        out = single_pred_features(source)
        if out is None:
            print(f"{label}: no face detected")
            continue
        features, face_features = out
        spiga_seg = face_features["spiga_seg"]
        print(f"{label}: {len(features['bbox'])} face(s), "
              f"conditioning image size {spiga_seg.size}")