# Spaces:
# Runtime error
# Runtime error
import argparse | |
import logging | |
import os | |
import sys | |
import PIL.Image | |
import numpy | |
import torch | |
import wx | |
import json | |
from typing import List | |
# The tha3 package lives in the "live2d" subdirectory. Temporarily make that directory the
# working directory and append it to sys.path so the tha3 imports below resolve.
# os.chdir raises if ./live2d does not exist, so the script must be launched from the
# directory that contains "live2d".
target_directory = os.path.join(os.getcwd(), "live2d")
os.chdir(target_directory)
sys.path.append(os.getcwd())
from tha3.poser.modes.load_poser import load_poser
from tha3.poser.poser import Poser, PoseParameterCategory, PoseParameterGroup
from tha3.util import extract_pytorch_image_from_filelike, rgba_to_numpy_image, grid_change_to_numpy_image, \
    rgb_to_numpy_image, resize_PIL_image, extract_PIL_image_from_filelike, extract_pytorch_image_from_PIL_image
# Switch back to the parent of "live2d" (normally the directory the script was launched from)
# so later relative paths such as "data/images" resolve from there.
current_directory = os.getcwd()
parent_directory = os.path.dirname(current_directory)
os.chdir(parent_directory)
class MorphCategoryControlPanel(wx.Panel):
    """Control panel for one category of morph pose parameters.

    A ``wx.Choice`` selects the active parameter group of the category. Continuous groups
    are driven by the left slider (plus the right slider when the group has arity 2);
    discrete groups are driven by the "Show" checkbox. Only the selected group is written
    into the pose by ``set_param_value``.
    """

    def __init__(self,
                 parent,
                 title: str,
                 pose_param_category: PoseParameterCategory,
                 param_groups: List[PoseParameterGroup]):
        """Build the panel.

        Args:
            parent: Parent wx window.
            title: Label shown at the top of the panel.
            pose_param_category: Category whose parameter groups this panel controls.
            param_groups: Candidate parameter groups; groups of other categories are ignored.
        """
        super().__init__(parent, style=wx.SIMPLE_BORDER)
        self.pose_param_category = pose_param_category
        self.sizer = wx.BoxSizer(wx.VERTICAL)
        self.SetSizer(self.sizer)
        self.SetAutoLayout(1)

        title_text = wx.StaticText(self, label=title, style=wx.ALIGN_CENTER)
        self.sizer.Add(title_text, 0, wx.EXPAND)

        self.param_groups = [group for group in param_groups if group.get_category() == pose_param_category]
        self.choice = wx.Choice(self, choices=[group.get_group_name() for group in self.param_groups])
        if self.param_groups:
            self.choice.SetSelection(0)
        self.choice.Bind(wx.EVT_CHOICE, self.on_choice_updated)
        self.sizer.Add(self.choice, 0, wx.EXPAND)

        # Integer sliders over [-1000, 1000]; set_param_value maps this linearly onto the
        # selected group's range. Both start at the minimum end.
        self.left_slider = wx.Slider(self, minValue=-1000, maxValue=1000, value=-1000, style=wx.HORIZONTAL)
        self.sizer.Add(self.left_slider, 0, wx.EXPAND)
        self.right_slider = wx.Slider(self, minValue=-1000, maxValue=1000, value=-1000, style=wx.HORIZONTAL)
        self.sizer.Add(self.right_slider, 0, wx.EXPAND)

        self.checkbox = wx.CheckBox(self, label="Show")
        self.checkbox.SetValue(True)
        self.sizer.Add(self.checkbox, 0, wx.SHAPED | wx.ALIGN_CENTER)

        self.update_ui()
        self.sizer.Fit(self)

    def _get_selected_param_group(self):
        """Return the selected parameter group, or None when there is no valid selection."""
        selected_index = self.choice.GetSelection()
        # GetSelection() is -1 when nothing is selected (e.g. the category has no groups),
        # and a negative index would otherwise wrap around or raise IndexError.
        if 0 <= selected_index < len(self.param_groups):
            return self.param_groups[selected_index]
        return None

    def update_ui(self):
        """Enable only the controls that apply to the selected parameter group.

        When no group is selectable, every control is disabled.
        """
        param_group = self._get_selected_param_group()
        if param_group is None:
            self.left_slider.Enable(False)
            self.right_slider.Enable(False)
            self.checkbox.Enable(False)
        elif param_group.is_discrete():
            self.left_slider.Enable(False)
            self.right_slider.Enable(False)
            self.checkbox.Enable(True)
        elif param_group.get_arity() == 1:
            self.left_slider.Enable(True)
            self.right_slider.Enable(False)
            self.checkbox.Enable(False)
        else:
            self.left_slider.Enable(True)
            self.right_slider.Enable(True)
            self.checkbox.Enable(False)

    def on_choice_updated(self, event: wx.Event):
        """Handle a new group selection.

        Re-checks "Show" when the new group is discrete, then refreshes which controls are enabled.
        """
        param_group = self._get_selected_param_group()
        if param_group is not None and param_group.is_discrete():
            self.checkbox.SetValue(True)
        self.update_ui()

    def set_param_value(self, pose: List[float]):
        """Write the selected group's value(s) into ``pose`` in place.

        Discrete groups set every parameter of the group to 1.0 when "Show" is checked and
        leave ``pose`` untouched otherwise. Continuous groups map slider values in
        [-1000, 1000] linearly onto ``param_group.get_range()``; the right slider is used
        only for arity-2 groups.
        """
        param_group = self._get_selected_param_group()
        if param_group is None:
            return
        param_index = param_group.get_parameter_index()
        if param_group.is_discrete():
            if self.checkbox.GetValue():
                for offset in range(param_group.get_arity()):
                    pose[param_index + offset] = 1.0
            return
        param_range = param_group.get_range()
        alpha = (self.left_slider.GetValue() + 1000) / 2000.0
        pose[param_index] = param_range[0] + (param_range[1] - param_range[0]) * alpha
        if param_group.get_arity() == 2:
            alpha = (self.right_slider.GetValue() + 1000) / 2000.0
            pose[param_index + 1] = param_range[0] + (param_range[1] - param_range[0]) * alpha
