# TomatoCocotree
# 上传 (upload)
# 6a62ffb
# raw | history | blame
# 23.9 kB
import argparse
import logging
import os
import sys
import PIL.Image
import numpy
import torch
import wx
import json
from typing import List
# Temporarily switch into the "live2d" subdirectory and add it to sys.path so
# the tha3 package inside it can be imported (presumably tha3 also depends on
# the working directory at import time -- TODO confirm). The original working
# directory is remembered and restored in `finally`, so relative paths used
# later (e.g. "data/images" in the file dialogs) resolve against the launch
# directory even if an import fails or "live2d" is a symlink.
original_directory = os.getcwd()
target_directory = os.path.join(original_directory, "live2d")
os.chdir(target_directory)
try:
    sys.path.append(os.getcwd())
    from tha3.poser.modes.load_poser import load_poser
    from tha3.poser.poser import Poser, PoseParameterCategory, PoseParameterGroup
    from tha3.util import extract_pytorch_image_from_filelike, rgba_to_numpy_image, grid_change_to_numpy_image, \
        rgb_to_numpy_image, resize_PIL_image, extract_PIL_image_from_filelike, extract_pytorch_image_from_PIL_image
    # Kept as module globals for backward compatibility.
    current_directory = os.getcwd()
finally:
    os.chdir(original_directory)
parent_directory = original_directory
class MorphCategoryControlPanel(wx.Panel):
    """Control panel for one morph category (e.g. eyebrow, eye, mouth).

    A choice box selects one parameter group of ``pose_param_category``.
    Discrete groups are toggled with the "Show" checkbox; continuous groups are
    driven by the left slider (arity 1) or by the left and right sliders
    (arity 2). Only the selected group is written into the pose.
    """

    # Integer slider range; mapped linearly onto the selected group's value range.
    SLIDER_MIN = -1000
    SLIDER_MAX = 1000

    def __init__(self,
                 parent,
                 title: str,
                 pose_param_category: PoseParameterCategory,
                 param_groups: List[PoseParameterGroup]):
        """Build the panel.

        Args:
            parent: wx parent window.
            title: caption shown at the top of the panel.
            pose_param_category: category whose groups this panel controls.
            param_groups: all parameter groups of the poser; groups of other
                categories are ignored.
        """
        super().__init__(parent, style=wx.SIMPLE_BORDER)
        self.pose_param_category = pose_param_category
        self.sizer = wx.BoxSizer(wx.VERTICAL)
        self.SetSizer(self.sizer)
        self.SetAutoLayout(True)

        title_text = wx.StaticText(self, label=title, style=wx.ALIGN_CENTER)
        self.sizer.Add(title_text, 0, wx.EXPAND)

        # Only the groups belonging to this panel's category are selectable.
        self.param_groups = [group for group in param_groups if group.get_category() == pose_param_category]
        self.choice = wx.Choice(self, choices=[group.get_group_name() for group in self.param_groups])
        if self.param_groups:
            self.choice.SetSelection(0)
        self.choice.Bind(wx.EVT_CHOICE, self.on_choice_updated)
        self.sizer.Add(self.choice, 0, wx.EXPAND)

        # Both sliders start at the minimum, i.e. at the low end of the group's range.
        self.left_slider = wx.Slider(self, minValue=self.SLIDER_MIN, maxValue=self.SLIDER_MAX,
                                     value=self.SLIDER_MIN, style=wx.HORIZONTAL)
        self.sizer.Add(self.left_slider, 0, wx.EXPAND)
        self.right_slider = wx.Slider(self, minValue=self.SLIDER_MIN, maxValue=self.SLIDER_MAX,
                                      value=self.SLIDER_MIN, style=wx.HORIZONTAL)
        self.sizer.Add(self.right_slider, 0, wx.EXPAND)

        self.checkbox = wx.CheckBox(self, label="Show")
        self.checkbox.SetValue(True)
        self.sizer.Add(self.checkbox, 0, wx.SHAPED | wx.ALIGN_CENTER)

        self.update_ui()
        self.sizer.Fit(self)

    def _selected_group(self):
        """Return the selected parameter group, or None when nothing is selected."""
        index = self.choice.GetSelection()
        if index == wx.NOT_FOUND or not self.param_groups:
            return None
        return self.param_groups[index]

    def update_ui(self):
        """Enable only the controls that apply to the selected group."""
        param_group = self._selected_group()
        if param_group is None:
            # No group of this category: nothing to edit. (Indexing the empty
            # group list used to raise IndexError here.)
            self.left_slider.Enable(False)
            self.right_slider.Enable(False)
            self.checkbox.Enable(False)
        elif param_group.is_discrete():
            self.left_slider.Enable(False)
            self.right_slider.Enable(False)
            self.checkbox.Enable(True)
        elif param_group.get_arity() == 1:
            self.left_slider.Enable(True)
            self.right_slider.Enable(False)
            self.checkbox.Enable(False)
        else:
            self.left_slider.Enable(True)
            self.right_slider.Enable(True)
            self.checkbox.Enable(False)

    def on_choice_updated(self, event: wx.Event):
        """Handle a new selection: re-check "Show" for discrete groups, then refresh the controls."""
        param_group = self._selected_group()
        if param_group is not None and param_group.is_discrete():
            self.checkbox.SetValue(True)
        self.update_ui()

    def _slider_alpha(self, slider):
        """Map a slider position from [SLIDER_MIN, SLIDER_MAX] to [0, 1]."""
        return (slider.GetValue() - self.SLIDER_MIN) / (self.SLIDER_MAX - self.SLIDER_MIN)

    def set_param_value(self, pose: List[float]):
        """Write the selected group's value(s) into ``pose`` in place.

        Discrete groups set all of their ``arity`` entries to 1.0 when "Show"
        is checked and leave them untouched otherwise. Continuous groups map
        the left slider (and, for arity 2, the right slider) linearly onto the
        group's range.
        """
        param_group = self._selected_group()
        if param_group is None:
            return
        param_index = param_group.get_parameter_index()
        if param_group.is_discrete():
            if self.checkbox.GetValue():
                for i in range(param_group.get_arity()):
                    pose[param_index + i] = 1.0
            return
        param_range = param_group.get_range()
        span = param_range[1] - param_range[0]
        pose[param_index] = param_range[0] + span * self._slider_alpha(self.left_slider)
        if param_group.get_arity() == 2:
            pose[param_index + 1] = param_range[0] + span * self._slider_alpha(self.right_slider)
