Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tracking client network addresses, close #100. #101

Merged
merged 1 commit into from
Dec 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CoreRemoting.Channels.Quic/QuicServerConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ private Guid CreateRemotingSession()
clientPublicKey = null;

Session = RemotingServer.SessionRepository.CreateSession(
clientPublicKey, RemotingServer, this);
clientPublicKey, Connection.RemoteEndPoint.ToString(),
RemotingServer, this);

return Session.SessionId;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ protected override void OnMessage(MessageEventArgs e)
_session =
_server.SessionRepository.CreateSession(
clientPublicKey,
Context.UserEndPoint.ToString(),
_server,
this);

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
using System;
using System.Net.Sockets;
using System.Reflection;
using WebSocketSharp.Server;

namespace CoreRemoting.Channels.WebsocketSharp
Expand Down
65 changes: 65 additions & 0 deletions CoreRemoting.Tests/RpcTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -902,5 +902,70 @@ public void Authentication_handler_has_access_to_the_current_session()
server.Config.AuthenticationRequired = false;
}
}

[Fact]
public void Authentication_handler_can_check_client_address()
{
var server = _serverFixture.Server;
var authProvider = server.Config.AuthenticationProvider;
server.Config.AuthenticationRequired = true;
server.Config.AuthenticationProvider = new FakeAuthProvider
{
AuthenticateFake = c =>
{
var address = RemotingSession.Current.ClientAddress ??
throw new ArgumentNullException();

// allow only localhost connections
return address.Contains("127.0.0.1") || // ipv4
address.Contains("[::1]"); // ipv6
}
};

try
{
using var client = new RemotingClient(new ClientConfig()
{
ConnectionTimeout = 0,
InvocationTimeout = 0,
SendTimeout = 0,
Channel = ClientChannel,
MessageEncryption = false,
ServerPort = _serverFixture.Server.Config.NetworkPort,
Credentials = [new Credential()],
});

client.Connect();

var proxy = client.CreateProxy<ITestService>();
Assert.Equal("123", proxy.Reverse("321"));
}
finally
{
server.Config.AuthenticationProvider = authProvider;
server.Config.AuthenticationRequired = false;
}
}

[Fact]
public void ServerComponent_can_track_client_network_address()
{
using var client = new RemotingClient(new ClientConfig()
{
ConnectionTimeout = 0,
InvocationTimeout = 0,
SendTimeout = 0,
MessageEncryption = false,
Channel = ClientChannel,
ServerPort = _serverFixture.Server.Config.NetworkPort,
});

client.Connect();

var proxy = client.CreateProxy<ISessionAwareService>();

// what's my address as seen by remote server?
Assert.NotNull(proxy.ClientAddress);
}
}
}
2 changes: 2 additions & 0 deletions CoreRemoting.Tests/Tools/ISessionAwareService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,7 @@
public interface ISessionAwareService
{
bool HasSameSessionInstance { get; }

string ClientAddress { get; }
}
}
7 changes: 7 additions & 0 deletions CoreRemoting.Tests/Tools/SessionAwareService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,18 @@ public SessionAwareService()
CurrentSession = RemotingSession.Current;
if (CurrentSession == null)
throw new ArgumentNullException(nameof(CurrentSession));

if (CurrentSession.ClientAddress == null)
throw new ArgumentNullException(nameof(CurrentSession.ClientAddress));
Console.WriteLine(CurrentSession.ClientAddress);
}

public RemotingSession CurrentSession { get; }

public bool HasSameSessionInstance =>
ReferenceEquals(CurrentSession, RemotingSession.Current);