class SimpleParamGroupsControlPanel(wx.Panel):
    """Control panel with one slider per parameter group of a category.

    Every group of the category must be continuous (not discrete) with arity 1.
    """

    def __init__(self, parent,
                 pose_param_category: PoseParameterCategory,
                 param_groups: List[PoseParameterGroup]):
        """Build one labelled slider per group of ``pose_param_category``.

        Args:
            parent: Parent wx window.
            pose_param_category: Category whose parameter groups this panel controls.
            param_groups: Candidate parameter groups; groups of other categories are ignored.
        """
        super().__init__(parent, style=wx.SIMPLE_BORDER)
        self.sizer = wx.BoxSizer(wx.VERTICAL)
        self.SetSizer(self.sizer)
        self.SetAutoLayout(1)

        self.param_groups = [group for group in param_groups if group.get_category() == pose_param_category]
        for param_group in self.param_groups:
            assert not param_group.is_discrete(), "discrete groups are not supported by this panel"
            assert param_group.get_arity() == 1, "only arity-1 groups are supported by this panel"

        self.sliders = []
        for param_group in self.param_groups:
            static_text = wx.StaticText(
                self,
                label=" ------------ %s ------------ " % param_group.get_group_name(), style=wx.ALIGN_CENTER)
            self.sizer.Add(static_text, 0, wx.EXPAND)
            # wx sliders are integer-valued, so the group's range is scaled by 1000.
            param_range = param_group.get_range()
            min_value = int(param_range[0] * 1000)
            max_value = int(param_range[1] * 1000)
            slider = wx.Slider(self, minValue=min_value, maxValue=max_value, value=0, style=wx.HORIZONTAL)
            self.sizer.Add(slider, 0, wx.EXPAND)
            self.sliders.append(slider)
        self.sizer.Fit(self)

    def set_param_value(self, pose: List[float]):
        """Write each slider's value, mapped linearly onto its group's range, into ``pose`` in place."""
        for param_group, slider in zip(self.param_groups, self.sliders):
            param_range = param_group.get_range()
            param_index = param_group.get_parameter_index()
            slider_min = slider.GetMin()
            slider_span = slider.GetMax() - slider_min
            # A degenerate range (min == max) gives a zero-width slider; avoid dividing by zero.
            alpha = (slider.GetValue() - slider_min) / slider_span if slider_span else 0.0
            pose[param_index] = param_range[0] + (param_range[1] - param_range[0]) * alpha
