#!/usr/bin/env python
# -*- coding: UTF-8 -*-
'''
@Project :EMO_digitalhuman 
@File    :wav_clip.py
@Author  :juzhen.czy
@Date    :2024/3/4 19:04 
'''
from transformers import Wav2Vec2Model, Wav2Vec2Processor
import torch
from torch import nn
import librosa
from diffusers.models.modeling_utils import ModelMixin


class Wav2Vec(ModelMixin):
    """Frozen wav2vec 2.0 backbone wrapped as a diffusers ModelMixin."""

    def __init__(self, model_path):
        super().__init__()
        self.processor = Wav2Vec2Processor.from_pretrained(model_path)
        self.wav2Vec = Wav2Vec2Model.from_pretrained(model_path)
        self.wav2Vec.eval()

    def forward(self, x):
        # The backbone is frozen, so skip gradient tracking entirely.
        with torch.no_grad():
            return self.wav2Vec(x).last_hidden_state

    def process(self, x):
        # Normalize the raw 16 kHz waveform into model-ready input values
        # of shape (1, num_samples) on this module's device.
        return self.processor(x, sampling_rate=16000, return_tensors="pt").input_values.to(self.device)
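
# Note: wav2vec 2.0's convolutional front end downsamples the 16 kHz input by
# a factor of 320 overall, so last_hidden_state carries roughly 50 feature
# frames per second of audio (e.g. ~1 s of sound -> shape (1, ~49, 768)).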

class AudioFeatureMapper(ModelMixin):
    """Linearly maps input_num audio feature frames to output_num tokens."""

    def __init__(self, input_num=15, output_num=77, model_path=None):
        super().__init__()
        self.linear = nn.Linear(input_num, output_num)
        if model_path is not None:
            self.load_state_dict(torch.load(model_path))

    def forward(self, x):
        # x: (B, input_num, C). Swap the sequence and channel axes so the
        # linear layer acts along the sequence axis, then swap back:
        # (B, input_num, C) -> (B, C, input_num) -> (B, C, output_num) -> (B, output_num, C).
        result = self.linear(x.permute(0, 2, 1))
        result = result.permute(0, 2, 1)
        return result
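
# Illustrative shape check (a sketch; the hidden size 768 below assumes the
# wav2vec2-base backbone, and 15/77 are just this class's defaults):
#
#   mapper = AudioFeatureMapper(input_num=15, output_num=77)
#   feats = torch.randn(2, 15, 768)   # (B, T, C) wav2vec feature frames
#   tokens = mapper(feats)            # -> (2, 77, 768), a CLIP-text-like layout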

def test():
    # Load the model.
    model_path = "/ossfs/workspace/projects/model_weights/Moore-AnimateAnyone/wav2vec2-base-960h"
    model = Wav2Vec(model_path)
    print("### model loaded ###")

    # Load the audio at wav2vec's expected 16 kHz sampling rate.
    audio_path = "/ossfs/workspace/projects/Moore-AnimateAnyone-master/assets/taken_clip.wav"
    input_audio, rate = librosa.load(audio_path, sr=16000)
    print(f"input shape: {input_audio.shape}, rate: {rate}")

    # Preprocess: the shape becomes (1, input_audio.shape[0]); a batch
    # dimension is added, the signal length itself is unchanged.
    input_v = model.process(input_audio)

    # Extract the hidden-state features.
    out = model(input_v)
    print(f"input shape: {input_v.shape}, output shape: {out.shape}")