Skip to content
This repository has been archived by the owner on Dec 18, 2023. It is now read-only.

Added tests for Stop, Start and Use #74

Merged
merged 6 commits into from
Oct 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ jobs:
dotnet:
[{ ch: "3.1", ver: "netcoreapp31" }, { ch: "6.0", ver: "net60" }]
nightly: [true, false]
fail-fast: false

steps:
- name: Checkout
Expand Down
9 changes: 9 additions & 0 deletions src/Driver/Rest/DatabaseRest.IDatabase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ namespace SurrealDB.Driver.Rest;

public sealed partial class DatabaseRest : IDatabase {
async Task<IResponse> IDatabase.Info(CancellationToken ct) {
ThrowIfInvalidConnection();
return await Info(ct);
}

Expand All @@ -29,34 +30,42 @@ async Task<IResponse> IDatabase.Authenticate(string token, CancellationToken ct)
}

async Task<IResponse> IDatabase.Let(string key, object? value, CancellationToken ct) {
ThrowIfInvalidConnection();
return await Let(key, value, ct);
}

async Task<IResponse> IDatabase.Query(string sql, IReadOnlyDictionary<string,object?>? vars, CancellationToken ct) {
ThrowIfInvalidConnection();
return await Query(sql, vars, ct);
}

async Task<IResponse> IDatabase.Select(Thing thing, CancellationToken ct) {
ThrowIfInvalidConnection();
return await Select(thing, ct);
}

async Task<IResponse> IDatabase.Create(Thing thing, object data, CancellationToken ct) {
ThrowIfInvalidConnection();
return await Create(thing, data, ct);
}

async Task<IResponse> IDatabase.Update(Thing thing, object data, CancellationToken ct) {
ThrowIfInvalidConnection();
return await Update(thing, data, ct);
}

async Task<IResponse> IDatabase.Change(Thing thing, object data, CancellationToken ct) {
ThrowIfInvalidConnection();
return await Change(thing, data, ct);
}

async Task<IResponse> IDatabase.Modify(Thing thing, Patch[] patches, CancellationToken ct) {
ThrowIfInvalidConnection();
return await Modify(thing, patches, ct);
}

async Task<IResponse> IDatabase.Delete(Thing thing, CancellationToken ct) {
ThrowIfInvalidConnection();
return await Delete(thing, ct);
}
}
40 changes: 31 additions & 9 deletions src/Driver/Rest/DatabaseRest.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using System.Net.Http.Headers;
using System.Net.Http.Headers;
using System.Text;
using System.Text.Json;