class SimpleParamGroupsControlPanel(wx.Panel):
    """Control panel with one slider per parameter group of a category.

    Every group of the category must be continuous (not discrete) with
    arity 1. This is used for the rotation / breathing categories.
    """

    def __init__(self, parent,
                 pose_param_category: PoseParameterCategory,
                 param_groups: List[PoseParameterGroup]):
        """Build one labelled slider per group of ``pose_param_category``.

        Args:
            parent: wx parent window.
            pose_param_category: category whose groups this panel controls.
            param_groups: all parameter groups of the poser; groups of other
                categories are ignored.

        Raises:
            ValueError: if a selected group is discrete or has arity != 1.
        """
        super().__init__(parent, style=wx.SIMPLE_BORDER)
        self.sizer = wx.BoxSizer(wx.VERTICAL)
        self.SetSizer(self.sizer)
        self.SetAutoLayout(True)

        self.param_groups = [group for group in param_groups if group.get_category() == pose_param_category]
        # Explicit checks rather than `assert`, which is stripped under `python -O`.
        for param_group in self.param_groups:
            if param_group.is_discrete():
                raise ValueError("Parameter group %s is discrete; a continuous group is required"
                                 % param_group.get_group_name())
            if param_group.get_arity() != 1:
                raise ValueError("Parameter group %s has arity %d; arity 1 is required"
                                 % (param_group.get_group_name(), param_group.get_arity()))

        self.sliders = []
        for param_group in self.param_groups:
            static_text = wx.StaticText(
                self,
                label=" ------------ %s ------------ " % param_group.get_group_name(), style=wx.ALIGN_CENTER)
            self.sizer.Add(static_text, 0, wx.EXPAND)
            # Slider positions are the group's range scaled by 1000 and truncated to int.
            param_range = param_group.get_range()
            min_value = int(param_range[0] * 1000)
            max_value = int(param_range[1] * 1000)
            slider = wx.Slider(self, minValue=min_value, maxValue=max_value, value=0, style=wx.HORIZONTAL)
            self.sizer.Add(slider, 0, wx.EXPAND)
            self.sliders.append(slider)
        self.sizer.Fit(self)

    def set_param_value(self, pose: List[float]):
        """Write every slider's value, mapped onto its group's range, into ``pose`` in place."""
        for param_group, slider in zip(self.param_groups, self.sliders):
            slider_min = slider.GetMin()
            slider_span = slider.GetMax() - slider_min
            # A degenerate slider (min == max) has a single position; avoid dividing by zero.
            alpha = (slider.GetValue() - slider_min) * 1.0 / slider_span if slider_span else 0.0
            param_range = param_group.get_range()
            pose[param_group.get_parameter_index()] = \
                param_range[0] + (param_range[1] - param_range[0]) * alpha
