# Source: gramt-binaural-time / modeling_gramt_binaural_time.py
# Uploaded by: GokseninYuksel
# Commit message: Upload model
# Commit: f0e612b (verified)
from transformers import PreTrainedModel
from transformers import AutoConfig, AutoModel
from .model import GRAMT
from .configuration_gramt_binaural_time import GRAMTBinauralTimeConfig
class GRAMTBinauralTimeModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper around the ``GRAMT`` network.

    The wrapped network is stored on ``self.model`` and is constructed
    entirely from the attributes of a ``GRAMTBinauralTimeConfig``.
    """

    config_class = GRAMTBinauralTimeConfig

    def __init__(self, config):
        """Build the underlying ``GRAMT`` module from ``config``.

        Args:
            config: Configuration object that exposes, as attributes, every
                hyper-parameter named in ``gramt_fields`` below.
        """
        super().__init__(config)

        # Each entry is read from ``config`` and passed to ``GRAMT`` as the
        # keyword argument of the same name, in this order.
        gramt_fields = (
            "in_channels",
            "decoder_mlp_ratio",
            "decoder_depth",
            "decoder_num_heads",
            "decoder_embedding_dim",
            "decoder_window_sizes",
            "encoder_num_layers",
            "encoder_num_heads",
            "encoder_hidden_dim",
            "encoder_mlp_ratio",
            "encoder_dropout",
            "encoder_attention_dropout",
            "encoder_norm_layer_eps",
            "patch_size",
            "frequency_stride",
            "time_stride",
            "max_length",
            "num_mel_bins",
        )
        gramt_kwargs = {field: getattr(config, field) for field in gramt_fields}
        self.model = GRAMT(**gramt_kwargs)

    def forward(self, tensor, strategy="raw"):
        """Return the audio representation computed by the wrapped network.

        Delegates to ``self.model.get_audio_representation``.

        Args:
            tensor: Input forwarded unchanged to the wrapped network
                (expected layout is defined by ``GRAMT`` -- TODO confirm).
            strategy: Forwarded as the ``strategy`` keyword; defaults to
                ``"raw"``. Accepted values are defined by ``GRAMT``.

        Returns:
            Whatever ``GRAMT.get_audio_representation`` returns.
        """
        representation = self.model.get_audio_representation(
            tensor, strategy=strategy
        )
        return representation
# Instantiate the model once from a default-constructed config when this
# module is imported.
# NOTE(review): this builds the full network at import time; confirm that the
# module-level ``gram`` object is actually needed by callers.
gram = GRAMTBinauralTimeModel(GRAMTBinauralTimeConfig())
# Register the config class with AutoConfig under the "gramt-binaural-time" key.
AutoConfig.register("gramt-binaural-time", GRAMTBinauralTimeConfig)
# Associate the config class with the model class in AutoModel.
AutoModel.register(GRAMTBinauralTimeConfig, GRAMTBinauralTimeModel)