diff --git a/VindicateLib/ClientActioner.cs b/VindicateLib/ClientActioner.cs index 9415f81..7433ebe 100644 --- a/VindicateLib/ClientActioner.cs +++ b/VindicateLib/ClientActioner.cs @@ -61,12 +61,12 @@ public Byte[] Receive(Socket client, out IPEndPoint remoteEndPoint) { remoteEndPoint = (IPEndPoint) newSocket.RemoteEndPoint; var data = new List(); - while (client.Available != 0) + while (newSocket.Available != 0) { - Int32 read = client.Receive(buffer); + Int32 read = newSocket.Receive(buffer); data.AddRange(buffer.Take(read)); } - client.Close(); + newSocket.Disconnect(true); return data.ToArray(); } } diff --git a/VindicateLib/Detector.cs b/VindicateLib/Detector.cs index 64ffe04..5e5fccb 100644 --- a/VindicateLib/Detector.cs +++ b/VindicateLib/Detector.cs @@ -37,7 +37,7 @@ public sealed class Detector : IDisposable private readonly NameServiceClientImpl _nameServiceClient; private ConfidenceLevel _highestConfidenceLevel = ConfidenceLevel.FalsePositive; private Random _fastRandom; - private Socket _llmnrClient, _nbnsClient, _mdnsClient; + private Socket _llmnrUDPClient, _nbnsUDPClient, _mdnsUDPClient; private String _localBroadcast; private Boolean _performSending = false; private Boolean _performListening = false; @@ -100,11 +100,14 @@ public void BeginSendingAndListening() while (_performSending) { if (_settings.UseLLMNR) - _nameServiceClient.SendRequest(_llmnrClient, Protocol.LLMNR, _settings.LLMNRTarget, _localBroadcast, clientActioner); + _nameServiceClient.SendRequest(_llmnrUDPClient, Protocol.LLMNR, _settings.LLMNRTarget, _localBroadcast, clientActioner); if (_settings.UseNBNS) - _nameServiceClient.SendRequest(_nbnsClient, Protocol.NBNS, _settings.NBNSTarget, _localBroadcast, clientActioner); + { + _nameServiceClient.SendRequest(_nbnsUDPClient, Protocol.NBNS, _settings.NBNSTarget, _localBroadcast, clientActioner); + + } if (_settings.UsemDNS) - _nameServiceClient.SendRequest(_mdnsClient, Protocol.mDNS, _settings.mDNSTarget, null, clientActioner); + _nameServiceClient.SendRequest(_mdnsUDPClient, Protocol.mDNS, _settings.mDNSTarget, null, clientActioner); OnMessagesSent(); Thread.Sleep(_settings.SendRequestFrequency); @@ -120,7 +123,7 @@ public void BeginSendingAndListening() while (_performListening) { //Valid transaction IDs should be acquired from sent requests, but for now we don't validate so send whatever - HandleResponseReceivedResult(_nameServiceClient.ReceiveAndHandleReply(_llmnrClient, Protocol.LLMNR, null, clientActioner)); + HandleResponseReceivedResult(_nameServiceClient.ReceiveAndHandleReply(_llmnrUDPClient, Protocol.LLMNR, null, clientActioner)); spinWait.SpinOnce(); } }); @@ -134,7 +137,7 @@ public void BeginSendingAndListening() var spinWait = new SpinWait(); while (_performListening) { - HandleResponseReceivedResult(_nameServiceClient.ReceiveAndHandleReply(_nbnsClient, Protocol.NBNS, null, clientActioner)); + HandleResponseReceivedResult(_nameServiceClient.ReceiveAndHandleReply(_nbnsUDPClient, Protocol.NBNS, null, clientActioner)); spinWait.SpinOnce(); } }); @@ -148,7 +151,7 @@ public void BeginSendingAndListening() var spinWait = new SpinWait(); while (_performListening) { - HandleResponseReceivedResult(_nameServiceClient.ReceiveAndHandleReply(_mdnsClient, Protocol.mDNS, null, clientActioner)); + HandleResponseReceivedResult(_nameServiceClient.ReceiveAndHandleReply(_mdnsUDPClient, Protocol.mDNS, null, clientActioner)); spinWait.SpinOnce(); } }); @@ -159,9 +162,9 @@ public void EndSendingAndListening() { _performSending = false; _performListening = false; - _llmnrClient?.Close(); - _nbnsClient?.Close(); - _mdnsClient?.Close(); + _llmnrUDPClient?.Close(); + _nbnsUDPClient?.Close(); + _mdnsUDPClient?.Close(); } @@ -173,7 +176,7 @@ private void InitialiseBroadcastAddress() { if (_localBroadcast == null) { - _logger.LogMessage("Unable to find broadcast address for NBNS", EventLogEntryType.Information, + _logger.LogMessage("Unable to find broadcast address for NBNS - disabling NBNS", EventLogEntryType.Information, (Int32) LogEvents.NoBroadcastAdapterFound, (Int16) LogCategories.NonFatalError); _settings.UseNBNS = false; } @@ -281,20 +284,20 @@ private void InitialiseClients() { if (_settings.UseLLMNR) { - _llmnrClient = SocketLoader.LoadUDPSocket(Protocol.LLMNR, _settings.LLMNRPort, _settings.Verbose, _logger); - if (_llmnrClient == null) + _llmnrUDPClient = SocketLoader.LoadUDPSocket(Protocol.LLMNR, _settings.LLMNRPort, _settings.Verbose, _logger); + if (_llmnrUDPClient == null) _settings.UseLLMNR = false; } if (_settings.UseNBNS) { - _nbnsClient = SocketLoader.LoadUDPSocket(Protocol.NBNS, _settings.NBNSPort, _settings.Verbose, _logger); - if (_nbnsClient == null) + _nbnsUDPClient = SocketLoader.LoadUDPSocket(Protocol.NBNS, _settings.NBNSPort, _settings.Verbose, _logger); + if (_nbnsUDPClient == null) _settings.UseNBNS = false; } if (_settings.UsemDNS) { - _mdnsClient = SocketLoader.LoadUDPSocket(Protocol.mDNS, _settings.mDNSPort, _settings.Verbose, _logger); - if (_mdnsClient == null) + _mdnsUDPClient = SocketLoader.LoadUDPSocket(Protocol.mDNS, _settings.mDNSPort, _settings.Verbose, _logger); + if (_mdnsUDPClient == null) _settings.UsemDNS = false; } } @@ -329,9 +332,9 @@ private void OnConfidenceLevelChange() [ExcludeFromCodeCoverage()] public void Dispose() { - ((IDisposable) _llmnrClient)?.Dispose(); - ((IDisposable) _nbnsClient)?.Dispose(); - ((IDisposable) _mdnsClient)?.Dispose(); + ((IDisposable) _llmnrUDPClient)?.Dispose(); + ((IDisposable) _nbnsUDPClient)?.Dispose(); + ((IDisposable) _mdnsUDPClient)?.Dispose(); } } } diff --git a/VindicateLib/Enums/LogEvents.cs b/VindicateLib/Enums/LogEvents.cs index 1812e99..74d6f70 100644 --- a/VindicateLib/Enums/LogEvents.cs +++ b/VindicateLib/Enums/LogEvents.cs @@ -34,6 +34,8 @@ public enum LogEvents SMBTestSucceeded = 11, SMBTestFailed = 12, RunningAsAdmin = 13, - InvalidArguments = 14 + InvalidArguments = 14, + LoadedTcpClient = 15, + UnableToLoadTcpClient = 16 } } \ No newline at end of file diff --git a/VindicateLib/SocketLoader.cs b/VindicateLib/SocketLoader.cs index 00c4c1a..6927b2f 100644 --- a/VindicateLib/SocketLoader.cs +++ b/VindicateLib/SocketLoader.cs @@ -47,5 +47,34 @@ public static Socket LoadUDPSocket(Protocol protocol, Int32 port, Boolean verbos return socket; } + + public static Socket LoadTCPSocket(Protocol protocol, Int32 port, Boolean verbose, Logger logger) + { + Socket socket = null; + try + { + socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + socket.ReceiveTimeout = 0; + socket.DontFragment = true; + socket.Bind(new IPEndPoint(IPAddress.Any, port)); + + if (protocol == Protocol.mDNS) + { + socket.SetSocketOption(SocketOptionLevel.IP, SocketOptionName.AddMembership, + new MulticastOption(IPAddress.Parse("224.0.0.251"))); + } + + if (verbose) + logger.LogMessage(String.Format("Loaded {0} service on TCP port {1}", protocol, port), EventLogEntryType.Information, (Int32)LogEvents.LoadedTcpClient, (Int16)LogCategories.LoadingInfo); + } + catch (SocketException ex) + { + logger.LogMessage(String.Format("Unable to load {0} service ({2}). Disabling. TCP Port {1} in use or insufficient privileges?", protocol, port, ex.Message), EventLogEntryType.Error + , (Int32)LogEvents.UnableToLoadTcpClient, (Int16)LogCategories.NonFatalError); + return null; + } + + return socket; + } } }