Skip to content

Commit

Permalink
Fix auto flush for Enumerable and AsyncEnumerable methods
Browse files Browse the repository at this point in the history
  • Loading branch information
maca88 committed Mar 1, 2020
1 parent 50c0420 commit ff31f42
Show file tree
Hide file tree
Showing 7 changed files with 147 additions and 67 deletions.
8 changes: 6 additions & 2 deletions src/NHibernate.Test/Async/GenericTest/Methods/Fixture.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ protected override string[] Mappings
{
get
{
return new string[] { "One.hbm.xml", "Many.hbm.xml" };
return new string[] { "One.hbm.xml", "Many.hbm.xml", "Simple.hbm.xml" };
}
}

Expand All @@ -49,12 +49,15 @@ protected override void OnSetUp()
many2.One = one;
one.Manies.Add( many2 );

using( ISession s = OpenSession() )
var simple = new Simple(1) {Count = 1};

using ( ISession s = OpenSession() )
using( ITransaction t = s.BeginTransaction() )
{
s.Save( one );
s.Save( many1 );
s.Save( many2 );
s.Save(simple, 1);
t.Commit();
}
}
Expand All @@ -66,6 +69,7 @@ protected override void OnTearDown()
{
session.Delete( "from Many" );
session.Delete( "from One" );
session.Delete("from Simple");
tx.Commit();
}
base.OnTearDown();
Expand Down
76 changes: 74 additions & 2 deletions src/NHibernate.Test/GenericTest/Methods/Fixture.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ protected override string[] Mappings
{
get
{
return new string[] { "One.hbm.xml", "Many.hbm.xml" };
return new string[] { "One.hbm.xml", "Many.hbm.xml", "Simple.hbm.xml" };
}
}

Expand All @@ -39,12 +39,15 @@ protected override void OnSetUp()
many2.One = one;
one.Manies.Add( many2 );

using( ISession s = OpenSession() )
var simple = new Simple(1) {Count = 1};

using ( ISession s = OpenSession() )
using( ITransaction t = s.BeginTransaction() )
{
s.Save( one );
s.Save( many1 );
s.Save( many2 );
s.Save(simple, 1);
t.Commit();
}
}
Expand All @@ -56,6 +59,7 @@ protected override void OnTearDown()
{
session.Delete( "from Many" );
session.Delete( "from One" );
session.Delete("from Simple");
tx.Commit();
}
base.OnTearDown();
Expand Down Expand Up @@ -106,6 +110,40 @@ public void QueryEnumerable()
}
}

[Test]
public void AutoFlushQueryEnumerable()
{
using (var s = OpenSession())
using (var t = s.BeginTransaction())
{
Assert.That(s.FlushMode, Is.EqualTo(FlushMode.Auto));
var results = s.CreateQuery("from Simple").Enumerable<Simple>();

var id = 2;
var simple = new Simple(id) {Count = id};
s.Save(simple, id);
var enumerator = results.GetEnumerator();

Assert.That(enumerator.MoveNext(), Is.True);
Assert.That(enumerator.MoveNext(), Is.True);
Assert.That(enumerator.MoveNext(), Is.False);
enumerator.Dispose();

id++;
simple = new Simple(id) {Count = id};
s.Save(simple, id);
enumerator = results.GetEnumerator();

Assert.That(enumerator.MoveNext(), Is.True);
Assert.That(enumerator.MoveNext(), Is.True);
Assert.That(enumerator.MoveNext(), Is.True);
Assert.That(enumerator.MoveNext(), Is.False);
enumerator.Dispose();

t.Rollback();
}
}

[Test]
public async Task QueryEnumerableAsync()
{
Expand All @@ -120,6 +158,40 @@ public async Task QueryEnumerableAsync()
}
}

[Test]
public async Task AutoFlushQueryEnumerableAsync()
{
using (var s = OpenSession())
using (var t = s.BeginTransaction())
{
Assert.That(s.FlushMode, Is.EqualTo(FlushMode.Auto));
var results = s.CreateQuery("from Simple").AsyncEnumerable<Simple>();

var id = 2;
var simple = new Simple(id) {Count = id};
s.Save(simple, id);
var enumerator = results.GetAsyncEnumerator();

Assert.That(await enumerator.MoveNextAsync(), Is.True);
Assert.That(await enumerator.MoveNextAsync(), Is.True);
Assert.That(await enumerator.MoveNextAsync(), Is.False);
await enumerator.DisposeAsync();

id++;
simple = new Simple(id) {Count = id};
s.Save(simple, id);
enumerator = results.GetAsyncEnumerator();

Assert.That(await enumerator.MoveNextAsync(), Is.True);
Assert.That(await enumerator.MoveNextAsync(), Is.True);
Assert.That(await enumerator.MoveNextAsync(), Is.True);
Assert.That(await enumerator.MoveNextAsync(), Is.False);
await enumerator.DisposeAsync();

await t.RollbackAsync();
}
}

