File size: 1,268 Bytes
3be620b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
from dataclasses import dataclass
from typing import List

try:
    from typing import Literal
except ImportError:
    # typing.Literal was added in Python 3.8; older interpreters need the backport.
    from typing_extensions import Literal
@dataclass()
class GPTConfig:
    """Plain container for the GPT model's hyperparameters.

    Field meanings are inferred from their names; confirm against the
    model code that consumes this config.
    """

    # --- architecture size ---
    n_layer: int  # presumably the number of stacked layers
    n_head: int  # presumably the number of attention heads
    n_embedding: int  # presumably the embedding width
    # --- sequence / vocabulary ---
    vocab_size: int
    block_size: int  # presumably the maximum context length
    # --- regularisation (presumably dropout probabilities — confirm range) ---
    embedding_percentage_drop: float
    attention_percentage_drop: float
@dataclass()
class VQVAEConfig:
    """Plain container for the VQ-VAE quantizer settings.

    Field meanings are inferred from their names; confirm against the
    quantizer implementation.
    """

    beta: float  # presumably the commitment-loss coefficient — confirm
    num_embeddings: int  # presumably the codebook size
    embedding_dim: int  # presumably the dimensionality of each code vector
@dataclass()
class AutoencoderConfig:
    """Plain container for the encoder/decoder network settings.

    Field meanings are inferred from their names; confirm against the
    autoencoder implementation.
    """

    # --- channel layout ---
    z_channels: int  # presumably channels of the latent tensor
    channels: int  # presumably the base channel count
    channels_multiplier: List[int]  # presumably per-level multipliers of `channels`
    # --- depth / attention ---
    num_res_blocks: int
    attention_resolution: List[int]  # presumably resolutions at which attention is applied
    # --- input / regularisation ---
    resolution: int
    dropout: float
@dataclass()
class DiscriminatorConfig:
    """Plain container for the discriminator network's size settings.

    Field meanings are inferred from their names; confirm against the
    discriminator implementation.
    """

    num_layers: int  # presumably the number of layers in the discriminator
    filters: int  # presumably the base filter/channel count
@dataclass
class DiscriminatorLossConfig:
    """Settings for the discriminator loss term.

    Field meanings other than ``loss`` are inferred from their names;
    confirm against the training code that consumes this config.
    """

    # Which discriminator loss to use. The choices must be separate Literal
    # arguments: the previous ``Literal["hinge, vanilla"]`` admitted only the
    # single string "hinge, vanilla", so type checkers rejected both
    # "hinge" and "vanilla".
    loss: Literal["hinge", "vanilla"]
    factor: float  # presumably a scale applied to the discriminator loss — confirm
    iter_start: int  # presumably the training step where this loss starts — confirm
    weight: float  # presumably a weight on the adversarial term — confirm
@dataclass()
class VQVAELossConfig:
    """Plain container for the VQ-VAE loss weights.

    Field meanings are inferred from their names; confirm against the
    loss implementation.
    """

    codebook_weight: float  # presumably the weight on the codebook/quantization loss
    perceptual_weight: float  # presumably the weight on the perceptual loss
@dataclass()
class LossConfig:
    """Groups the loss-related sub-configurations.

    ``discriminator`` and ``vqvae`` hold the nested loss settings;
    ``perceptual_loss`` is a free-form string (presumably a name or
    identifier selecting the perceptual network — confirm with the caller).
    """

    discriminator: DiscriminatorLossConfig
    vqvae: VQVAELossConfig
    perceptual_loss: str
@dataclass()
class ModelConfig:
    """Top-level configuration bundling every sub-config for the model.

    Each field holds one of the nested config dataclasses: the quantizer,
    the autoencoder network, the discriminator network and the losses.
    """

    # --- networks ---
    vqvae_config: VQVAEConfig
    autoencoder_config: AutoencoderConfig
    discriminator_config: DiscriminatorConfig
    # --- objectives ---
    loss_config: LossConfig
|