import math
import torch
import torch.nn as nn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR

from ding.policy import Policy
from ding.model import model_wrap
from ding.torch_utils import to_device
from ding.utils import EasyTimer


class ImageClassificationPolicy(Policy):
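    """Policy for supervised image classification: SGD with momentum, a
    warmup-then-step-decay learning-rate schedule, and cross-entropy loss.
    The collect-related hooks are no-ops because no environment interaction
    is involved.
    """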
    config = dict(
        type='image_classification',
        on_policy=False,
    )

    def _init_learn(self):
        self._optimizer = SGD(
            self._model.parameters(),
            lr=self._cfg.learn.learning_rate,
            weight_decay=self._cfg.learn.weight_decay,
            momentum=0.9
        )
        self._timer = EasyTimer(cuda=True)

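        # Warmup-then-step-decay schedule. LambdaLR multiplies the base LR by
        # the returned factor, so the warmup factor is warmup_lr / learning_rate;
        # after warmup the LR decays by ``decay_rate`` every ``decay_epoch`` epochs.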
        def lr_scheduler_fn(epoch):
            if epoch <= self._cfg.learn.warmup_epoch:
                return self._cfg.learn.warmup_lr / self._cfg.learn.learning_rate
            else:
                ratio = epoch // self._cfg.learn.decay_epoch
                return math.pow(self._cfg.learn.decay_rate, ratio)

        self._lr_scheduler = LambdaLR(self._optimizer, lr_scheduler_fn)
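        # Step once so the active factor corresponds to epoch 1 (still the
        # warmup LR). Note that PyTorch warns when ``scheduler.step()`` is
        # called before the first ``optimizer.step()``.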
        self._lr_scheduler.step()
        self._learn_model = model_wrap(self._model, 'base')
        self._learn_model.reset()

        self._ce_loss = nn.CrossEntropyLoss()

    def _forward_learn(self, data):
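        """Run one supervised training step on a ``(img, target)`` batch and
        return logged scalars, timing the forward, backward and gradient-sync
        phases separately.
        """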
        if self._cuda:
            data = to_device(data, self._device)
        self._learn_model.train()

        with self._timer:
            img, target = data
            logit = self._learn_model.forward(img)
            loss = self._ce_loss(logit, target)
        forward_time = self._timer.value

        with self._timer:
            self._optimizer.zero_grad()
            loss.backward()
        backward_time = self._timer.value

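        # With multi-GPU training, all-reduce gradients across processes before
        # the optimizer step; in the single-GPU case this timer measures ~0.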
        with self._timer:
            if self._cfg.multi_gpu:
                self.sync_gradients(self._learn_model)
        sync_time = self._timer.value
        self._optimizer.step()

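        # Log the mean learning rate across parameter groups.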
        cur_lr = [param_group['lr'] for param_group in self._optimizer.param_groups]
        cur_lr = sum(cur_lr) / len(cur_lr)
        return {
            'cur_lr': cur_lr,
            'total_loss': loss.item(),
            'forward_time': forward_time,
            'backward_time': backward_time,
            'sync_time': sync_time,
        }

    def _monitor_vars_learn(self):
        return ['cur_lr', 'total_loss', 'forward_time', 'backward_time', 'sync_time']

    def _init_eval(self):
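        # Evaluation reuses the underlying model through the plain 'base' wrapper.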
        self._eval_model = model_wrap(self._model, 'base')

    def _forward_eval(self, data):
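        """Run a forward pass without gradient tracking, moving inputs to the
        policy device and outputs back to CPU when CUDA is enabled.
        """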
        if self._cuda:
            data = to_device(data, self._device)
        self._eval_model.eval()
        with torch.no_grad():
            output = self._eval_model.forward(data)
        if self._cuda:
            output = to_device(output, 'cpu')
        return output

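    # Supervised classification needs no environment interaction, so the
    # collect/transition hooks required by the Policy interface are no-ops.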
    def _init_collect(self):
        pass

    def _forward_collect(self, data):
        pass

    def _process_transition(self, obs, policy_output, timestep):
        pass

    def _get_train_sample(self, data):
        pass
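

# Minimal usage sketch (an assumption, not part of this file): it presumes a
# DI-engine config whose ``learn`` field defines learning_rate, weight_decay,
# warmup_lr, warmup_epoch, decay_epoch and decay_rate, plus any classifier
# backbone (e.g. torchvision's resnet18). ``_forward_learn`` expects a
# ``(img, target)`` batch.
#
#   policy = ImageClassificationPolicy(cfg.policy, model=model, enable_field=['learn', 'eval'])
#   for img, target in train_loader:
#       log_vars = policy._forward_learn((img, target))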