[Test]
public void Filter()
{
Expand Down
44 changes: 22 additions & 22 deletions src/NHibernate/Async/Impl/SessionImpl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -302,38 +302,38 @@ public override async Task<IQueryTranslator[]> GetQueriesAsync(IQueryExpression
/// <inheritdoc />
// Since v5.3
[Obsolete("Use AsyncEnumerable extension method instead.")]
public override async Task<IEnumerable<T>> EnumerableAsync<T>(IQueryExpression queryExpression, QueryParameters queryParameters, CancellationToken cancellationToken)
public override Task<IEnumerable<T>> EnumerableAsync<T>(IQueryExpression queryExpression, QueryParameters queryParameters, CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();
using (BeginProcess())
if (cancellationToken.IsCancellationRequested)
{
queryParameters.ValidateParameters();
var plan = GetHQLQueryPlan(queryExpression, true);
await (AutoFlushIfRequiredAsync(plan.QuerySpaces, cancellationToken)).ConfigureAwait(false);

using (SuspendAutoFlush()) //stops flush being called multiple times if this method is recursively called
{
return plan.PerformIterate<T>(queryParameters, this);
}
return Task.FromCanceled<IEnumerable<T>>(cancellationToken);
}
try
{
return Task.FromResult<IEnumerable<T>>(Enumerable<T>(queryExpression, queryParameters));
}
catch (Exception ex)
{
return Task.FromException<IEnumerable<T>>(ex);
}
}

/// <inheritdoc />
// Since v5.3
[Obsolete("Use AsyncEnumerable extension method instead.")]
public override async Task<IEnumerable> EnumerableAsync(IQueryExpression queryExpression, QueryParameters queryParameters, CancellationToken cancellationToken)
public override Task<IEnumerable> EnumerableAsync(IQueryExpression queryExpression, QueryParameters queryParameters, CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();
using (BeginProcess())
if (cancellationToken.IsCancellationRequested)
{
queryParameters.ValidateParameters();
var plan = GetHQLQueryPlan(queryExpression, true);
await (AutoFlushIfRequiredAsync(plan.QuerySpaces, cancellationToken)).ConfigureAwait(false);

using (SuspendAutoFlush()) //stops flush being called multiple times if this method is recursively called
{
return plan.PerformIterate(queryParameters, this);
}
return Task.FromCanceled<IEnumerable>(cancellationToken);
}
try
{
return Task.FromResult<IEnumerable>(Enumerable(queryExpression, queryParameters));
}
catch (Exception ex)
{
return Task.FromException<IEnumerable>(ex);
}
}

Expand Down
30 changes: 17 additions & 13 deletions src/NHibernate/Async/Loader/Hql/QueryLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -90,22 +90,26 @@ protected override async Task<object[]> GetResultRowAsync(object[] row, DbDataRe
internal async Task<InitializeEnumerableResult> InitializeEnumerableAsync(QueryParameters queryParameters, ISessionImplementor session, CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();
Stopwatch stopWatch = null;
if (session.Factory.Statistics.IsStatisticsEnabled)
await (session.AutoFlushIfRequiredAsync(_queryTranslator.QuerySpaces, cancellationToken)).ConfigureAwait(false);
using (session.SuspendAutoFlush())
{
stopWatch = Stopwatch.StartNew();
}
Stopwatch stopWatch = null;
if (session.Factory.Statistics.IsStatisticsEnabled)
{
stopWatch = Stopwatch.StartNew();
}

var command = await (PrepareQueryCommandAsync(queryParameters, false, session, cancellationToken)).ConfigureAwait(false);
var dataReader = await (GetResultSetAsync(command, queryParameters, session, null, cancellationToken)).ConfigureAwait(false);
if (stopWatch != null)
{
stopWatch.Stop();
session.Factory.StatisticsImplementor.QueryExecuted("HQL: " + _queryTranslator.QueryString, 0, stopWatch.Elapsed);
session.Factory.StatisticsImplementor.QueryExecuted(QueryIdentifier, 0, stopWatch.Elapsed);
}
var command = await (PrepareQueryCommandAsync(queryParameters, false, session, cancellationToken)).ConfigureAwait(false);
var dataReader = await (GetResultSetAsync(command, queryParameters, session, null, cancellationToken)).ConfigureAwait(false);
if (stopWatch != null)
{
stopWatch.Stop();
session.Factory.StatisticsImplementor.QueryExecuted("HQL: " + _queryTranslator.QueryString, 0, stopWatch.Elapsed);
session.Factory.StatisticsImplementor.QueryExecuted(QueryIdentifier, 0, stopWatch.Elapsed);
}

return new InitializeEnumerableResult(command, dataReader);
return new InitializeEnumerableResult(command, dataReader);
}
}
}
}
5 changes: 5 additions & 0 deletions src/NHibernate/Engine/ISessionImplementor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ internal static void AutoFlushIfRequired(this ISessionImplementor implementor, I
(implementor as AbstractSessionImpl)?.AutoFlushIfRequired(querySpaces);
}