def convert_output_image_from_torch_to_numpy(output_image):
    """Convert one poser output tensor into a uint8 numpy image.

    Handled layouts, by the checks below:
      * (H, W, 2): a channels-last 2-channel grid. It is permuted to
        (2, H, W) and then rendered like the channels-first grid case.
      * (4, H, W): rendered with ``rgba_to_numpy_image``.
      * (3, H, W): rendered with ``rgb_to_numpy_image``.
      * (1, H, W): replicated to 3 channels with an all-ones alpha channel,
        then rendered with ``rgba_to_numpy_image``.
      * (2, H, W): rendered with ``grid_change_to_numpy_image``.

    Returns:
        ``uint8`` array holding ``rint(image * 255)`` of the rendered image.

    Raises:
        RuntimeError: if the channel count is not one of the above.
    """
    if output_image.shape[2] == 2:
        # Permute (H, W, C) -> (C, H, W) so the grid branch below renders it.
        # Previously the permuted raw grid was returned without rendering.
        h, w, c = output_image.shape
        output_image = torch.transpose(output_image.reshape(h * w, c), 0, 1).reshape(c, h, w)

    num_channels = output_image.shape[0]
    if num_channels == 4:
        numpy_image = rgba_to_numpy_image(output_image)
    elif num_channels == 3:
        numpy_image = rgb_to_numpy_image(output_image)
    elif num_channels == 1:
        c, h, w = output_image.shape
        # v -> 2v - 1 on the replicated gray channel; presumably rgba_to_numpy_image
        # expects values in [-1, 1] -- TODO confirm. The alpha channel is all ones.
        alpha_image = torch.cat([output_image.repeat(3, 1, 1) * 2.0 - 1.0, torch.ones(1, h, w)], dim=0)
        numpy_image = rgba_to_numpy_image(alpha_image)
    elif num_channels == 2:
        numpy_image = grid_change_to_numpy_image(output_image, num_channels=4)
    else:
        raise RuntimeError("Unsupported # image channels: %d" % num_channels)
    numpy_image = numpy.uint8(numpy.rint(numpy_image * 255.0))
    return numpy_image