public string ClientAddress =>
CurrentSession.ClientAddress;
}
}
1 change: 1 addition & 0 deletions CoreRemoting/Channels/Tcp/TcpConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ internal void FireReceiveMessage(byte[] rawMessage, Dictionary<string, object> m
_session =
_server.SessionRepository.CreateSession(
clientPublicKey,
_clientMetadata.IpPort,
_server,
this);

Expand Down
4 changes: 3 additions & 1 deletion CoreRemoting/Channels/Websocket/WebsocketServerChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ private async Task ReceiveConnection()
// accept websocket request and start a new session
var websocketContext = await context.AcceptWebSocketAsync(null);
var websocket = websocketContext.WebSocket;
var connection = new WebsocketServerConnection(websocketContext, websocket, Server);
var connection = new WebsocketServerConnection(
context.Request.RemoteEndPoint.ToString(),
websocketContext, websocket, Server);

// handle incoming websocket messages
var sessionId = connection.StartListening();
Expand Down
7 changes: 5 additions & 2 deletions CoreRemoting/Channels/Websocket/WebsocketServerConnection.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Net;
using System.Net.WebSockets;
using System.Threading;
using System.Threading.Tasks;
Expand All @@ -17,12 +18,14 @@
/// <summary>
/// Initializes a new instance of the <see cref="WebsocketServerConnection"/> class.
/// </summary>
public WebsocketServerConnection(HttpListenerWebSocketContext websocketContext, WebSocket websocket, IRemotingServer remotingServer)
public WebsocketServerConnection(string clientAddress, HttpListenerWebSocketContext websocketContext, WebSocket websocket, IRemotingServer remotingServer)
{
ClientAddress = clientAddress ?? throw new ArgumentNullException(nameof(clientAddress));
WebSocketContext = websocketContext ?? throw new ArgumentNullException(nameof(websocketContext));
WebSocket = websocket ?? throw new ArgumentNullException(nameof(websocket));
RemotingServer = remotingServer ?? throw new ArgumentNullException(nameof(remotingServer));
}
public string ClientAddress { get; private set; }

Check warning on line 28 in CoreRemoting/Channels/Websocket/WebsocketServerConnection.cs

View workflow job for this annotation

GitHub Actions / build

Missing XML comment for publicly visible type or member 'WebsocketServerConnection.ClientAddress'

Check warning on line 28 in CoreRemoting/Channels/Websocket/WebsocketServerConnection.cs

View workflow job for this annotation

GitHub Actions / build

Missing XML comment for publicly visible type or member 'WebsocketServerConnection.ClientAddress'

private HttpListenerWebSocketContext WebSocketContext { get; set; }

Expand Down Expand Up @@ -96,7 +99,7 @@
}

Session = RemotingServer.SessionRepository.CreateSession(
clientPublicKey, RemotingServer, this);
clientPublicKey, ClientAddress, RemotingServer, this);

return Session.SessionId;
}
Expand Down
10 changes: 6 additions & 4 deletions CoreRemoting/ISessionRepository.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,28 @@ public interface ISessionRepository : IDisposable
/// Creates a new session.
/// </summary>
/// <param name="clientPublicKey">Client's public key</param>
/// <param name="clientAddress">Client's network address</param>
/// <param name="server">Server instance</param>
/// <param name="rawMessageTransport">Component that does the raw message transport</param>
/// <returns>The newly created session</returns>
RemotingSession CreateSession(
byte[] clientPublicKey,
byte[] clientPublicKey,
string clientAddress,
IRemotingServer server,
IRawMessageTransport rawMessageTransport);

/// <summary>
/// Gets a specified session by its ID.
/// </summary>
/// <param name="sessionId">Session ID</param>
/// <returns>The session correlating to the specified session ID</returns>
RemotingSession GetSession(Guid sessionId);

/// <summary>
/// Gets a list of all sessions.
/// </summary>
IEnumerable<RemotingSession> Sessions { get; }

