From c255d584807dfc436cc976fce48f3ebcf1a93f5e Mon Sep 17 00:00:00 2001 From: Bertrand Higy Date: Tue, 18 Oct 2022 15:02:18 +0000 Subject: [PATCH 1/5] added cropping_mode to DelayedAggregation --- src/diart/blocks/aggregation.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/diart/blocks/aggregation.py b/src/diart/blocks/aggregation.py index 3b5e6e1f..be642593 100644 --- a/src/diart/blocks/aggregation.py +++ b/src/diart/blocks/aggregation.py @@ -102,6 +102,8 @@ class DelayedAggregation: "mean": simple average "hamming": average weighted by the Hamming window values (aligned to the buffer) "any": no aggregation, pick the first overlapping window + cropping_mode: ("strict", "loose", "center"), optional + Defines the cropping mode. Defaults to "loose". Example -------- @@ -130,10 +132,12 @@ def __init__( step: float, latency: Optional[float] = None, strategy: Literal["mean", "hamming", "first"] = "hamming", + cropping_mode: Literal["strict", "loose", "center"] = "loose" ): self.step = step self.latency = latency self.strategy = strategy + self.cropping_mode = cropping_mode if self.latency is None: self.latency = self.step @@ -159,7 +163,7 @@ def _prepend( num_frames = output_window.data.shape[0] first_region = Segment(0, output_region.end) first_output = buffers[0].crop( - first_region, fixed=first_region.duration + first_region, mode=self.cropping_mode, fixed=first_region.duration ) first_output[-num_frames:] = output_window.data resolution = output_region.end / first_output.shape[0] From a4f735246a93f27d677f71c8177c9c3f5e23aff2 Mon Sep 17 00:00:00 2001 From: Bertrand Higy Date: Wed, 19 Oct 2022 10:16:14 +0200 Subject: [PATCH 2/5] added assert for cropping mode Co-authored-by: Juan Coria --- src/diart/blocks/aggregation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diart/blocks/aggregation.py b/src/diart/blocks/aggregation.py index be642593..41edbe78 100644 --- a/src/diart/blocks/aggregation.py +++ b/src/diart/blocks/aggregation.py @@ -137,6 +137,7 @@ def __init__( self.step = step self.latency = latency self.strategy = strategy + assert cropping_mode in ["strict", "loose", "center"], f"Invalid cropping mode `{cropping_mode}`" self.cropping_mode = cropping_mode if self.latency is None: From e7d6dd7b056e45c56c2e84b37fd3f9de85579afc Mon Sep 17 00:00:00 2001 From: Bertrand Higy Date: Wed, 19 Oct 2022 08:31:55 +0000 Subject: [PATCH 3/5] added details about the cropping mode to the documentation of DelayedAggregation --- src/diart/blocks/aggregation.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/diart/blocks/aggregation.py b/src/diart/blocks/aggregation.py index 41edbe78..3c121915 100644 --- a/src/diart/blocks/aggregation.py +++ b/src/diart/blocks/aggregation.py @@ -104,6 +104,9 @@ class DelayedAggregation: "any": no aggregation, pick the first overlapping window cropping_mode: ("strict", "loose", "center"), optional Defines the cropping mode. Defaults to "loose". + This will define the value of the mode parameter used in SlidingWindowFeature.crop + (from pyannote.core, see https://pyannote.github.io/pyannote-core/reference.html#pyannote.core.SlidingWindowFeature.crop + for more details). Example -------- From 4866586c6b458811b4b152842a929b234bb4e369 Mon Sep 17 00:00:00 2001 From: Bertrand Higy Date: Wed, 19 Oct 2022 08:47:57 +0000 Subject: [PATCH 4/5] added cropping_mode to AggregationStrategy and children --- src/diart/blocks/aggregation.py | 34 ++++++++++++++++++++++++--------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/src/diart/blocks/aggregation.py b/src/diart/blocks/aggregation.py index 3c121915..38673262 100644 --- a/src/diart/blocks/aggregation.py +++ b/src/diart/blocks/aggregation.py @@ -6,18 +6,34 @@ class AggregationStrategy: - """Abstract class representing a strategy to aggregate overlapping buffers""" + """Abstract class representing a strategy to aggregate overlapping buffers + + Parameters + ---------- + cropping_mode: ("strict", "loose", "center"), optional + Defines the cropping mode. Defaults to "loose". + This will define the value of the mode parameter used in SlidingWindowFeature.crop + (from pyannote.core, see https://pyannote.github.io/pyannote-core/reference.html#pyannote.core.SlidingWindowFeature.crop + for more details). + """ + + def __init__(self, cropping_mode: Literal["strict", "loose", "center"] = "loose"): + assert cropping_mode in ["strict", "loose", "center"], f"Invalid cropping mode `{cropping_mode}`" + self.cropping_mode = cropping_mode @staticmethod - def build(name: Literal["mean", "hamming", "first"]) -> 'AggregationStrategy': + def build( + name: Literal["mean", "hamming", "first"], + cropping_mode: Literal["strict", "loose", "center"] = "loose" + ) -> 'AggregationStrategy': """Build an AggregationStrategy instance based on its name""" assert name in ("mean", "hamming", "first") if name == "mean": - return AverageStrategy() + return AverageStrategy(cropping_mode) elif name == "hamming": - return HammingWeightedAverageStrategy() + return HammingWeightedAverageStrategy(cropping_mode) else: - return FirstOnlyStrategy() + return FirstOnlyStrategy(cropping_mode) def __call__(self, buffers: List[SlidingWindowFeature], focus: Segment) -> SlidingWindowFeature: """Aggregate chunks over a specific region. @@ -55,11 +71,11 @@ def aggregate(self, buffers: List[SlidingWindowFeature], focus: Segment) -> np.n hamming, intersection = [], [] for buffer in buffers: # Crop buffer to focus region - b = buffer.crop(focus, fixed=focus.duration) + b = buffer.crop(focus, mode=self.cropping_mode, fixed=focus.duration) # Crop Hamming window to focus region h = np.expand_dims(np.hamming(num_frames), axis=-1) h = SlidingWindowFeature(h, buffer.sliding_window) - h = h.crop(focus, fixed=focus.duration) + h = h.crop(focus, mode=self.cropping_mode, fixed=focus.duration) hamming.append(h.data) intersection.append(b.data) hamming, intersection = np.stack(hamming), np.stack(intersection) @@ -73,7 +89,7 @@ class AverageStrategy(AggregationStrategy): def aggregate(self, buffers: List[SlidingWindowFeature], focus: Segment) -> np.ndarray: # Stack all overlapping regions intersection = np.stack([ - buffer.crop(focus, fixed=focus.duration) + buffer.crop(focus, mode=self.cropping_mode, fixed=focus.duration) for buffer in buffers ]) return np.mean(intersection, axis=0) @@ -83,7 +99,7 @@ class FirstOnlyStrategy(AggregationStrategy): """Instead of aggregating, keep the first focus region in the buffer list""" def aggregate(self, buffers: List[SlidingWindowFeature], focus: Segment) -> np.ndarray: - return buffers[0].crop(focus, fixed=focus.duration) + return buffers[0].crop(focus, mode=self.cropping_mode, fixed=focus.duration) class DelayedAggregation: From 6e0f27a09ffd31e6d49356b9b746f2213a2a1c70 Mon Sep 17 00:00:00 2001 From: juanmc2005 Date: Wed, 19 Oct 2022 11:18:04 +0200 Subject: [PATCH 5/5] Add missing cropping_mode in aggregation calls --- src/diart/blocks/aggregation.py | 20 +++++++++----------- src/diart/blocks/diarization.py | 2 ++ 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/diart/blocks/aggregation.py b/src/diart/blocks/aggregation.py index 38673262..b6352a28 100644 --- a/src/diart/blocks/aggregation.py +++ b/src/diart/blocks/aggregation.py @@ -11,10 +11,9 @@ class AggregationStrategy: Parameters ---------- cropping_mode: ("strict", "loose", "center"), optional - Defines the cropping mode. Defaults to "loose". - This will define the value of the mode parameter used in SlidingWindowFeature.crop - (from pyannote.core, see https://pyannote.github.io/pyannote-core/reference.html#pyannote.core.SlidingWindowFeature.crop - for more details). + Defines the mode to crop buffer chunks as in pyannote.core. + See https://pyannote.github.io/pyannote-core/reference.html#pyannote.core.SlidingWindowFeature.crop + Defaults to "loose". """ def __init__(self, cropping_mode: Literal["strict", "loose", "center"] = "loose"): @@ -113,16 +112,15 @@ class DelayedAggregation: latency: float, optional Desired latency, in seconds. Defaults to step. The higher the latency, the more overlapping windows to aggregate. - strategy: ("mean", "hamming", "any"), optional + strategy: ("mean", "hamming", "first"), optional Specifies how to aggregate overlapping windows. Defaults to "hamming". "mean": simple average "hamming": average weighted by the Hamming window values (aligned to the buffer) - "any": no aggregation, pick the first overlapping window + "first": no aggregation, pick the first overlapping window cropping_mode: ("strict", "loose", "center"), optional - Defines the cropping mode. Defaults to "loose". - This will define the value of the mode parameter used in SlidingWindowFeature.crop - (from pyannote.core, see https://pyannote.github.io/pyannote-core/reference.html#pyannote.core.SlidingWindowFeature.crop - for more details). + Defines the mode to crop buffer chunks as in pyannote.core. + See https://pyannote.github.io/pyannote-core/reference.html#pyannote.core.SlidingWindowFeature.crop + Defaults to "loose". Example -------- @@ -165,7 +163,7 @@ def __init__( assert self.step <= self.latency, "Invalid latency requested" self.num_overlapping_windows = int(round(self.latency / self.step)) - self.aggregate = AggregationStrategy.build(self.strategy) + self.aggregate = AggregationStrategy.build(self.strategy, self.cropping_mode) def _prepend( self, diff --git a/src/diart/blocks/diarization.py b/src/diart/blocks/diarization.py index a53ee213..c889f24d 100644 --- a/src/diart/blocks/diarization.py +++ b/src/diart/blocks/diarization.py @@ -100,11 +100,13 @@ def __init__(self, config: Optional[PipelineConfig] = None): self.config.step, self.config.latency, strategy="hamming", + cropping_mode="loose", ) self.audio_aggregation = DelayedAggregation( self.config.step, self.config.latency, strategy="first", + cropping_mode="center", ) self.binarize = Binarize(self.config.tau_active)