---
language:
  - en
license: mit
base_model:
  - facebook/wav2vec2-base
---

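SCD (Speaker Change Detection) is the task of identifying the points in audio or video content where the active speaker changes. It is typically applied to multi-speaker conversations or talks to detect when one speaker hands over to another.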
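As a rough illustration of how frame-level output can be turned into change points, here is a minimal sketch. It assumes the model emits, for each frame of about 20 ms (wav2vec2-base at 16 kHz has a 320-sample encoder stride), a probability that a speaker change occurs in that frame. The function name `change_points`, the 0.5 threshold and the peak-picking rule are illustrative choices, not part of this repository.

```python
from typing import List, Optional

# Assumption: 16 kHz input and a 320-sample encoder stride (wav2vec2-base),
# i.e. one output frame every 20 ms.
SAMPLE_RATE = 16_000
FRAME_STRIDE = 320


def change_points(frame_probs: List[float], threshold: float = 0.5) -> List[float]:
    """Return approximate change-point times (seconds) from per-frame probabilities.

    Consecutive frames at or above the threshold are merged into one event,
    which is reported at the start time of its highest-probability frame.
    """
    times: List[float] = []
    best_idx: Optional[int] = None  # peak frame of the current above-threshold run
    for i, p in enumerate(frame_probs):
        if p >= threshold:
            if best_idx is None or p > frame_probs[best_idx]:
                best_idx = i
        elif best_idx is not None:
            # The run has ended: emit its peak as one change point.
            times.append(best_idx * FRAME_STRIDE / SAMPLE_RATE)
            best_idx = None
    if best_idx is not None:  # a run that lasts until the final frame
        times.append(best_idx * FRAME_STRIDE / SAMPLE_RATE)
    return times


print(change_points([0.1, 0.2, 0.7, 0.9, 0.6, 0.1, 0.1, 0.8, 0.2]))
```

Merging adjacent above-threshold frames keeps a single speaker turn from being reported several times in a row; because the reported time is the frame start, it is only accurate to about one frame.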

## How to use

Note: when this code was written, `transformers.Wav2Vec2ForAudioFrameClassification` was incomplete. The subclass below adds the parts that were missing, marked `# ADDED`: a `labels` argument to `forward` and the loss computation.
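The added `labels` argument needs one target per frame produced by the wav2vec2 convolutional feature encoder, so the label sequence has to match that frame count. A minimal sketch, assuming the default `Wav2Vec2Config` kernel and stride values used by `facebook/wav2vec2-base` and a binary labelling in which 1 marks the frame that contains a speaker change (the helper names `num_output_frames` and `frame_labels` are hypothetical):

```python
import math
from typing import List, Sequence

# Default conv_kernel / conv_stride of transformers.Wav2Vec2Config,
# as used by facebook/wav2vec2-base.
CONV_KERNEL = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDE = (5, 2, 2, 2, 2, 2, 2)
SAMPLE_RATE = 16_000  # assumption: 16 kHz input audio


def num_output_frames(num_samples: int) -> int:
    """Number of frames the conv feature encoder emits (no padding)."""
    length = num_samples
    for kernel, stride in zip(CONV_KERNEL, CONV_STRIDE):
        length = max((length - kernel) // stride + 1, 0)
    return length


def frame_labels(num_samples: int, change_times: Sequence[float]) -> List[int]:
    """Hypothetical helper: 1 for the frame containing each change time, else 0."""
    n = num_output_frames(num_samples)
    labels = [0] * n
    hop = math.prod(CONV_STRIDE)  # 320 samples = 20 ms at 16 kHz
    for t in change_times:
        idx = min(int(t * SAMPLE_RATE) // hop, n - 1)
        labels[idx] = 1
    return labels


print(num_output_frames(16_000))  # frames for one second of audio
```

For a one-second clip this yields 49 frames, so for a single-example batch `labels` would be a tensor of shape `(1, 49)`, which `forward` flattens before computing the loss.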

```python
import torch
import torch.nn as nn
import transformers
from huggingface_hub import PyTorchModelHubMixin
from torch.nn import CrossEntropyLoss, MSELoss
from transformers import Wav2Vec2Model
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.models.wav2vec2.modeling_wav2vec2 import _HIDDEN_STATES_START_POSITION


class Wav2Vec2ForAudioFrameClassification_custom(
    transformers.Wav2Vec2ForAudioFrameClassification,
    PyTorchModelHubMixin,
    repo_url="your-repo-url",
    pipeline_tag="audio-classification",
    license="mit",
):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        if hasattr(config, "add_adapter") and config.add_adapter:
            raise ValueError(
                "Audio frame classification does not support the use of Wav2Vec2 adapters (config.add_adapter=True)"
            )
        self.wav2vec2 = Wav2Vec2Model(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(
        self,
        input_values,
        attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,  # ADDED: one target per output frame
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # The weighted layer sum needs every layer's hidden states.
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            # Softmax-weighted average over all hidden layers.
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        logits = self.classifier(hidden_states)  # (batch, frames, num_labels)

        # ADDED: frame-level loss
        loss = None
        if labels is not None:
            labels = labels.reshape(-1, 1)  # 1xN -> Nx1
            if self.num_labels == 1:
                # Regression: MSELoss needs floating-point targets.
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.to(logits.dtype))
            else:
                # Classification: one class index per frame.
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
```
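When `config.use_weighted_layer_sum` is enabled, the classifier sees a softmax-weighted average of every hidden layer rather than only the last one. The NumPy sketch below mirrors that branch of `forward` outside PyTorch to make the shapes explicit; the function name `weighted_layer_sum` is illustrative and not part of this repository.

```python
import numpy as np


def weighted_layer_sum(hidden_states, layer_weights):
    """NumPy mirror of the use_weighted_layer_sum branch.

    hidden_states: sequence of arrays, one per layer, each (batch, time, hidden)
    layer_weights: array of shape (num_layers,) holding raw, unnormalized weights
    """
    stacked = np.stack(hidden_states, axis=1)        # (batch, layers, time, hidden)
    w = np.exp(layer_weights - layer_weights.max())  # numerically stable softmax
    w = w / w.sum()
    return (stacked * w[:, None, None]).sum(axis=1)  # (batch, time, hidden)


# 13 layers for wav2vec2-base: 12 transformer layers + input embeddings.
layers = [np.full((1, 4, 8), float(i)) for i in range(13)]
out = weighted_layer_sum(layers, np.ones(13) / 13)
print(out.shape)
```

Because `layer_weights` is initialised to `torch.ones(num_layers) / num_layers`, the softmax starts out uniform, so at initialisation the classifier receives a plain average of all 13 layers and learns during training which layers to favour.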