diff --git a/src/Machine/src/Serval.Machine.Shared/Configuration/BuildJobOptions.cs b/src/Machine/src/Serval.Machine.Shared/Configuration/BuildJobOptions.cs index 0a7947d5..547a9dbd 100644 --- a/src/Machine/src/Serval.Machine.Shared/Configuration/BuildJobOptions.cs +++ b/src/Machine/src/Serval.Machine.Shared/Configuration/BuildJobOptions.cs @@ -5,5 +5,4 @@ public class BuildJobOptions public const string Key = "BuildJob"; public IList<ClearMLBuildQueue> ClearML { get; set; } = new List<ClearMLBuildQueue>(); - public TimeSpan PostProcessLockLifetime { get; set; } = TimeSpan.FromSeconds(120); } diff --git a/src/Machine/src/Serval.Machine.Shared/Configuration/IMachineBuilderExtensions.cs b/src/Machine/src/Serval.Machine.Shared/Configuration/IMachineBuilderExtensions.cs index 4f60ae90..5a577cb5 100644 --- a/src/Machine/src/Serval.Machine.Shared/Configuration/IMachineBuilderExtensions.cs +++ b/src/Machine/src/Serval.Machine.Shared/Configuration/IMachineBuilderExtensions.cs @@ -378,6 +378,7 @@ public static IMachineBuilder AddServalTranslationEngineService( { options.Interceptors.Add<CancellationInterceptor>(); options.Interceptors.Add<UnimplementedInterceptor>(); + options.Interceptors.Add<TimeoutInterceptor>(); }); builder.AddServalPlatformService(connectionString); diff --git a/src/Machine/src/Serval.Machine.Shared/Configuration/SmtTransferEngineOptions.cs b/src/Machine/src/Serval.Machine.Shared/Configuration/SmtTransferEngineOptions.cs index 15002604..58a425ff 100644 --- a/src/Machine/src/Serval.Machine.Shared/Configuration/SmtTransferEngineOptions.cs +++ b/src/Machine/src/Serval.Machine.Shared/Configuration/SmtTransferEngineOptions.cs @@ -7,4 +7,6 @@ public class SmtTransferEngineOptions public string EnginesDir { get; set; } = "translation_engines"; public TimeSpan EngineCommitFrequency { get; set; } = TimeSpan.FromMinutes(5); public TimeSpan InactiveEngineTimeout { get; set; } = TimeSpan.FromMinutes(10); + public TimeSpan SaveModelTimeout { get; set; } = TimeSpan.FromMinutes(5); + public TimeSpan EngineCommitTimeout { get; set; } = TimeSpan.FromMinutes(2); } diff --git a/src/Machine/src/Serval.Machine.Shared/Models/TranslationEngine.cs b/src/Machine/src/Serval.Machine.Shared/Models/TranslationEngine.cs index 80b1f648..e3143a3c 100644 --- a/src/Machine/src/Serval.Machine.Shared/Models/TranslationEngine.cs +++ b/src/Machine/src/Serval.Machine.Shared/Models/TranslationEngine.cs @@ -11,4 +11,5 @@ public record TranslationEngine : IEntity public required bool IsModelPersisted { get; init; } public int BuildRevision { get; init; } public Build? CurrentBuild { get; init; } + public bool? CollectTrainSegmentPairs { get; init; } } diff --git a/src/Machine/src/Serval.Machine.Shared/Services/BuildJobService.cs b/src/Machine/src/Serval.Machine.Shared/Services/BuildJobService.cs index 244aa04a..da670439 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/BuildJobService.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/BuildJobService.cs @@ -59,6 +59,7 @@ public async Task DeleteEngineAsync(string engineId, CancellationToken cancellat public async Task<bool> StartBuildJobAsync( BuildJobRunnerType runnerType, + TranslationEngineType engineType, string engineId, string buildId, BuildStage stage, @@ -67,18 +68,9 @@ public async Task<bool> StartBuildJobAsync( CancellationToken cancellationToken = default ) { - TranslationEngine? engine = await _engines.GetAsync( - e => - e.EngineId == engineId - && (e.CurrentBuild == null || e.CurrentBuild.JobState != BuildJobState.Canceling), - cancellationToken - ); - if (engine is null) - return false; - IBuildJobRunner runner = _runners[runnerType]; string jobId = await runner.CreateJobAsync( - engine.Type, + engineType, engineId, buildId, stage, @@ -88,8 +80,17 @@ public async Task<bool> StartBuildJobAsync( ); try { - await _engines.UpdateAsync( - e => e.EngineId == engineId, + TranslationEngine? engine = await _engines.UpdateAsync( + e => + e.EngineId == engineId + && ( + (stage == BuildStage.Preprocess && e.CurrentBuild == null) + || ( + stage != BuildStage.Preprocess + && e.CurrentBuild != null + && e.CurrentBuild.JobState != BuildJobState.Canceling + ) + ), u => u.Set( e => e.CurrentBuild, @@ -105,6 +106,11 @@ await _engines.UpdateAsync( ), cancellationToken: cancellationToken ); + if (engine is null) + { + await runner.DeleteJobAsync(jobId, CancellationToken.None); + return false; + } await runner.EnqueueJobAsync(jobId, engine.Type, cancellationToken); return true; } @@ -120,44 +126,36 @@ await _engines.UpdateAsync( CancellationToken cancellationToken = default ) { - TranslationEngine? engine = await _engines.GetAsync( - e => e.EngineId == engineId && e.CurrentBuild != null, - cancellationToken + // cancel a job that hasn't started yet + TranslationEngine? engine = await _engines.UpdateAsync( + e => e.EngineId == engineId && e.CurrentBuild != null && e.CurrentBuild.JobState == BuildJobState.Pending, + u => + { + u.Unset(b => b.CurrentBuild); + u.Set(e => e.CollectTrainSegmentPairs, false); + }, + returnOriginal: true, + cancellationToken: cancellationToken ); - if (engine is null || engine.CurrentBuild is null) - return (null, BuildJobState.None); - - IBuildJobRunner runner = _runners[engine.CurrentBuild.BuildJobRunner]; - - if (engine.CurrentBuild.JobState is BuildJobState.Pending) + if (engine is not null && engine.CurrentBuild is not null) { - // cancel a job that hasn't started yet - engine = await _engines.UpdateAsync( - e => e.EngineId == engineId && e.CurrentBuild != null, - u => u.Unset(b => b.CurrentBuild), - returnOriginal: true, - cancellationToken: cancellationToken - ); - if (engine is not null && engine.CurrentBuild is not null) - { - // job will be deleted from the queue - await runner.StopJobAsync(engine.CurrentBuild.JobId, CancellationToken.None); - return (engine.CurrentBuild.BuildId, BuildJobState.None); - } + // job will be deleted from the queue + IBuildJobRunner runner = _runners[engine.CurrentBuild.BuildJobRunner]; + await runner.StopJobAsync(engine.CurrentBuild.JobId, CancellationToken.None); + return (engine.CurrentBuild.BuildId, BuildJobState.None); } - else if (engine.CurrentBuild.JobState is BuildJobState.Active) + + // cancel a job that is already running + engine = await _engines.UpdateAsync( + e => e.EngineId == engineId && e.CurrentBuild != null && e.CurrentBuild.JobState == BuildJobState.Active, + u => u.Set(e => e.CurrentBuild!.JobState, BuildJobState.Canceling), + cancellationToken: cancellationToken + ); + if (engine is not null && engine.CurrentBuild is not null) { - // cancel a job that is already running - engine = await _engines.UpdateAsync( - e => e.EngineId == engineId && e.CurrentBuild != null, - u => u.Set(e => e.CurrentBuild!.JobState, BuildJobState.Canceling), - cancellationToken: cancellationToken - ); - if (engine is not null && engine.CurrentBuild is not null) - { - await runner.StopJobAsync(engine.CurrentBuild.JobId, CancellationToken.None); - return (engine.CurrentBuild.BuildId, BuildJobState.Canceling); - } + IBuildJobRunner runner = _runners[engine.CurrentBuild.BuildJobRunner]; + await runner.StopJobAsync(engine.CurrentBuild.JobId, CancellationToken.None); + return (engine.CurrentBuild.BuildId, BuildJobState.Canceling); } return (null, BuildJobState.None); @@ -193,6 +191,7 @@ public Task BuildJobFinishedAsync( u => { u.Unset(e => e.CurrentBuild); + u.Set(e => e.CollectTrainSegmentPairs, false); if (buildComplete) u.Inc(e => e.BuildRevision); }, diff --git a/src/Machine/src/Serval.Machine.Shared/Services/ClearMLMonitorService.cs b/src/Machine/src/Serval.Machine.Shared/Services/ClearMLMonitorService.cs index f577fdce..c527a03f 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/ClearMLMonitorService.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/ClearMLMonitorService.cs @@ -85,7 +85,6 @@ await _clearMLService.GetTasksForQueueAsync(_queuePerEngineType[engineType], can var dataAccessContext = scope.ServiceProvider.GetRequiredService<IDataAccessContext>(); var platformService = scope.ServiceProvider.GetRequiredService<IPlatformService>(); - var lockFactory = scope.ServiceProvider.GetRequiredService<IDistributedReaderWriterLockFactory>(); foreach (TranslationEngine engine in trainingEngines) { if (engine.CurrentBuild is null || !tasks.TryGetValue(engine.CurrentBuild.JobId, out ClearMLTask? task)) @@ -119,7 +118,6 @@ or ClearMLTaskStatus.Completed { bool canceled = !await TrainJobStartedAsync( dataAccessContext, - lockFactory, buildJobService, platformService, engine.EngineId, @@ -159,8 +157,8 @@ await UpdateTrainJobStatus( cancellationToken ); bool canceling = !await TrainJobCompletedAsync( - lockFactory, buildJobService, + engine.Type, engine.EngineId, engine.CurrentBuild.BuildId, (int)GetMetric(task, SummaryMetric, TrainCorpusSizeVariant), @@ -172,7 +170,6 @@ await UpdateTrainJobStatus( { await TrainJobCanceledAsync( dataAccessContext, - lockFactory, buildJobService, platformService, engine.EngineId, @@ -187,7 +184,6 @@ await TrainJobCanceledAsync( { await TrainJobCanceledAsync( dataAccessContext, - lockFactory, buildJobService, platformService, engine.EngineId, @@ -201,7 +197,6 @@ await TrainJobCanceledAsync( { await TrainJobFaultedAsync( dataAccessContext, - lockFactory, buildJobService, platformService, engine.EngineId, @@ -223,7 +218,6 @@ await TrainJobFaultedAsync( private async Task<bool> TrainJobStartedAsync( IDataAccessContext dataAccessContext, - IDistributedReaderWriterLockFactory lockFactory, IBuildJobService buildJobService, IPlatformService platformService, string engineId, @@ -231,29 +225,24 @@ private async Task<bool> TrainJobStartedAsync( CancellationToken cancellationToken = default ) { - bool success; - IDistributedReaderWriterLock @lock = await lockFactory.CreateAsync(engineId, cancellationToken); - await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) - { - success = await dataAccessContext.WithTransactionAsync( - async (ct) => - { - if (!await buildJobService.BuildJobStartedAsync(engineId, buildId, ct)) - return false; - await platformService.BuildStartedAsync(buildId, CancellationToken.None); - return true; - }, - cancellationToken: cancellationToken - ); - } + bool success = await dataAccessContext.WithTransactionAsync( + async (ct) => + { + if (!await buildJobService.BuildJobStartedAsync(engineId, buildId, ct)) + return false; + await platformService.BuildStartedAsync(buildId, CancellationToken.None); + return true; + }, + cancellationToken: cancellationToken + ); await UpdateTrainJobStatus(platformService, buildId, new ProgressStatus(0), 0, cancellationToken); _logger.LogInformation("Build started ({BuildId})", buildId); return success; } private async Task<bool> TrainJobCompletedAsync( - IDistributedReaderWriterLockFactory lockFactory, IBuildJobService buildJobService, + TranslationEngineType engineType, string engineId, string buildId, int corpusSize, @@ -264,19 +253,16 @@ CancellationToken cancellationToken { try { - IDistributedReaderWriterLock @lock = await lockFactory.CreateAsync(engineId, cancellationToken); - await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) - { - return await buildJobService.StartBuildJobAsync( - BuildJobRunnerType.Hangfire, - engineId, - buildId, - BuildStage.Postprocess, - (corpusSize, confidence), - buildOptions, - cancellationToken - ); - } + return await buildJobService.StartBuildJobAsync( + BuildJobRunnerType.Hangfire, + engineType, + engineId, + buildId, + BuildStage.Postprocess, + (corpusSize, confidence), + buildOptions, + cancellationToken + ); } finally { @@ -286,7 +272,6 @@ CancellationToken cancellationToken private async Task TrainJobFaultedAsync( IDataAccessContext dataAccessContext, - IDistributedReaderWriterLockFactory lockFactory, IBuildJobService buildJobService, IPlatformService platformService, string engineId, @@ -297,23 +282,19 @@ CancellationToken cancellationToken { try { - IDistributedReaderWriterLock @lock = await lockFactory.CreateAsync(engineId, cancellationToken); - await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) - { - await dataAccessContext.WithTransactionAsync( - async (ct) => - { - await platformService.BuildFaultedAsync(buildId, message, ct); - await buildJobService.BuildJobFinishedAsync( - engineId, - buildId, - buildComplete: false, - CancellationToken.None - ); - }, - cancellationToken: cancellationToken - ); - } + await dataAccessContext.WithTransactionAsync( + async (ct) => + { + await platformService.BuildFaultedAsync(buildId, message, ct); + await buildJobService.BuildJobFinishedAsync( + engineId, + buildId, + buildComplete: false, + CancellationToken.None + ); + }, + cancellationToken: cancellationToken + ); _logger.LogError("Build faulted ({BuildId}). Error: {ErrorMessage}", buildId, message); } finally @@ -324,7 +305,6 @@ await buildJobService.BuildJobFinishedAsync( private async Task TrainJobCanceledAsync( IDataAccessContext dataAccessContext, - IDistributedReaderWriterLockFactory lockFactory, IBuildJobService buildJobService, IPlatformService platformService, string engineId, @@ -334,23 +314,19 @@ CancellationToken cancellationToken { try { - IDistributedReaderWriterLock @lock = await lockFactory.CreateAsync(engineId, cancellationToken); - await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) - { - await dataAccessContext.WithTransactionAsync( - async (ct) => - { - await platformService.BuildCanceledAsync(buildId, ct); - await buildJobService.BuildJobFinishedAsync( - engineId, - buildId, - buildComplete: false, - CancellationToken.None - ); - }, - cancellationToken: cancellationToken - ); - } + await dataAccessContext.WithTransactionAsync( + async (ct) => + { + await platformService.BuildCanceledAsync(buildId, ct); + await buildJobService.BuildJobFinishedAsync( + engineId, + buildId, + buildComplete: false, + CancellationToken.None + ); + }, + cancellationToken: cancellationToken + ); _logger.LogInformation("Build canceled ({BuildId})", buildId); } finally diff --git a/src/Machine/src/Serval.Machine.Shared/Services/DistributedReaderWriterLock.cs b/src/Machine/src/Serval.Machine.Shared/Services/DistributedReaderWriterLock.cs index 6dfea687..f69c74d6 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/DistributedReaderWriterLock.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/DistributedReaderWriterLock.cs @@ -14,14 +14,51 @@ DistributedReaderWriterLockOptions lockOptions private readonly string _id = id; private readonly DistributedReaderWriterLockOptions _lockOptions = lockOptions; - public async Task<IAsyncDisposable> ReaderLockAsync( - TimeSpan? lifetime = default, + public Task ReaderLockAsync( + Func<CancellationToken, Task> action, + TimeSpan? lifetime = null, + CancellationToken cancellationToken = default + ) + { + return ReaderLockAsync<object?>( + async ct => + { + await action(ct); + return null; + }, + lifetime, + cancellationToken + ); + } + + public Task WriterLockAsync( + Func<CancellationToken, Task> action, + TimeSpan? lifetime = null, + CancellationToken cancellationToken = default + ) + { + return WriterLockAsync<object?>( + async ct => + { + await action(ct); + return null; + }, + lifetime, + cancellationToken + ); + } + + public async Task<T> ReaderLockAsync<T>( + Func<CancellationToken, Task<T>> action, + TimeSpan? lifetime = null, CancellationToken cancellationToken = default ) { string lockId = _idGenerator.GenerateId(); TimeSpan resolvedLifetime = lifetime ?? _lockOptions.DefaultLifetime; - if (!await TryAcquireReaderLock(lockId, resolvedLifetime, cancellationToken)) + + (bool acquired, DateTime expiresAt) = await TryAcquireReaderLock(lockId, resolvedLifetime, cancellationToken); + if (!acquired) { using ISubscription<RWLock> sub = await _locks.SubscribeAsync(rwl => rwl.Id == _id, cancellationToken); do @@ -39,19 +76,38 @@ public async Task<IAsyncDisposable> ReaderLockAsync( if (timeout != TimeSpan.Zero) await sub.WaitForChangeAsync(timeout, cancellationToken); } - } while (!await TryAcquireReaderLock(lockId, resolvedLifetime, cancellationToken)); + (acquired, expiresAt) = await TryAcquireReaderLock(lockId, resolvedLifetime, cancellationToken); + } while (!acquired); + } + + try + { + (bool completed, T? result) = await TaskEx.Timeout(action, expiresAt - DateTime.UtcNow, cancellationToken); + if (!completed) + throw new TimeoutException($"A reader lock for the distributed lock '{_id}' expired."); + return result!; + } + finally + { + Expression<Func<RWLock, bool>> filter = rwl => rwl.Id == _id && rwl.ReaderLocks.Any(l => l.Id == lockId); + await _locks.UpdateAsync( + filter, + u => u.RemoveAll(rwl => rwl.ReaderLocks, l => l.Id == lockId), + cancellationToken: CancellationToken.None + ); } - return new ReaderLockReleaser(this, lockId); } - public async Task<IAsyncDisposable> WriterLockAsync( - TimeSpan? lifetime = default, + public async Task<T> WriterLockAsync<T>( + Func<CancellationToken, Task<T>> action, + TimeSpan? lifetime = null, CancellationToken cancellationToken = default ) { string lockId = _idGenerator.GenerateId(); TimeSpan resolvedLifetime = lifetime ?? _lockOptions.DefaultLifetime; - if (!await TryAcquireWriterLock(lockId, resolvedLifetime, cancellationToken)) + (bool acquired, DateTime expiresAt) = await TryAcquireWriterLock(lockId, resolvedLifetime, cancellationToken); + if (!acquired) { await _locks.UpdateAsync( _id, @@ -79,7 +135,8 @@ await _locks.UpdateAsync( if (timeout != TimeSpan.Zero) await sub.WaitForChangeAsync(timeout, cancellationToken); } - } while (!await TryAcquireWriterLock(lockId, resolvedLifetime, cancellationToken)); + (acquired, expiresAt) = await TryAcquireWriterLock(lockId, resolvedLifetime, cancellationToken); + } while (!acquired); } catch { @@ -91,12 +148,34 @@ await _locks.UpdateAsync( throw; } } - return new WriterLockReleaser(this, lockId); + + try + { + (bool completed, T? result) = await TaskEx.Timeout(action, expiresAt - DateTime.UtcNow, cancellationToken); + if (!completed) + throw new TimeoutException($"A writer lock for the distributed lock '{_id}' expired."); + return result!; + } + finally + { + Expression<Func<RWLock, bool>> filter = rwl => + rwl.Id == _id && rwl.WriterLock != null && rwl.WriterLock.Id == lockId; + await _locks.UpdateAsync( + filter, + u => u.Unset(rwl => rwl.WriterLock), + cancellationToken: CancellationToken.None + ); + } } - private async Task<bool> TryAcquireWriterLock(string lockId, TimeSpan lifetime, CancellationToken cancellationToken) + private async Task<(bool, DateTime)> TryAcquireWriterLock( + string lockId, + TimeSpan lifetime, + CancellationToken cancellationToken + ) { - var now = DateTime.UtcNow; + DateTime now = DateTime.UtcNow; + DateTime expiresAt = now + lifetime; Expression<Func<RWLock, bool>> filter = rwl => rwl.Id == _id && (rwl.WriterLock == null || rwl.WriterLock.ExpiresAt <= now) @@ -109,19 +188,24 @@ void Update(IUpdateBuilder<RWLock> u) new Lock { Id = lockId, - ExpiresAt = now + lifetime, + ExpiresAt = expiresAt, HostId = _hostId } ); u.RemoveAll(rwl => rwl.WriterQueue, l => l.Id == lockId); } RWLock? rwLock = await _locks.UpdateAsync(filter, Update, cancellationToken: cancellationToken); - return rwLock is not null; + return (rwLock is not null, expiresAt); } - private async Task<bool> TryAcquireReaderLock(string lockId, TimeSpan lifetime, CancellationToken cancellationToken) + private async Task<(bool, DateTime)> TryAcquireReaderLock( + string lockId, + TimeSpan lifetime, + CancellationToken cancellationToken + ) { - var now = DateTime.UtcNow; + DateTime now = DateTime.UtcNow; + DateTime expiresAt = now + lifetime; Expression<Func<RWLock, bool>> filter = rwl => rwl.Id == _id && (rwl.WriterLock == null || rwl.WriterLock.ExpiresAt <= now) && !rwl.WriterQueue.Any(); void Update(IUpdateBuilder<RWLock> u) @@ -131,42 +215,13 @@ void Update(IUpdateBuilder<RWLock> u) new Lock { Id = lockId, - ExpiresAt = now + lifetime, + ExpiresAt = expiresAt, HostId = _hostId } ); } RWLock? rwLock = await _locks.UpdateAsync(filter, Update, cancellationToken: cancellationToken); - return rwLock is not null; - } - - private class WriterLockReleaser(DistributedReaderWriterLock distributedLock, string lockId) : AsyncDisposableBase - { - private readonly DistributedReaderWriterLock _distributedLock = distributedLock; - private readonly string _lockId = lockId; - - protected override async ValueTask DisposeAsyncCore() - { - Expression<Func<RWLock, bool>> filter = rwl => - rwl.Id == _distributedLock._id && rwl.WriterLock != null && rwl.WriterLock.Id == _lockId; - await _distributedLock._locks.UpdateAsync(filter, u => u.Unset(rwl => rwl.WriterLock)); - } - } - - private class ReaderLockReleaser(DistributedReaderWriterLock distributedLock, string lockId) : AsyncDisposableBase - { - private readonly DistributedReaderWriterLock _distributedLock = distributedLock; - private readonly string _lockId = lockId; - - protected override async ValueTask DisposeAsyncCore() - { - Expression<Func<RWLock, bool>> filter = rwl => - rwl.Id == _distributedLock._id && rwl.ReaderLocks.Any(l => l.Id == _lockId); - await _distributedLock._locks.UpdateAsync( - filter, - u => u.RemoveAll(rwl => rwl.ReaderLocks, l => l.Id == _lockId) - ); - } + return (rwLock is not null, expiresAt); } } diff --git a/src/Machine/src/Serval.Machine.Shared/Services/HangfireBuildJob.cs b/src/Machine/src/Serval.Machine.Shared/Services/HangfireBuildJob.cs index 26fe58ed..13fc9add 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/HangfireBuildJob.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/HangfireBuildJob.cs @@ -3,11 +3,10 @@ public abstract class HangfireBuildJob( IPlatformService platformService, IRepository<TranslationEngine> engines, - IDistributedReaderWriterLockFactory lockFactory, IDataAccessContext dataAccessContext, IBuildJobService buildJobService, ILogger<HangfireBuildJob> logger -) : HangfireBuildJob<object?>(platformService, engines, lockFactory, dataAccessContext, buildJobService, logger) +) : HangfireBuildJob<object?>(platformService, engines, dataAccessContext, buildJobService, logger) { public virtual Task RunAsync( string engineId, @@ -23,7 +22,6 @@ CancellationToken cancellationToken public abstract class HangfireBuildJob<T>( IPlatformService platformService, IRepository<TranslationEngine> engines, - IDistributedReaderWriterLockFactory lockFactory, IDataAccessContext dataAccessContext, IBuildJobService buildJobService, ILogger<HangfireBuildJob<T>> logger @@ -31,7 +29,6 @@ ILogger<HangfireBuildJob<T>> logger { protected IPlatformService PlatformService { get; } = platformService; protected IRepository<TranslationEngine> Engines { get; } = engines; - protected IDistributedReaderWriterLockFactory LockFactory { get; } = lockFactory; protected IDataAccessContext DataAccessContext { get; } = dataAccessContext; protected IBuildJobService BuildJobService { get; } = buildJobService; protected ILogger<HangfireBuildJob<T>> Logger { get; } = logger; @@ -44,21 +41,17 @@ public virtual async Task RunAsync( CancellationToken cancellationToken ) { - IDistributedReaderWriterLock @lock = await LockFactory.CreateAsync(engineId, cancellationToken); JobCompletionStatus completionStatus = JobCompletionStatus.Completed; try { - await InitializeAsync(engineId, buildId, data, @lock, cancellationToken); - await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) + await InitializeAsync(engineId, buildId, data, cancellationToken); + if (!await BuildJobService.BuildJobStartedAsync(engineId, buildId, cancellationToken)) { - if (!await BuildJobService.BuildJobStartedAsync(engineId, buildId, cancellationToken)) - { - completionStatus = JobCompletionStatus.Canceled; - return; - } + completionStatus = JobCompletionStatus.Canceled; + return; } - await DoWorkAsync(engineId, buildId, data, buildOptions, @lock, cancellationToken); + await DoWorkAsync(engineId, buildId, data, buildOptions, cancellationToken); } catch (OperationCanceledException) { @@ -70,22 +63,19 @@ CancellationToken cancellationToken if (engine?.CurrentBuild?.JobState is BuildJobState.Canceling) { completionStatus = JobCompletionStatus.Canceled; - await using (await @lock.WriterLockAsync(cancellationToken: CancellationToken.None)) - { - await DataAccessContext.WithTransactionAsync( - async (ct) => - { - await PlatformService.BuildCanceledAsync(buildId, CancellationToken.None); - await BuildJobService.BuildJobFinishedAsync( - engineId, - buildId, - buildComplete: false, - CancellationToken.None - ); - }, - cancellationToken: CancellationToken.None - ); - } + await DataAccessContext.WithTransactionAsync( + async (ct) => + { + await PlatformService.BuildCanceledAsync(buildId, CancellationToken.None); + await BuildJobService.BuildJobFinishedAsync( + engineId, + buildId, + buildComplete: false, + CancellationToken.None + ); + }, + cancellationToken: CancellationToken.None + ); Logger.LogInformation("Build canceled ({0})", buildId); } else if (engine is not null) @@ -93,17 +83,14 @@ await BuildJobService.BuildJobFinishedAsync( // the build was canceled, because of a server shutdown // switch state back to pending completionStatus = JobCompletionStatus.Restarting; - await using (await @lock.WriterLockAsync(cancellationToken: CancellationToken.None)) - { - await DataAccessContext.WithTransactionAsync( - async (ct) => - { - await PlatformService.BuildRestartingAsync(buildId, CancellationToken.None); - await BuildJobService.BuildJobRestartingAsync(engineId, buildId, CancellationToken.None); - }, - cancellationToken: CancellationToken.None - ); - } + await DataAccessContext.WithTransactionAsync( + async (ct) => + { + await PlatformService.BuildRestartingAsync(buildId, CancellationToken.None); + await BuildJobService.BuildJobRestartingAsync(engineId, buildId, CancellationToken.None); + }, + cancellationToken: CancellationToken.None + ); throw; } else @@ -114,38 +101,29 @@ await DataAccessContext.WithTransactionAsync( catch (Exception e) { completionStatus = JobCompletionStatus.Faulted; - await using (await @lock.WriterLockAsync(cancellationToken: CancellationToken.None)) - { - await DataAccessContext.WithTransactionAsync( - async (ct) => - { - await PlatformService.BuildFaultedAsync(buildId, e.Message, CancellationToken.None); - await BuildJobService.BuildJobFinishedAsync( - engineId, - buildId, - buildComplete: false, - CancellationToken.None - ); - }, - cancellationToken: CancellationToken.None - ); - } + await DataAccessContext.WithTransactionAsync( + async (ct) => + { + await PlatformService.BuildFaultedAsync(buildId, e.Message, CancellationToken.None); + await BuildJobService.BuildJobFinishedAsync( + engineId, + buildId, + buildComplete: false, + CancellationToken.None + ); + }, + cancellationToken: CancellationToken.None + ); Logger.LogError(0, e, "Build faulted ({0})", buildId); throw; } finally { - await CleanupAsync(engineId, buildId, data, @lock, completionStatus); + await CleanupAsync(engineId, buildId, data, completionStatus); } } - protected virtual Task InitializeAsync( - string engineId, - string buildId, - T data, - IDistributedReaderWriterLock @lock, - CancellationToken cancellationToken - ) + protected virtual Task InitializeAsync(string engineId, string buildId, T data, CancellationToken cancellationToken) { return Task.CompletedTask; } @@ -155,17 +133,10 @@ protected abstract Task DoWorkAsync( string buildId, T data, string? buildOptions, - IDistributedReaderWriterLock @lock, CancellationToken cancellationToken ); - protected virtual Task CleanupAsync( - string engineId, - string buildId, - T data, - IDistributedReaderWriterLock @lock, - JobCompletionStatus completionStatus - ) + protected virtual Task CleanupAsync(string engineId, string buildId, T data, JobCompletionStatus completionStatus) { return Task.CompletedTask; } diff --git a/src/Machine/src/Serval.Machine.Shared/Services/IBuildJobService.cs b/src/Machine/src/Serval.Machine.Shared/Services/IBuildJobService.cs index c9ddf983..61c6122e 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/IBuildJobService.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/IBuildJobService.cs @@ -14,7 +14,8 @@ Task<IReadOnlyList<TranslationEngine>> GetBuildingEnginesAsync( Task DeleteEngineAsync(string engineId, CancellationToken cancellationToken = default); Task<bool> StartBuildJobAsync( - BuildJobRunnerType jobType, + BuildJobRunnerType runnerType, + TranslationEngineType engineType, string engineId, string buildId, BuildStage stage, diff --git a/src/Machine/src/Serval.Machine.Shared/Services/IDistributedReaderWriterLock.cs b/src/Machine/src/Serval.Machine.Shared/Services/IDistributedReaderWriterLock.cs index 026aff28..7edf79f7 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/IDistributedReaderWriterLock.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/IDistributedReaderWriterLock.cs @@ -2,6 +2,25 @@ public interface IDistributedReaderWriterLock { - Task<IAsyncDisposable> ReaderLockAsync(TimeSpan? lifetime = default, CancellationToken cancellationToken = default); - Task<IAsyncDisposable> WriterLockAsync(TimeSpan? lifetime = default, CancellationToken cancellationToken = default); + Task ReaderLockAsync( + Func<CancellationToken, Task> action, + TimeSpan? lifetime = default, + CancellationToken cancellationToken = default + ); + Task WriterLockAsync( + Func<CancellationToken, Task> action, + TimeSpan? lifetime = default, + CancellationToken cancellationToken = default + ); + + Task<T> ReaderLockAsync<T>( + Func<CancellationToken, Task<T>> action, + TimeSpan? lifetime = default, + CancellationToken cancellationToken = default + ); + Task<T> WriterLockAsync<T>( + Func<CancellationToken, Task<T>> action, + TimeSpan? lifetime = default, + CancellationToken cancellationToken = default + ); } diff --git a/src/Machine/src/Serval.Machine.Shared/Services/ISmtModelFactory.cs b/src/Machine/src/Serval.Machine.Shared/Services/ISmtModelFactory.cs index 6612e11e..01776084 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/ISmtModelFactory.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/ISmtModelFactory.cs @@ -2,21 +2,19 @@ public interface ISmtModelFactory { - Task<IInteractiveTranslationModel> CreateAsync( + IInteractiveTranslationModel Create( string engineDir, IRangeTokenizer<string, int, string> tokenizer, IDetokenizer<string, string> detokenizer, - ITruecaser truecaser, - CancellationToken cancellationToken = default + ITruecaser truecaser ); - Task<ITrainer> CreateTrainerAsync( + ITrainer CreateTrainer( string engineDir, IRangeTokenizer<string, int, string> tokenizer, - IParallelTextCorpus corpus, - CancellationToken cancellationToken = default + IParallelTextCorpus corpus ); - Task InitNewAsync(string engineDir, CancellationToken cancellationToken = default); - Task CleanupAsync(string engineDir, CancellationToken cancellationToken = default); + void InitNew(string engineDir); + void Cleanup(string engineDir); Task UpdateEngineFromAsync(string engineDir, Stream source, CancellationToken cancellationToken = default); Task SaveEngineToAsync(string engineDir, Stream destination, CancellationToken cancellationToken = default); } diff --git a/src/Machine/src/Serval.Machine.Shared/Services/ITransferEngineFactory.cs b/src/Machine/src/Serval.Machine.Shared/Services/ITransferEngineFactory.cs index c76b8e91..7ac3eb5b 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/ITransferEngineFactory.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/ITransferEngineFactory.cs @@ -2,13 +2,12 @@ public interface ITransferEngineFactory { - Task<ITranslationEngine?> CreateAsync( + ITranslationEngine? Create( string engineDir, IRangeTokenizer<string, int, string> tokenizer, IDetokenizer<string, string> detokenizer, - ITruecaser truecaser, - CancellationToken cancellationToken = default + ITruecaser truecaser ); - Task InitNewAsync(string engineDir, CancellationToken cancellationToken = default); - Task CleanupAsync(string engineDir, CancellationToken cancellationToken = default); + void InitNew(string engineDir); + void Cleanup(string engineDir); } diff --git a/src/Machine/src/Serval.Machine.Shared/Services/ITruecaserFactory.cs b/src/Machine/src/Serval.Machine.Shared/Services/ITruecaserFactory.cs index e83337d3..c4470925 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/ITruecaserFactory.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/ITruecaserFactory.cs @@ -2,12 +2,7 @@ public interface ITruecaserFactory { - Task<ITruecaser> CreateAsync(string engineDir, CancellationToken cancellationToken = default); - Task<ITrainer> CreateTrainerAsync( - string engineDir, - ITokenizer<string, int, string> tokenizer, - ITextCorpus corpus, - CancellationToken cancellationToken = default - ); - Task CleanupAsync(string engineDir, CancellationToken cancellationToken = default); + ITruecaser Create(string engineDir); + ITrainer CreateTrainer(string engineDir, ITokenizer<string, int, string> tokenizer, ITextCorpus corpus); + void Cleanup(string engineDir); } diff --git a/src/Machine/src/Serval.Machine.Shared/Services/NmtEngineService.cs b/src/Machine/src/Serval.Machine.Shared/Services/NmtEngineService.cs index 5a2fb912..fc1c2c95 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/NmtEngineService.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/NmtEngineService.cs @@ -2,7 +2,6 @@ public class NmtEngineService( IPlatformService platformService, - IDistributedReaderWriterLockFactory lockFactory, IDataAccessContext dataAccessContext, IRepository<TranslationEngine> engines, IBuildJobService buildJobService, @@ -11,7 +10,6 @@ public class NmtEngineService( ISharedFileService sharedFileService ) : ITranslationEngineService { - private readonly IDistributedReaderWriterLockFactory _lockFactory = lockFactory; private readonly IPlatformService _platformService = platformService; private readonly IDataAccessContext _dataAccessContext = dataAccessContext; private readonly IRepository<TranslationEngine> _engines = engines; @@ -61,15 +59,10 @@ public async Task<TranslationEngine> CreateAsync( public async Task DeleteAsync(string engineId, CancellationToken cancellationToken = default) { - IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, cancellationToken); - await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) - { - await CancelBuildJobAsync(engineId, cancellationToken); + await CancelBuildJobAsync(engineId, cancellationToken); - await _engines.DeleteAsync(e => e.EngineId == engineId, cancellationToken); - await _buildJobService.DeleteEngineAsync(engineId, CancellationToken.None); - } - await _lockFactory.DeleteAsync(engineId, CancellationToken.None); + await _engines.DeleteAsync(e => e.EngineId == engineId, cancellationToken); + await _buildJobService.DeleteEngineAsync(engineId, CancellationToken.None); } public async Task StartBuildAsync( @@ -80,33 +73,26 @@ public async Task StartBuildAsync( CancellationToken cancellationToken = default ) { - IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, cancellationToken); - await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) - { - // If there is a pending/running build, then no need to start a new one. - if (await _buildJobService.IsEngineBuilding(engineId, cancellationToken)) - throw new InvalidOperationException("The engine is already building or in the process of canceling."); - - await _buildJobService.StartBuildJobAsync( - BuildJobRunnerType.Hangfire, - engineId, - buildId, - BuildStage.Preprocess, - corpora, - buildOptions, - cancellationToken - ); - } + bool building = !await _buildJobService.StartBuildJobAsync( + BuildJobRunnerType.Hangfire, + TranslationEngineType.Nmt, + engineId, + buildId, + BuildStage.Preprocess, + corpora, + buildOptions, + cancellationToken + ); + // If there is a pending/running build, then no need to start a new one. + if (building) + throw new InvalidOperationException("The engine is already building or in the process of canceling."); } public async Task CancelBuildAsync(string engineId, CancellationToken cancellationToken = default) { - IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, cancellationToken); - await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) - { - if (!await CancelBuildJobAsync(engineId, cancellationToken)) - throw new InvalidOperationException("The engine is not currently building."); - } + bool building = await CancelBuildJobAsync(engineId, cancellationToken); + if (!building) + throw new InvalidOperationException("The engine is not currently building."); } public async Task<ModelDownloadUrl> GetModelDownloadUrlAsync( diff --git a/src/Machine/src/Serval.Machine.Shared/Services/NmtPreprocessBuildJob.cs b/src/Machine/src/Serval.Machine.Shared/Services/NmtPreprocessBuildJob.cs index b4c61648..3c46a34e 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/NmtPreprocessBuildJob.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/NmtPreprocessBuildJob.cs @@ -3,7 +3,6 @@ public class NmtPreprocessBuildJob( IPlatformService platformService, IRepository<TranslationEngine> engines, - IDistributedReaderWriterLockFactory lockFactory, IDataAccessContext dataAccessContext, ILogger<NmtPreprocessBuildJob> logger, IBuildJobService buildJobService, @@ -14,7 +13,6 @@ ILanguageTagService languageTagService : PreprocessBuildJob( platformService, engines, - lockFactory, dataAccessContext, logger, buildJobService, diff --git a/src/Machine/src/Serval.Machine.Shared/Services/PostprocessBuildJob.cs b/src/Machine/src/Serval.Machine.Shared/Services/PostprocessBuildJob.cs index f208161c..0992293f 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/PostprocessBuildJob.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/PostprocessBuildJob.cs @@ -3,23 +3,19 @@ public class PostprocessBuildJob( IPlatformService platformService, IRepository<TranslationEngine> engines, - IDistributedReaderWriterLockFactory lockFactory, IDataAccessContext dataAccessContext, IBuildJobService buildJobService, ILogger<PostprocessBuildJob> logger, - ISharedFileService sharedFileService, - IOptionsMonitor<BuildJobOptions> buildJobOptions -) : HangfireBuildJob<(int, double)>(platformService, engines, lockFactory, dataAccessContext, buildJobService, logger) + ISharedFileService sharedFileService +) : HangfireBuildJob<(int, double)>(platformService, engines, dataAccessContext, buildJobService, logger) { protected ISharedFileService SharedFileService { get; } = sharedFileService; - private readonly BuildJobOptions _buildJobOptions = buildJobOptions.CurrentValue; protected override async Task DoWorkAsync( string engineId, string buildId, (int, double) data, string? buildOptions, - IDistributedReaderWriterLock @lock, CancellationToken cancellationToken ) { @@ -35,33 +31,20 @@ CancellationToken cancellationToken await PlatformService.InsertPretranslationsAsync(engineId, pretranslationsStream, cancellationToken); } - await using ( - await @lock.WriterLockAsync( - lifetime: _buildJobOptions.PostProcessLockLifetime, - cancellationToken: CancellationToken.None - ) - ) - { - await DataAccessContext.WithTransactionAsync( - async (ct) => - { - int additionalCorpusSize = await SaveModelAsync(engineId, buildId); - await PlatformService.BuildCompletedAsync( - buildId, - corpusSize + additionalCorpusSize, - Math.Round(confidence, 2, MidpointRounding.AwayFromZero), - CancellationToken.None - ); - await BuildJobService.BuildJobFinishedAsync( - engineId, - buildId, - buildComplete: true, - CancellationToken.None - ); - }, - cancellationToken: CancellationToken.None - ); - } + int additionalCorpusSize = await SaveModelAsync(engineId, buildId); + await DataAccessContext.WithTransactionAsync( + async (ct) => + { + await PlatformService.BuildCompletedAsync( + buildId, + corpusSize + additionalCorpusSize, + Math.Round(confidence, 2, MidpointRounding.AwayFromZero), + ct + ); + await BuildJobService.BuildJobFinishedAsync(engineId, buildId, buildComplete: true, ct); + }, + cancellationToken: CancellationToken.None + ); Logger.LogInformation("Build completed ({0}).", buildId); } @@ -75,7 +58,6 @@ protected override async Task CleanupAsync( string engineId, string buildId, (int, double) data, - IDistributedReaderWriterLock @lock, JobCompletionStatus completionStatus ) { diff --git a/src/Machine/src/Serval.Machine.Shared/Services/PreprocessBuildJob.cs b/src/Machine/src/Serval.Machine.Shared/Services/PreprocessBuildJob.cs index d15e5a69..5a909f39 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/PreprocessBuildJob.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/PreprocessBuildJob.cs @@ -14,14 +14,13 @@ public class PreprocessBuildJob : HangfireBuildJob<IReadOnlyList<Corpus>> public PreprocessBuildJob( IPlatformService platformService, IRepository<TranslationEngine> engines, - IDistributedReaderWriterLockFactory lockFactory, IDataAccessContext dataAccessContext, ILogger<PreprocessBuildJob> logger, IBuildJobService buildJobService, ISharedFileService sharedFileService, ICorpusService corpusService ) - : base(platformService, engines, lockFactory, dataAccessContext, buildJobService, logger) + : base(platformService, engines, dataAccessContext, buildJobService, logger) { _sharedFileService = sharedFileService; _corpusService = corpusService; @@ -46,7 +45,6 @@ protected override async Task DoWorkAsync( string buildId, IReadOnlyList<Corpus> data, string? buildOptions, - IDistributedReaderWriterLock @lock, CancellationToken cancellationToken ) { @@ -86,19 +84,17 @@ CancellationToken cancellationToken cancellationToken.ThrowIfCancellationRequested(); - await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) - { - bool canceling = !await BuildJobService.StartBuildJobAsync( - TrainJobRunnerType, - engineId, - buildId, - BuildStage.Train, - buildOptions: buildOptions, - cancellationToken: cancellationToken - ); - if (canceling) - throw new OperationCanceledException(); - } + bool canceling = !await BuildJobService.StartBuildJobAsync( + TrainJobRunnerType, + engine.Type, + engine.Id, + buildId, + BuildStage.Train, + buildOptions: buildOptions, + cancellationToken: cancellationToken + ); + if (canceling) + throw new OperationCanceledException(); } private async Task<(int TrainCount, int PretranslateCount)> WriteDataFilesAsync( @@ -209,7 +205,6 @@ protected override async Task CleanupAsync( string engineId, string buildId, IReadOnlyList<Corpus> data, - IDistributedReaderWriterLock @lock, JobCompletionStatus completionStatus ) { diff --git a/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferBuildJob.cs b/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferBuildJob.cs deleted file mode 100644 index d4ba43ef..00000000 --- a/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferBuildJob.cs +++ /dev/null @@ -1,173 +0,0 @@ -namespace Serval.Machine.Shared.Services; - -public class SmtTransferBuildJob( - IPlatformService platformService, - IRepository<TranslationEngine> engines, - IDistributedReaderWriterLockFactory lockFactory, - IDataAccessContext dataAccessContext, - IBuildJobService buildJobService, - ILogger<SmtTransferBuildJob> logger, - IRepository<TrainSegmentPair> trainSegmentPairs, - ITruecaserFactory truecaserFactory, - ISmtModelFactory smtModelFactory, - ICorpusService corpusService, - IOptions<BuildJobOptions> buildJobOptions -) - : HangfireBuildJob<IReadOnlyList<Corpus>>( - platformService, - engines, - lockFactory, - dataAccessContext, - buildJobService, - logger - ) -{ - private readonly IRepository<TrainSegmentPair> _trainSegmentPairs = trainSegmentPairs; - private readonly ITruecaserFactory _truecaserFactory = truecaserFactory; - private readonly ISmtModelFactory _smtModelFactory = smtModelFactory; - private readonly ICorpusService _corpusService = corpusService; - private readonly BuildJobOptions _buildJobOptions = buildJobOptions.Value; - - protected override Task InitializeAsync( - string engineId, - string buildId, - IReadOnlyList<Corpus> data, - IDistributedReaderWriterLock @lock, - CancellationToken cancellationToken - ) - { - return _trainSegmentPairs.DeleteAllAsync(p => p.TranslationEngineRef == engineId, cancellationToken); - } - - protected override async Task DoWorkAsync( - string engineId, - string buildId, - IReadOnlyList<Corpus> data, - string? buildOptions, - IDistributedReaderWriterLock @lock, - CancellationToken cancellationToken - ) - { - await PlatformService.BuildStartedAsync(buildId, cancellationToken); - Logger.LogInformation("Build started ({0})", buildId); - var stopwatch = new Stopwatch(); - stopwatch.Start(); - - cancellationToken.ThrowIfCancellationRequested(); - - JsonObject? buildOptionsObject = null; - if (buildOptions is not null) - buildOptionsObject = JsonSerializer.Deserialize<JsonObject>(buildOptions); - - var targetCorpora = new List<ITextCorpus>(); - var parallelCorpora = new List<IParallelTextCorpus>(); - foreach (Corpus corpus in data) - { - ITextCorpus? sourceTextCorpus = _corpusService.CreateTextCorpora(corpus.SourceFiles).FirstOrDefault(); - ITextCorpus? targetTextCorpus = _corpusService.CreateTextCorpora(corpus.TargetFiles).FirstOrDefault(); - if (sourceTextCorpus is null || targetTextCorpus is null) - continue; - - targetCorpora.Add(targetTextCorpus); - parallelCorpora.Add(sourceTextCorpus.AlignRows(targetTextCorpus)); - - if ((bool?)buildOptionsObject?["use_key_terms"] ?? true) - { - ITextCorpus? sourceTermCorpus = _corpusService.CreateTermCorpora(corpus.SourceFiles).FirstOrDefault(); - ITextCorpus? targetTermCorpus = _corpusService.CreateTermCorpora(corpus.TargetFiles).FirstOrDefault(); - if (sourceTermCorpus is not null && targetTermCorpus is not null) - { - IParallelTextCorpus parallelKeyTermsCorpus = sourceTermCorpus.AlignRows(targetTermCorpus); - parallelCorpora.Add(parallelKeyTermsCorpus); - } - } - } - - IParallelTextCorpus parallelCorpus = parallelCorpora.Flatten(); - ITextCorpus targetCorpus = targetCorpora.Flatten(); - - var tokenizer = new LatinWordTokenizer(); - var detokenizer = new LatinWordDetokenizer(); - - using ITrainer smtModelTrainer = await _smtModelFactory.CreateTrainerAsync( - engineId, - tokenizer, - parallelCorpus, - cancellationToken - ); - using ITrainer truecaseTrainer = await _truecaserFactory.CreateTrainerAsync( - engineId, - tokenizer, - targetCorpus, - cancellationToken - ); - - cancellationToken.ThrowIfCancellationRequested(); - - var progress = new BuildProgress(PlatformService, buildId); - await smtModelTrainer.TrainAsync(progress, cancellationToken); - await truecaseTrainer.TrainAsync(cancellationToken: cancellationToken); - - TranslationEngine? engine = await Engines.GetAsync(e => e.EngineId == engineId, cancellationToken); - if (engine is null) - throw new OperationCanceledException(); - - await using ( - await @lock.WriterLockAsync( - lifetime: _buildJobOptions.PostProcessLockLifetime, - cancellationToken: cancellationToken - ) - ) - { - cancellationToken.ThrowIfCancellationRequested(); - await smtModelTrainer.SaveAsync(CancellationToken.None); - await truecaseTrainer.SaveAsync(CancellationToken.None); - ITruecaser truecaser = await _truecaserFactory.CreateAsync(engineId, CancellationToken.None); - IReadOnlyList<TrainSegmentPair> segmentPairs = await _trainSegmentPairs.GetAllAsync( - p => p.TranslationEngineRef == engine.Id, - CancellationToken.None - ); - using ( - IInteractiveTranslationModel smtModel = await _smtModelFactory.CreateAsync( - engineId, - tokenizer, - detokenizer, - truecaser, - CancellationToken.None - ) - ) - { - foreach (TrainSegmentPair segmentPair in segmentPairs) - { - await smtModel.TrainSegmentAsync( - segmentPair.Source, - segmentPair.Target, - cancellationToken: CancellationToken.None - ); - } - } - - await DataAccessContext.WithTransactionAsync( - async (ct) => - { - await PlatformService.BuildCompletedAsync( - buildId, - smtModelTrainer.Stats.TrainCorpusSize + segmentPairs.Count, - smtModelTrainer.Stats.Metrics["bleu"] * 100.0, - CancellationToken.None - ); - await BuildJobService.BuildJobFinishedAsync( - engineId, - buildId, - buildComplete: true, - CancellationToken.None - ); - }, - cancellationToken: CancellationToken.None - ); - } - - stopwatch.Stop(); - Logger.LogInformation("Build completed in {0}s ({1})", stopwatch.Elapsed.TotalSeconds, buildId); - } -} diff --git a/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferEngineService.cs b/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferEngineService.cs index bdda5353..5789d67d 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferEngineService.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferEngineService.cs @@ -40,7 +40,7 @@ public async Task<TranslationEngine> CreateAsync( } TranslationEngine translationEngine = await _dataAccessContext.WithTransactionAsync( - async (ct) => + async ct => { var translationEngine = new TranslationEngine { @@ -57,38 +57,30 @@ public async Task<TranslationEngine> CreateAsync( cancellationToken: cancellationToken ); - IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, CancellationToken.None); - await using (await @lock.WriterLockAsync(cancellationToken: CancellationToken.None)) - { - SmtTransferEngineState state = _stateService.Get(engineId); - await state.InitNewAsync(CancellationToken.None); - } + SmtTransferEngineState state = _stateService.Get(engineId); + state.InitNew(); return translationEngine; } public async Task DeleteAsync(string engineId, CancellationToken cancellationToken = default) { - IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, cancellationToken); - await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) - { - await CancelBuildJobAsync(engineId, cancellationToken); - - await _dataAccessContext.WithTransactionAsync( - async (ct) => - { - await _engines.DeleteAsync(e => e.EngineId == engineId, ct); - await _trainSegmentPairs.DeleteAllAsync(p => p.TranslationEngineRef == engineId, ct); - }, - cancellationToken: cancellationToken - ); - await _buildJobService.DeleteEngineAsync(engineId, CancellationToken.None); + await CancelBuildJobAsync(engineId, cancellationToken); - if (_stateService.TryRemove(engineId, out SmtTransferEngineState? state)) + await _dataAccessContext.WithTransactionAsync( + async ct => { - await state.DeleteDataAsync(); - await state.DisposeAsync(); - } - } + await _engines.DeleteAsync(e => e.EngineId == engineId, ct); + await _trainSegmentPairs.DeleteAllAsync(p => p.TranslationEngineRef == engineId, ct); + }, + cancellationToken: cancellationToken + ); + await _buildJobService.DeleteEngineAsync(engineId, CancellationToken.None); + + SmtTransferEngineState state = _stateService.Get(engineId); + _stateService.Remove(engineId); + // there is no way to cancel this call + state.DeleteData(); + state.Dispose(); await _lockFactory.DeleteAsync(engineId, CancellationToken.None); } @@ -99,16 +91,22 @@ public async Task<IReadOnlyList<TranslationResult>> TranslateAsync( CancellationToken cancellationToken = default ) { + TranslationEngine engine = await GetBuiltEngineAsync(engineId, cancellationToken); + SmtTransferEngineState state = _stateService.Get(engineId); + IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, cancellationToken); - await using (await @lock.ReaderLockAsync(cancellationToken: cancellationToken)) - { - TranslationEngine engine = await GetBuiltEngineAsync(engineId, cancellationToken); - SmtTransferEngineState state = _stateService.Get(engineId); - HybridTranslationEngine hybridEngine = await state.GetHybridEngineAsync(engine.BuildRevision); - IReadOnlyList<TranslationResult> results = await hybridEngine.TranslateAsync(n, segment, cancellationToken); - state.LastUsedTime = DateTime.Now; - return results; - } + IReadOnlyList<TranslationResult> results = await @lock.ReaderLockAsync( + async ct => + { + HybridTranslationEngine hybridEngine = await state.GetHybridEngineAsync(engine.BuildRevision, ct); + // there is no way to cancel this call + return hybridEngine.Translate(n, segment); + }, + cancellationToken: cancellationToken + ); + + state.Touch(); + return results; } public async Task<WordGraph> GetWordGraphAsync( @@ -117,16 +115,22 @@ public async Task<WordGraph> GetWordGraphAsync( CancellationToken cancellationToken = default ) { + TranslationEngine engine = await GetBuiltEngineAsync(engineId, cancellationToken); + SmtTransferEngineState state = _stateService.Get(engineId); + IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, cancellationToken); - await using (await @lock.ReaderLockAsync(cancellationToken: cancellationToken)) - { - TranslationEngine engine = await GetBuiltEngineAsync(engineId, cancellationToken); - SmtTransferEngineState state = _stateService.Get(engineId); - HybridTranslationEngine hybridEngine = await state.GetHybridEngineAsync(engine.BuildRevision); - WordGraph result = await hybridEngine.GetWordGraphAsync(segment, cancellationToken); - state.LastUsedTime = DateTime.Now; - return result; - } + WordGraph result = await @lock.ReaderLockAsync( + async ct => + { + HybridTranslationEngine hybridEngine = await state.GetHybridEngineAsync(engine.BuildRevision, ct); + // there is no way to cancel this call + return hybridEngine.GetWordGraph(segment); + }, + cancellationToken: cancellationToken + ); + + state.Touch(); + return result; } public async Task TrainSegmentPairAsync( @@ -137,47 +141,39 @@ public async Task TrainSegmentPairAsync( CancellationToken cancellationToken = default ) { - IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, cancellationToken); - await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) - { - TranslationEngine engine = await GetEngineAsync(engineId, cancellationToken); + SmtTransferEngineState state = _stateService.Get(engineId); - async Task TrainSubroutineAsync(SmtTransferEngineState state, CancellationToken ct) + IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, cancellationToken); + await @lock.WriterLockAsync( + async ct => { - HybridTranslationEngine hybridEngine = await state.GetHybridEngineAsync(engine.BuildRevision); - await hybridEngine.TrainSegmentAsync(sourceSegment, targetSegment, sentenceStart, ct); - await _platformService.IncrementTrainSizeAsync(engineId, cancellationToken: CancellationToken.None); - } + TranslationEngine engine = await GetEngineAsync(engineId, ct); + + HybridTranslationEngine hybridEngine = await state.GetHybridEngineAsync(engine.BuildRevision, ct); + // there is no way to cancel this call + hybridEngine.TrainSegment(sourceSegment, targetSegment, sentenceStart); - SmtTransferEngineState state = _stateService.Get(engineId); - await _dataAccessContext.WithTransactionAsync( - async (ct) => + if (engine.CollectTrainSegmentPairs ?? false) { - if (engine.CurrentBuild?.JobState is BuildJobState.Active) - { - await _trainSegmentPairs.InsertAsync( - new TrainSegmentPair - { - TranslationEngineRef = engineId, - Source = sourceSegment, - Target = targetSegment, - SentenceStart = sentenceStart - }, - CancellationToken.None - ); - await TrainSubroutineAsync(state, CancellationToken.None); - } - else - { - await TrainSubroutineAsync(state, ct); - } - }, - cancellationToken: cancellationToken - ); + await _trainSegmentPairs.InsertAsync( + new TrainSegmentPair + { + TranslationEngineRef = engineId, + Source = sourceSegment, + Target = targetSegment, + SentenceStart = sentenceStart + }, + CancellationToken.None + ); + } - state.IsUpdated = true; - state.LastUsedTime = DateTime.Now; - } + state.IsUpdated = true; + }, + cancellationToken: cancellationToken + ); + + await _platformService.IncrementTrainSizeAsync(engineId, cancellationToken: CancellationToken.None); + state.Touch(); } public async Task StartBuildAsync( @@ -188,37 +184,32 @@ public async Task StartBuildAsync( CancellationToken cancellationToken = default ) { - IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, cancellationToken); - await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) - { - // If there is a pending/running build, then no need to start a new one. - if (await _buildJobService.IsEngineBuilding(engineId, cancellationToken)) - throw new InvalidOperationException("The engine is already building or in the process of canceling."); + bool building = !await _buildJobService.StartBuildJobAsync( + BuildJobRunnerType.Hangfire, + TranslationEngineType.SmtTransfer, + engineId, + buildId, + BuildStage.Preprocess, + corpora, + buildOptions, + cancellationToken + ); + // If there is a pending/running build, then no need to start a new one. + if (building) + throw new InvalidOperationException("The engine is already building or in the process of canceling."); - await _buildJobService.StartBuildJobAsync( - BuildJobRunnerType.Hangfire, - engineId, - buildId, - BuildStage.Preprocess, - corpora, - buildOptions, - cancellationToken - ); - SmtTransferEngineState state = _stateService.Get(engineId); - state.LastUsedTime = DateTime.UtcNow; - } + SmtTransferEngineState state = _stateService.Get(engineId); + state.Touch(); } public async Task CancelBuildAsync(string engineId, CancellationToken cancellationToken = default) { - IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, cancellationToken); - await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) - { - if (!await CancelBuildJobAsync(engineId, cancellationToken)) - throw new InvalidOperationException("The engine is not currently building."); - SmtTransferEngineState state = _stateService.Get(engineId); - state.LastUsedTime = DateTime.UtcNow; - } + bool building = await CancelBuildJobAsync(engineId, cancellationToken); + if (!building) + throw new InvalidOperationException("The engine is not currently building."); + + SmtTransferEngineState state = _stateService.Get(engineId); + state.Touch(); } public int GetQueueSize() @@ -235,7 +226,7 @@ private async Task<bool> CancelBuildJobAsync(string engineId, CancellationToken { string? buildId = null; await _dataAccessContext.WithTransactionAsync( - async (ct) => + async ct => { (buildId, BuildJobState jobState) = await _buildJobService.CancelBuildJobAsync(engineId, ct); if (buildId is not null && jobState is BuildJobState.None) diff --git a/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferEngineState.cs b/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferEngineState.cs index a5f4300a..1e6fde8f 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferEngineState.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferEngineState.cs @@ -1,4 +1,6 @@ -namespace Serval.Machine.Shared.Services; +using SIL.ObjectModel; + +namespace Serval.Machine.Shared.Services; public class SmtTransferEngineState( ISmtModelFactory smtModelFactory, @@ -6,7 +8,7 @@ public class SmtTransferEngineState( ITruecaserFactory truecaserFactory, IOptionsMonitor<SmtTransferEngineOptions> options, string engineId -) : AsyncDisposableBase +) : DisposableBase { private readonly ISmtModelFactory _smtModelFactory = smtModelFactory; private readonly ITransferEngineFactory _transferEngineFactory = transferEngineFactory; @@ -21,34 +23,37 @@ string engineId public bool IsUpdated { get; set; } public int CurrentBuildRevision { get; set; } = -1; - public DateTime LastUsedTime { get; set; } = DateTime.UtcNow; + public DateTime LastUsedTime { get; private set; } = DateTime.UtcNow; public bool IsLoaded => _hybridEngine != null; private string EngineDir => Path.Combine(_options.CurrentValue.EnginesDir, EngineId); - public async Task InitNewAsync(CancellationToken cancellationToken = default) + public void InitNew() { - await _smtModelFactory.InitNewAsync(EngineDir, cancellationToken); - await _transferEngineFactory.InitNewAsync(EngineDir, cancellationToken); + _smtModelFactory.InitNew(EngineDir); + _transferEngineFactory.InitNew(EngineDir); } - public async Task<HybridTranslationEngine> GetHybridEngineAsync(int buildRevision) + public async Task<HybridTranslationEngine> GetHybridEngineAsync( + int buildRevision, + CancellationToken cancellationToken = default + ) { - using (await _lock.LockAsync()) + using (await _lock.LockAsync(cancellationToken)) { if (_hybridEngine is not null && CurrentBuildRevision != -1 && buildRevision != CurrentBuildRevision) { IsUpdated = false; - await UnloadAsync(); + Unload(); } if (_hybridEngine is null) { LatinWordTokenizer tokenizer = new(); LatinWordDetokenizer detokenizer = new(); - ITruecaser truecaser = await _truecaserFactory.CreateAsync(EngineDir); - _smtModel = await _smtModelFactory.CreateAsync(EngineDir, tokenizer, detokenizer, truecaser); - ITranslationEngine? transferEngine = await _transferEngineFactory.CreateAsync( + ITruecaser truecaser = _truecaserFactory.Create(EngineDir); + _smtModel = _smtModelFactory.Create(EngineDir, tokenizer, detokenizer, truecaser); + ITranslationEngine? transferEngine = _transferEngineFactory.Create( EngineDir, tokenizer, detokenizer, @@ -64,19 +69,15 @@ public async Task<HybridTranslationEngine> GetHybridEngineAsync(int buildRevisio } } - public async Task DeleteDataAsync() + public void DeleteData() { - await UnloadAsync(); - await _smtModelFactory.CleanupAsync(EngineDir); - await _transferEngineFactory.CleanupAsync(EngineDir); - await _truecaserFactory.CleanupAsync(EngineDir); + Unload(); + _smtModelFactory.Cleanup(EngineDir); + _transferEngineFactory.Cleanup(EngineDir); + _truecaserFactory.Cleanup(EngineDir); } - public async Task CommitAsync( - int buildRevision, - TimeSpan inactiveTimeout, - CancellationToken cancellationToken = default - ) + public void Commit(int buildRevision, TimeSpan inactiveTimeout) { if (_hybridEngine is null) return; @@ -85,34 +86,39 @@ public async Task CommitAsync( CurrentBuildRevision = buildRevision; if (buildRevision != CurrentBuildRevision) { - await UnloadAsync(cancellationToken); + Unload(); CurrentBuildRevision = buildRevision; } else if (DateTime.Now - LastUsedTime > inactiveTimeout) { - await UnloadAsync(cancellationToken); + Unload(); } else { - await SaveModelAsync(cancellationToken); + SaveModel(); } } - private async Task SaveModelAsync(CancellationToken cancellationToken = default) + public void Touch() + { + LastUsedTime = DateTime.UtcNow; + } + + private void SaveModel() { if (_smtModel is not null && IsUpdated) { - await _smtModel.SaveAsync(cancellationToken); + _smtModel.Save(); IsUpdated = false; } } - private async Task UnloadAsync(CancellationToken cancellationToken = default) + private void Unload() { if (_hybridEngine is null) return; - await SaveModelAsync(cancellationToken); + SaveModel(); _hybridEngine.Dispose(); @@ -121,8 +127,8 @@ private async Task UnloadAsync(CancellationToken cancellationToken = default) CurrentBuildRevision = -1; } - protected override async ValueTask DisposeAsyncCore() + protected override void DisposeManagedResources() { - await UnloadAsync(); + Unload(); } } diff --git a/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferEngineStateService.cs b/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferEngineStateService.cs index 03ef2ad8..9b97e004 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferEngineStateService.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferEngineStateService.cs @@ -1,16 +1,20 @@ -namespace Serval.Machine.Shared.Services; +using SIL.ObjectModel; + +namespace Serval.Machine.Shared.Services; public class SmtTransferEngineStateService( ISmtModelFactory smtModelFactory, ITransferEngineFactory transferEngineFactory, ITruecaserFactory truecaserFactory, - IOptionsMonitor<SmtTransferEngineOptions> options -) : AsyncDisposableBase + IOptionsMonitor<SmtTransferEngineOptions> options, + ILogger<SmtTransferEngineStateService> logger +) : DisposableBase { private readonly ISmtModelFactory _smtModelFactory = smtModelFactory; private readonly ITransferEngineFactory _transferEngineFactory = transferEngineFactory; private readonly ITruecaserFactory _truecaserFactory = truecaserFactory; private readonly IOptionsMonitor<SmtTransferEngineOptions> _options = options; + private readonly ILogger<SmtTransferEngineStateService> _logger = logger; private readonly ConcurrentDictionary<string, SmtTransferEngineState> _engineStates = new ConcurrentDictionary<string, SmtTransferEngineState>(); @@ -20,9 +24,9 @@ public SmtTransferEngineState Get(string engineId) return _engineStates.GetOrAdd(engineId, CreateState); } - public bool TryRemove(string engineId, [MaybeNullWhen(false)] out SmtTransferEngineState state) + public void Remove(string engineId) { - return _engineStates.TryRemove(engineId, out state); + _engineStates.TryRemove(engineId, out _); } public async Task CommitAsync( @@ -34,20 +38,24 @@ public async Task CommitAsync( { foreach (SmtTransferEngineState state in _engineStates.Values) { - IDistributedReaderWriterLock @lock = await lockFactory.CreateAsync(state.EngineId, cancellationToken); - await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) + try { - TranslationEngine? engine = await engines.GetAsync( - e => e.EngineId == state.EngineId, - cancellationToken + IDistributedReaderWriterLock @lock = await lockFactory.CreateAsync(state.EngineId, cancellationToken); + await @lock.WriterLockAsync( + async ct => + { + TranslationEngine? engine = await engines.GetAsync(state.EngineId, ct); + if (engine is not null && !(engine.CollectTrainSegmentPairs ?? false)) + // there is no way to cancel this call + state.Commit(engine.BuildRevision, inactiveTimeout); + }, + _options.CurrentValue.EngineCommitTimeout, + cancellationToken: cancellationToken ); - if ( - engine is not null - && (engine.CurrentBuild is null || engine.CurrentBuild.JobState is BuildJobState.Pending) - ) - { - await state.CommitAsync(engine.BuildRevision, inactiveTimeout, cancellationToken); - } + } + catch (Exception e) + { + _logger.LogError(e, "Error occurred while committing SMT transfer engine {EngineId}.", state.EngineId); } } } @@ -63,10 +71,10 @@ private SmtTransferEngineState CreateState(string engineId) ); } - protected override async ValueTask DisposeAsyncCore() + protected override void DisposeManagedResources() { foreach (SmtTransferEngineState state in _engineStates.Values) - await state.DisposeAsync(); + state.Dispose(); _engineStates.Clear(); } } diff --git a/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferPostprocessBuildJob.cs b/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferPostprocessBuildJob.cs index 1f8a4d48..8d2c12ca 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferPostprocessBuildJob.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferPostprocessBuildJob.cs @@ -3,64 +3,66 @@ public class SmtTransferPostprocessBuildJob( IPlatformService platformService, IRepository<TranslationEngine> engines, - IDistributedReaderWriterLockFactory lockFactory, IDataAccessContext dataAccessContext, IBuildJobService buildJobService, ILogger<SmtTransferPostprocessBuildJob> logger, ISharedFileService sharedFileService, - IOptionsMonitor<BuildJobOptions> buildJobOptions, + IDistributedReaderWriterLockFactory lockFactory, IRepository<TrainSegmentPair> trainSegmentPairs, ISmtModelFactory smtModelFactory, ITruecaserFactory truecaserFactory, IOptionsMonitor<SmtTransferEngineOptions> engineOptions -) - : PostprocessBuildJob( - platformService, - engines, - lockFactory, - dataAccessContext, - buildJobService, - logger, - sharedFileService, - buildJobOptions - ) +) : PostprocessBuildJob(platformService, engines, dataAccessContext, buildJobService, logger, sharedFileService) { private readonly ISmtModelFactory _smtModelFactory = smtModelFactory; private readonly ITruecaserFactory _truecaserFactory = truecaserFactory; private readonly IRepository<TrainSegmentPair> _trainSegmentPairs = trainSegmentPairs; private readonly IOptionsMonitor<SmtTransferEngineOptions> _engineOptions = engineOptions; + private readonly IDistributedReaderWriterLockFactory _lockFactory = lockFactory; protected override async Task<int> SaveModelAsync(string engineId, string buildId) { - await using ( - Stream engineStream = await SharedFileService.OpenReadAsync( - $"builds/{buildId}/model.tar.gz", - CancellationToken.None - ) - ) - { - await _smtModelFactory.UpdateEngineFromAsync( - Path.Combine(_engineOptions.CurrentValue.EnginesDir, engineId), - engineStream, - CancellationToken.None - ); - } - return await TrainOnNewSegmentPairsAsync(engineId); + IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId); + return await @lock.WriterLockAsync( + async ct => + { + await using ( + Stream engineStream = await SharedFileService.OpenReadAsync($"builds/{buildId}/model.tar.gz", ct) + ) + { + await _smtModelFactory.UpdateEngineFromAsync( + Path.Combine(_engineOptions.CurrentValue.EnginesDir, engineId), + engineStream, + ct + ); + } + IReadOnlyList<TrainSegmentPair> segmentPairs = await _trainSegmentPairs.GetAllAsync( + p => p.TranslationEngineRef == engineId, + ct + ); + TrainOnNewSegmentPairs(engineId, segmentPairs, ct); + await Engines.UpdateAsync( + engineId, + u => u.Set(e => e.CollectTrainSegmentPairs, false), + cancellationToken: ct + ); + return segmentPairs.Count; + }, + _engineOptions.CurrentValue.SaveModelTimeout + ); } - private async Task<int> TrainOnNewSegmentPairsAsync(string engineId) + private void TrainOnNewSegmentPairs( + string engineId, + IReadOnlyList<TrainSegmentPair> segmentPairs, + CancellationToken cancellationToken + ) { - IReadOnlyList<TrainSegmentPair> segmentPairs = await _trainSegmentPairs.GetAllAsync(p => - p.TranslationEngineRef == engineId - ); - if (segmentPairs.Count == 0) - return segmentPairs.Count; - string engineDir = Path.Combine(_engineOptions.CurrentValue.EnginesDir, engineId); var tokenizer = new LatinWordTokenizer(); var detokenizer = new LatinWordDetokenizer(); - ITruecaser truecaser = await _truecaserFactory.CreateAsync(engineDir); - using IInteractiveTranslationModel smtModel = await _smtModelFactory.CreateAsync( + ITruecaser truecaser = _truecaserFactory.Create(engineDir); + using IInteractiveTranslationModel smtModel = _smtModelFactory.Create( engineDir, tokenizer, detokenizer, @@ -68,9 +70,10 @@ private async Task<int> TrainOnNewSegmentPairsAsync(string engineId) ); foreach (TrainSegmentPair segmentPair in segmentPairs) { - await smtModel.TrainSegmentAsync(segmentPair.Source, segmentPair.Target); + cancellationToken.ThrowIfCancellationRequested(); + smtModel.TrainSegment(segmentPair.Source, segmentPair.Target); } - await smtModel.SaveAsync(); - return segmentPairs.Count; + cancellationToken.ThrowIfCancellationRequested(); + smtModel.Save(); } } diff --git a/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferPreprocessBuildJob.cs b/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferPreprocessBuildJob.cs new file mode 100644 index 00000000..9e14037a --- /dev/null +++ b/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferPreprocessBuildJob.cs @@ -0,0 +1,48 @@ +namespace Serval.Machine.Shared.Services; + +public class SmtTransferPreprocessBuildJob( + IPlatformService platformService, + IRepository<TranslationEngine> engines, + IDataAccessContext dataAccessContext, + ILogger<PreprocessBuildJob> logger, + IBuildJobService buildJobService, + ISharedFileService sharedFileService, + ICorpusService corpusService, + IDistributedReaderWriterLockFactory lockFactory, + IRepository<TrainSegmentPair> trainSegmentPairs +) + : PreprocessBuildJob( + platformService, + engines, + dataAccessContext, + logger, + buildJobService, + sharedFileService, + corpusService + ) +{ + private readonly IDistributedReaderWriterLockFactory _lockFactory = lockFactory; + private readonly IRepository<TrainSegmentPair> _trainSegmentPairs = trainSegmentPairs; + + protected override async Task InitializeAsync( + string engineId, + string buildId, + IReadOnlyList<Corpus> data, + CancellationToken cancellationToken + ) + { + IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, cancellationToken); + await @lock.WriterLockAsync( + async ct => + { + await _trainSegmentPairs.DeleteAllAsync(p => p.TranslationEngineRef == engineId, ct); + await Engines.UpdateAsync( + engineId, + u => u.Set(e => e.CollectTrainSegmentPairs, true), + cancellationToken: ct + ); + }, + cancellationToken: cancellationToken + ); + } +} diff --git a/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferTrainBuildJob.cs b/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferTrainBuildJob.cs index bb4870c1..e81fc354 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferTrainBuildJob.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferTrainBuildJob.cs @@ -3,7 +3,6 @@ public class SmtTransferTrainBuildJob( IPlatformService platformService, IRepository<TranslationEngine> engines, - IDistributedReaderWriterLockFactory lockFactory, IDataAccessContext dataAccessContext, IBuildJobService buildJobService, ILogger<SmtTransferTrainBuildJob> logger, @@ -11,7 +10,7 @@ public class SmtTransferTrainBuildJob( ITruecaserFactory truecaserFactory, ISmtModelFactory smtModelFactory, ITransferEngineFactory transferEngineFactory -) : HangfireBuildJob(platformService, engines, lockFactory, dataAccessContext, buildJobService, logger) +) : HangfireBuildJob(platformService, engines, dataAccessContext, buildJobService, logger) { private static readonly JsonWriterOptions PretranslateWriterOptions = new() { Indented = true }; private static readonly JsonSerializerOptions JsonSerializerOptions = @@ -28,7 +27,6 @@ protected override async Task DoWorkAsync( string buildId, object? data, string? buildOptions, - IDistributedReaderWriterLock @lock, CancellationToken cancellationToken ) { @@ -55,27 +53,24 @@ CancellationToken cancellationToken await GeneratePretranslationsAsync(buildId, engineDir, cancellationToken); - await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) - { - bool canceling = !await BuildJobService.StartBuildJobAsync( - BuildJobRunnerType.Hangfire, - engineId, - buildId, - BuildStage.Postprocess, - data: (trainCorpusSize, confidence), - buildOptions: buildOptions, - cancellationToken: cancellationToken - ); - if (canceling) - throw new OperationCanceledException(); - } + bool canceling = !await BuildJobService.StartBuildJobAsync( + BuildJobRunnerType.Hangfire, + TranslationEngineType.SmtTransfer, + engineId, + buildId, + BuildStage.Postprocess, + data: (trainCorpusSize, confidence), + buildOptions: buildOptions, + cancellationToken: cancellationToken + ); + if (canceling) + throw new OperationCanceledException(); } protected override async Task CleanupAsync( string engineId, string buildId, object? data, - IDistributedReaderWriterLock @lock, JobCompletionStatus completionStatus ) { @@ -118,22 +113,12 @@ private async Task DownloadDataAsync(string buildId, string corpusDir, Cancellat CancellationToken cancellationToken ) { - await _smtModelFactory.InitNewAsync(engineDir, cancellationToken); + _smtModelFactory.InitNew(engineDir); LatinWordTokenizer tokenizer = new(); int trainCorpusSize; double confidence; - using ITrainer smtModelTrainer = await _smtModelFactory.CreateTrainerAsync( - engineDir, - tokenizer, - parallelCorpus, - cancellationToken - ); - using ITrainer truecaseTrainer = await _truecaserFactory.CreateTrainerAsync( - engineDir, - tokenizer, - targetCorpus, - cancellationToken - ); + using ITrainer smtModelTrainer = _smtModelFactory.CreateTrainer(engineDir, tokenizer, parallelCorpus); + using ITrainer truecaseTrainer = _truecaserFactory.CreateTrainer(engineDir, tokenizer, targetCorpus); cancellationToken.ThrowIfCancellationRequested(); var progress = new BuildProgress(PlatformService, buildId); @@ -179,20 +164,18 @@ CancellationToken cancellationToken LatinWordTokenizer tokenizer = new(); LatinWordDetokenizer detokenizer = new(); - ITruecaser truecaser = await _truecaserFactory.CreateAsync(engineDir, CancellationToken.None); - using IInteractiveTranslationModel smtModel = await _smtModelFactory.CreateAsync( + ITruecaser truecaser = _truecaserFactory.Create(engineDir); + using IInteractiveTranslationModel smtModel = _smtModelFactory.Create( engineDir, tokenizer, detokenizer, - truecaser, - cancellationToken + truecaser ); - using ITranslationEngine? transferEngine = await _transferEngineFactory.CreateAsync( + using ITranslationEngine? transferEngine = _transferEngineFactory.Create( engineDir, tokenizer, detokenizer, - truecaser, - cancellationToken + truecaser ); HybridTranslationEngine hybridEngine = new(smtModel, transferEngine) { TargetDetokenizer = detokenizer }; diff --git a/src/Machine/src/Serval.Machine.Shared/Services/ThotSmtModelFactory.cs b/src/Machine/src/Serval.Machine.Shared/Services/ThotSmtModelFactory.cs index 031891c4..03f4ab5d 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/ThotSmtModelFactory.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/ThotSmtModelFactory.cs @@ -4,12 +4,11 @@ public class ThotSmtModelFactory(IOptionsMonitor<ThotSmtModelOptions> options) : { private readonly IOptionsMonitor<ThotSmtModelOptions> _options = options; - public Task<IInteractiveTranslationModel> CreateAsync( + public IInteractiveTranslationModel Create( string engineDir, IRangeTokenizer<string, int, string> tokenizer, IDetokenizer<string, string> detokenizer, - ITruecaser truecaser, - CancellationToken cancellationToken = default + ITruecaser truecaser ) { string smtConfigFileName = Path.Combine(engineDir, "smt.cfg"); @@ -22,14 +21,13 @@ public Task<IInteractiveTranslationModel> CreateAsync( LowercaseTarget = true, Truecaser = truecaser }; - return Task.FromResult(model); + return model; } - public Task<ITrainer> CreateTrainerAsync( + public ITrainer CreateTrainer( string engineDir, IRangeTokenizer<string, int, string> tokenizer, - IParallelTextCorpus corpus, - CancellationToken cancellationToken = default + IParallelTextCorpus corpus ) { string smtConfigFileName = Path.Combine(engineDir, "smt.cfg"); @@ -40,21 +38,20 @@ public Task<ITrainer> CreateTrainerAsync( LowercaseSource = true, LowercaseTarget = true }; - return Task.FromResult(trainer); + return trainer; } - public Task InitNewAsync(string engineDir, CancellationToken cancellationToken = default) + public void InitNew(string engineDir) { if (!Directory.Exists(engineDir)) Directory.CreateDirectory(engineDir); ZipFile.ExtractToDirectory(_options.CurrentValue.NewModelFile, engineDir); - return Task.CompletedTask; } - public Task CleanupAsync(string engineDir, CancellationToken cancellationToken = default) + public void Cleanup(string engineDir) { if (!Directory.Exists(engineDir)) - return Task.CompletedTask; + return; DirectoryHelper.DeleteDirectoryRobust(Path.Combine(engineDir, "lm")); DirectoryHelper.DeleteDirectoryRobust(Path.Combine(engineDir, "tm")); string smtConfigFileName = Path.Combine(engineDir, "smt.cfg"); @@ -62,7 +59,6 @@ public Task CleanupAsync(string engineDir, CancellationToken cancellationToken = File.Delete(smtConfigFileName); if (!Directory.EnumerateFileSystemEntries(engineDir).Any()) Directory.Delete(engineDir); - return Task.CompletedTask; } public async Task UpdateEngineFromAsync( diff --git a/src/Machine/src/Serval.Machine.Shared/Services/TimeoutInterceptor.cs b/src/Machine/src/Serval.Machine.Shared/Services/TimeoutInterceptor.cs new file mode 100644 index 00000000..8f33674d --- /dev/null +++ b/src/Machine/src/Serval.Machine.Shared/Services/TimeoutInterceptor.cs @@ -0,0 +1,23 @@ +namespace Serval.Machine.Shared.Services; + +public class TimeoutInterceptor(ILogger<TimeoutInterceptor> logger) : Interceptor +{ + private readonly ILogger<TimeoutInterceptor> _logger = logger; + + public override async Task<TResponse> UnaryServerHandler<TRequest, TResponse>( + TRequest request, + ServerCallContext context, + UnaryServerMethod<TRequest, TResponse> continuation + ) + { + try + { + return await continuation(request, context); + } + catch (TimeoutException te) + { + _logger.LogError(te, "The method {Method} took too long to complete.", context.Method); + throw new RpcException(new Status(StatusCode.Unavailable, "The method took too long to complete.")); + } + } +} diff --git a/src/Machine/src/Serval.Machine.Shared/Services/TransferEngineFactory.cs b/src/Machine/src/Serval.Machine.Shared/Services/TransferEngineFactory.cs index a140792b..7834bd73 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/TransferEngineFactory.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/TransferEngineFactory.cs @@ -2,12 +2,11 @@ public class TransferEngineFactory : ITransferEngineFactory { - public Task<ITranslationEngine?> CreateAsync( + public ITranslationEngine? Create( string engineDir, IRangeTokenizer<string, int, string> tokenizer, IDetokenizer<string, string> detokenizer, - ITruecaser truecaser, - CancellationToken cancellationToken = default + ITruecaser truecaser ) { string hcSrcConfigFileName = Path.Combine(engineDir, "src-hc.xml"); @@ -35,19 +34,18 @@ public class TransferEngineFactory : ITransferEngineFactory Truecaser = truecaser }; } - return Task.FromResult(transferEngine); + return transferEngine; } - public Task InitNewAsync(string engineDir, CancellationToken cancellationToken = default) + public void InitNew(string engineDir) { // TODO: generate source and target config files - return Task.CompletedTask; } - public Task CleanupAsync(string engineDir, CancellationToken cancellationToken = default) + public void Cleanup(string engineDir) { if (!Directory.Exists(engineDir)) - return Task.CompletedTask; + return; string hcSrcConfigFileName = Path.Combine(engineDir, "src-hc.xml"); if (File.Exists(hcSrcConfigFileName)) File.Delete(hcSrcConfigFileName); @@ -56,6 +54,5 @@ public Task CleanupAsync(string engineDir, CancellationToken cancellationToken = File.Delete(hcTrgConfigFileName); if (!Directory.EnumerateFileSystemEntries(engineDir).Any()) Directory.Delete(engineDir); - return Task.CompletedTask; } } diff --git a/src/Machine/src/Serval.Machine.Shared/Services/UnigramTruecaserFactory.cs b/src/Machine/src/Serval.Machine.Shared/Services/UnigramTruecaserFactory.cs index cbf9c8b5..0821c10e 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/UnigramTruecaserFactory.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/UnigramTruecaserFactory.cs @@ -2,32 +2,26 @@ public class UnigramTruecaserFactory : ITruecaserFactory { - public async Task<ITruecaser> CreateAsync(string engineDir, CancellationToken cancellationToken = default) + public ITruecaser Create(string engineDir) { var truecaser = new UnigramTruecaser(); string path = GetModelPath(engineDir); - await truecaser.LoadAsync(path); + truecaser.Load(path); return truecaser; } - public Task<ITrainer> CreateTrainerAsync( - string engineDir, - ITokenizer<string, int, string> tokenizer, - ITextCorpus corpus, - CancellationToken cancellationToken = default - ) + public ITrainer CreateTrainer(string engineDir, ITokenizer<string, int, string> tokenizer, ITextCorpus corpus) { string path = GetModelPath(engineDir); ITrainer trainer = new UnigramTruecaserTrainer(path, corpus) { Tokenizer = tokenizer }; - return Task.FromResult(trainer); + return trainer; } - public Task CleanupAsync(string engineDir, CancellationToken cancellationToken = default) + public void Cleanup(string engineDir) { string path = GetModelPath(engineDir); if (File.Exists(path)) File.Delete(path); - return Task.CompletedTask; } private static string GetModelPath(string engineDir) diff --git a/src/Machine/src/Serval.Machine.Shared/Usings.cs b/src/Machine/src/Serval.Machine.Shared/Usings.cs index 159f4f01..8d75abec 100644 --- a/src/Machine/src/Serval.Machine.Shared/Usings.cs +++ b/src/Machine/src/Serval.Machine.Shared/Usings.cs @@ -1,7 +1,6 @@ global using System.Collections.Concurrent; global using System.Data; global using System.Diagnostics; -global using System.Diagnostics.CodeAnalysis; global using System.Formats.Tar; global using System.Globalization; global using System.IO.Compression; diff --git a/src/Serval/src/Serval.ApiServer/Serval.ApiServer.csproj b/src/Serval/src/Serval.ApiServer/Serval.ApiServer.csproj index d40fb933..2d8c5622 100644 --- a/src/Serval/src/Serval.ApiServer/Serval.ApiServer.csproj +++ b/src/Serval/src/Serval.ApiServer/Serval.ApiServer.csproj @@ -21,9 +21,9 @@ <PackageReference Include="AspNetCore.HealthChecks.Aws.S3" Version="6.0.2" /> <PackageReference Include="AspNetCore.HealthChecks.OpenIdConnectServer" Version="6.0.2" /> <PackageReference Include="AspNetCore.HealthChecks.System" Version="6.0.2" /> - <PackageReference Include="Hangfire" Version="1.7.33" /> - <PackageReference Include="Hangfire.Mongo" Version="1.9.2" /> - <PackageReference Include="Microsoft.AspNetCore.Authentication.JwtBearer" Version="6.0.14" /> + <PackageReference Include="Hangfire" Version="1.8.14" /> + <PackageReference Include="Hangfire.Mongo" Version="1.9.10" /> + <PackageReference Include="Microsoft.AspNetCore.Authentication.JwtBearer" Version="8.0.8" /> <PackageReference Include="NSwag.AspNetCore" Version="14.1.0" /> <PackageReference Include="NSwag.MSBuild" Version="14.1.0"> <PrivateAssets>all</PrivateAssets> diff --git a/src/Serval/src/Serval.Assessment/Controllers/AssessmentEnginesController.cs b/src/Serval/src/Serval.Assessment/Controllers/AssessmentEnginesController.cs index 17cdf116..459d3b34 100644 --- a/src/Serval/src/Serval.Assessment/Controllers/AssessmentEnginesController.cs +++ b/src/Serval/src/Serval.Assessment/Controllers/AssessmentEnginesController.cs @@ -311,7 +311,7 @@ CancellationToken cancellationToken await AuthorizeAsync(id, cancellationToken); if (minRevision != null) { - EntityChange<Job> change = await TaskEx.Timeout( + (_, EntityChange<Job> change) = await TaskEx.Timeout( ct => _jobService.GetNewerRevisionAsync(jobId, minRevision.Value, ct), _apiOptions.CurrentValue.LongPollTimeout, cancellationToken: cancellationToken diff --git a/src/Serval/src/Serval.Assessment/Serval.Assessment.csproj b/src/Serval/src/Serval.Assessment/Serval.Assessment.csproj index 96391e79..81838382 100644 --- a/src/Serval/src/Serval.Assessment/Serval.Assessment.csproj +++ b/src/Serval/src/Serval.Assessment/Serval.Assessment.csproj @@ -19,6 +19,7 @@ </ItemGroup> <ItemGroup> + <ProjectReference Include="..\..\..\ServiceToolkit\src\SIL.ServiceToolkit\SIL.ServiceToolkit.csproj" /> <ProjectReference Include="..\Serval.Grpc\Serval.Grpc.csproj" /> <ProjectReference Include="..\Serval.Shared\Serval.Shared.csproj" /> </ItemGroup> diff --git a/src/Serval/src/Serval.Assessment/Usings.cs b/src/Serval/src/Serval.Assessment/Usings.cs index 17020327..29d2b2f7 100644 --- a/src/Serval/src/Serval.Assessment/Usings.cs +++ b/src/Serval/src/Serval.Assessment/Usings.cs @@ -29,3 +29,4 @@ global using Serval.Shared.Utils; global using SIL.DataAccess; global using SIL.Scripture; +global using SIL.ServiceToolkit.Utils; diff --git a/src/Serval/src/Serval.Shared/Utils/TaskEx.cs b/src/Serval/src/Serval.Shared/Utils/TaskEx.cs deleted file mode 100644 index edceaa93..00000000 --- a/src/Serval/src/Serval.Shared/Utils/TaskEx.cs +++ /dev/null @@ -1,50 +0,0 @@ -namespace Serval.Shared.Utils; - -public static class TaskEx -{ - public static async Task<T?> Timeout<T>( - Func<CancellationToken, Task<T?>> action, - TimeSpan timeout, - CancellationToken cancellationToken = default - ) - { - if (timeout == System.Threading.Timeout.InfiniteTimeSpan) - return await action(cancellationToken); - - var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - Task<T?> task = action(cts.Token); - Task<T?> delayTask = Delay<T?>(timeout, cancellationToken); - var completedTask = await Task.WhenAny(task, delayTask); - if (delayTask.Status == TaskStatus.RanToCompletion) - cts.Cancel(); - return await completedTask; - } - - public static async Task Timeout( - Func<CancellationToken, Task> action, - TimeSpan timeout, - CancellationToken cancellationToken = default - ) - { - if (timeout == System.Threading.Timeout.InfiniteTimeSpan) - { - await action(cancellationToken); - } - else - { - var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - Task task = action(cts.Token); - Task delayTask = Task.Delay(timeout, cancellationToken); - var completedTask = await Task.WhenAny(task, delayTask); - if (delayTask.Status == TaskStatus.RanToCompletion) - cts.Cancel(); - await completedTask; - } - } - - public static async Task<T?> Delay<T>(TimeSpan timeout, CancellationToken cancellationToken = default) - { - await Task.Delay(timeout, cancellationToken); - return default; - } -} diff --git a/src/Serval/src/Serval.Translation/Controllers/TranslationEnginesController.cs b/src/Serval/src/Serval.Translation/Controllers/TranslationEnginesController.cs index 9b13f1f3..75f29483 100644 --- a/src/Serval/src/Serval.Translation/Controllers/TranslationEnginesController.cs +++ b/src/Serval/src/Serval.Translation/Controllers/TranslationEnginesController.cs @@ -754,7 +754,7 @@ CancellationToken cancellationToken await AuthorizeAsync(id, cancellationToken); if (minRevision != null) { - EntityChange<Build> change = await TaskEx.Timeout( + (_, EntityChange<Build> change) = await TaskEx.Timeout( ct => _buildService.GetNewerRevisionAsync(buildId, minRevision.Value, ct), _apiOptions.CurrentValue.LongPollTimeout, cancellationToken: cancellationToken @@ -864,7 +864,7 @@ CancellationToken cancellationToken await AuthorizeAsync(id, cancellationToken); if (minRevision != null) { - EntityChange<Build> change = await TaskEx.Timeout( + (_, EntityChange<Build> change) = await TaskEx.Timeout( ct => _buildService.GetActiveNewerRevisionAsync(id, minRevision.Value, ct), _apiOptions.CurrentValue.LongPollTimeout, cancellationToken: cancellationToken diff --git a/src/Serval/src/Serval.Translation/Serval.Translation.csproj b/src/Serval/src/Serval.Translation/Serval.Translation.csproj index 96391e79..81838382 100644 --- a/src/Serval/src/Serval.Translation/Serval.Translation.csproj +++ b/src/Serval/src/Serval.Translation/Serval.Translation.csproj @@ -19,6 +19,7 @@ </ItemGroup> <ItemGroup> + <ProjectReference Include="..\..\..\ServiceToolkit\src\SIL.ServiceToolkit\SIL.ServiceToolkit.csproj" /> <ProjectReference Include="..\Serval.Grpc\Serval.Grpc.csproj" /> <ProjectReference Include="..\Serval.Shared\Serval.Shared.csproj" /> </ItemGroup> diff --git a/src/Serval/src/Serval.Translation/Usings.cs b/src/Serval/src/Serval.Translation/Usings.cs index 77bb4439..1d3800f3 100644 --- a/src/Serval/src/Serval.Translation/Usings.cs +++ b/src/Serval/src/Serval.Translation/Usings.cs @@ -27,3 +27,4 @@ global using Serval.Translation.Models; global using Serval.Translation.Services; global using SIL.DataAccess; +global using SIL.ServiceToolkit.Utils; diff --git a/src/ServiceToolkit/src/SIL.ServiceToolkit/Utils/TaskEx.cs b/src/ServiceToolkit/src/SIL.ServiceToolkit/Utils/TaskEx.cs new file mode 100644 index 00000000..b6404bf5 --- /dev/null +++ b/src/ServiceToolkit/src/SIL.ServiceToolkit/Utils/TaskEx.cs @@ -0,0 +1,53 @@ +namespace SIL.ServiceToolkit.Utils; + +public static class TaskEx +{ + public static async Task<(bool, T?)> Timeout<T>( + Func<CancellationToken, Task<T>> action, + TimeSpan timeout, + CancellationToken cancellationToken = default + ) + { + if (timeout == System.Threading.Timeout.InfiniteTimeSpan) + return (true, await action(cancellationToken)); + + using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + Task<T> task = action(cts.Token); + Task<T> completedTask = await Task.WhenAny(task, Delay<T>(timeout, cancellationToken)); + if (completedTask == task) + return (true, await task); + + cts.Cancel(); + return (false, default); + } + + public static async Task<bool> Timeout( + Func<CancellationToken, Task> action, + TimeSpan timeout, + CancellationToken cancellationToken = default + ) + { + if (timeout == System.Threading.Timeout.InfiniteTimeSpan) + { + await action(cancellationToken); + return true; + } + else + { + using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + Task task = action(cts.Token); + Task completedTask = await Task.WhenAny(task, Task.Delay(timeout, cancellationToken)); + if (completedTask == task) + return true; + + cts.Cancel(); + return false; + } + } + + private static async Task<T> Delay<T>(TimeSpan timeout, CancellationToken cancellationToken = default) + { + await Task.Delay(timeout, cancellationToken); + return default!; + } +}