diff --git a/src/tools/attack/attack.cpp b/src/tools/attack/attack.cpp index c2963684b1..dc7fe9e437 100644 --- a/src/tools/attack/attack.cpp +++ b/src/tools/attack/attack.cpp @@ -24,6 +24,8 @@ #define ATTACK_THREADS_DEFAULT CxPlatProcCount() +#define ATTACK_RATE_DEFAULT 1000000 + #define ATTACK_PORT_DEFAULT 443 const QUIC_HKDF_LABELS HkdfLabels = { "quic key", "quic iv", "quic hp", "quic ku" }; @@ -37,6 +39,7 @@ static const char* IpAddress; static QUIC_ADDR ServerAddress; static uint64_t TimeoutMs = ATTACK_TIMEOUT_DEFAULT_MS; static uint32_t ThreadCount = ATTACK_THREADS_DEFAULT; +static uint64_t AttackRate = ATTACK_RATE_DEFAULT; static const char* Alpn = "h3-29"; static uint32_t Version = QUIC_VERSION_DRAFT_29; @@ -50,7 +53,7 @@ void PrintUsage() printf("Usage:\n"); printf(" quicattack.exe -list\n\n"); - printf(" quicattack.exe -type: -ip: [-alpn:] [-sni:] [-timeout:] [-threads:]\n\n"); + printf(" quicattack.exe -type: -ip: [-alpn:] [-sni:] [-timeout:] [-threads:] [-rate:]\n\n"); } void PrintUsageList() @@ -61,7 +64,7 @@ void PrintUsageList() printf("#1 - Random UDP 1 byte UDP packets.\n"); printf("#2 - Random UDP full length UDP packets.\n"); printf("#3 - Random QUIC initial packets.\n"); - printf("#4 - Valid QUIC initial packets.\n"); + printf("#4 - Valid QUIC initial packets.\n\n"); } struct CallbackContext { @@ -126,14 +129,16 @@ ResolveRouteComplete( { UNREFERENCED_PARAMETER(PathId); CallbackContext* CContext = (CallbackContext*) Context; - if(Succeeded) { + if (Succeeded) { CxPlatResolveRouteComplete(nullptr, CContext->Route, PhysicalAddress, 0); } CxPlatEventSet(CContext->Event); } -void RunAttackRandom(CXPLAT_SOCKET* Binding, uint16_t Length, bool ValidQuic, bool TCP = false) +void RunAttackRandom(CXPLAT_SOCKET* Binding, uint16_t DatagramLength, bool ValidQuic, bool TCP = false) { + const uint16_t HeadersLength = ((TCP)? 20 : ((ValidQuic)? 8 + MIN_LONG_HEADER_LENGTH_V1 : 8)) + 20; + CXPLAT_ROUTE Route = {0}; CxPlatSocketGetLocalAddress(Binding, &Route.LocalAddress); CxPlatSocketGetRemoteAddress(Binding, &Route.RemoteAddress); @@ -152,21 +157,34 @@ void RunAttackRandom(CXPLAT_SOCKET* Binding, uint16_t Length, bool ValidQuic, bo uint64_t ConnectionId = 0; CxPlatRandom(sizeof(ConnectionId), &ConnectionId); - while (CxPlatTimeDiff64(TimeStart, CxPlatTimeMs64()) < TimeoutMs) { + uint64_t BucketTime = CxPlatTimeMs64(), CurTime; + uint64_t BucketCount = 0; + uint64_t BucketThreshold = CXPLAT_MAX(1, AttackRate / ThreadCount); + + while (CxPlatTimeDiff64(TimeStart, (CurTime = CxPlatTimeMs64())) < TimeoutMs) { + + if (CxPlatTimeDiff64(BucketTime, CurTime) > 1000) { + BucketTime = CurTime; + BucketCount = 0; + } + + if (BucketCount >= BucketThreshold) { + continue; + } - CXPLAT_SEND_CONFIG SendConfig = {&Route, Length, CXPLAT_ECN_NON_ECT, 0 }; + CXPLAT_SEND_CONFIG SendConfig = {&Route, DatagramLength, CXPLAT_ECN_NON_ECT, 0 }; CXPLAT_SEND_DATA* SendData = CxPlatSendDataAlloc(Binding, &SendConfig); if (SendData == nullptr) { continue; } do { - QUIC_BUFFER* SendBuffer = CxPlatSendDataAllocBuffer(SendData, Length); + QUIC_BUFFER* SendBuffer = CxPlatSendDataAllocBuffer(SendData, DatagramLength); if (SendBuffer == nullptr) { continue; } - CxPlatRandom(Length, SendBuffer->Buffer); + CxPlatRandom(DatagramLength, SendBuffer->Buffer); if (ValidQuic) { QUIC_LONG_HEADER_V1* Header = @@ -182,12 +200,12 @@ void RunAttackRandom(CXPLAT_SOCKET* Binding, uint16_t Length, bool ValidQuic, bo Header->DestCid[8] = 8; Header->DestCid[17] = 0; QuicVarIntEncode( - Length - (MIN_LONG_HEADER_LENGTH_V1 + 19), + DatagramLength - (MIN_LONG_HEADER_LENGTH_V1 + 19), Header->DestCid + 18); } InterlockedExchangeAdd64(&TotalPacketCount, 1); - InterlockedExchangeAdd64(&TotalByteCount, Length); + InterlockedExchangeAdd64(&TotalByteCount, DatagramLength + HeadersLength); } while (CxPlatTimeDiff64(TimeStart, CxPlatTimeMs64()) < TimeoutMs && !CxPlatSendDataIsFull(SendData)); @@ -195,7 +213,9 @@ void RunAttackRandom(CXPLAT_SOCKET* Binding, uint16_t Length, bool ValidQuic, bo Binding, &Route, SendData); - + + BucketCount++; + if (TCP) { CxPlatSendDataFree(SendData); Route.LocalAddress.Ipv4.sin_port++; @@ -252,7 +272,20 @@ void RunAttackValidInitial(CXPLAT_SOCKET* Binding) CxPlatRandom(sizeof(uint64_t), DestCid); CxPlatRandom(sizeof(uint64_t), SrcCid); - while (CxPlatTimeDiff64(TimeStart, CxPlatTimeMs64()) < TimeoutMs) { + uint64_t BucketTime = CxPlatTimeMs64(), CurTime; + uint64_t BucketCount = 0; + uint64_t BucketThreshold = CXPLAT_MAX(1, AttackRate / ThreadCount); + + while (CxPlatTimeDiff64(TimeStart, (CurTime = CxPlatTimeMs64())) < TimeoutMs) { + + if (CxPlatTimeDiff64(BucketTime, CurTime) > 1000) { + BucketTime = CurTime; + BucketCount = 0; + } + + if (BucketCount >= BucketThreshold) { + continue; + } CXPLAT_SEND_CONFIG SendConfig = {&Route, DatagramLength, CXPLAT_ECN_NON_ECT, 0 }; CXPLAT_SEND_DATA* SendData = CxPlatSendDataAlloc(Binding, &SendConfig); @@ -310,7 +343,7 @@ void RunAttackValidInitial(CXPLAT_SOCKET* Binding) } InterlockedExchangeAdd64(&TotalPacketCount, 1); - InterlockedExchangeAdd64(&TotalByteCount, DatagramLength); + InterlockedExchangeAdd64(&TotalByteCount, DatagramLength + MIN_LONG_HEADER_LENGTH_V1); } while (CxPlatTimeDiff64(TimeStart, CxPlatTimeMs64()) < TimeoutMs && !CxPlatSendDataIsFull(SendData)); @@ -318,6 +351,8 @@ void RunAttackValidInitial(CXPLAT_SOCKET* Binding) Binding, &Route, SendData); + + BucketCount++; } } @@ -368,7 +403,6 @@ CXPLAT_THREAD_CALLBACK(RunAttackThread, /* Context */) void RunAttack() { Writer = new PacketWriter(Version, Alpn, ServerName); - CXPLAT_THREAD* Threads = (CXPLAT_THREAD*)CXPLAT_ALLOC_PAGED(ThreadCount * sizeof(CXPLAT_THREAD), QUIC_POOL_TOOL); @@ -437,7 +471,10 @@ main( TryGetValue(argc, argv, "alpn", &Alpn); TryGetValue(argc, argv, "sni", &ServerName); TryGetValue(argc, argv, "timeout", &TimeoutMs); - TryGetValue(argc, argv, "threads", &ThreadCount); + TryGetValue(argc, argv, "rate", &AttackRate); + if (!TryGetValue(argc, argv, "threads", &ThreadCount)) { + ThreadCount = ATTACK_THREADS_DEFAULT; + }; if (IpAddress == nullptr) { if (ServerName == nullptr) {