Error while loading the model
#2 opened by sunovivid
Hi! Thank you for the cool work.
I want to use this model, but when I run the code below, I get the following error:
Traceback (most recent call last):
  File "/home/cvlab12/project/donghoon/photo-restoration/ColorizeNet/colorize.py", line 14, in <module>
    model.load_state_dict(load_state_dict(
  File "/home/cvlab12/project/donghoon/photo-restoration/ColorizeNet/utils/model.py", line 40, in load_state_dict
    state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
  File "/home/cvlab12/miniconda3/envs/photo/lib/python3.9/site-packages/torch/serialization.py", line 1114, in load
    return _legacy_load(
  File "/home/cvlab12/miniconda3/envs/photo/lib/python3.9/site-packages/torch/serialization.py", line 1338, in _legacy_load
    magic_number = pickle_module.load(f, **pickle_load_args)
_pickle.UnpicklingError: invalid load key, 'v'.
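The last line of the traceback hints at the cause. As the traceback shows, `torch.load` falls back to `_legacy_load`, which starts unpickling from the first byte of the file. If the `.ckpt` is a Git LFS pointer (the small text file that a clone without Git LFS leaves in place of the real weights), that first byte is the `v` of `version`, which is not a valid pickle opcode. The sketch below reproduces the same message with the standard `pickle` module; the pointer's `oid` and `size` are made-up placeholder values, not the real ones for this checkpoint.

```python
import pickle

# Placeholder Git LFS pointer file. The three-line format is the standard
# LFS pointer layout, but the oid and size values are invented for illustration.
FAKE_LFS_POINTER = (
    b"version https://git-lfs.github.com/spec/v1\n"
    b"oid sha256:0000000000000000000000000000000000000000000000000000000000000000\n"
    b"size 12345\n"
)


def unpickle_error_message(data: bytes) -> str:
    """Try to unpickle raw bytes; return the UnpicklingError message, or "" on success."""
    try:
        pickle.loads(data)
    except pickle.UnpicklingError as exc:
        return str(exc)
    return ""


# The unpickler reads b"v" as an opcode and rejects it.
print(unpickle_error_message(FAKE_LFS_POINTER))  # prints: invalid load key, 'v'.
```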
This is the code I use to generate samples (almost the same as the example, except for the model loading path):
import random
import cv2
import einops
import numpy as np
import torch
from pytorch_lightning import seed_everything
from utils.data import HWC3, apply_color, resize_image
from utils.ddim import DDIMSampler
from utils.model import create_model, load_state_dict
model = create_model('./models/cldm_v21.yaml').cpu()
model.load_state_dict(load_state_dict(
'./models/colorizenet-sd21.ckpt', location='cuda'))
model = model.cuda()
ddim_sampler = DDIMSampler(model)
input_image = cv2.imread("sample_data/sample1_bw.jpg")
input_image = HWC3(input_image)
img = resize_image(input_image, resolution=512)
H, W, C = img.shape
num_samples = 1
control = torch.from_numpy(img.copy()).float().cuda() / 255.0
control = torch.stack([control for _ in range(num_samples)], dim=0)
control = einops.rearrange(control, 'b h w c -> b c h w').clone()
# seed = random.randint(0, 65535)
seed = 1294574436
seed_everything(seed)
prompt = "Colorize this image"
n_prompt = ""
guess_mode = False
strength = 1.0
eta = 0.0
ddim_steps = 20
scale = 9.0
cond = {"c_concat": [control], "c_crossattn": [
model.get_learned_conditioning([prompt] * num_samples)]}
un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [
model.get_learned_conditioning([n_prompt] * num_samples)]}
shape = (4, H // 8, W // 8)
model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else (
[strength] * 13)
samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
shape, cond, verbose=False, eta=eta,
unconditional_guidance_scale=scale,
unconditional_conditioning=un_cond)
x_samples = model.decode_first_stage(samples)
x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c')
* 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
results = [x_samples[i] for i in range(num_samples)]
colored_results = [apply_color(img, result) for result in results]
for i, result in enumerate(colored_results):
    cv2.imwrite(f"colorized_{i}.jpg", cv2.cvtColor(result, cv2.COLOR_RGB2BGR))
Why does this error happen? Do you have any idea what causes it?
Sorry, I was not familiar with Hugging Face repositories, so I didn't know that I had to download the model separately. I thought git clone was sufficient.
For future newcomers like me, I'm leaving this issue here.
I used wget https://huggingface.co/rsortino/ColorizeNet/resolve/main/colorizenet-sd21.ckpt?download=true, which is the download URL of colorizenet-sd21.ckpt.
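To catch this before `torch.load` fails with an opaque pickle error, checking the first bytes of the checkpoint is enough. The helpers below (`looks_like_lfs_pointer`, `check_checkpoint`) are hypothetical and not part of ColorizeNet or PyTorch; they assume the standard Git LFS pointer header and use only the standard library. The file-exists check is also worth keeping because, depending on the wget version and options, the saved file may keep the `?download=true` suffix in its name, so it would not be found at `./models/colorizenet-sd21.ckpt`.

```python
from pathlib import Path

# First line of every Git LFS pointer file.
LFS_HEADER = b"version https://git-lfs.github.com/spec/v1"


def looks_like_lfs_pointer(path) -> bool:
    """Return True if the file starts with the Git LFS pointer header."""
    with open(path, "rb") as f:
        head = f.read(len(LFS_HEADER))
    return head == LFS_HEADER


def check_checkpoint(path) -> None:
    """Raise a readable error for a missing file or an LFS pointer instead of real weights."""
    p = Path(path)
    if not p.is_file():
        raise FileNotFoundError(
            f"{p} not found; check the file name and location after downloading"
        )
    if looks_like_lfs_pointer(p):
        raise RuntimeError(
            f"{p} is a Git LFS pointer ({p.stat().st_size} bytes), not the checkpoint; "
            "download the actual file instead of relying on git clone"
        )
```

For example, calling `check_checkpoint('./models/colorizenet-sd21.ckpt')` right before `model.load_state_dict(load_state_dict(...))` would report the real problem instead of `invalid load key, 'v'`.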