Skip to content

Commit

Permalink
make diarization faster
Browse files Browse the repository at this point in the history
  • Loading branch information
davidas1 authored Aug 2, 2023
1 parent d80b986 commit 8de0e2a
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 3 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ print(result["segments"]) # after alignment
diarize_model = whisperx.DiarizationPipeline(use_auth_token=YOUR_HF_TOKEN, device=device)

# add min/max number of speakers if known
diarize_segments = diarize_model(audio_file)
diarize_segments = diarize_model(audio)
# diarize_model(audio, min_speakers=min_speakers, max_speakers=max_speakers)

result = whisperx.assign_word_speakers(diarize_segments, result)
Expand Down
8 changes: 7 additions & 1 deletion whisperx/diarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from typing import Optional, Union
import torch

from .audio import SAMPLE_RATE

class DiarizationPipeline:
def __init__(
self,
Expand All @@ -16,7 +18,11 @@ def __init__(
self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)

def __call__(self, audio, min_speakers=None, max_speakers=None):
segments = self.model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
audio_data = {
'waveform': torch.from_numpy(audio[None, :]),
'sample_rate': SAMPLE_RATE
}
segments = self.model(audio_data, min_speakers=min_speakers, max_speakers=max_speakers)
diarize_df = pd.DataFrame(segments.itertracks(yield_label=True))
diarize_df['start'] = diarize_df[0].apply(lambda x: x.start)
diarize_df['end'] = diarize_df[0].apply(lambda x: x.end)
Expand Down
3 changes: 2 additions & 1 deletion whisperx/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,8 @@ def cli():
results = []
diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device)
for result, input_audio_path in tmp_results:
diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
audio = load_audio(input_audio_path)
diarize_segments = diarize_model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
result = assign_word_speakers(diarize_segments, result)
results.append((result, input_audio_path))
# >> Write
Expand Down

0 comments on commit 8de0e2a

Please sign in to comment.