File size: 6,397 Bytes
b213d84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import glob
import os
import shutil
import time
from random import randint

import cv2
import numpy as np
import torch
from densepose import add_densepose_config
from densepose.vis.base import CompoundVisualizer
from densepose.vis.densepose_results import DensePoseResultsFineSegmentationVisualizer
from densepose.vis.extractor import CompoundExtractor, create_extractor
from detectron2.config import get_cfg
from detectron2.data.detection_utils import read_image
from detectron2.engine.defaults import DefaultPredictor
from PIL import Image


class DensePose:
    """
    DensePose used in this project is from Detectron2 (https://github.com/facebookresearch/detectron2).
    These codes are modified from https://github.com/facebookresearch/detectron2/tree/main/projects/DensePose.
    The checkpoint is downloaded from https://github.com/facebookresearch/detectron2/blob/main/projects/DensePose/doc/DENSEPOSE_IUV.md#ModelZoo.

    We use the model R_50_FPN_s1x with id 165712039, but other models should also work.
    The config file is downloaded from https://github.com/facebookresearch/detectron2/tree/main/projects/DensePose/configs.
    Note that the config file should match the model checkpoint and Base-DensePose-RCNN-FPN.yaml is also needed.

    Calling an instance on an image returns a single-channel ("L") PIL image of the
    input's size. It holds the fine-segmentation labels of the first detected
    instance (presumably the highest-scoring one) and zeros everywhere else.
    """

    def __init__(self, model_path="./checkpoints/densepose_", device="cuda"):
        """
        :param model_path: Directory containing ``densepose_rcnn_R_50_FPN_s1x.yaml``
            and ``model_final_162be9.pkl``.
        :param device: Torch device string the model is built on and run on
            (e.g. "cuda", "cuda:0", "cpu").
        """
        self.device = device
        self.config_path = os.path.join(model_path, "densepose_rcnn_R_50_FPN_s1x.yaml")
        self.model_path = os.path.join(model_path, "model_final_162be9.pkl")
        self.visualizations = ["dp_segm"]
        self.VISUALIZERS = {"dp_segm": DensePoseResultsFineSegmentationVisualizer}
        # Detections scoring below this are discarded by the ROI heads (see setup_config).
        self.min_score = 0.8

        self.cfg = self.setup_config()
        self.predictor = DefaultPredictor(self.cfg)
        self.predictor.model.to(self.device)

    def setup_config(self):
        """Build and freeze the detectron2 config used by the DensePose predictor."""
        opts = ["MODEL.ROI_HEADS.SCORE_THRESH_TEST", str(self.min_score)]
        cfg = get_cfg()
        add_densepose_config(cfg)
        cfg.merge_from_file(self.config_path)
        cfg.merge_from_list(opts)
        cfg.MODEL.WEIGHTS = self.model_path
        # Build the model directly on the requested device. Otherwise DefaultPredictor
        # uses detectron2's default MODEL.DEVICE ("cuda"), and device="cpu" fails
        # on machines without CUDA.
        cfg.MODEL.DEVICE = str(self.device)
        cfg.freeze()
        return cfg

    @staticmethod
    def _get_input_file_list(input_spec: str):
        """
        Expand ``input_spec`` into a list of file paths.

        A directory yields its direct regular-file children; an existing file yields
        itself; anything else is treated as a glob pattern.
        """
        if os.path.isdir(input_spec):
            file_list = [
                os.path.join(input_spec, fname)
                for fname in os.listdir(input_spec)
                if os.path.isfile(os.path.join(input_spec, fname))
            ]
        elif os.path.isfile(input_spec):
            file_list = [input_spec]
        else:
            file_list = glob.glob(input_spec)
        return file_list

    def create_context(self, cfg, output_path):
        """
        Create the extractor/visualizer bundle for ``self.visualizations``.

        :return: Dict with keys "extractor", "visualizer", "out_fname" (= output_path)
            and "entry_idx".
        """
        vis_specs = self.visualizations
        visualizers = []
        extractors = []
        for vis_spec in vis_specs:
            texture_atlas = texture_atlases_dict = None
            vis = self.VISUALIZERS[vis_spec](
                cfg=cfg,
                texture_atlas=texture_atlas,
                texture_atlases_dict=texture_atlases_dict,
                alpha=1.0,
            )
            visualizers.append(vis)
            extractor = create_extractor(vis)
            extractors.append(extractor)
        visualizer = CompoundVisualizer(visualizers)
        extractor = CompoundExtractor(extractors)
        context = {
            "extractor": extractor,
            "visualizer": visualizer,
            "out_fname": output_path,
            "entry_idx": 0,
        }
        return context

    @staticmethod
    def _paste_labels(canvas: np.ndarray, box_xywh, labels: np.ndarray) -> np.ndarray:
        """
        Write the 2-D ``labels`` patch into ``canvas`` with its top-left corner at
        ``(int(box_xywh[0]), int(box_xywh[1]))``, clipped to the canvas bounds.

        The patch extent comes from ``labels.shape``; it is assumed to match the box size.
        Clipping keeps boxes that touch or cross the image border from raising a
        shape-mismatch error. ``canvas`` is modified in place and also returned.
        """
        height, width = canvas.shape[:2]
        x, y = int(box_xywh[0]), int(box_xywh[1])
        patch_h, patch_w = labels.shape[:2]
        x0, y0 = max(x, 0), max(y, 0)
        x1, y1 = min(x + patch_w, width), min(y + patch_h, height)
        if x1 > x0 and y1 > y0:
            canvas[y0:y1, x0:x1] = labels[y0 - y : y1 - y, x0 - x : x1 - x]
        return canvas

    def execute_on_outputs(self, context, entry, outputs):
        """
        Rasterize the first instance's fine-segmentation labels and save them as an image.

        :param context: Dict from :meth:`create_context`; reads "extractor" and "out_fname".
        :param entry: Dict whose "image" is the (H, W, C) array the prediction was made on.
        :param outputs: Predicted instances from the DensePose predictor.
        :raises Exception: When the extractor yields no usable instance (e.g. indexing an
            empty result list). :meth:`__call__` treats this as "nothing detected".
        """
        extractor = context["extractor"]
        data = extractor(outputs)

        H, W, _ = entry["image"].shape
        result = np.zeros((H, W), dtype=np.uint8)

        results, boxes_xywh = data[0]
        box = boxes_xywh[0].cpu().numpy()
        labels = results[0].labels.cpu().numpy()
        self._paste_labels(result, box, labels)
        Image.fromarray(result).save(context["out_fname"])

    @staticmethod
    def _limit_max_side(img, max_side):
        """
        Downscale ``img`` (H, W[, C]) so its longer *spatial* side is at most ``max_side``.

        Returns ``img`` unchanged when it is already small enough. Only H and W are
        compared; the channel axis is ignored.
        """
        longest = max(img.shape[:2])
        if longest <= max_side:
            return img
        scale = max_side / longest
        # cv2.resize takes (width, height); keep each side at least 1 pixel.
        new_size = (
            max(1, int(img.shape[1] * scale)),
            max(1, int(img.shape[0] * scale)),
        )
        return cv2.resize(img, new_size)

    def __call__(self, image_or_path, resize=512) -> "Image.Image":
        """
        Predict the DensePose fine-segmentation label map for one image.

        :param image_or_path: Path of the input image (.jpg/.jpeg/.png, case-insensitive)
            or a PIL image.
        :param resize: Resize the input image if its max size is larger than this value.
        :return: Dense pose image: single-channel ("L"), original input size, all zeros
            when nothing is detected.
        :raises TypeError: If ``image_or_path`` is neither a path nor a PIL image.
        :raises AssertionError: If the path has an unsupported extension.
        """
        is_path = isinstance(image_or_path, (str, os.PathLike))
        if not is_path and not isinstance(image_or_path, Image.Image):
            # Validate before touching the filesystem, so a bad argument can never
            # delete the shared scratch directory or other calls' files.
            raise TypeError("image_path must be str or PIL.Image.Image")
        if is_path:
            src_path = os.fspath(image_or_path)
            # An explicit raise (not ``assert``) survives ``python -O``. AssertionError is
            # kept for backward compatibility with existing callers.
            if os.path.splitext(src_path)[1].lower() not in (".jpg", ".jpeg", ".png"):
                raise AssertionError("Only support jpg and png images.")

        # Fixed scratch directory; each call uses a unique timestamped/random file name.
        tmp_path = "./densepose_/tmp/"
        os.makedirs(tmp_path, exist_ok=True)
        image_path = os.path.join(
            tmp_path, f"{int(time.time())}-{self.device}-{randint(0, 100000)}.png"
        )
        output_path = image_path[: -len(".png")] + "_dense.png"

        try:
            if is_path:
                shutil.copy(src_path, image_path)
            else:
                image_or_path.save(image_path)
            with Image.open(image_path) as src:
                w, h = src.size

            file_list = self._get_input_file_list(image_path)
            assert len(file_list), "No input images found!"
            context = self.create_context(self.cfg, output_path)
            for file_name in file_list:
                img = read_image(file_name, format="BGR")  # predictor expects BGR image.
                img = self._limit_max_side(img, resize)

                with torch.no_grad():
                    outputs = self.predictor(img)["instances"]
                    try:
                        self.execute_on_outputs(
                            context, {"file_name": file_name, "image": img}, outputs
                        )
                    except Exception:
                        # Deliberate best effort: no usable detection -> blank mask.
                        # The 1x1 image is upscaled to (w, h) below, giving all zeros.
                        Image.new("L", (1, 1)).save(output_path)

            with Image.open(output_path) as dense:
                dense_gray = dense.convert("L")
            return dense_gray.resize((w, h), Image.NEAREST)
        finally:
            # Always remove this call's scratch files, even if prediction raised.
            for path in (image_path, output_path):
                if os.path.exists(path):
                    os.remove(path)


if __name__ == "__main__":
    # No command-line entry point: running this file directly does nothing;
    # importing it only defines the DensePose class.
    pass