# wav2vec2-2-bart-base / load_and_save.py
# Author: cromz22
# Commit: "upload model files" (3665b39)
#!/usr/bin/env python3
# from transformers.utils.dummy_pt_objects import SpeechEncoderDecoderModel
# from transformers.models.auto.feature_extraction_auto import AutoFeatureExtractor
# from transformers.models.auto.tokenization_auto import AutoTokenizer
# from transformers.models.wav2vec2.processing_wav2vec2 import Wav2Vec2Processor
from transformers import SpeechEncoderDecoderModel, AutoFeatureExtractor, AutoTokenizer, Wav2Vec2Processor
# Pretrained checkpoints to combine into a single speech-to-text model.
encoder_id = "facebook/wav2vec2-base"
decoder_id = "facebook/bart-base"

# Instantiate the speech-encoder-decoder model from the two checkpoints,
# fixing the encoder regularisation and generation hyper-parameters at
# load time so they are baked into the saved config.
model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder_id,
    decoder_id,
    encoder_add_adapter=True,
    encoder_feat_proj_dropout=0.0,
    encoder_layerdrop=0.0,
    max_length=200,
    num_beams=5,
)

# Mirror the decoder's special-token ids onto the top-level config so that
# generation starts/pads/stops with the BART tokens.
decoder_config = model.decoder.config
model.config.decoder_start_token_id = decoder_config.bos_token_id
model.config.pad_token_id = decoder_config.pad_token_id
model.config.eos_token_id = decoder_config.eos_token_id

# Persist the combined model into the current directory.
model.save_pretrained("./")
# Build the matching processor: the encoder's feature extractor paired with
# the decoder's tokenizer, bundled into one Wav2Vec2Processor and saved
# alongside the model.
feature_extractor = AutoFeatureExtractor.from_pretrained(encoder_id)
tokenizer = AutoTokenizer.from_pretrained(decoder_id)
processor = Wav2Vec2Processor(
    feature_extractor=feature_extractor,
    tokenizer=tokenizer,
)
processor.save_pretrained("./")