config.py
from typing import Any, Optional, Union, Tuple

import numpy as np
import torch
from typing_extensions import Literal

from .. import models as m
from .. import utils
from ..audio import FilePath, AudioLoader


class BasePipelineConfig:
    @property
    def duration(self) -> float:
        raise NotImplementedError

    @property
    def step(self) -> float:
        raise NotImplementedError

    @property
    def latency(self) -> float:
        raise NotImplementedError

    @property
    def sample_rate(self) -> int:
        raise NotImplementedError

    @staticmethod
    def from_dict(data: Any) -> 'BasePipelineConfig':
        raise NotImplementedError

    def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]:
        # (left, right) padding in seconds, derived from the file duration,
        # the latency, the step and the window duration
        file_duration = AudioLoader(self.sample_rate, mono=True).get_duration(filepath)
        right = utils.get_padding_right(self.latency, self.step)
        left = utils.get_padding_left(file_duration + right, self.duration)
        return left, right

    def optimal_block_size(self) -> int:
        # Number of samples per block, given the step (in seconds) and the sample rate
        return int(np.rint(self.step * self.sample_rate))
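
# Worked example (illustrative note, not part of the original module): with the
# default step of 0.5 s and a 16 kHz segmentation model, optimal_block_size()
# yields int(round(0.5 * 16000)) = 8000 samples per block. Only the 16 kHz rate
# is an assumed example value; everything else comes from the code above.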


class PipelineConfig(BasePipelineConfig):
    def __init__(
        self,
        segmentation: Optional[m.SegmentationModel] = None,
        embedding: Optional[m.EmbeddingModel] = None,
        duration: Optional[float] = None,
        step: float = 0.5,
        latency: Optional[Union[float, Literal["max", "min"]]] = None,
        tau_active: float = 0.6,
        rho_update: float = 0.3,
        delta_new: float = 1,
        gamma: float = 3,
        beta: float = 10,
        max_speakers: int = 20,
        device: Optional[torch.device] = None,
        **kwargs,
    ):
        # Default segmentation model is pyannote/segmentation
        self.segmentation = segmentation
        if self.segmentation is None:
            self.segmentation = m.SegmentationModel.from_pyannote("pyannote/segmentation")

        # Default duration is the one given by the segmentation model
        self._duration = duration

        # Expected sample rate is given by the segmentation model
        self._sample_rate: Optional[int] = None

        # Default embedding model is pyannote/embedding
        self.embedding = embedding
        if self.embedding is None:
            self.embedding = m.EmbeddingModel.from_pyannote("pyannote/embedding")

        # Latency defaults to the step duration
        self._step = step
        self._latency = latency
        if self._latency is None or self._latency == "min":
            self._latency = self._step
        elif self._latency == "max":
            self._latency = self._duration

        self.tau_active = tau_active
        self.rho_update = rho_update
        self.delta_new = delta_new
        self.gamma = gamma
        self.beta = beta
        self.max_speakers = max_speakers

        self.device = device
        if self.device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    @staticmethod
    def from_dict(data: Any) -> 'PipelineConfig':
        # Check for explicit device, otherwise check for 'cpu' bool, otherwise pass None
        device = utils.get(data, "device", None)
        if device is None:
            device = torch.device("cpu") if utils.get(data, "cpu", False) else None

        # Instantiate models
        hf_token = utils.parse_hf_token_arg(utils.get(data, "hf_token", True))
        segmentation = utils.get(data, "segmentation", "pyannote/segmentation")
        segmentation = m.SegmentationModel.from_pyannote(segmentation, hf_token)
        embedding = utils.get(data, "embedding", "pyannote/embedding")
        embedding = m.EmbeddingModel.from_pyannote(embedding, hf_token)

        # Hyper-parameters and their aliases
        tau = utils.get(data, "tau_active", None)
        if tau is None:
            tau = utils.get(data, "tau", 0.6)
        rho = utils.get(data, "rho_update", None)
        if rho is None:
            rho = utils.get(data, "rho", 0.3)
        delta = utils.get(data, "delta_new", None)
        if delta is None:
            delta = utils.get(data, "delta", 1)

        return PipelineConfig(
            segmentation=segmentation,
            embedding=embedding,
            duration=utils.get(data, "duration", None),
            step=utils.get(data, "step", 0.5),
            latency=utils.get(data, "latency", None),
            tau_active=tau,
            rho_update=rho,
            delta_new=delta,
            gamma=utils.get(data, "gamma", 3),
            beta=utils.get(data, "beta", 10),
            max_speakers=utils.get(data, "max_speakers", 20),
            device=device,
        )
    @property
    def duration(self) -> float:
        if self._duration is None:
            self._duration = self.segmentation.duration
        return self._duration

    @property
    def step(self) -> float:
        return self._step

    @property
    def latency(self) -> float:
        return self._latency

    @property
    def sample_rate(self) -> int:
        if self._sample_rate is None:
            self._sample_rate = self.segmentation.sample_rate
        return self._sample_rate