---
language:
- en
license: mit
base_model:
- facebook/wav2vec2-base
---
|
|
|
SCD (Speaker Change Detection) is the task of identifying the points in audio or video content where the active speaker changes. It is typically applied to multi-speaker conversations or presentations to detect when the speech hands off from one speaker to another.
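In the frame-classification formulation used below, SCD reduces to labeling each audio frame: 1 where a speaker change occurs, 0 elsewhere. A minimal sketch of building such labels from known turn boundaries (the 20 ms frame duration and the boundary times are illustrative assumptions, not properties of this checkpoint):

```python
import math

def make_scd_labels(turn_boundaries_s, total_s, frame_s=0.02):
    """Illustrative only: mark the frame containing each speaker-change time with 1."""
    num_frames = math.ceil(total_s / frame_s)
    labels = [0] * num_frames
    for t in turn_boundaries_s:
        idx = min(int(t / frame_s), num_frames - 1)
        labels[idx] = 1
    return labels

# e.g. speakers switch at 3.2 s and 7.5 s in a 10 s clip
print(sum(make_scd_labels([3.2, 7.5], 10.0)))  # -> 2 change frames
```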
|
|
|
How to use
|
```python
# Note: at the time this code was originally written,
# transformers.Wav2Vec2ForAudioFrameClassification was incomplete
# -> this adds the then-missing parts (notably the loss computation in forward)
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss, MSELoss

import transformers
from transformers import Wav2Vec2Model
from transformers.modeling_outputs import TokenClassifierOutput
from huggingface_hub import PyTorchModelHubMixin

# Matches the constant defined in transformers' wav2vec2 modeling code:
# in the tuple output, hidden states start after (logits, extract_features).
_HIDDEN_STATES_START_POSITION = 2


class Wav2Vec2ForAudioFrameClassification_custom(
    transformers.Wav2Vec2ForAudioFrameClassification,
    PyTorchModelHubMixin,
    repo_url="your-repo-url",
    pipeline_tag="audio-classification",
    license="mit",
):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        if hasattr(config, "add_adapter") and config.add_adapter:
            raise ValueError(
                "Audio frame classification does not support the use of Wav2Vec2 adapters (config.add_adapter=True)"
            )
        self.wav2vec2 = Wav2Vec2Model(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(
        self,
        input_values,
        attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,  # ADDED
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            # Combine all layers' hidden states with learned softmax weights.
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        logits = self.classifier(hidden_states)

        # ADDED: loss computation (regression for a single label, classification otherwise)
        loss = None
        if labels is not None:
            labels = labels.reshape(-1, 1)  # 1xN -> Nx1
            if self.num_labels == 1:
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels)
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
```
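A minimal inference sketch follows, assuming the class above is in scope and a trained SCD checkpoint exists; the repo id `your-username/wav2vec2-scd` is a hypothetical placeholder, and the feature extractor comes from the base model.

```python
import torch
from transformers import Wav2Vec2FeatureExtractor

# Hypothetical repo id -- replace with the actual checkpoint location.
model = Wav2Vec2ForAudioFrameClassification_custom.from_pretrained("your-username/wav2vec2-scd")
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")

# 16 kHz mono waveform; random noise stands in for real audio here.
waveform = torch.randn(16000 * 5)
inputs = feature_extractor(waveform.numpy(), sampling_rate=16000, return_tensors="pt")

model.eval()
with torch.no_grad():
    logits = model(inputs.input_values).logits  # (batch, frames, num_labels)
frame_preds = logits.argmax(dim=-1)  # per-frame class ids

# With frame-level labels, forward also returns a loss (the ADDED behavior above).
labels = torch.zeros_like(frame_preds)
out = model(inputs.input_values, labels=labels)
print(out.loss)
```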
|
|
|