File size: 3,332 Bytes
991f07c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#!/usr/bin/env python3
# coding=utf-8

import torch
import torch.nn.functional as F


class Batch:
    """Static helpers for collating, filtering, moving and printing batches.

    A batch is a ``dict`` mapping field names to tensors, or to tuples/lists
    of tensors. Every method is a ``@staticmethod``; the class is only used as
    a namespace.
    """

    @staticmethod
    def build(data):
        """Collate a list of per-example dicts into a single batch dict.

        Args:
            data: non-empty list of dicts sharing the keys of ``data[0]``. A
                field whose value in ``data[0]`` is a tuple is collated
                element-wise and stays a tuple in the result.

        Returns:
            dict mapping each field to the result of ``Batch._stack`` (or a
            tuple of such results for tuple-valued fields).

        Raises:
            IndexError: if ``data`` is empty (``data[0]`` supplies the keys).
        """
        fields = list(data[0].keys())
        transposed = {}
        for field in fields:
            if isinstance(data[0][field], tuple):
                # Collate the i-th tuple element of every example separately.
                transposed[field] = tuple(
                    Batch._stack(field, [example[field][i] for example in data])
                    for i in range(len(data[0][field]))
                )
            else:
                transposed[field] = Batch._stack(field, [example[field] for example in data])

        return transposed

    @staticmethod
    def _stack(field: str, examples):
        """Stack same-rank tensors, zero-padding every dim to the batch maximum.

        The field ``"anchored_labels"`` is returned as the unchanged list of
        examples. NOTE(review): presumably because its entries are not
        paddable tensors; confirm against the dataset code.

        Args:
            field: name of the field being collated.
            examples: list of tensors with equal rank.

        Returns:
            0-dim inputs: ``torch.stack(examples)``. Otherwise a tensor of
            shape ``(len(examples), *max_sizes)`` whose dtype follows the
            inputs; padding is appended at the end of each dim with zeros.
        """
        if field == "anchored_labels":
            return examples

        dim = examples[0].dim()

        if dim == 0:
            return torch.stack(examples)

        lengths = [max(example.size(i) for example in examples) for i in range(dim)]
        if any(length == 0 for length in lengths):
            # The result has zero elements, so there is nothing to pad or copy.
            # Allocate it from examples[0] so dtype/device match the stacking
            # path below, instead of always producing a CPU LongTensor.
            return examples[0].new_zeros((len(examples), *lengths))

        examples = [F.pad(example, Batch._pad_size(example, lengths)) for example in examples]
        return torch.stack(examples)

    @staticmethod
    def _pad_size(example, total_size):
        """Build the ``F.pad`` spec that grows ``example`` to ``total_size``.

        ``F.pad`` expects ``(left, right)`` pairs starting from the LAST
        dimension, hence the reversed iteration. Only right-side padding is
        added.
        """
        return [p for i, l in enumerate(total_size[::-1]) for p in (0, l - example.size(-1 - i))]

    @staticmethod
    def index_select(batch, indices):
        """Select rows ``indices`` (along dim 0) from every field of ``batch``.

        Tensor fields are ``index_select``-ed directly. For list/tuple fields,
        each element tensor is ``index_select``-ed along its own dim 0 and the
        results are returned as a list.

        NOTE(review): for ``"anchored_labels"`` (kept by ``_stack`` as a list of
        per-example items) this selects rows inside each item rather than
        picking examples; confirm this is the intended behavior.
        """
        filtered_batch = {}
        for key, examples in batch.items():
            if isinstance(examples, (list, tuple)):
                filtered_batch[key] = [example.index_select(0, indices) for example in examples]
            else:
                filtered_batch[key] = examples.index_select(0, indices)

        return filtered_batch

    @staticmethod
    def to_str(batch):
        """Return one tab-indented ``name: description`` line per field."""
        return "\n".join([f"\t{name}: {Batch._short_str(item)}" for name, item in batch.items()])

    @staticmethod
    def to(batch, device):
        """Return a new batch dict with every tensor moved to ``device``."""
        return {field: Batch._to(value, device) for field, value in batch.items()}

    @staticmethod
    def _short_str(tensor):
        """Describe a tensor compactly, e.g. ``[torch.FloatTensor of size 2x3]``.

        Tuples/lists are described element-wise (as a tuple of strings).
        Objects wrapping a tensor in ``.data`` are unwrapped; anything else
        falls back to ``str``.
        """
        if not torch.is_tensor(tensor):
            # (1) unpack a Variable-style wrapper, but only if it really wraps a
            #     tensor (e.g. numpy arrays expose a non-tensor `.data` buffer)
            if hasattr(tensor, "data") and torch.is_tensor(tensor.data):
                tensor = tensor.data
            # (2) handle include_lengths-style tuples/lists
            elif isinstance(tensor, (tuple, list)):
                return str(tuple(Batch._short_str(t) for t in tensor))
            # (3) fallback to default str
            else:
                return str(tensor)

        # adapted from torch's _tensor_str
        size_str = "x".join(str(size) for size in tensor.size())
        device_str = "" if not tensor.is_cuda else " (GPU {})".format(tensor.get_device())
        return "[{} of size {}{}]".format(torch.typename(tensor), size_str, device_str)

    @staticmethod
    def _to(tensor, device):
        """Recursively move a tensor, or tuple/list of them, to ``device``.

        Container types are preserved (tuple stays tuple, list stays list).

        Raises:
            TypeError: for any value that is not a tensor, tuple or list.
        """
        if torch.is_tensor(tensor):
            return tensor.to(device, non_blocking=True)
        if isinstance(tensor, tuple):
            return tuple(Batch._to(t, device) for t in tensor)
        if isinstance(tensor, list):
            return [Batch._to(t, device) for t in tensor]
        raise TypeError(f"unsupported type of {tensor} to be casted to {device}")