Benchmark script + improvements and bug fixes #46

Merged 23 commits on May 8, 2022
Commits
863b9bb
Start refactoring for batched diarization pipeline
juanmc2005 Apr 22, 2022
d6a4d4b
Batchify FrameWiseModel and ChunkWiseModel
juanmc2005 Apr 22, 2022
35a4ffe
Add batched pipeline implementation
juanmc2005 Apr 25, 2022
6ab9a21
Move pre-calculated pipeline to OnlineSpeakerDiarization.from_file()
juanmc2005 Apr 25, 2022
fe1d5e8
Add argument to skip plotting for faster inference in demo script
juanmc2005 Apr 25, 2022
63a04a6
Merge branch 'develop' of github.com:juanmc2005/StreamingSpeakerDiari…
juanmc2005 Apr 26, 2022
3907efb
Remove empty line
juanmc2005 Apr 26, 2022
66a15ac
Add benchmark script. Add optional verbosity to from_file(). Add tqdm…
juanmc2005 Apr 26, 2022
97a4f43
Dumb down PipelineConfig. Make sample rate completely depend on the s…
juanmc2005 Apr 27, 2022
63f24a8
Fix segmentation resolution not being adapted to chunk duration
juanmc2005 Apr 27, 2022
dff01e6
Add DER evaluation to benchmark script. Add FileAudioSource parameter…
juanmc2005 Apr 27, 2022
974e5c2
Add optional processing time profiling in FileAudioSource
juanmc2005 Apr 27, 2022
0aa0778
Add GPU support in demo and benchmarking
juanmc2005 Apr 27, 2022
fdc6f04
Make reference optional in benchmarking script
juanmc2005 Apr 27, 2022
7656f32
Calculate number of chunks from duration instead of samples in ChunkL…
juanmc2005 May 2, 2022
ab41ebd
Fix bug in batched pipeline: an edge case was causing the batch dimen…
juanmc2005 May 3, 2022
c1e725c
Fix bug in from_file(): segmentation and embedding remove batch dimen…
juanmc2005 May 3, 2022
354cd29
Fix end time bug in batched pipeline
juanmc2005 May 4, 2022
d228d55
Centralize stream end time calculation
juanmc2005 May 4, 2022
2981143
Add diart.benchmark in readme file
juanmc2005 May 4, 2022
11009d2
Add pyannote.metrics performance report in diart.benchmark
juanmc2005 May 4, 2022
7b9291e
Add progress bar to demo script
juanmc2005 May 4, 2022
6b67831
Fix method docstring
juanmc2005 May 4, 2022
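One of the commits above (7656f32) switches to calculating the number of chunks from the file duration instead of from raw sample counts. A minimal sketch of that arithmetic — the helper name and exact formula are assumptions for illustration, not the actual diart `ChunkLoader` code:

```python
import math

def num_chunks(file_duration: float, window_duration: float, step: float) -> int:
    """Number of sliding-window chunks covering `file_duration` seconds.

    Counting in seconds avoids the rounding drift that can appear when the
    same quantity is derived from raw sample counts at varying sample rates.
    """
    if file_duration <= window_duration:
        return 1
    # One chunk at the start, plus one per additional step needed to reach the end
    return int(math.ceil((file_duration - window_duration) / step)) + 1

# A 5-second window sliding by 0.5s over a 10-second file
print(num_chunks(10.0, 5.0, 0.5))  # prints 11
```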
29 changes: 10 additions & 19 deletions README.md
@@ -91,7 +91,7 @@ torch.Size([4, 512])
 1) Create environment:
 
 ```shell
-conda create -n diarization python==3.8
+conda create -n diarization python=3.8
 conda activate diarization
 ```
 
@@ -130,7 +130,10 @@ Awaiting paper publication (ASRU 2021).
 
 ![Results table](/table1.png)
 
-To reproduce the results of the paper, use the following hyper-parameters:
+Diart aims to be lightweight and capable of real-time streaming in practical scenarios.
+Its performance is very close to what is reported in the paper (and sometimes even a bit better).
+
+To obtain the best results, make sure to use the following hyper-parameters:
 
 Dataset | latency | tau | rho | delta
 ------------|---------|--------|--------|------
@@ -140,28 +143,16 @@ VoxConverse | any | 0.576 | 0.915 | 0.648
 DIHARD II | 1s | 0.619 | 0.326 | 0.997
 DIHARD II | 5s | 0.555 | 0.422 | 1.517
 
-For instance, for a DIHARD III configuration:
+`diart.benchmark` can quickly run and evaluate the pipeline, and even measure its real-time latency. For instance, for a DIHARD III configuration:
 
 ```shell
-python -m diart.demo /path/to/file.wav --tau=0.555 --rho=0.422 --delta=1.517 --output /output/dir
+python -m diart.benchmark /wav/dir --reference /rttm/dir --tau=0.555 --rho=0.422 --delta=1.517 --output /out/dir
 ```
 
-And then to obtain the diarization error rate:
-
-```python
-from pyannote.metrics.diarization import DiarizationErrorRate
-from pyannote.database.util import load_rttm
-
-metric = DiarizationErrorRate()
-hypothesis = load_rttm("/output/dir/output.rttm")
-hypothesis = list(hypothesis.values())[0]  # Extract hypothesis from dictionary
-reference = load_rttm("/path/to/reference.rttm")
-reference = list(reference.values())[0]  # Extract reference from dictionary
-
-der = metric(reference, hypothesis)
-```
+`diart.benchmark` runs a faster inference and evaluation by pre-calculating model outputs in batches.
+More options about benchmarking can be found by running `python -m diart.benchmark -h`.
 
-For convenience and to facilitate future comparisons, we also provide the [expected outputs](/expected_outputs) in RTTM format for every entry of Table 1 and Figure 5 in the paper. This includes the VBx offline topline as well as our proposed online approach with latencies 500ms, 1s, 2s, 3s, 4s, and 5s.
+For convenience and to facilitate future comparisons, we also provide the [expected outputs](/expected_outputs) of the paper implementation in RTTM format for every entry of Table 1 and Figure 5. This includes the VBx offline topline as well as our proposed online approach with latencies 500ms, 1s, 2s, 3s, 4s, and 5s.
 
 ![Figure 5](/figure5.png)
1 change: 1 addition & 0 deletions requirements.txt
@@ -4,4 +4,5 @@ rx>=3.2.0
 scipy>=1.6.0
 sounddevice>=0.4.2
 einops>=0.3.0
+tqdm>=4.64.0
 git+https://github.com/pyannote/pyannote-audio.git@develop#egg=pyannote-audio
89 changes: 89 additions & 0 deletions src/diart/benchmark.py
@@ -0,0 +1,89 @@
import argparse
from pathlib import Path

import torch
from pyannote.database.util import load_rttm
from pyannote.metrics.diarization import DiarizationErrorRate

import diart.operators as dops
import diart.sources as src
from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig
from diart.sinks import RTTMWriter

# Define script arguments
parser = argparse.ArgumentParser()
parser.add_argument("root", type=str, help="Directory with audio files <conversation>.(wav|flac|m4a|...)")
parser.add_argument("--reference", type=str, help="Directory with RTTM files <conversation>.rttm")
parser.add_argument("--step", default=0.5, type=float, help="Source sliding window step in seconds. Defaults to 0.5")
parser.add_argument("--latency", default=0.5, type=float, help="System latency in seconds. Defaults to 0.5")
parser.add_argument("--tau", default=0.5, type=float, help="Activity threshold tau active in [0,1]. Defaults to 0.5")
parser.add_argument("--rho", default=0.3, type=float, help="Speech ratio threshold rho update in [0,1]. Defaults to 0.3")
parser.add_argument("--delta", default=1, type=float, help="Maximum distance threshold delta new in [0,2]. Defaults to 1")
parser.add_argument("--gamma", default=3, type=float, help="Parameter gamma for overlapped speech penalty. Defaults to 3")
parser.add_argument("--beta", default=10, type=float, help="Parameter beta for overlapped speech penalty. Defaults to 10")
parser.add_argument("--max-speakers", default=20, type=int, help="Maximum number of identifiable speakers. Defaults to 20")
parser.add_argument("--batch-size", default=32, type=int, help="For segmentation and embedding pre-calculation. If lower than 2, run fully online and estimate real-time latency. Defaults to 32")
parser.add_argument("--output", type=str, help="Output directory to store RTTM files. Defaults to `root`")
parser.add_argument("--gpu", dest="gpu", action="store_true", help="Add this flag to run on GPU")
args = parser.parse_args()

args.root = Path(args.root)
assert args.root.is_dir(), "Root argument must be a directory"
if args.reference is not None:
    args.reference = Path(args.reference)
    assert args.reference.is_dir(), "Reference argument must be a directory"
args.output = args.root if args.output is None else Path(args.output)
args.output.mkdir(parents=True, exist_ok=True)

# Define online speaker diarization pipeline
config = PipelineConfig(
    step=args.step,
    latency=args.latency,
    tau_active=args.tau,
    rho_update=args.rho,
    delta_new=args.delta,
    gamma=args.gamma,
    beta=args.beta,
    max_speakers=args.max_speakers,
    device=torch.device("cuda") if args.gpu else None,
)
pipeline = OnlineSpeakerDiarization(config)

# Run inference
chunk_loader = src.ChunkLoader(pipeline.sample_rate, pipeline.duration, config.step)
for filepath in args.root.expanduser().iterdir():
    num_chunks = chunk_loader.num_chunks(filepath)

    # Stream fully online if batch size is 1 or lower
    source = None
    if args.batch_size < 2:
        source = src.FileAudioSource(
            filepath,
            filepath.stem,
            src.RegularAudioFileReader(pipeline.sample_rate, pipeline.duration, config.step),
            # Benchmark the processing time of a single chunk
            profile=True,
        )
        observable = pipeline.from_source(source, output_waveform=False)
    else:
        observable = pipeline.from_file(filepath, batch_size=args.batch_size, verbose=True)

    observable.pipe(
        dops.progress(f"Streaming {filepath.stem}", total=num_chunks, leave=source is None)
    ).subscribe(
        RTTMWriter(path=args.output / f"{filepath.stem}.rttm")
    )

    if source is not None:
        source.read()

# Run evaluation
if args.reference is not None:
    metric = DiarizationErrorRate(collar=0, skip_overlap=False)
    for ref_path in args.reference.iterdir():
        ref = load_rttm(ref_path).popitem()[1]
        hyp = load_rttm(args.output / ref_path.name).popitem()[1]
        metric(ref, hyp)
    print()
    metric.report(display=True)
    print()
51 changes: 31 additions & 20 deletions src/diart/demo.py
@@ -1,24 +1,27 @@
import argparse
from pathlib import Path

import rx.operators as ops
import torch

import diart.operators as dops
import diart.sources as src
import rx.operators as ops
from diart.pipelines import OnlineSpeakerDiarization
from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig
from diart.sinks import RealTimePlot, RTTMWriter

# Define script arguments
parser = argparse.ArgumentParser()
parser.add_argument("source", type=str, help="Path to an audio file | 'microphone'")
parser.add_argument("--step", default=0.5, type=float, help="Source sliding window step")
parser.add_argument("--latency", default=0.5, type=float, help="System latency")
parser.add_argument("--sample-rate", default=16000, type=int, help="Source sample rate")
parser.add_argument("--tau", default=0.5, type=float, help="Activity threshold tau active")
parser.add_argument("--rho", default=0.3, type=float, help="Speech duration threshold rho update")
parser.add_argument("--delta", default=1, type=float, help="Maximum distance threshold delta new")
parser.add_argument("--gamma", default=3, type=float, help="Parameter gamma for overlapped speech penalty")
parser.add_argument("--beta", default=10, type=float, help="Parameter beta for overlapped speech penalty")
parser.add_argument("--max-speakers", default=20, type=int, help="Maximum number of identifiable speakers")
parser.add_argument("--no-plot", dest="no_plot", action="store_true", help="Skip plotting for faster inference")
parser.add_argument("--gpu", dest="gpu", action="store_true", help="Add this flag to run on GPU")
parser.add_argument(
"--output", type=str,
help="Output directory to store the RTTM. Defaults to home directory "
@@ -27,7 +30,7 @@
args = parser.parse_args()

# Define online speaker diarization pipeline
pipeline = OnlineSpeakerDiarization(
config = PipelineConfig(
step=args.step,
latency=args.latency,
tau_active=args.tau,
@@ -36,36 +39,44 @@
gamma=args.gamma,
beta=args.beta,
max_speakers=args.max_speakers,
device=torch.device("cuda") if args.gpu else None,
)
pipeline = OnlineSpeakerDiarization(config)

# Manage audio source
if args.source != "microphone":
args.source = Path(args.source).expanduser()
uri = args.source.name.split(".")[0]
output_dir = args.source.parent if args.output is None else Path(args.output)
audio_source = src.FileAudioSource(
file=args.source,
uri=uri,
uri=args.source.stem,
reader=src.RegularAudioFileReader(
args.sample_rate, pipeline.duration, pipeline.step
pipeline.sample_rate, pipeline.duration, config.step
),
)
else:
output_dir = Path("~/").expanduser() if args.output is None else Path(args.output)
audio_source = src.MicrophoneAudioSource(args.sample_rate)
audio_source = src.MicrophoneAudioSource(pipeline.sample_rate)

# Build pipeline from audio source and stream predictions to a real-time plot
pipeline.from_source(audio_source).pipe(
ops.do(RTTMWriter(path=output_dir / "output.rttm")),
dops.buffer_output(
duration=pipeline.duration,
step=pipeline.step,
latency=pipeline.latency,
sample_rate=audio_source.sample_rate
),
).subscribe(RealTimePlot(pipeline.duration, pipeline.latency))
# Build pipeline from audio source and stream predictions
rttm_writer = RTTMWriter(path=output_dir / f"{audio_source.uri}.rttm")
observable = pipeline.from_source(audio_source).pipe(
dops.progress(f"Streaming {audio_source.uri}", total=audio_source.length, leave=True)
)
if args.no_plot:
# Write RTTM file only
observable.subscribe(rttm_writer)
else:
# Write RTTM file + buffering and real-time plot
observable.pipe(
ops.do(rttm_writer),
dops.buffer_output(
duration=pipeline.duration,
step=config.step,
latency=config.latency,
sample_rate=pipeline.sample_rate
),
).subscribe(RealTimePlot(pipeline.duration, config.latency))

# Read audio source as a stream
if args.source == "microphone":
print("Recording...")
audio_source.read()