from typing import Optional, Callable, List, Any

from ding.policy import PolicyFactory
from ding.worker import IMetric, MetricSerialEvaluator


class AccMetric(IMetric):
    """
    Overview:
        Top-1 classification accuracy metric for ``MetricSerialEvaluator``.
    """

    def eval(self, inputs: Any, label: Any) -> dict:
        # Fraction of samples whose predicted class (argmax over the logits) equals the label.
        return {'Acc': (inputs['logit'].argmax(dim=1) == label).sum().item() / label.shape[0]}

    def reduce_mean(self, inputs: List[Any]) -> Any:
        # Average the per-batch accuracies into a single scalar.
        s = 0
        for item in inputs:
            s += item['Acc']
        return {'Acc': s / len(inputs)}

    def gt(self, metric1: Any, metric2: Any) -> bool:
        # Whether metric1 is strictly better than metric2; metric2 is None before the first evaluation.
        if metric2 is None:
            return True
        if isinstance(metric2, dict):
            m2 = metric2['Acc']
        else:
            m2 = metric2
        return metric1['Acc'] > m2
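
# The reduce/compare half of the metric protocol above can be sanity-checked with
# plain dicts (``eval`` itself needs a torch tensor batch, which is omitted here).
# ``AccMetricSketch`` is a hypothetical, dependency-free restatement for illustration:
#
# ```python
# from typing import Any, List
#
#
# class AccMetricSketch:
#
#     def reduce_mean(self, inputs: List[Any]) -> Any:
#         # Average per-batch accuracies, mirroring AccMetric.reduce_mean.
#         return {'Acc': sum(item['Acc'] for item in inputs) / len(inputs)}
#
#     def gt(self, metric1: Any, metric2: Any) -> bool:
#         # metric2 is None on the first evaluation, so any result beats it.
#         if metric2 is None:
#             return True
#         m2 = metric2['Acc'] if isinstance(metric2, dict) else metric2
#         return metric1['Acc'] > m2
#
#
# m = AccMetricSketch()
# avg = m.reduce_mean([{'Acc': 0.5}, {'Acc': 0.75}])
# print(avg)  # {'Acc': 0.625}
# print(m.gt(avg, None))          # True: first result replaces the empty baseline
# print(m.gt(avg, {'Acc': 0.9}))  # False: current best is kept
# ```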


def mark_not_expert(ori_data: List[dict]) -> List[dict]:
    """Mark every transition as agent-generated rather than expert data."""
    for item in ori_data:
        # Set is_expert flag (expert 1, agent 0)
        item['is_expert'] = 0
    return ori_data


def mark_warm_up(ori_data: List[dict]) -> List[dict]:
    """Mark every transition as warm-up data (used by td3_vae)."""
    for item in ori_data:
        item['warm_up'] = True
    return ori_data
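
# Both markers are intended to be passed (possibly composed) as the
# ``postprocess_data_fn`` argument of ``random_collect`` below. A minimal,
# self-contained illustration on fake transition dicts (the 'obs'/'action'
# fields are made up for the example; only 'is_expert'/'warm_up' mirror the
# functions above):
#
# ```python
# from typing import List
#
#
# def mark_not_expert(ori_data: List[dict]) -> List[dict]:
#     for item in ori_data:
#         item['is_expert'] = 0  # expert 1, agent 0
#     return ori_data
#
#
# def mark_warm_up(ori_data: List[dict]) -> List[dict]:
#     for item in ori_data:
#         item['warm_up'] = True
#     return ori_data
#
#
# transitions = [{'obs': 0, 'action': 1}, {'obs': 1, 'action': 0}]
#
# # Compose both markers into a single postprocess hook.
# def postprocess(data: List[dict]) -> List[dict]:
#     return mark_warm_up(mark_not_expert(data))
#
# out = postprocess(transitions)
# print(out[0])  # {'obs': 0, 'action': 1, 'is_expert': 0, 'warm_up': True}
# ```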


def random_collect(
        policy_cfg: 'EasyDict',  # noqa
        policy: 'Policy',  # noqa
        collector: 'ISerialCollector',  # noqa
        collector_env: 'BaseEnvManager',  # noqa
        commander: 'BaseSerialCommander',  # noqa
        replay_buffer: 'IBuffer',  # noqa
        postprocess_data_fn: Optional[Callable] = None
) -> None:  # noqa
    """
    Overview:
        Collect ``policy_cfg.random_collect_size`` samples (or episodes) with a random policy,
        optionally postprocess them, push them into ``replay_buffer``, and restore the regular
        collect-mode policy afterwards.
    """
    assert policy_cfg.random_collect_size > 0
    if policy_cfg.get('transition_with_policy_data', False):
        # The stored transition needs policy outputs (e.g. logit), so keep the real policy.
        collector.reset_policy(policy.collect_mode)
    else:
        action_space = collector_env.action_space
        random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space)
        collector.reset_policy(random_policy)
    collect_kwargs = commander.step()
    if policy_cfg.collect.collector.type == 'episode':
        new_data = collector.collect(n_episode=policy_cfg.random_collect_size, policy_kwargs=collect_kwargs)
    else:
        # 'record_random_collect=False' means random collect without output log.
        new_data = collector.collect(
            n_sample=policy_cfg.random_collect_size,
            random_collect=True,
            record_random_collect=False,
            policy_kwargs=collect_kwargs
        )
    if postprocess_data_fn is not None:
        new_data = postprocess_data_fn(new_data)
    replay_buffer.push(new_data, cur_collector_envstep=0)
    # Switch back to the regular collect-mode policy for subsequent training collection.
    collector.reset_policy(policy.collect_mode)
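

# To make random_collect's contract concrete without pulling in the full ding
# stack, here is a condensed, dependency-free restatement of its 'sample' branch
# against duck-typed stand-ins. All stub names (Cfg, StubCollector, StubBuffer,
# the policy strings) are invented for illustration; only the call shapes mirror
# the function above:
#
# ```python
# class Cfg(dict):
#     # EasyDict-style attribute access for the sketch (not ding's EasyDict).
#     __getattr__ = dict.__getitem__
#
#
# class StubCollector:
#
#     def __init__(self):
#         self.active_policy = None
#
#     def reset_policy(self, policy):
#         self.active_policy = policy
#
#     def collect(self, n_sample, **kwargs):
#         # One fake transition per requested sample.
#         return [{'obs': i} for i in range(n_sample)]
#
#
# class StubBuffer:
#
#     def __init__(self):
#         self.data = []
#
#     def push(self, data, cur_collector_envstep):
#         self.data.extend(data)
#
#
# cfg = Cfg(random_collect_size=4, collect=Cfg(collector=Cfg(type='sample')))
# collector, buffer = StubCollector(), StubBuffer()
#
# # Mirrors random_collect: swap in a random policy, collect, push, restore.
# collector.reset_policy('random_policy')  # stand-in for PolicyFactory.get_random_policy(...)
# if cfg.collect.collector.type == 'episode':
#     new_data = collector.collect(n_episode=cfg.random_collect_size, policy_kwargs={})
# else:
#     new_data = collector.collect(
#         n_sample=cfg.random_collect_size, random_collect=True,
#         record_random_collect=False, policy_kwargs={}
#     )
# buffer.push(new_data, cur_collector_envstep=0)
# collector.reset_policy('collect_mode_policy')
#
# print(len(buffer.data), collector.active_policy)  # 4 collect_mode_policy
# ```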