unclecode sergeipetrov committed on
Commit
3ef7085
0 Parent(s):

Duplicate from sergeipetrov/asrdiarization-handler


Co-authored-by: Sergei Petrov <sergeipetrov@users.noreply.huggingface.co>

Files changed (6)
  1. .gitattributes +35 -0
  2. README.md +23 -0
  3. config.py +33 -0
  4. diarization_utils.py +141 -0
  5. handler.py +103 -0
  6. requirements.txt +8 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,23 @@
+ ASR+Diarization handler that works natively with Inference Endpoints.
+
+ Example payload:
+ ```python
+ import base64
+ import requests
+
+ API_URL = "<your endpoint URL>"
+ filepath = "/path/to/audio"
+
+ with open(filepath, 'rb') as f:
+     audio_encoded = base64.b64encode(f.read()).decode("utf-8")
+
+ data = {
+     "inputs": audio_encoded,
+     "parameters": {
+         "batch_size": 24
+     }
+ }
+
+ resp = requests.post(API_URL, json=data, headers={"Authorization": "Bearer <your token>"})
+ print(resp.json())
+ ```
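For reference, the handler defined in `handler.py` below returns a JSON object with `speakers`, `chunks`, and `text` keys. A minimal sketch of consuming the response, continuing the example above:

```python
# Response shape, per handler.py's return value:
# {"speakers": [...], "chunks": [...], "text": "..."}
result = resp.json()

print(result["text"])  # full transcript as a single string

# per-chunk ASR output with (start, end) timestamps
for chunk in result["chunks"]:
    print(chunk["timestamp"], chunk["text"])

# speaker-annotated chunks; an empty list when no diarization model is configured
for turn in result["speakers"]:
    print(turn["speaker"], turn["text"])
```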
config.py ADDED
@@ -0,0 +1,33 @@
+ import logging
+
+ from pydantic import BaseModel
+ from pydantic_settings import BaseSettings
+ from typing import Optional, Literal
+
+ logger = logging.getLogger(__name__)
+
+
+ class ModelSettings(BaseSettings):
+     asr_model: str
+     assistant_model: Optional[str] = None
+     diarization_model: Optional[str] = None
+     hf_token: Optional[str] = None
+
+
+ class InferenceConfig(BaseModel):
+     task: Literal["transcribe", "translate"] = "transcribe"
+     batch_size: int = 24
+     assisted: bool = False
+     chunk_length_s: int = 30
+     sampling_rate: int = 16000
+     language: Optional[str] = None
+     num_speakers: Optional[int] = None
+     min_speakers: Optional[int] = None
+     max_speakers: Optional[int] = None
+
+
+ model_settings = ModelSettings()
+
+ logger.info(f"asr model: {model_settings.asr_model}")
+ logger.info(f"assist model: {model_settings.assistant_model}")
+ logger.info(f"diar model: {model_settings.diarization_model}")
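`ModelSettings` subclasses pydantic's `BaseSettings`, so its fields are populated from environment variables (matched case-insensitively); only `asr_model` is required, and the optional fields fall back to `None`. A minimal sketch with hypothetical model names, not taken from this repo:

```python
# Hypothetical configuration via environment variables; they must be set
# before config.py is imported, since ModelSettings() runs at import time.
import os

os.environ["ASR_MODEL"] = "openai/whisper-large-v3"                   # required
os.environ["ASSISTANT_MODEL"] = "distil-whisper/distil-large-v2"      # optional: enables assisted generation
os.environ["DIARIZATION_MODEL"] = "pyannote/speaker-diarization-3.1"  # optional: enables diarization
os.environ["HF_TOKEN"] = "<your token>"                               # needed for gated pyannote models

from config import model_settings
print(model_settings.asr_model)
```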
diarization_utils.py ADDED
@@ -0,0 +1,141 @@
+ import torch
+ import numpy as np
+ from torchaudio import functional as F
+ from transformers.pipelines.audio_utils import ffmpeg_read
+ from starlette.exceptions import HTTPException
+ import sys
+
+ # Code from insanely-fast-whisper:
+ # https://github.com/Vaibhavs10/insanely-fast-whisper
+
+ import logging
+ logger = logging.getLogger(__name__)
+
+ def preprocess_inputs(inputs, sampling_rate):
+     inputs = ffmpeg_read(inputs, sampling_rate)
+
+     if sampling_rate != 16000:
+         inputs = F.resample(
+             torch.from_numpy(inputs), sampling_rate, 16000
+         ).numpy()
+
+     if len(inputs.shape) != 1:
+         logger.error(f"Diarization pipeline expects single channel audio, received {inputs.shape}")
+         raise HTTPException(
+             status_code=400,
+             detail=f"Diarization pipeline expects single channel audio, received {inputs.shape}"
+         )
+
+     # diarization model expects float32 torch tensor of shape `(channels, seq_len)`
+     diarizer_inputs = torch.from_numpy(inputs).float()
+     diarizer_inputs = diarizer_inputs.unsqueeze(0)
+
+     return inputs, diarizer_inputs
+
+
+ def diarize_audio(diarizer_inputs, diarization_pipeline, parameters):
+     diarization = diarization_pipeline(
+         {"waveform": diarizer_inputs, "sample_rate": parameters.sampling_rate},
+         num_speakers=parameters.num_speakers,
+         min_speakers=parameters.min_speakers,
+         max_speakers=parameters.max_speakers,
+     )
+
+     segments = []
+     for segment, track, label in diarization.itertracks(yield_label=True):
+         segments.append(
+             {
+                 "segment": {"start": segment.start, "end": segment.end},
+                 "track": track,
+                 "label": label,
+             }
+         )
+
+     # diarizer output may contain consecutive segments from the same speaker (e.g. {(0 -> 1, speaker_1), (1 -> 1.5, speaker_1), ...})
+     # we combine these segments to give overall timestamps for each speaker's turn (e.g. {(0 -> 1.5, speaker_1), ...})
+     new_segments = []
+     prev_segment = cur_segment = segments[0]
+
+     for i in range(1, len(segments)):
+         cur_segment = segments[i]
+
+         # check if we have changed speaker ("label")
+         if cur_segment["label"] != prev_segment["label"]:
+             # add the start/end times for the super-segment to the new list
+             new_segments.append(
+                 {
+                     "segment": {
+                         "start": prev_segment["segment"]["start"],
+                         "end": cur_segment["segment"]["start"],
+                     },
+                     "speaker": prev_segment["label"],
+                 }
+             )
+             prev_segment = segments[i]
+
+     # add the last segment(s) if there was no speaker change
+     new_segments.append(
+         {
+             "segment": {
+                 "start": prev_segment["segment"]["start"],
+                 "end": cur_segment["segment"]["end"],
+             },
+             "speaker": prev_segment["label"],
+         }
+     )
+
+     return new_segments
+
+
+ def post_process_segments_and_transcripts(new_segments, transcript, group_by_speaker) -> list:
+     # get the end timestamps for each chunk from the ASR output
+     end_timestamps = np.array(
+         [chunk["timestamp"][-1] if chunk["timestamp"][-1] is not None else sys.float_info.max for chunk in transcript])
+     segmented_preds = []
+
+     # align the diarizer timestamps and the ASR timestamps
+     for segment in new_segments:
+         # get the diarizer end timestamp
+         end_time = segment["segment"]["end"]
+         # find the ASR end timestamp that is closest to the diarizer's end timestamp and cut the transcript to here
+         upto_idx = np.argmin(np.abs(end_timestamps - end_time))
+
+         if group_by_speaker:
+             segmented_preds.append(
+                 {
+                     "speaker": segment["speaker"],
+                     "text": "".join(
+                         [chunk["text"] for chunk in transcript[: upto_idx + 1]]
+                     ),
+                     "timestamp": (
+                         transcript[0]["timestamp"][0],
+                         transcript[upto_idx]["timestamp"][1],
+                     ),
+                 }
+             )
+         else:
+             for i in range(upto_idx + 1):
+                 segmented_preds.append({"speaker": segment["speaker"], **transcript[i]})
+
+         # crop the transcripts and timestamp lists according to the latest timestamp (for faster argmin)
+         transcript = transcript[upto_idx + 1:]
+         end_timestamps = end_timestamps[upto_idx + 1:]
+
+         if len(end_timestamps) == 0:
+             break
+
+     return segmented_preds
+
+
+ def diarize(diarization_pipeline, file, parameters, asr_outputs):
+     _, diarizer_inputs = preprocess_inputs(file, parameters.sampling_rate)
+
+     segments = diarize_audio(
+         diarizer_inputs,
+         diarization_pipeline,
+         parameters
+     )
+
+     return post_process_segments_and_transcripts(
+         segments, asr_outputs["chunks"], group_by_speaker=False
+     )
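To make the alignment step concrete: for each diarizer turn, `post_process_segments_and_transcripts` attributes to that speaker all remaining ASR chunks up to the chunk whose end timestamp is closest to the turn's end. A toy sketch with made-up timestamps and texts:

```python
# Toy illustration of the diarizer/ASR alignment; all values are invented.
from diarization_utils import post_process_segments_and_transcripts

new_segments = [
    {"segment": {"start": 0.0, "end": 3.1}, "speaker": "SPEAKER_00"},
    {"segment": {"start": 3.1, "end": 6.0}, "speaker": "SPEAKER_01"},
]
transcript = [
    {"text": " Hello there.", "timestamp": (0.0, 3.0)},
    {"text": " Hi, how are you?", "timestamp": (3.0, 6.0)},
]

# With group_by_speaker=False each chunk keeps its own text and timestamp
# and gains a "speaker" key:
# [{'speaker': 'SPEAKER_00', 'text': ' Hello there.', 'timestamp': (0.0, 3.0)},
#  {'speaker': 'SPEAKER_01', 'text': ' Hi, how are you?', 'timestamp': (3.0, 6.0)}]
print(post_process_segments_and_transcripts(new_segments, transcript, group_by_speaker=False))
```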
handler.py ADDED
@@ -0,0 +1,103 @@
+ import logging
+ import torch
+ import os
+ import base64
+
+ from pyannote.audio import Pipeline
+ from transformers import pipeline, AutoModelForCausalLM
+ from diarization_utils import diarize
+ from huggingface_hub import HfApi
+ from pydantic import ValidationError
+ from starlette.exceptions import HTTPException
+
+ from config import model_settings, InferenceConfig
+
+ logger = logging.getLogger(__name__)
+
+
+ class EndpointHandler():
+     def __init__(self, path=""):
+
+         device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+         logger.info(f"Using device: {device.type}")
+         torch_dtype = torch.float32 if device.type == "cpu" else torch.float16
+
+         self.assistant_model = AutoModelForCausalLM.from_pretrained(
+             model_settings.assistant_model,
+             torch_dtype=torch_dtype,
+             low_cpu_mem_usage=True,
+             use_safetensors=True
+         ) if model_settings.assistant_model else None
+
+         if self.assistant_model:
+             self.assistant_model.to(device)
+
+         self.asr_pipeline = pipeline(
+             "automatic-speech-recognition",
+             model=model_settings.asr_model,
+             torch_dtype=torch_dtype,
+             device=device
+         )
+
+         if model_settings.diarization_model:
+             # diarization pipeline doesn't raise if there is no token
+             HfApi().whoami(model_settings.hf_token)
+             self.diarization_pipeline = Pipeline.from_pretrained(
+                 checkpoint_path=model_settings.diarization_model,
+                 use_auth_token=model_settings.hf_token,
+             )
+             self.diarization_pipeline.to(device)
+         else:
+             self.diarization_pipeline = None
+
+
+     def __call__(self, inputs):
+         file = inputs.pop("inputs")
+         file = base64.b64decode(file)
+         parameters = inputs.pop("parameters", {})
+         try:
+             parameters = InferenceConfig(**parameters)
+         except ValidationError as e:
+             logger.error(f"Error validating parameters: {e}")
+             raise HTTPException(status_code=400, detail=f"Error validating parameters: {e}")
+
+         logger.info(f"inference parameters: {parameters}")
+
+         generate_kwargs = {
+             "task": parameters.task,
+             "language": parameters.language,
+             "assistant_model": self.assistant_model if parameters.assisted else None
+         }
+
+         try:
+             asr_outputs = self.asr_pipeline(
+                 file,
+                 chunk_length_s=parameters.chunk_length_s,
+                 batch_size=parameters.batch_size,
+                 generate_kwargs=generate_kwargs,
+                 return_timestamps=True,
+             )
+         except RuntimeError as e:
+             logger.error(f"ASR inference error: {str(e)}")
+             raise HTTPException(status_code=400, detail=f"ASR inference error: {str(e)}")
+         except Exception as e:
+             logger.error(f"Unknown error during ASR inference: {str(e)}")
+             raise HTTPException(status_code=500, detail=f"Unknown error during ASR inference: {str(e)}")
+
+         if self.diarization_pipeline:
+             try:
+                 transcript = diarize(self.diarization_pipeline, file, parameters, asr_outputs)
+             except RuntimeError as e:
+                 logger.error(f"Diarization inference error: {str(e)}")
+                 raise HTTPException(status_code=400, detail=f"Diarization inference error: {str(e)}")
+             except Exception as e:
+                 logger.error(f"Unknown error during diarization: {str(e)}")
+                 raise HTTPException(status_code=500, detail=f"Unknown error during diarization: {str(e)}")
+         else:
+             transcript = []
+
+         return {
+             "speakers": transcript,
+             "chunks": asr_outputs["chunks"],
+             "text": asr_outputs["text"],
+         }
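For a quick local smoke test outside Inference Endpoints (assuming the environment variables from `config.py` are set and an audio file is available; the path below is illustrative):

```python
# Minimal local test of EndpointHandler; not part of the repo.
import base64
from handler import EndpointHandler

handler = EndpointHandler()

with open("/path/to/audio", "rb") as f:
    payload = {
        "inputs": base64.b64encode(f.read()).decode("utf-8"),
        "parameters": {"batch_size": 24},
    }

out = handler(payload)
print(out["text"])
```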
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ accelerate==0.27.2
+ torch==2.2.1
+ pyannote-audio==3.1.1
+ transformers==4.38.2
+ numpy==1.26.4
+ torchaudio==2.2.1
+ pydantic==2.6.3
+ pydantic-settings==2.2.1