Skip to content

Commit

Permalink
fix: thread safety, avoid parallel trainings, delay initial training …
Browse files Browse the repository at this point in the history
…until more data has been gathered (core sends train bus message), reduce log spam
  • Loading branch information
JarbasAl committed Jan 26, 2025
1 parent 7ca51b5 commit e78dac4
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 28 deletions.
16 changes: 14 additions & 2 deletions ovos_padatious/domain_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,18 @@ def __init__(self, cache_dir: Optional[str] = None, disable_padaos: bool = False
disable_padaos=disable_padaos)
self.domains: Dict[str, IntentContainer] = {}
self.training_data: Dict[str, List[str]] = defaultdict(list)
self.instantiate_from_disk()
self.must_train = True

def instantiate_from_disk(self) -> None:
"""
Instantiate the internal data structures needed when loading a persisted model from disk.

Delegates to ``instantiate_from_disk()`` on the top-level ``self.domain_engine`` first,
and then on every per-domain engine currently held in ``self.domains``.
This is done by injecting entities and intents back from their cached file versions.
"""
# restore the top-level engine first
self.domain_engine.instantiate_from_disk()
# NOTE: when called from __init__, self.domains has just been set to an empty dict,
# so this loop does nothing there; per-domain engines are instead loaded from disk
# individually when add_domain_intent creates them.
for engine in self.domains.values():
engine.instantiate_from_disk()

def remove_domain(self, domain_name: str):
"""
Remove a domain and its associated intents and training data.
Expand All @@ -54,6 +64,8 @@ def add_domain_intent(self, domain_name: str, intent_name: str, intent_samples:
if domain_name not in self.domains:
self.domains[domain_name] = IntentContainer(cache_dir=self.cache_dir,
disable_padaos=self.disable_padaos)
self.domains[domain_name].instantiate_from_disk()

self.domains[domain_name].add_intent(intent_name, intent_samples)
self.training_data[domain_name] += intent_samples
self.must_train = True
Expand Down Expand Up @@ -165,11 +177,11 @@ def calc_intents(self, query: str, domain: Optional[str] = None, top_k_domains:
return sorted(matches, reverse=True, key=lambda k: k.conf)

def train(self):
for domain, samples in self.training_data.items():
for domain, samples in dict(self.training_data).items(): # copy for thread safety
LOG.debug(f"Training domain: {domain}")
self.domain_engine.add_intent(domain, samples)
self.domain_engine.train()
for domain in self.domains:
for domain in dict(self.domains): # copy for thread safety
LOG.debug(f"Training domain sub-intents: {domain}")
self.domains[domain].train()
self.must_train = False
49 changes: 29 additions & 20 deletions ovos_padatious/opm.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ def normalize_utterances(utterances: List[str], lang: str, cast_to_ascii: bool =
# Replace accented characters and punctuation if needed
if cast_to_ascii:
utterances = [remove_accents_and_punct(u) for u in utterances]
# strip punctuation marks, that just causes duplicate training data
utterances = [u.rstrip(string.punctuation) for u in utterances]
# Stem words if stemmer is provided
if stemmer is not None:
utterances = stemmer.stem_sentences(utterances)
Expand Down Expand Up @@ -274,6 +276,7 @@ def __init__(self, bus: Optional[Union[MessageBusClient, FakeBus]] = None,
if engine_class is None and self.config.get("domain_engine"):
engine_class = DomainIntentContainer

self.remove_punct = self.config.get("cast_to_ascii", False)
use_stemmer = self.config.get("stem", False)
self.engine_class = engine_class or IntentContainer
intent_cache = expanduser(self.config.get('intent_cache') or
Expand All @@ -284,6 +287,8 @@ def __init__(self, bus: Optional[Union[MessageBusClient, FakeBus]] = None,
intent_cache += "_domain"
if use_stemmer:
intent_cache += "_stemmer"
if self.remove_punct:
intent_cache += "_normalized"
self.containers = {lang: self.engine_class(cache_dir=f"{intent_cache}/{lang}",
disable_padaos=self.config.get("disable_padaos", False))
for lang in langs}
Expand All @@ -300,8 +305,9 @@ def __init__(self, bus: Optional[Union[MessageBusClient, FakeBus]] = None,
else:
self.stemmers = {}

self.finished_training_event = Event() # DEPRECATED
self.finished_initial_train = False
self.first_train = Event()
self.finished_training_event = Event()
self.finished_training_event.set() # is cleared when training starts

self.registered_intents = []
self.registered_entities = []
Expand Down Expand Up @@ -348,7 +354,7 @@ def _match_level(self, utterances, limit, lang=None, message: Optional[Message]
utterances = normalize_utterances(utterances, lang,
stemmer=stemmer,
keep_order=True,
cast_to_ascii=self.config.get("cast_to_ascii", False))
cast_to_ascii=self.remove_punct)
padatious_intent = self.calc_intent(utterances, lang, message)
if padatious_intent is not None and padatious_intent.conf > limit:
skill_id = padatious_intent.name.split(':')[0]
Expand Down Expand Up @@ -391,26 +397,30 @@ def train(self, message=None):
Args:
message (Message): optional triggering message
"""
LOG.debug("Padatious training start")
if not any(engine.must_train for engine in self.containers.values()):
LOG.debug(f"Nothing new to train for padatious")
# inform the rest of the system to not wait for training finish
self.bus.emit(Message('mycroft.skills.trained'))
return