Expand All @@ -23,6 +23,23 @@ public DatabaseRest(in Configuration.Config config) {

private readonly Dictionary<string, object?> _vars = new();

private const string NAMESPACE = "NS";
private const string DATABASE = "DB";

/// <summary>
/// Indicates whether the client has valid connection details.
/// </summary>
public bool InvalidConnectionDetails =>
_client.DefaultRequestHeaders.Contains(NAMESPACE) &&
_client.DefaultRequestHeaders.Contains(DATABASE) &&
_client.DefaultRequestHeaders.Authorization != null;

private void ThrowIfInvalidConnection() {
if (!InvalidConnectionDetails) {
throw new InvalidOperationException("The connection details is invalid.");
}
}

public void Dispose() {
_client.Dispose();
}
Expand Down Expand Up @@ -60,6 +77,7 @@ public Task Open(CancellationToken ct = default) {
}

public Task Close(CancellationToken ct = default) {
Invalidate(ct);
return Task.CompletedTask;
}

Expand All @@ -71,8 +89,8 @@ public Task<RestResponse> Info(CancellationToken ct = default) {
}

public Task<RestResponse> Use(
string db,
string ns,
string? db,
string? ns,
CancellationToken ct = default) {
SetUse(db, ns);

Expand Down Expand Up @@ -199,13 +217,17 @@ private void RemoveAuth() {
private void SetUse(
string? db,
string? ns) {
_config.Database = db;
_config.Namespace = ns;
if (db != null) {
_config.Database = db;
_client.DefaultRequestHeaders.Remove(DATABASE);
_client.DefaultRequestHeaders.Add(DATABASE, db);
}

_client.DefaultRequestHeaders.Remove("DB");
_client.DefaultRequestHeaders.Add("DB", db);
_client.DefaultRequestHeaders.Remove("NS");
_client.DefaultRequestHeaders.Add("NS", ns);
if (ns != null) {
_config.Namespace = ns;
_client.DefaultRequestHeaders.Remove(NAMESPACE);
_client.DefaultRequestHeaders.Add(NAMESPACE, ns);
}
}

/// <inheritdoc cref="Signup(Authentication, CancellationToken)" />
Expand Down
6 changes: 4 additions & 2 deletions src/Driver/Rpc/DatabaseRpc.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using SurrealDB.Abstractions;
using SurrealDB.Abstractions;
using SurrealDB.Configuration;
using SurrealDB.Models;
using SurrealDB.Ws;
Expand Down Expand Up @@ -33,6 +33,7 @@ public async Task Open(CancellationToken ct = default) {
return;
}
_config.ThrowIfInvalid();

_configured = true;

// Open connection
Expand All @@ -47,13 +48,14 @@ public async Task Open(CancellationToken ct = default) {
}

public async Task Close(CancellationToken ct = default) {
_configured = false;
await _client.Close(ct);
}

/// <param name="ct"> </param>
/// <inheritdoc />
public async Task<RpcResponse> Info(CancellationToken ct) {
return await _client.Send(new() { method = "info", }).ToSurreal();
return await _client.Send(new() { method = "info", }, ct).ToSurreal();
}

/// <inheritdoc />
Expand Down
11 changes: 10 additions & 1 deletion src/Ws/WsTx.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,16 @@ public async Task Close(CancellationToken ct = default) {
if (_ws.State == WebSocketState.Closed) {
return;
}
await _ws.CloseAsync(WebSocketCloseStatus.NormalClosure, "client disconnect", ct);

try {
await _ws.CloseAsync(WebSocketCloseStatus.NormalClosure, "client disconnect", ct);
} catch (OperationCanceledException) {
if (ct.IsCancellationRequested) {
// Catch any canceled exception that is generated during the close,
// but still throw for cancellations that we requested.
throw;
}
}
}

public void Dispose() {
Expand Down
114 changes: 114 additions & 0 deletions tests/Driver.Tests/Queries/GeneralQueryTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,120 @@ private class MathResultDocument {
public float result {get; set;}
}

[Fact]
public async Task StopStartConnectionTest() => await DbHandle<T>.WithDatabase(
async db => {
string sql = "INFO FOR DB;";
var response = await db.Query(sql, null);
Assert.NotNull(response);
TestHelper.AssertOk(response);

await db.Close();
await Assert.ThrowsAsync<InvalidOperationException>(async () => await db.Query(sql, null));
db.Dispose();
await Assert.ThrowsAsync<InvalidOperationException>(async () => await db.Query(sql, null));

db = new();
await db.Open(TestHelper.Default);

response = await db.Query(sql, null);
Assert.NotNull(response);
TestHelper.AssertOk(response);
}
);

[Fact]
public async Task SwitchDatabaseTest() => await DbHandle<T>.WithDatabase(
async db => {
var nsName = db.GetConfig().Namespace!;
var originalDbName = db.GetConfig().Database!;
var otherDbName = "DifferentDb";

TestObject<int, string> expectedOriginalObject = new(1, originalDbName);
TestObject<int, string> expectedOtherObject = new(1, otherDbName);

Thing thing = Thing.From("object", expectedOriginalObject.Key.ToString());
await db.Create(thing, expectedOriginalObject);

{
var useResponse = await db.Use(otherDbName, nsName);
Assert.NotNull(useResponse);
TestHelper.AssertOk(useResponse);

await db.Create(thing, expectedOtherObject);

var response = await db.Select(thing);

Assert.NotNull(response);
TestHelper.AssertOk(response);
Assert.True(response.TryGetResult(out Result result));
TestObject<int, string>? doc = result.GetObject<TestObject<int, string>>();
doc.Should().BeEquivalentTo(expectedOtherObject);
}

{
var useResponse = await db.Use(originalDbName, nsName);
Assert.NotNull(useResponse);
TestHelper.AssertOk(useResponse);

var response = await db.Select(thing);

Assert.NotNull(response);
TestHelper.AssertOk(response);
Assert.True(response.TryGetResult(out Result result));
TestObject<int, string>? doc = result.GetObject<TestObject<int, string>>();
doc.Should().BeEquivalentTo(expectedOriginalObject);
}

}
);

[Fact]
public async Task SwitchNamespaceTest() => await DbHandle<T>.WithDatabase(
async db => {
var originalNsName = db.GetConfig().Namespace!;
var otherNsName = "DifferentNs";
var dbName = db.GetConfig().Database!;

TestObject<int, string> expectedOriginalObject = new(1, originalNsName);
TestObject<int, string> expectedOtherObject = new(1, otherNsName);

Thing thing = Thing.From("object", expectedOriginalObject.Key.ToString());
await db.Create(thing, expectedOriginalObject);

{
var useResponse = await db.Use(dbName, otherNsName);
Assert.NotNull(useResponse);
TestHelper.AssertOk(useResponse);

await db.Create(thing, expectedOtherObject);

var response = await db.Select(thing);

Assert.NotNull(response);
TestHelper.AssertOk(response);
Assert.True(response.TryGetResult(out Result result));
TestObject<int, string>? doc = result.GetObject<TestObject<int, string>>();
doc.Should().BeEquivalentTo(expectedOtherObject);
}

{
var useResponse = await db.Use(dbName, originalNsName);
Assert.NotNull(useResponse);
TestHelper.AssertOk(useResponse);

var response = await db.Select(thing);

Assert.NotNull(response);
TestHelper.AssertOk(response);
Assert.True(response.TryGetResult(out Result result));
TestObject<int, string>? doc = result.GetObject<TestObject<int, string>>();
doc.Should().BeEquivalentTo(expectedOriginalObject);
}

}
);

[Fact]
public async Task SimpleFuturesQueryTest() => await DbHandle<T>.WithDatabase(
async db => {
Expand Down
12 changes: 12 additions & 0 deletions tests/Shared/TestHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,16 @@ public static void AssertOk(
Exception ex = new($"Expected Ok, got error code {err.Code} ({err.Message}) in {caller}");
throw ex;
}

public static void AssertError(
in IResponse rpcResponse,
// [CallerArgumentExpression("rpcResponse")]
string caller = "") {
if (rpcResponse.TryGetError(out Error err)) {
return;
}

Exception ex = new($"Expected Error, got ok response in {caller}");
throw ex;
}
}