/// <summary>
/// Removes a specified session by its ID.
/// </summary>
Expand Down
12 changes: 10 additions & 2 deletions CoreRemoting/RemotingSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ public sealed class RemotingSession : IDisposable
private readonly RsaKeyPair _keyPair;
private readonly Guid _sessionId;
private readonly byte[] _clientPublicKeyBlob;
private readonly string _clientAddress;
private readonly RemoteDelegateInvocationEventAggregator _remoteDelegateInvocationEventAggregator;
private IDelegateProxyFactory _delegateProxyFactory;
private ConcurrentDictionary<Guid, IDelegateProxy> _delegateProxyCache;
Expand All @@ -53,10 +54,11 @@ public sealed class RemotingSession : IDisposable
/// </summary>
/// <param name="keySize">Key size of the RSA keys for asymmetric encryption</param>
/// <param name="clientPublicKey">Public key of this session's client</param>
/// <param name="clientAddress">Client's network address</param>
/// <param name="server">Server instance, that hosts this session</param>
/// <param name="rawMessageTransport">Component, that does the raw message transport (send and receive)</param>
internal RemotingSession(int keySize, byte[] clientPublicKey, IRemotingServer server,
IRawMessageTransport rawMessageTransport)
internal RemotingSession(int keySize, byte[] clientPublicKey, string clientAddress,
IRemotingServer server, IRawMessageTransport rawMessageTransport)
{
_isDisposing = false;
_currentlyProcessedMessagesCounter = new CountdownEvent(initialCount: 1);
Expand All @@ -71,6 +73,7 @@ internal RemotingSession(int keySize, byte[] clientPublicKey, IRemotingServer se
_delegateProxyCache = new ConcurrentDictionary<Guid, IDelegateProxy>();
_rawMessageTransport = rawMessageTransport ?? throw new ArgumentNullException(nameof(rawMessageTransport));
_clientPublicKeyBlob = clientPublicKey;
_clientAddress = clientAddress;

_rawMessageTransport.ReceiveMessage += OnReceiveMessage;
_rawMessageTransport.ErrorOccured += OnErrorOccured;
Expand Down Expand Up @@ -195,6 +198,11 @@ private void OnErrorOccured(string errorMessage, Exception ex)
/// </summary>
public Guid SessionId => _sessionId;

/// <summary>
/// Gets this session's client network address.
/// </summary>
public string ClientAddress => _clientAddress;

/// <summary>
/// Gets whether message encryption is enabled for this session.
/// </summary>
Expand Down
26 changes: 14 additions & 12 deletions CoreRemoting/SessionRepository.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ private void StartInactiveSessionSweepTimer(int inactiveSessionSweepInterval)
{
if (inactiveSessionSweepInterval <= 0)
return;

_inactiveSessionSweepTimer =
new Timer(Convert.ToDouble(inactiveSessionSweepInterval * 1000));

Expand All @@ -49,21 +49,21 @@ private void StartInactiveSessionSweepTimer(int inactiveSessionSweepInterval)
}

/// <summary>
/// Event procedure: Called when the inactive session sweep timer elapses.
/// Event procedure: Called when the inactive session sweep timer elapses.
/// </summary>
/// <param name="sender">Event sender</param>
/// <param name="e">Event arguments</param>
private void InactiveSessionSweepTimerOnElapsed(object sender, ElapsedEventArgs e)
{
if (_inactiveSessionSweepTimer == null)
return;

if (!_inactiveSessionSweepTimer.Enabled)
return;

var inactiveSessionIdList =
_sessions
.Where(item =>
.Where(item =>
DateTime.Now.Subtract(item.Value.LastActivityTimestamp).TotalMinutes > _maximumSessionInactivityTime)
.Select(item => item.Key);

Expand All @@ -82,25 +82,27 @@ private void InactiveSessionSweepTimerOnElapsed(object sender, ElapsedEventArgs
/// Creates a new session.
/// </summary>
/// <param name="clientPublicKey">Client's public key</param>
/// <param name="clientAddress">Client's network address</param>
/// <param name="server">Server instance</param>
/// <param name="rawMessageTransport">Component that does the raw message transport</param>
/// <returns>The newly created session</returns>
public RemotingSession CreateSession(byte[] clientPublicKey, IRemotingServer server, IRawMessageTransport rawMessageTransport)
public RemotingSession CreateSession(byte[] clientPublicKey, string clientAddress, IRemotingServer server, IRawMessageTransport rawMessageTransport)
{
if (server == null)
throw new ArgumentException(nameof(server));

if (rawMessageTransport == null)
throw new ArgumentNullException(nameof(rawMessageTransport));

var session = new RemotingSession(
KeySize,
KeySize,
clientPublicKey,
clientAddress,
server,
rawMessageTransport);

_sessions.TryAdd(session.SessionId, session);

return session;
}

Expand All @@ -114,7 +116,7 @@ public RemotingSession GetSession(Guid sessionId)
{
if (_sessions.TryGetValue(sessionId, out var session))
return session;

throw new KeyNotFoundException($"Session '{sessionId}' not found.");
}

Expand Down
Loading