Spaces:
Runtime error
Runtime error
import cv2 | |
import numpy as np | |
import torch | |
import random | |
# import mediapipe as mp | |
from lite_openpose.body_bbox_detector import BodyPoseEstimator | |
from NTED.extraction_distribution_model import Generator | |
from NTED.demo_dataset import DemoDataset | |
from NTED.base_function import accumulate | |
from NTED.config import Config | |
def set_random_seed(seed):
    r"""Seed every random number generator used by this module.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (the default generator plus the CUDA generators) with the same value.

    Args:
        seed (int): Random seed.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # The CUDA seeding calls are documented as safe (silently ignored) when
    # CUDA is not available, so no availability guard is needed here.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
class NTED:
    """CPU inference wrapper around the NTED ``Generator``.

    On construction it builds the body-pose estimator, loads the generator
    weights stored under the ``'net_G_ema'`` key of the checkpoint and reads a
    fixed reference image. ``inference`` then detects the body pose in an
    input frame and runs the generator on the reference image with that pose.
    """

    # File locations, relative to the working directory.
    CONFIG_PATH = 'NTED/fashion_512.yaml'
    CHECKPOINT_PATH = 'NTED/nted_checkpoint.pt'
    REF_IMG_PATH = 'example/ref_img.png'

    # ``cv2.resize`` target sizes, given as (width, height).
    INPUT_SIZE = (352, 512)
    OUTPUT_SIZE = (288, 480)

    # Hand-pose layout: 21 joints per hand, left hand first, then right hand.
    NUM_HAND_JOINTS = 21
    # Coordinate pair written for every joint of a hand that was not detected.
    MISSING_JOINT = (-1, -1)

    def __init__(self):
        """Build the pose estimator and generator and load the reference image.

        Raises:
            FileNotFoundError: If the reference image cannot be read.
        """
        super().__init__()
        self.openpose_module = BodyPoseEstimator('cpu')

        # NOTE: the statement order below is kept exactly as it was. Seeding
        # happens before the generators are built, and building them
        # presumably consumes random numbers, so reordering or dropping a
        # construction could change later sampling -- TODO confirm.
        set_random_seed(0)
        self.opt = Config(self.CONFIG_PATH, is_train=False)
        net_G = Generator(**self.opt.gen.param).to('cpu')
        net_G_ema = Generator(**self.opt.gen.param).to('cpu')
        net_G_ema.eval()
        accumulate(net_G_ema, net_G, 0)

        # Keep every storage on the CPU regardless of where it was saved.
        checkpoint = torch.load(
            self.CHECKPOINT_PATH,
            map_location=lambda storage, loc: storage,
        )
        net_G_ema.load_state_dict(checkpoint['net_G_ema'])
        self.net_G = net_G_ema.eval()
        self.data_loader = DemoDataset()

        # Hand detection is disabled (the mediapipe import is commented out at
        # the top of the file). To enable it, assign a MediaPipe ``Hands``
        # instance here, e.g.
        #   mp.solutions.hands.Hands(static_image_mode=True, max_num_hands=2,
        #                            min_detection_confidence=0.1)
        self.hands = None

        # cv2.imread signals failure by returning None instead of raising.
        ref_img = cv2.imread(self.REF_IMG_PATH)
        if ref_img is None:
            raise FileNotFoundError(
                'reference image is missing or unreadable: '
                + self.REF_IMG_PATH
            )
        self.ref_img = cv2.resize(ref_img, self.INPUT_SIZE)

    def hand_pose_est(self, img):
        """Estimate 2D hand joints for both hands in ``img``.

        The image is mirrored horizontally before detection and the x
        coordinates are mirrored back afterwards.

        Args:
            img: Image array whose ``shape`` is (height, width, channels);
                converted with ``cv2.COLOR_BGR2RGB`` before detection.

        Returns:
            ``np.int32`` array; see ``_landmarks_to_pose`` for the layout.

        Raises:
            RuntimeError: If no hand detector is configured (``self.hands``
                is ``None`` or was never set).
        """
        detector = getattr(self, 'hands', None)
        if detector is None:
            raise RuntimeError(
                'hand pose estimation is disabled: no hand detector is '
                'configured on this NTED instance (self.hands is None)'
            )
        mirrored_rgb = cv2.cvtColor(cv2.flip(img, 1), cv2.COLOR_BGR2RGB)
        results = detector.process(mirrored_rgb)
        image_height, image_width, _ = img.shape
        return self._landmarks_to_pose(results, image_width, image_height)

    @classmethod
    def _landmarks_to_pose(cls, results, image_width, image_height):
        """Convert detector results into a pixel-coordinate joint array.

        Args:
            results: Object exposing ``multi_hand_landmarks`` (``None`` or a
                sequence of hands, each with an indexable ``landmark`` whose
                items have normalised ``x`` / ``y``) and ``multi_handedness``
                (per hand, ``classification[0].label`` is ``'Left'`` or
                ``'Right'``).
            image_width: Width in pixels used to scale and un-mirror ``x``.
            image_height: Height in pixels used to scale ``y``.

        Returns:
            ``np.int32`` array. With zero (``None``), one or two detected
            hands it has shape (42, 2): left-hand joints first, then
            right-hand joints, with ``[-1, -1]`` for every joint of an
            undetected hand.
        """
        joints = cls.NUM_HAND_JOINTS
        missing_hand = [list(cls.MISSING_JOINT) for _ in range(joints)]
        detected = results.multi_hand_landmarks

        if detected is None:
            pose_data = missing_hand + [list(cls.MISSING_JOINT) for _ in range(joints)]
            return np.array(pose_data, dtype=np.int32)

        # Detection ran on the mirrored image, so mirror x back.
        pose_data = [
            [
                image_width - hand.landmark[idx].x * image_width,
                hand.landmark[idx].y * image_height,
            ]
            for hand in detected
            for idx in range(joints)
        ]

        if len(detected) == 2:
            if results.multi_handedness[0].classification[0].label == 'Right':
                # Swap so the left hand comes first, then the right hand.
                pose_data = pose_data[joints:] + pose_data[:joints]
        elif len(detected) == 1:
            if results.multi_handedness[0].classification[0].label == 'Left':
                pose_data = pose_data + missing_hand
            else:
                pose_data = missing_hand + pose_data

        return np.array(pose_data, dtype=np.int32)

    def inference(self, img):
        """Generate the reference person in the body pose found in ``img``.

        Args:
            img: Input image; resized to ``INPUT_SIZE`` before pose detection.

        Returns:
            Tuple ``(skeleton_img, fake_image)``: the ``'skeleton_img'`` entry
            produced by the data loader and the generated image resized to
            ``OUTPUT_SIZE``.

        Raises:
            IndexError: If the pose detector returns no body pose.
        """
        img = cv2.resize(img, self.INPUT_SIZE)
        body_pose, _bbox = self.openpose_module.detect_body_pose(img.copy())
        try:
            target_pose = body_pose[0]
        except IndexError as err:
            raise IndexError('no body pose detected in the input image') from err

        # hand_pose = self.hand_pose_est(img.copy())  # hand pose is disabled
        data = self.data_loader.load_item(self.ref_img, target_pose, None)

        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            output = self.net_G(
                data['reference_image'],
                data['target_skeleton'],
            )

        fake_image = output['fake_image'][0]
        fake_image = self.data_loader.tensor2im(fake_image)
        fake_image = cv2.resize(fake_image, self.OUTPUT_SIZE)
        return data['skeleton_img'], fake_image