class MainFrame(wx.Frame):
    """Main window: source image (left), pose controls (middle), posed result (right).

    ``update_images`` is bound to ``self.timer``. On each tick it re-runs the
    poser only when the pose, the output index, or the source image changed.
    Ctrl+S or the "Save Image" button saves the last result as a PNG and
    writes the current pose next to it as JSON.
    """

    # Key names for the pose JSON, in the same order as get_current_pose().
    # NOTE(review): 'eye_unimpressed' appears twice. When zipped into a dict, the
    # value at position 20 is overwritten by the one at position 21, so the JSON
    # has one entry fewer than the pose vector. The names are kept unchanged
    # because JSON files written with exactly these keys may be read elsewhere.
    POSE_DICT_KEYS = (
        'eyebrow_troubled_left_index', 'eyebrow_troubled_right_index',
        'eyebrow_angry_left_index', 'eyebrow_angry_right_index',
        'eyebrow_lowered_left_index', 'eyebrow_lowered_right_index',
        'eyebrow_raised_left_index', 'eyebrow_raised_right_index',
        'eyebrow_happy_left_index', 'eyebrow_happy_right_index',
        'eyebrow_serious_left_index', 'eyebrow_serious_right_index',
        'eye_wink_left_index', 'eye_wink_right_index',
        'eye_happy_wink_left_index', 'eye_happy_wink_right_index',
        'eye_surprised_left_index', 'eye_surprised_right_index',
        'eye_relaxed_left_index', 'eye_relaxed_right_index',
        'eye_unimpressed', 'eye_unimpressed',
        'eye_raised_lower_eyelid_left_index', 'eye_raised_lower_eyelid_right_index',
        'iris_small_left_index', 'iris_small_right_index',
        'mouth_aaa_index', 'mouth_iii_index', 'mouth_uuu_index', 'mouth_eee_index', 'mouth_ooo_index',
        'mouth_delta',
        'mouth_lowered_corner_left_index', 'mouth_lowered_corner_right_index',
        'mouth_raised_corner_left_index', 'mouth_raised_corner_right_index',
        'mouth_smirk',
        'iris_rotation_x_index', 'iris_rotation_y_index',
        'head_x_index', 'head_y_index', 'neck_z_index',
        'body_y_index', 'body_z_index',
        'breathing_index',
    )

    def __init__(self, poser: Poser, device: torch.device):
        """Build the window for ``poser`` running on ``device``; the timer is not started here."""
        super().__init__(None, wx.ID_ANY, "Poser")
        self.poser = poser
        self.dtype = self.poser.get_dtype()
        self.device = device
        self.image_size = self.poser.get_image_size()
        # Source image as a wx.Bitmap (for display) and as a tensor (poser input).
        self.wx_source_image = None
        self.torch_source_image = None

        self.main_sizer = wx.BoxSizer(wx.HORIZONTAL)
        self.SetSizer(self.main_sizer)
        self.SetAutoLayout(True)
        self.init_left_panel()
        self.init_control_panel()
        self.init_right_panel()
        self.main_sizer.Fit(self)

        self.timer = wx.Timer(self, wx.ID_ANY)
        self.Bind(wx.EVT_TIMER, self.update_images, self.timer)

        save_image_id = wx.NewIdRef()
        self.Bind(wx.EVT_MENU, self.on_save_image, id=save_image_id)
        accelerator_table = wx.AcceleratorTable([
            (wx.ACCEL_CTRL, ord('S'), save_image_id)
        ])
        self.SetAcceleratorTable(accelerator_table)

        # State used by update_images() to skip redundant poser runs.
        self.last_pose = None
        self.last_output_index = self.output_index_choice.GetSelection()
        self.last_output_numpy_image = None
        self.source_image_bitmap = wx.Bitmap(self.image_size, self.image_size)
        self.result_image_bitmap = wx.Bitmap(self.image_size, self.image_size)
        self.source_image_dirty = True

    def init_left_panel(self):
        """Build the left column: source-image preview and the "Load Image" button."""
        self.left_panel = wx.Panel(self, style=wx.SIMPLE_BORDER)
        left_panel_sizer = wx.BoxSizer(wx.VERTICAL)
        self.left_panel.SetSizer(left_panel_sizer)
        self.left_panel.SetAutoLayout(True)

        self.source_image_panel = wx.Panel(self.left_panel, size=(self.image_size, self.image_size),
                                           style=wx.SIMPLE_BORDER)
        self.source_image_panel.Bind(wx.EVT_PAINT, self.paint_source_image_panel)
        self.source_image_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background)
        left_panel_sizer.Add(self.source_image_panel, 0, wx.FIXED_MINSIZE)

        self.load_image_button = wx.Button(self.left_panel, wx.ID_ANY, "\nLoad Image\n\n")
        left_panel_sizer.Add(self.load_image_button, 1, wx.EXPAND)
        self.load_image_button.Bind(wx.EVT_BUTTON, self.load_image)

        left_panel_sizer.Fit(self.left_panel)
        self.main_sizer.Add(self.left_panel, 0, wx.FIXED_MINSIZE)

    def on_erase_background(self, event: wx.Event):
        """Skip background erasing; the image panels are repainted entirely from their bitmaps."""
        pass

    def init_control_panel(self):
        """Build the middle column: one control panel per category the poser actually uses."""
        # The control panel is created here, next to the code that fills it
        # (it used to be created in init_left_panel).
        self.control_panel = wx.Panel(self, style=wx.SIMPLE_BORDER, size=(self.image_size, -1))
        self.control_panel_sizer = wx.BoxSizer(wx.VERTICAL)
        self.control_panel.SetSizer(self.control_panel_sizer)
        self.control_panel.SetMinSize(wx.Size(256, 1))

        param_groups = self.poser.get_pose_parameter_groups()

        def has_groups(category):
            """Whether the poser has at least one parameter group of ``category``."""
            return any(group.get_category() == category for group in param_groups)

        # Morph categories, in display order, with their panel titles.
        morph_category_titles = {
            PoseParameterCategory.EYEBROW: " ------------ Eyebrow ------------ ",
            PoseParameterCategory.EYE: " ------------ Eye ------------ ",
            PoseParameterCategory.MOUTH: " ------------ Mouth ------------ ",
            PoseParameterCategory.IRIS_MORPH: " ------------ Iris morphs ------------ ",
        }
        self.morph_control_panels = {}
        for category, title in morph_category_titles.items():
            if not has_groups(category):
                continue
            control_panel = MorphCategoryControlPanel(self.control_panel, title, category, param_groups)
            self.morph_control_panels[category] = control_panel
            self.control_panel_sizer.Add(control_panel, 0, wx.EXPAND)

        self.non_morph_control_panels = {}
        non_morph_categories = [
            PoseParameterCategory.IRIS_ROTATION,
            PoseParameterCategory.FACE_ROTATION,
            PoseParameterCategory.BODY_ROTATION,
            PoseParameterCategory.BREATHING
        ]
        for category in non_morph_categories:
            if not has_groups(category):
                continue
            control_panel = SimpleParamGroupsControlPanel(self.control_panel, category, param_groups)
            self.non_morph_control_panels[category] = control_panel
            self.control_panel_sizer.Add(control_panel, 0, wx.EXPAND)

        self.control_panel_sizer.Fit(self.control_panel)
        self.main_sizer.Add(self.control_panel, 1, wx.FIXED_MINSIZE)

    def init_right_panel(self):
        """Build the right column: result preview, output-index choice, "Save Image" button."""
        self.right_panel = wx.Panel(self, style=wx.SIMPLE_BORDER)
        right_panel_sizer = wx.BoxSizer(wx.VERTICAL)
        self.right_panel.SetSizer(right_panel_sizer)
        self.right_panel.SetAutoLayout(True)

        self.result_image_panel = wx.Panel(self.right_panel,
                                           size=(self.image_size, self.image_size),
                                           style=wx.SIMPLE_BORDER)
        self.result_image_panel.Bind(wx.EVT_PAINT, self.paint_result_image_panel)
        self.result_image_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background)

        # Selects which of the poser's outputs is displayed.
        self.output_index_choice = wx.Choice(
            self.right_panel,
            choices=[str(i) for i in range(self.poser.get_output_length())])
        self.output_index_choice.SetSelection(0)
        right_panel_sizer.Add(self.result_image_panel, 0, wx.FIXED_MINSIZE)
        right_panel_sizer.Add(self.output_index_choice, 0, wx.EXPAND)

        self.save_image_button = wx.Button(self.right_panel, wx.ID_ANY, "\nSave Image\n\n")
        right_panel_sizer.Add(self.save_image_button, 1, wx.EXPAND)
        self.save_image_button.Bind(wx.EVT_BUTTON, self.on_save_image)

        right_panel_sizer.Fit(self.right_panel)
        self.main_sizer.Add(self.right_panel, 0, wx.FIXED_MINSIZE)

    def create_param_category_choice(self, param_category: PoseParameterCategory):
        """Return a wx.Choice (child of the control panel) listing the group names of ``param_category``."""
        params = [param_group.get_group_name()
                  for param_group in self.poser.get_pose_parameter_groups()
                  if param_group.get_category() == param_category]
        choice = wx.Choice(self.control_panel, choices=params)
        if params:
            choice.SetSelection(0)
        return choice

    def _show_message(self, message, caption, style=wx.OK):
        """Show a modal message box and return its result; the dialog is always destroyed."""
        message_dialog = wx.MessageDialog(self, message, caption, style)
        try:
            return message_dialog.ShowModal()
        finally:
            message_dialog.Destroy()

    def load_image(self, event: wx.Event):
        """Ask for a PNG file and load it as the new source image.

        The image is resized to the poser's image size. Images without an
        alpha channel (mode other than RGBA) are rejected: the source image
        is cleared and the user is told why. Load errors are logged and shown
        in a message box.
        """
        dir_name = "data/images"
        file_dialog = wx.FileDialog(self, "Choose an image", dir_name, "", "*.png", wx.FD_OPEN)
        try:
            if file_dialog.ShowModal() != wx.ID_OK:
                return
            image_file_name = os.path.join(file_dialog.GetDirectory(), file_dialog.GetFilename())
        finally:
            file_dialog.Destroy()

        try:
            pil_image = resize_PIL_image(extract_PIL_image_from_filelike(image_file_name),
                                         (self.poser.get_image_size(), self.poser.get_image_size()))
            w, h = pil_image.size
            has_alpha = pil_image.mode == 'RGBA'
            if has_alpha:
                self.wx_source_image = wx.Bitmap.FromBufferRGBA(w, h, pil_image.tobytes())
                self.torch_source_image = extract_pytorch_image_from_PIL_image(pil_image) \
                    .to(self.device).to(self.dtype)
            else:
                self.source_image_string = "Image must have alpha channel!"
                self.wx_source_image = None
                self.torch_source_image = None
            self.source_image_dirty = True
            self.Refresh()
            self.Update()
        except Exception:
            logging.exception("Could not load image %s", image_file_name)
            self._show_message("Could not load image " + image_file_name, "Poser")
            return

        if not has_alpha:
            # Previously this message was stored but never shown to the user.
            self._show_message(self.source_image_string, "Poser")

    def paint_source_image_panel(self, event: wx.Event):
        """Paint handler: blit the cached source bitmap."""
        wx.BufferedPaintDC(self.source_image_panel, self.source_image_bitmap)

    def paint_result_image_panel(self, event: wx.Event):
        """Paint handler: blit the cached result bitmap."""
        wx.BufferedPaintDC(self.result_image_panel, self.result_image_bitmap)

    def draw_nothing_yet_string_to_bitmap(self, bitmap):
        """Clear ``bitmap`` and draw a centered "Nothing yet!" placeholder on it."""
        dc = wx.MemoryDC()
        dc.SelectObject(bitmap)
        dc.Clear()
        font = wx.Font(wx.FontInfo(14).Family(wx.FONTFAMILY_SWISS))
        dc.SetFont(font)
        w, h = dc.GetTextExtent("Nothing yet!")
        # Center the text (this was `image_size - - h`, which pushed it below center).
        dc.DrawText("Nothing yet!", (self.image_size - w) // 2, (self.image_size - h) // 2)
        # Release the bitmap from the DC before it is drawn elsewhere.
        dc.SelectObject(wx.NullBitmap)
        del dc

    def get_current_pose(self):
        """Return the pose vector (a list of floats) assembled from all control panels."""
        current_pose = [0.0] * self.poser.get_num_parameters()
        for morph_control_panel in self.morph_control_panels.values():
            morph_control_panel.set_param_value(current_pose)
        for rotation_control_panel in self.non_morph_control_panels.values():
            rotation_control_panel.set_param_value(current_pose)
        return current_pose

    def update_images(self, event: wx.Event):
        """Timer handler: re-render the source and result bitmaps when something changed."""
        current_pose = self.get_current_pose()
        output_index = self.output_index_choice.GetSelection()
        if not self.source_image_dirty \
                and self.last_pose is not None \
                and self.last_pose == current_pose \
                and self.last_output_index == output_index:
            return
        self.last_pose = current_pose
        self.last_output_index = output_index

        if self.torch_source_image is None:
            self.draw_nothing_yet_string_to_bitmap(self.source_image_bitmap)
            self.draw_nothing_yet_string_to_bitmap(self.result_image_bitmap)
            self.source_image_dirty = False
            self.Refresh()
            self.Update()
            return

        if self.source_image_dirty:
            dc = wx.MemoryDC()
            dc.SelectObject(self.source_image_bitmap)
            dc.Clear()
            dc.DrawBitmap(self.wx_source_image, 0, 0)
            dc.SelectObject(wx.NullBitmap)
            del dc
            self.source_image_dirty = False

        pose = torch.tensor(current_pose, device=self.device, dtype=self.dtype)
        with torch.no_grad():
            output_image = self.poser.pose(self.torch_source_image, pose, output_index)[0].detach().cpu()
        numpy_image = convert_output_image_from_torch_to_numpy(output_image)
        self.last_output_numpy_image = numpy_image

        # numpy_image is indexed as (rows, columns, channels) with RGBA channels.
        # wx.ImageFromBuffer takes width first, plus separate RGB and alpha buffers.
        height, width = numpy_image.shape[0], numpy_image.shape[1]
        wx_image = wx.ImageFromBuffer(
            width,
            height,
            numpy_image[:, :, 0:3].tobytes(),
            numpy_image[:, :, 3].tobytes())
        wx_bitmap = wx_image.ConvertToBitmap()

        dc = wx.MemoryDC()
        dc.SelectObject(self.result_image_bitmap)
        dc.Clear()
        dc.DrawBitmap(wx_bitmap,
                      (self.image_size - width) // 2,
                      (self.image_size - height) // 2,
                      True)
        dc.SelectObject(wx.NullBitmap)
        del dc
        self.Refresh()
        self.Update()

    def get_current_posedict(self):
        """Return the current pose as a ``{key name: value}`` dict keyed by POSE_DICT_KEYS."""
        current_pose_values = self.get_current_pose()
        if len(current_pose_values) != len(self.POSE_DICT_KEYS):
            # zip() silently truncates; make a mismatch with the poser visible.
            logging.warning("Pose has %d values but %d key names are defined",
                            len(current_pose_values), len(self.POSE_DICT_KEYS))
        return dict(zip(self.POSE_DICT_KEYS, current_pose_values))

    def on_save_image(self, event: wx.Event):
        """Ask for a target PNG path and save the last result image plus its pose JSON.

        Asks for confirmation before overwriting an existing file. Errors are
        logged and reported in a message box.
        """
        if self.last_output_numpy_image is None:
            logging.info("There is no output image to save!!!")
            return
        dir_name = "data/images"
        file_dialog = wx.FileDialog(self, "Choose an image", dir_name, "", "*.png", wx.FD_SAVE)
        try:
            if file_dialog.ShowModal() != wx.ID_OK:
                return
            image_file_name = os.path.join(file_dialog.GetDirectory(), file_dialog.GetFilename())
        finally:
            file_dialog.Destroy()

        try:
            if os.path.exists(image_file_name):
                # The confirmation dialog is now destroyed after use (it used to leak).
                answer = self._show_message(f"Override {image_file_name}", "Manual Poser",
                                            wx.YES_NO | wx.ICON_QUESTION)
                if answer != wx.ID_YES:
                    return
            self.save_last_numpy_image(image_file_name)
        except Exception:
            logging.exception("Could not save %s", image_file_name)
            self._show_message(f"Could not save {image_file_name}", "Manual Poser")

    def save_last_numpy_image(self, image_file_name):
        """Save the last result image to ``image_file_name`` and the current pose beside it.

        The pose goes to ``<name>.json`` as ``{<name>: {key: value, ...}}``,
        where ``<name>`` is the image file name without directory or extension.
        Missing parent directories are created.
        """
        numpy_image = self.last_output_numpy_image
        pil_image = PIL.Image.fromarray(numpy_image, mode='RGBA')
        directory = os.path.dirname(image_file_name)
        if directory:
            # os.makedirs('') raises, so only create a directory when the path has one.
            os.makedirs(directory, exist_ok=True)
        pil_image.save(image_file_name)

        data_dict = self.get_current_posedict()
        json_file_path = os.path.splitext(image_file_name)[0] + ".json"
        filename_without_extension = os.path.splitext(os.path.basename(image_file_name))[0]
        data_dict_with_filename = {filename_without_extension: data_dict}
        with open(json_file_path, "w") as file:
            json.dump(data_dict_with_filename, file, indent=4)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Manually pose a character image.')
    parser.add_argument(
        '--model',
        type=str,
        required=False,
        default='separable_float',
        choices=['standard_float', 'separable_float', 'standard_half', 'separable_half'],
        help='The model to use.')
    # Previously hard-coded to CUDA; the default keeps that behavior.
    parser.add_argument(
        '--device',
        type=str,
        required=False,
        default='cuda',
        help='The torch device to run the poser on (default: cuda).')
    args = parser.parse_args()
    device = torch.device(args.device)
    try:
        poser = load_poser(args.model, device)
    except RuntimeError as e:
        print(e, file=sys.stderr)
        # Non-zero exit status so scripts can detect the failure.
        sys.exit(1)

    app = wx.App()
    main_frame = MainFrame(poser, device)
    main_frame.Show(True)
    # Poll the controls and refresh the images every 30 ms.
    main_frame.timer.Start(30)
    app.MainLoop()