# wait for any already ongoing training
# padatious doesnt like threads
if not self.finished_training_event.is_set():
self.finished_training_event.wait()
with self.lock:
if not any(engine.must_train for engine in self.containers.values()):
# LOG.debug(f"Nothing new to train for padatious")
# inform the rest of the system to not wait for training finish
self.bus.emit(Message('mycroft.skills.trained'))
self.finished_training_event.set()
return
self.finished_training_event.clear()
# TODO - run this in subprocess?, sometimes fann2 segfaults and kills ovos-core...
for lang in self.containers:
if self.containers[lang].must_train:
LOG.debug(f"Training padatious for lang '{lang}'")
#LOG.debug(f"Training padatious for lang '{lang}'")
self.containers[lang].train()

LOG.debug(f"Training complete for padatious!")
if not self.finished_initial_train:
self.finished_initial_train = True
# inform the rest of the system to stop waiting for training finish
self.bus.emit(Message('mycroft.skills.trained'))
self.finished_training_event.set()

# inform the rest of the system to stop waiting for training finish
self.bus.emit(Message('mycroft.skills.trained'))
LOG.debug("Padatious training end")
if not self.first_train.is_set():
self.first_train.set()

@deprecated("'wait_and_train' has been deprecated, use 'train' directly", "2.0.0")
def wait_and_train(self):
Expand Down Expand Up @@ -493,15 +503,14 @@ def _register_object(self, message, object_name, register_func):
samples = normalize_utterances(samples, lang,
stemmer=stemmer,
keep_order=False,
cast_to_ascii=self.config.get("cast_to_ascii", False))
cast_to_ascii=self.remove_punct)

if self.engine_class == DomainIntentContainer:
register_func(skill_id, name, samples)
else:
register_func(name, samples)

self.finished_initial_train = False
if self.config.get("instant_train", True):
if self.config.get("instant_train", False) or self.first_train.is_set():
self.train(message)

def register_intent(self, message):
Expand Down
14 changes: 8 additions & 6 deletions ovos_padatious/training_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,17 @@ def add(self, name: str, lines: List[str], reload_cache: bool = False, must_trai
else:
hash_fn = join(self.cache, name + '.hash')
old_hsh = None
min_ver = splitext(ovos_padatious.__version__)[0]
new_hsh = lines_hash([min_ver] + lines)

if isfile(hash_fn):
with open(hash_fn, 'rb') as g:
old_hsh = g.read()
min_ver = splitext(ovos_padatious.__version__)[0]
new_hsh = lines_hash([min_ver] + lines)
if not old_hsh:
LOG.debug("First time training")
elif old_hsh and old_hsh != new_hsh:
LOG.debug(f"{name} Hash changed! retraining - {old_hsh} {new_hsh}")
if old_hsh != new_hsh:
LOG.debug(f"{name} training data changed! retraining")
else:
LOG.debug(f"First time training '{name}")

retrain = reload_cache or old_hsh != new_hsh
if not retrain:
try:
Expand Down

0 comments on commit e78dac4

Please sign in to comment.