"""Hugging Face wrapper for the GRAMT binaural-time audio model.

Defines :class:`GRAMTBinauralTimeModel`, a thin ``PreTrainedModel`` shell
around the project-local :class:`GRAMT` network, and registers it (together
with :class:`GRAMTBinauralTimeConfig`) with the ``transformers`` Auto*
factories so it can be loaded via ``AutoModel.from_pretrained``.
"""

from transformers import AutoConfig, AutoModel, PreTrainedModel

from .configuration_gramt_binaural_time import GRAMTBinauralTimeConfig
from .model import GRAMT


class GRAMTBinauralTimeModel(PreTrainedModel):
    """``PreTrainedModel`` adapter exposing GRAMT audio representations.

    All hyperparameters are taken verbatim from the associated
    :class:`GRAMTBinauralTimeConfig` instance and forwarded to the
    underlying :class:`GRAMT` network.
    """

    config_class = GRAMTBinauralTimeConfig

    def __init__(self, config):
        """Build the wrapped GRAMT network from *config*.

        Args:
            config: A ``GRAMTBinauralTimeConfig`` carrying every
                encoder/decoder/patching hyperparameter listed below.
        """
        super().__init__(config)
        # One-to-one pass-through of config fields to the GRAMT constructor.
        self.model = GRAMT(
            in_channels=config.in_channels,
            decoder_mlp_ratio=config.decoder_mlp_ratio,
            decoder_depth=config.decoder_depth,
            decoder_num_heads=config.decoder_num_heads,
            decoder_embedding_dim=config.decoder_embedding_dim,
            decoder_window_sizes=config.decoder_window_sizes,
            encoder_num_layers=config.encoder_num_layers,
            encoder_num_heads=config.encoder_num_heads,
            encoder_hidden_dim=config.encoder_hidden_dim,
            encoder_mlp_ratio=config.encoder_mlp_ratio,
            encoder_dropout=config.encoder_dropout,
            encoder_attention_dropout=config.encoder_attention_dropout,
            encoder_norm_layer_eps=config.encoder_norm_layer_eps,
            patch_size=config.patch_size,
            frequency_stride=config.frequency_stride,
            time_stride=config.time_stride,
            max_length=config.max_length,
            num_mel_bins=config.num_mel_bins,
        )

    def forward(self, tensor, strategy="raw"):
        """Return the audio representation computed by the wrapped GRAMT model.

        Args:
            tensor: Input audio tensor; shape/dtype requirements are those of
                ``GRAMT.get_audio_representation`` (assumed spectrogram-like
                input — confirm against the GRAMT implementation).
            strategy: Representation-extraction strategy forwarded verbatim
                to ``GRAMT.get_audio_representation``; defaults to ``"raw"``.
        """
        return self.model.get_audio_representation(tensor, strategy=strategy)


# Make the model discoverable through the transformers Auto* factories
# under the "gramt-binaural-time" model type.
AutoConfig.register("gramt-binaural-time", GRAMTBinauralTimeConfig)
AutoModel.register(GRAMTBinauralTimeConfig, GRAMTBinauralTimeModel)


if __name__ == "__main__":
    # Smoke test: instantiate the model from a default config. This
    # previously ran unconditionally at module level, building the full
    # network as a side effect of every import; it is now guarded so it
    # only runs when the module is executed directly.
    gram = GRAMTBinauralTimeModel(GRAMTBinauralTimeConfig())