Spaces:
Running
on
L40S
Running
on
L40S
File size: 2,152 Bytes
6f199b8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
'''
@Project :EMO_digitalhuman
@File :wav_clip.py
@Author :juzhen.czy
@Date :2024/3/4 19:04
'''
from transformers import Wav2Vec2Model, Wav2Vec2Processor
import torch
from torch import nn
import librosa
from diffusers.models.modeling_utils import ModelMixin
from einops import rearrange, repeat
class Wav2Vec(ModelMixin):
def __init__(self, model_path):
super(Wav2Vec, self).__init__()
self.processor = Wav2Vec2Processor.from_pretrained(model_path)
self.wav2Vec = Wav2Vec2Model.from_pretrained(model_path)
self.wav2Vec.eval()
def forward(self, x):
with torch.no_grad():
return self.wav2Vec(x).last_hidden_state
# def forward(self, x):
# return self.wav2Vec(x).last_hidden_state
def process(self, x):
return self.processor(x, sampling_rate=16000, return_tensors="pt").input_values.to(self.device)
class AudioFeatureMapper(ModelMixin):
def __init__(self, input_num=15, output_num=77, model_path=None):
super(AudioFeatureMapper, self).__init__()
self.linear = nn.Linear(input_num, output_num)
if model_path is not None:
self.load_state_dict(torch.load(model_path))
def forward(self, x):
# print(x.shape)
result = self.linear(x.permute(0, 2, 1))
result = result.permute(0, 2, 1)
# result = self.linear(x)
return result
def test():
#加载模型
model_path = "/ossfs/workspace/projects/model_weights/Moore-AnimateAnyone/wav2vec2-base-960h"
model = Wav2Vec(model_path)
print("### model loaded ###")
#加载音频
audio_path = "/ossfs/workspace/projects/Moore-AnimateAnyone-master/assets/taken_clip.wav"
input_audio, rate = librosa.load(audio_path, sr=16000)
print(f"输入shape: {input_audio.shape}, rate: {rate}")
# 预处理, 维度变为 (1, input_audio.shape[0]), 增加了一个维度, 声音信号长度本身没有变
input_v = model.process(input_audio)
# 输出结果为
out = model(input_v)
print(f"输入shape: {input_v.shape}, 输出shape: {out.shape}") |