def convert_output_image_from_torch_to_numpy(output_image):
    """Convert one poser output tensor into a uint8 numpy image.

    Supported inputs:
      * (H, W, 2): a channel-last 2-channel grid, moved to channel-first and rendered like the
        (2, H, W) case;
      * (4, H, W) / (3, H, W): converted with ``rgba_to_numpy_image`` / ``rgb_to_numpy_image``;
      * (1, H, W): replicated to three channels, passed through ``x * 2 - 1`` and given a
        constant alpha channel of 1 before ``rgba_to_numpy_image``;
      * (2, H, W): rendered with ``grid_change_to_numpy_image(num_channels=4)``.

    The helper's float result is scaled by 255 and rounded to uint8.

    Raises:
        RuntimeError: if the channel count is not one of the cases above.
    """
    if output_image.shape[2] == 2:
        h, w, c = output_image.shape
        chw_image = torch.transpose(output_image.reshape(h * w, c), 0, 1).reshape(c, h, w)
        # Render the grid like the channel-first branch below, instead of returning the raw
        # 2-channel tensor, which callers cannot slice as RGB + alpha.
        numpy_image = grid_change_to_numpy_image(chw_image, num_channels=4)
    elif output_image.shape[0] == 4:
        numpy_image = rgba_to_numpy_image(output_image)
    elif output_image.shape[0] == 3:
        numpy_image = rgb_to_numpy_image(output_image)
    elif output_image.shape[0] == 1:
        c, h, w = output_image.shape
        # Match dtype/device so torch.cat also works for half-precision outputs.
        alpha_channel = torch.ones(1, h, w, dtype=output_image.dtype, device=output_image.device)
        # NOTE(review): `* 2 - 1` presumably maps [0, 1] to the [-1, 1] range rgba_to_numpy_image
        # expects — confirm against tha3.util.
        alpha_image = torch.cat([output_image.repeat(3, 1, 1) * 2.0 - 1.0, alpha_channel], dim=0)
        numpy_image = rgba_to_numpy_image(alpha_image)
    elif output_image.shape[0] == 2:
        numpy_image = grid_change_to_numpy_image(output_image, num_channels=4)
    else:
        raise RuntimeError("Unsupported # image channels: %d" % output_image.shape[0])
    numpy_image = numpy.uint8(numpy.rint(numpy_image * 255.0))
    return numpy_image
