File size: 4,183 Bytes
3be620b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from tensorflow.keras import layers
from tensorflow.keras import Model
import tensorflow as tf
from transformers import TFPreTrainedModel

# GPT-2 checkpoint names accepted for config["transformer_type"]; the first
# entry is the default used when the key is absent.
valid_types = ["gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"]


class Transformer(Model):
    """Keras model wrapping a GPT-2 LM head over sequences of frame token indices.

    The model takes a ``remaining_frames`` value plus the token indices of the
    last and previous frames. It runs them through a ``TFGPT2LMHeadModel`` and
    returns the logits for the final positions. The number of positions kept
    equals the length of ``last_frame_indices`` along axis 1.

    ``config["remaining_frames_method"]`` selects how ``remaining_frames``
    reaches the transformer:

    * ``"concat"`` (default): concatenated in front of the token sequence.
    * ``"token_type_ids"``: passed as the ``token_type_ids`` argument.
    * ``"own_embeddings"``: passed as ``remaining_frames_ids`` to the
      project-local ``TFGPT2LMHeadModel`` from ``gpt2_embedding``.
    """

    # Remaining-frames strategies handled by ``call_transformer``. A subclass
    # that adds a strategy should extend this tuple as well.
    REMAINING_FRAMES_METHODS = ("concat", "token_type_ids", "own_embeddings")

    def __init__(self, config):
        """Build the model from a mapping-like ``config``.

        Reads the optional keys ``"remaining_frames_method"`` and
        ``"transformer_type"``, then loads the pretrained transformer.
        """
        super().__init__()
        self.config = config
        self.remaining_frames_method = self.get_remaining_frames_method(config)
        self.transformer_type = self.get_transformer_type(config)
        self.transformer = self.load_transformer(
            self.remaining_frames_method, self.transformer_type
        )

    def get_transformer_type(self, config) -> str:
        """Return the GPT-2 checkpoint name to load.

        Falls back to ``valid_types[0]`` when ``"transformer_type"`` is not set.

        Raises:
            ValueError: if the configured name is not in ``valid_types``.
        """
        if "transformer_type" not in config:
            return valid_types[0]
        transformer_type = config["transformer_type"]
        if transformer_type not in valid_types:
            raise ValueError(
                f"transformer_type {transformer_type} is not valid. Valid types are {valid_types}"
            )
        return transformer_type

    def get_remaining_frames_method(self, config) -> str:
        """Return the strategy used to feed ``remaining_frames`` to the transformer.

        Uses ``config["remaining_frames_method"]`` when set, otherwise
        ``"concat"``. The value is validated here, so a typo fails at
        construction time instead of on the first forward pass. Failing here
        also avoids loading pretrained weights first.

        Raises:
            ValueError: if the method is not in ``REMAINING_FRAMES_METHODS``.
        """
        if "remaining_frames_method" in config:
            method = config["remaining_frames_method"]
        else:
            method = "concat"
        if method not in self.REMAINING_FRAMES_METHODS:
            raise ValueError(
                f"remaining_frames_method {method} is not valid. "
                f"Valid methods are {list(self.REMAINING_FRAMES_METHODS)}"
            )
        return method

    def load_transformer(self, method: str, transformer_type: str) -> TFPreTrainedModel:
        """Load the pretrained ``TFGPT2LMHeadModel`` for ``transformer_type``.

        ``"own_embeddings"`` uses the project-local GPT-2 variant that accepts
        ``remaining_frames_ids``. Every other method uses the stock Hugging Face
        class.
        """
        print("using method ", method)
        if method == "own_embeddings":
            from ganime.model.vqgan_clean.experimental.gpt2_embedding import (
                TFGPT2LMHeadModel,
            )
        else:
            from transformers import TFGPT2LMHeadModel

        return TFGPT2LMHeadModel.from_pretrained(transformer_type)

    def concatenate_inputs(
        self, remaining_frames, last_frame_indices, previous_frame_indices
    ) -> tf.Tensor:
        """Build the transformer's token sequence along axis 1.

        For ``"concat"`` the order is remaining_frames, last frame, previous
        frames. For every other method ``remaining_frames`` is left out here;
        ``call_transformer`` supplies it through a keyword argument instead.
        """
        if self.remaining_frames_method == "concat":
            return tf.concat(
                [remaining_frames, last_frame_indices, previous_frame_indices], axis=1
            )
        return tf.concat([last_frame_indices, previous_frame_indices], axis=1)

    def call_transformer(
        self, transformer_input, remaining_frames, training, attention_mask
    ):
        """Invoke the wrapped transformer according to the remaining-frames method.

        Raises:
            ValueError: if ``self.remaining_frames_method`` is not handled here.
        """
        if self.remaining_frames_method == "concat":
            return self.transformer(
                transformer_input, training=training, attention_mask=attention_mask
            )
        if self.remaining_frames_method == "token_type_ids":
            return self.transformer(
                transformer_input,
                token_type_ids=remaining_frames,
                training=training,
                attention_mask=attention_mask,
            )
        if self.remaining_frames_method == "own_embeddings":
            return self.transformer(
                transformer_input,
                remaining_frames_ids=remaining_frames,
                training=training,
                attention_mask=attention_mask,
            )
        raise ValueError(
            f"Unknown remaining_frames_method {self.remaining_frames_method}"
        )

    def call(self, inputs, training=True, mask=None):
        """Run the transformer and return logits for the last-frame positions.

        Args:
            inputs: ``(remaining_frames, last_frame_indices, previous_frame_indices)``.
                ``remaining_frames`` gets a new axis 1 before use. This
                presumably means ``remaining_frames`` is ``(batch,)`` and the
                index tensors are ``(batch, seq)`` — TODO confirm.
            training: forwarded to the wrapped transformer.
            mask: accepted for Keras API compatibility but ignored. The
                attention mask is derived from ``remaining_frames``.

        Returns:
            The transformer output's ``logits``, restricted to the last
            ``tf.shape(last_frame_indices)[1]`` positions along axis 1.
        """
        remaining_frames, last_frame_indices, previous_frame_indices = inputs
        remaining_frames = tf.expand_dims(remaining_frames, axis=1)
        shape_to_keep = tf.shape(last_frame_indices)[1]

        transformer_input = self.concatenate_inputs(
            remaining_frames, last_frame_indices, previous_frame_indices
        )

        # The mask is 1 where the sample's remaining_frames value is non-zero
        # and 0 otherwise, broadcast across the whole sequence. It is cast to
        # the input dtype so the product is valid even when remaining_frames
        # uses a different numeric dtype than the frame indices.
        # NOTE(review): a zero remaining_frames value masks every position of
        # that sample — confirm this is intended.
        has_remaining = tf.cast(
            tf.cast(remaining_frames, dtype=tf.bool), dtype=transformer_input.dtype
        )
        attention_mask = tf.ones_like(transformer_input) * has_remaining

        output = self.call_transformer(
            transformer_input, remaining_frames, training, attention_mask
        )
        logits = output.logits
        # Use an explicit start index: `logits[:, -shape_to_keep:]` would
        # return the entire sequence when shape_to_keep is 0 (since -0 == 0).
        seq_len = tf.shape(logits)[1]
        return logits[:, seq_len - shape_to_keep :]