internal static IDisposable SuspendAutoFlush(this ISessionImplementor implementor)
{
return (implementor as IEventSource)?.SuspendAutoFlush();
}

/// <summary>
/// Returns an <see cref="IAsyncEnumerable{T}" /> which can be enumerated asynchronously.
/// </summary>
Expand Down
21 changes: 6 additions & 15 deletions src/NHibernate/Impl/SessionImpl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -585,12 +585,9 @@ public override IEnumerable<T> Enumerable<T>(IQueryExpression queryExpression, Q
{
queryParameters.ValidateParameters();
var plan = GetHQLQueryPlan(queryExpression, true);
AutoFlushIfRequired(plan.QuerySpaces);

using (SuspendAutoFlush()) //stops flush being called multiple times if this method is recursively called
{
return plan.PerformIterate<T>(queryParameters, this);
}
// AutoFlushIfRequired will be called when iterating through the enumerable
return plan.PerformIterate<T>(queryParameters, this);
}
}

Expand All @@ -600,12 +597,9 @@ public override IAsyncEnumerable<T> AsyncEnumerable<T>(IQueryExpression queryExp
{
queryParameters.ValidateParameters();
var plan = GetHQLQueryPlan(queryExpression, true);
AutoFlushIfRequired(plan.QuerySpaces);

using (SuspendAutoFlush()) //stops flush being called multiple times if this method is recursively called
{
return plan.PerformAsyncIterate<T>(queryParameters, this);
}
// AutoFlushIfRequired will be called when iterating through the enumerable
return plan.PerformAsyncIterate<T>(queryParameters, this);
}
}

Expand All @@ -615,12 +609,9 @@ public override IEnumerable Enumerable(IQueryExpression queryExpression, QueryPa
{
queryParameters.ValidateParameters();
var plan = GetHQLQueryPlan(queryExpression, true);
AutoFlushIfRequired(plan.QuerySpaces);

using (SuspendAutoFlush()) //stops flush being called multiple times if this method is recursively called
{
return plan.PerformIterate(queryParameters, this);
}
// AutoFlushIfRequired will be called when iterating through the enumerable
return plan.PerformIterate(queryParameters, this);
}
}

Expand Down
30 changes: 17 additions & 13 deletions src/NHibernate/Loader/Hql/QueryLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -486,22 +486,26 @@ internal AsyncEnumerableImpl<T> GetAsyncEnumerable<T>(QueryParameters queryParam

internal InitializeEnumerableResult InitializeEnumerable(QueryParameters queryParameters, ISessionImplementor session)
{
Stopwatch stopWatch = null;
if (session.Factory.Statistics.IsStatisticsEnabled)
session.AutoFlushIfRequired(_queryTranslator.QuerySpaces);
using (session.SuspendAutoFlush())
{
stopWatch = Stopwatch.StartNew();
}
Stopwatch stopWatch = null;
if (session.Factory.Statistics.IsStatisticsEnabled)
{
stopWatch = Stopwatch.StartNew();
}

var command = PrepareQueryCommand(queryParameters, false, session);
var dataReader = GetResultSet(command, queryParameters, session, null);
if (stopWatch != null)
{
stopWatch.Stop();
session.Factory.StatisticsImplementor.QueryExecuted("HQL: " + _queryTranslator.QueryString, 0, stopWatch.Elapsed);
session.Factory.StatisticsImplementor.QueryExecuted(QueryIdentifier, 0, stopWatch.Elapsed);
}
var command = PrepareQueryCommand(queryParameters, false, session);
var dataReader = GetResultSet(command, queryParameters, session, null);
if (stopWatch != null)
{
stopWatch.Stop();
session.Factory.StatisticsImplementor.QueryExecuted("HQL: " + _queryTranslator.QueryString, 0, stopWatch.Elapsed);
session.Factory.StatisticsImplementor.QueryExecuted(QueryIdentifier, 0, stopWatch.Elapsed);
}

return new InitializeEnumerableResult(command, dataReader);
return new InitializeEnumerableResult(command, dataReader);
}
}

protected override void ResetEffectiveExpectedType(IEnumerable<IParameterSpecification> parameterSpecs, QueryParameters queryParameters)
Expand Down

0 comments on commit ff31f42

Please sign in to comment.