class MainFrame(wx.Frame):
    """Main poser window.

    Layout, left to right: source image, pose controls, posed result. A ``wx.Timer`` (started
    by the caller) drives ``update_images``, which re-runs the poser only when the pose, the
    selected output index, or the source image changed.
    """

    # Keys of the JSON pose file written next to each saved image, in pose-vector order.
    # NOTE(review): 'eye_unimpressed' appears twice, so in the dict built by
    # get_current_posedict the value at the second occurrence overwrites the first.
    # 'mouth_delta' and 'mouth_smirk' also lack the '_index' suffix. The names are kept
    # verbatim because existing JSON files / readers may rely on them — confirm before renaming.
    _POSE_DICT_KEYS = (
        'eyebrow_troubled_left_index', 'eyebrow_troubled_right_index',
        'eyebrow_angry_left_index', 'eyebrow_angry_right_index',
        'eyebrow_lowered_left_index', 'eyebrow_lowered_right_index',
        'eyebrow_raised_left_index', 'eyebrow_raised_right_index',
        'eyebrow_happy_left_index', 'eyebrow_happy_right_index',
        'eyebrow_serious_left_index', 'eyebrow_serious_right_index',
        'eye_wink_left_index', 'eye_wink_right_index',
        'eye_happy_wink_left_index', 'eye_happy_wink_right_index',
        'eye_surprised_left_index', 'eye_surprised_right_index',
        'eye_relaxed_left_index', 'eye_relaxed_right_index',
        'eye_unimpressed', 'eye_unimpressed',
        'eye_raised_lower_eyelid_left_index', 'eye_raised_lower_eyelid_right_index',
        'iris_small_left_index', 'iris_small_right_index',
        'mouth_aaa_index', 'mouth_iii_index', 'mouth_uuu_index', 'mouth_eee_index', 'mouth_ooo_index',
        'mouth_delta',
        'mouth_lowered_corner_left_index', 'mouth_lowered_corner_right_index',
        'mouth_raised_corner_left_index', 'mouth_raised_corner_right_index',
        'mouth_smirk',
        'iris_rotation_x_index', 'iris_rotation_y_index',
        'head_x_index', 'head_y_index', 'neck_z_index',
        'body_y_index', 'body_z_index',
        'breathing_index',
    )

    def __init__(self, poser: Poser, device: torch.device):
        """Build the window for ``poser``, running inference on ``device``."""
        super().__init__(None, wx.ID_ANY, "Poser")
        self.poser = poser
        self.dtype = self.poser.get_dtype()
        self.device = device
        self.image_size = self.poser.get_image_size()

        self.wx_source_image = None
        self.torch_source_image = None

        self.main_sizer = wx.BoxSizer(wx.HORIZONTAL)
        self.SetSizer(self.main_sizer)
        self.SetAutoLayout(1)

        self.init_left_panel()
        self.init_control_panel()
        self.init_right_panel()
        self.main_sizer.Fit(self)

        self.timer = wx.Timer(self, wx.ID_ANY)
        self.Bind(wx.EVT_TIMER, self.update_images, self.timer)

        # Ctrl+S saves the current output image.
        save_image_id = wx.NewIdRef()
        self.Bind(wx.EVT_MENU, self.on_save_image, id=save_image_id)
        accelerator_table = wx.AcceleratorTable([
            (wx.ACCEL_CTRL, ord('S'), save_image_id)
        ])
        self.SetAcceleratorTable(accelerator_table)

        # State used by update_images to skip redundant re-renders.
        self.last_pose = None
        self.last_output_index = self.output_index_choice.GetSelection()
        self.last_output_numpy_image = None
        self.source_image_bitmap = wx.Bitmap(self.image_size, self.image_size)
        self.result_image_bitmap = wx.Bitmap(self.image_size, self.image_size)
        self.source_image_dirty = True

    def init_left_panel(self):
        """Create the source-image panel and the "Load Image" button.

        Also creates the (still empty) middle control panel, which init_control_panel fills.
        """
        self.control_panel = wx.Panel(self, style=wx.SIMPLE_BORDER, size=(self.image_size, -1))
        self.left_panel = wx.Panel(self, style=wx.SIMPLE_BORDER)
        left_panel_sizer = wx.BoxSizer(wx.VERTICAL)
        self.left_panel.SetSizer(left_panel_sizer)
        self.left_panel.SetAutoLayout(1)

        self.source_image_panel = wx.Panel(self.left_panel, size=(self.image_size, self.image_size),
                                           style=wx.SIMPLE_BORDER)
        self.source_image_panel.Bind(wx.EVT_PAINT, self.paint_source_image_panel)
        self.source_image_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background)
        left_panel_sizer.Add(self.source_image_panel, 0, wx.FIXED_MINSIZE)

        self.load_image_button = wx.Button(self.left_panel, wx.ID_ANY, "\nLoad Image\n\n")
        left_panel_sizer.Add(self.load_image_button, 1, wx.EXPAND)
        self.load_image_button.Bind(wx.EVT_BUTTON, self.load_image)

        left_panel_sizer.Fit(self.left_panel)
        self.main_sizer.Add(self.left_panel, 0, wx.FIXED_MINSIZE)

    def on_erase_background(self, event: wx.Event):
        """Ignore background erases; the paint handlers redraw the whole panel from a bitmap."""
        pass

    def init_control_panel(self):
        """Fill the middle panel with one control panel per pose-parameter category the poser has."""
        self.control_panel_sizer = wx.BoxSizer(wx.VERTICAL)
        self.control_panel.SetSizer(self.control_panel_sizer)
        self.control_panel.SetMinSize(wx.Size(256, 1))

        param_groups = self.poser.get_pose_parameter_groups()

        def has_groups(category):
            return any(group.get_category() == category for group in param_groups)

        morph_categories = [
            PoseParameterCategory.EYEBROW,
            PoseParameterCategory.EYE,
            PoseParameterCategory.MOUTH,
            PoseParameterCategory.IRIS_MORPH
        ]
        morph_category_titles = {
            PoseParameterCategory.EYEBROW: " ------------ Eyebrow ------------ ",
            PoseParameterCategory.EYE: " ------------ Eye ------------ ",
            PoseParameterCategory.MOUTH: " ------------ Mouth ------------ ",
            PoseParameterCategory.IRIS_MORPH: " ------------ Iris morphs ------------ ",
        }
        self.morph_control_panels = {}
        for category in morph_categories:
            if not has_groups(category):
                continue
            control_panel = MorphCategoryControlPanel(
                self.control_panel,
                morph_category_titles[category],
                category,
                param_groups)
            self.morph_control_panels[category] = control_panel
            self.control_panel_sizer.Add(control_panel, 0, wx.EXPAND)

        self.non_morph_control_panels = {}
        non_morph_categories = [
            PoseParameterCategory.IRIS_ROTATION,
            PoseParameterCategory.FACE_ROTATION,
            PoseParameterCategory.BODY_ROTATION,
            PoseParameterCategory.BREATHING
        ]
        for category in non_morph_categories:
            if not has_groups(category):
                continue
            control_panel = SimpleParamGroupsControlPanel(
                self.control_panel,
                category,
                param_groups)
            self.non_morph_control_panels[category] = control_panel
            self.control_panel_sizer.Add(control_panel, 0, wx.EXPAND)

        self.control_panel_sizer.Fit(self.control_panel)
        self.main_sizer.Add(self.control_panel, 1, wx.FIXED_MINSIZE)

    def init_right_panel(self):
        """Create the result-image panel, the output-index choice and the "Save Image" button."""
        self.right_panel = wx.Panel(self, style=wx.SIMPLE_BORDER)
        right_panel_sizer = wx.BoxSizer(wx.VERTICAL)
        self.right_panel.SetSizer(right_panel_sizer)
        self.right_panel.SetAutoLayout(1)

        self.result_image_panel = wx.Panel(self.right_panel,
                                           size=(self.image_size, self.image_size),
                                           style=wx.SIMPLE_BORDER)
        self.result_image_panel.Bind(wx.EVT_PAINT, self.paint_result_image_panel)
        self.result_image_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background)
        # Selects which of the poser's outputs is displayed and saved.
        self.output_index_choice = wx.Choice(
            self.right_panel,
            choices=[str(i) for i in range(self.poser.get_output_length())])
        self.output_index_choice.SetSelection(0)
        right_panel_sizer.Add(self.result_image_panel, 0, wx.FIXED_MINSIZE)
        right_panel_sizer.Add(self.output_index_choice, 0, wx.EXPAND)

        self.save_image_button = wx.Button(self.right_panel, wx.ID_ANY, "\nSave Image\n\n")
        right_panel_sizer.Add(self.save_image_button, 1, wx.EXPAND)
        self.save_image_button.Bind(wx.EVT_BUTTON, self.on_save_image)

        right_panel_sizer.Fit(self.right_panel)
        self.main_sizer.Add(self.right_panel, 0, wx.FIXED_MINSIZE)

    def create_param_category_choice(self, param_category: PoseParameterCategory):
        """Return a wx.Choice on the control panel listing the group names of ``param_category``."""
        params = [param_group.get_group_name()
                  for param_group in self.poser.get_pose_parameter_groups()
                  if param_group.get_category() == param_category]
        choice = wx.Choice(self.control_panel, choices=params)
        if params:
            choice.SetSelection(0)
        return choice

    def _show_message(self, message: str, caption: str, style: int = wx.OK) -> int:
        """Show a modal message dialog, always destroy it, and return the id of the pressed button."""
        message_dialog = wx.MessageDialog(self, message, caption, style)
        try:
            return message_dialog.ShowModal()
        finally:
            message_dialog.Destroy()

    def load_image(self, event: wx.Event):
        """Let the user pick a PNG and make it the new source image.

        The image is resized to the poser's input size. If the result is not in 'RGBA' mode,
        the source image is cleared and the user is told an alpha channel is required. Any
        other failure is logged and reported in a dialog.
        """
        dir_name = "data/images"
        file_dialog = wx.FileDialog(self, "Choose an image", dir_name, "", "*.png", wx.FD_OPEN)
        if file_dialog.ShowModal() == wx.ID_OK:
            image_file_name = os.path.join(file_dialog.GetDirectory(), file_dialog.GetFilename())
            rejected = False
            try:
                image_size = self.poser.get_image_size()
                pil_image = resize_PIL_image(extract_PIL_image_from_filelike(image_file_name),
                                             (image_size, image_size))
                w, h = pil_image.size
                if pil_image.mode != 'RGBA':
                    rejected = True
                    self.source_image_string = "Image must have alpha channel!"
                    self.wx_source_image = None
                    self.torch_source_image = None
                else:
                    self.wx_source_image = wx.Bitmap.FromBufferRGBA(w, h, pil_image.convert("RGBA").tobytes())
                    self.torch_source_image = extract_pytorch_image_from_PIL_image(pil_image) \
                        .to(self.device).to(self.dtype)
                self.source_image_dirty = True
                self.Refresh()
                self.Update()
            except Exception:
                # Exception (not a bare except) so KeyboardInterrupt/SystemExit still propagate.
                logging.exception("Could not load image %s", image_file_name)
                self._show_message("Could not load image " + image_file_name, "Poser")
            else:
                if rejected:
                    # Tell the user why the image was not loaded.
                    self._show_message(self.source_image_string, "Poser")
        file_dialog.Destroy()

    def paint_source_image_panel(self, event: wx.Event):
        """Paint the source panel from its off-screen bitmap."""
        wx.BufferedPaintDC(self.source_image_panel, self.source_image_bitmap)

    def paint_result_image_panel(self, event: wx.Event):
        """Paint the result panel from its off-screen bitmap."""
        wx.BufferedPaintDC(self.result_image_panel, self.result_image_bitmap)

    def draw_nothing_yet_string_to_bitmap(self, bitmap):
        """Clear ``bitmap`` and draw a centred "Nothing yet!" placeholder on it."""
        dc = wx.MemoryDC()
        dc.SelectObject(bitmap)
        dc.Clear()
        font = wx.Font(wx.FontInfo(14).Family(wx.FONTFAMILY_SWISS))
        dc.SetFont(font)
        w, h = dc.GetTextExtent("Nothing yet!")
        dc.DrawText("Nothing yet!", (self.image_size - w) // 2, (self.image_size - h) // 2)
        # Deselect so the bitmap is free to be drawn by the paint handlers.
        dc.SelectObject(wx.NullBitmap)
        del dc

    def get_current_pose(self):
        """Return the pose vector (list of floats) assembled from all control panels."""
        current_pose = [0.0] * self.poser.get_num_parameters()
        for morph_control_panel in self.morph_control_panels.values():
            morph_control_panel.set_param_value(current_pose)
        for rotation_control_panel in self.non_morph_control_panels.values():
            rotation_control_panel.set_param_value(current_pose)
        return current_pose

    def update_images(self, event: wx.Event):
        """Timer callback: redraw the source/result bitmaps when the inputs changed.

        Nothing is done when the source image is unchanged and the pose and output index equal
        those of the previous call. Without a source image both panels show a placeholder.
        """
        current_pose = self.get_current_pose()
        output_index = self.output_index_choice.GetSelection()
        if not self.source_image_dirty \
                and self.last_pose is not None \
                and self.last_pose == current_pose \
                and self.last_output_index == output_index:
            return
        self.last_pose = current_pose
        self.last_output_index = output_index

        if self.torch_source_image is None:
            self.draw_nothing_yet_string_to_bitmap(self.source_image_bitmap)
            self.draw_nothing_yet_string_to_bitmap(self.result_image_bitmap)
            self.source_image_dirty = False
            self.Refresh()
            self.Update()
            return

        if self.source_image_dirty:
            dc = wx.MemoryDC()
            dc.SelectObject(self.source_image_bitmap)
            dc.Clear()
            dc.DrawBitmap(self.wx_source_image, 0, 0)
            dc.SelectObject(wx.NullBitmap)
            del dc
            self.source_image_dirty = False

        pose = torch.tensor(current_pose, device=self.device, dtype=self.dtype)
        with torch.no_grad():
            output_image = self.poser.pose(self.torch_source_image, pose, output_index)[0].detach().cpu()
        numpy_image = convert_output_image_from_torch_to_numpy(output_image)
        self.last_output_numpy_image = numpy_image

        # numpy_image is channel-last (rows, columns, 4): wx expects width (columns) first.
        image_height, image_width = numpy_image.shape[0], numpy_image.shape[1]
        wx_image = wx.ImageFromBuffer(
            image_width,
            image_height,
            numpy_image[:, :, 0:3].tobytes(),
            numpy_image[:, :, 3].tobytes())
        wx_bitmap = wx_image.ConvertToBitmap()

        dc = wx.MemoryDC()
        dc.SelectObject(self.result_image_bitmap)
        dc.Clear()
        dc.DrawBitmap(wx_bitmap,
                      (self.image_size - image_width) // 2,
                      (self.image_size - image_height) // 2,
                      True)
        dc.SelectObject(wx.NullBitmap)
        del dc
        self.Refresh()
        self.Update()

    def get_current_posedict(self):
        """Return the current pose as a dict keyed by ``_POSE_DICT_KEYS``.

        ``zip`` stops at the shorter of the key tuple and the pose vector.
        """
        return dict(zip(self._POSE_DICT_KEYS, self.get_current_pose()))

    def on_save_image(self, event: wx.Event):
        """Ask for a destination and save the last output image together with its pose JSON.

        A name typed without an extension gets ".png" appended so PIL can infer the format.
        Overwriting an existing file requires confirmation. Failures are logged and reported.
        """
        if self.last_output_numpy_image is None:
            logging.info("There is no output image to save!!!")
            return
        dir_name = "data/images"
        file_dialog = wx.FileDialog(self, "Choose an image", dir_name, "", "*.png", wx.FD_SAVE)
        if file_dialog.ShowModal() == wx.ID_OK:
            image_file_name = os.path.join(file_dialog.GetDirectory(), file_dialog.GetFilename())
            if not os.path.splitext(image_file_name)[1]:
                image_file_name += ".png"
            try:
                if os.path.exists(image_file_name):
                    result = self._show_message(f"Override {image_file_name}", "Manual Poser",
                                                wx.YES_NO | wx.ICON_QUESTION)
                    if result == wx.ID_YES:
                        self.save_last_numpy_image(image_file_name)
                else:
                    self.save_last_numpy_image(image_file_name)
            except Exception:
                logging.exception("Could not save %s", image_file_name)
                self._show_message(f"Could not save {image_file_name}", "Manual Poser")
        file_dialog.Destroy()

    def save_last_numpy_image(self, image_file_name):
        """Save the last output image to ``image_file_name`` and the pose to a sibling ``.json``.

        The JSON has the form ``{<file name without extension>: {<pose key>: <value>, ...}}``.
        """
        pil_image = PIL.Image.fromarray(self.last_output_numpy_image, mode='RGBA')
        directory = os.path.dirname(image_file_name)
        # os.makedirs('') raises, so only create a directory when the path names one.
        if directory:
            os.makedirs(directory, exist_ok=True)
        pil_image.save(image_file_name)

        data_dict = self.get_current_posedict()
        json_file_path = os.path.splitext(image_file_name)[0] + ".json"
        filename_without_extension = os.path.splitext(os.path.basename(image_file_name))[0]
        data_dict_with_filename = {filename_without_extension: data_dict}
        with open(json_file_path, "w", encoding="utf-8") as file:
            json.dump(data_dict_with_filename, file, indent=4)
if __name__ == "__main__":
    # Without a handler configuration the logging.info / logging.exception calls are not shown.
    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser(description='Manually pose a character image.')
    parser.add_argument(
        '--model',
        type=str,
        required=False,
        default='separable_float',
        choices=['standard_float', 'separable_float', 'standard_half', 'separable_half'],
        help='The model to use.')
    parser.add_argument(
        '--device',
        type=str,
        required=False,
        default=None,
        help="Torch device to run the poser on (e.g. 'cuda', 'cpu'). "
             "Defaults to 'cuda' when available, otherwise 'cpu'.")
    args = parser.parse_args()

    # Fall back to the CPU when CUDA is unavailable instead of failing at model load.
    # NOTE(review): the *_half models may be slow or unsupported on CPU — confirm.
    device = torch.device(args.device or ('cuda' if torch.cuda.is_available() else 'cpu'))
    try:
        poser = load_poser(args.model, device)
    except RuntimeError as e:
        print(e)
        sys.exit(1)

    app = wx.App()
    main_frame = MainFrame(poser, device)
    main_frame.Show(True)
    # Re-render check every 30 ms (see MainFrame.update_images).
    main_frame.timer.Start(30)
    app.MainLoop()