from transformers import AutoConfig, AutoModel, PreTrainedModel

from .configuration_gramt_binaural_time import GRAMTBinauralTimeConfig
from .model import GRAMT


class GRAMTBinauralTimeModel(PreTrainedModel):
    """PreTrainedModel wrapper around GRAMT (binaural, time-domain variant)."""

    config_class = GRAMTBinauralTimeConfig

    def __init__(self, config):
        super().__init__(config)
        # Build the underlying GRAMT network from the configuration values.
        self.model = GRAMT(
            in_channels=config.in_channels,
            decoder_mlp_ratio=config.decoder_mlp_ratio,
            decoder_depth=config.decoder_depth,
            decoder_num_heads=config.decoder_num_heads,
            decoder_embedding_dim=config.decoder_embedding_dim,
            decoder_window_sizes=config.decoder_window_sizes,
            encoder_num_layers=config.encoder_num_layers,
            encoder_num_heads=config.encoder_num_heads,
            encoder_hidden_dim=config.encoder_hidden_dim,
            encoder_mlp_ratio=config.encoder_mlp_ratio,
            encoder_dropout=config.encoder_dropout,
            encoder_attention_dropout=config.encoder_attention_dropout,
            encoder_norm_layer_eps=config.encoder_norm_layer_eps,
            patch_size=config.patch_size,
            frequency_stride=config.frequency_stride,
            time_stride=config.time_stride,
            max_length=config.max_length,
            num_mel_bins=config.num_mel_bins,
        )

    def forward(self, tensor, strategy="raw"):
        # Delegate to GRAMT's representation extraction; `strategy` is passed
        # through to get_audio_representation unchanged.
        return self.model.get_audio_representation(tensor, strategy=strategy)


# Instantiate once with the default configuration.
gram = GRAMTBinauralTimeModel(GRAMTBinauralTimeConfig())

# Register the custom config/model so AutoConfig and AutoModel can resolve them.
AutoConfig.register("gramt-binaural-time", GRAMTBinauralTimeConfig)
AutoModel.register(GRAMTBinauralTimeConfig, GRAMTBinauralTimeModel)
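

# Usage sketch (not part of the original module): a minimal, hedged example of
# exercising the registered classes via the Auto* API. The checkpoint directory
# name "gramt-binaural-time-checkpoint" is hypothetical, and the dummy input
# shape (batch, 2 channels, samples) is only an assumption about GRAMT's
# binaural time-domain input format.
if __name__ == "__main__":
    import torch

    # Save the wrapper, then reload it through AutoModel; the registration above
    # lets AutoConfig/AutoModel resolve "gramt-binaural-time" to these classes.
    gram.save_pretrained("gramt-binaural-time-checkpoint")
    model = AutoModel.from_pretrained("gramt-binaural-time-checkpoint")

    dummy_audio = torch.randn(1, 2, 16000)  # hypothetical (batch, channels, samples)
    with torch.no_grad():
        representation = model(dummy_audio, strategy="raw")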