Skip to content

Commit

Permalink
Add 'attacking rate' option to the attacking tool. (#4017)
Browse files Browse the repository at this point in the history
  • Loading branch information
VipAlekseyPetrenko authored Dec 20, 2023
1 parent 428b52f commit f79e0f3
Showing 1 changed file with 52 additions and 15 deletions.
67 changes: 52 additions & 15 deletions src/tools/attack/attack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@

#define ATTACK_THREADS_DEFAULT CxPlatProcCount()

#define ATTACK_RATE_DEFAULT 1000000

#define ATTACK_PORT_DEFAULT 443

const QUIC_HKDF_LABELS HkdfLabels = { "quic key", "quic iv", "quic hp", "quic ku" };
Expand All @@ -37,6 +39,7 @@ static const char* IpAddress;
static QUIC_ADDR ServerAddress;
static uint64_t TimeoutMs = ATTACK_TIMEOUT_DEFAULT_MS;
static uint32_t ThreadCount = ATTACK_THREADS_DEFAULT;
static uint64_t AttackRate = ATTACK_RATE_DEFAULT;
static const char* Alpn = "h3-29";
static uint32_t Version = QUIC_VERSION_DRAFT_29;

Expand All @@ -50,7 +53,7 @@ void PrintUsage()

printf("Usage:\n");
printf(" quicattack.exe -list\n\n");
printf(" quicattack.exe -type:<number> -ip:<ip_address_and_port> [-alpn:<protocol_name>] [-sni:<host_name>] [-timeout:<ms>] [-threads:<count>]\n\n");
printf(" quicattack.exe -type:<number> -ip:<ip_address_and_port> [-alpn:<protocol_name>] [-sni:<host_name>] [-timeout:<ms>] [-threads:<count>] [-rate:<packet_rate>]\n\n");
}

void PrintUsageList()
Expand All @@ -61,7 +64,7 @@ void PrintUsageList()
printf("#1 - Random UDP 1 byte UDP packets.\n");
printf("#2 - Random UDP full length UDP packets.\n");
printf("#3 - Random QUIC initial packets.\n");
printf("#4 - Valid QUIC initial packets.\n");
printf("#4 - Valid QUIC initial packets.\n\n");
}

struct CallbackContext {
Expand Down Expand Up @@ -126,14 +129,16 @@ ResolveRouteComplete(
{
UNREFERENCED_PARAMETER(PathId);
CallbackContext* CContext = (CallbackContext*) Context;
if(Succeeded) {
if (Succeeded) {
CxPlatResolveRouteComplete(nullptr, CContext->Route, PhysicalAddress, 0);
}
CxPlatEventSet(CContext->Event);
}

void RunAttackRandom(CXPLAT_SOCKET* Binding, uint16_t Length, bool ValidQuic, bool TCP = false)
void RunAttackRandom(CXPLAT_SOCKET* Binding, uint16_t DatagramLength, bool ValidQuic, bool TCP = false)
{
const uint16_t HeadersLength = ((TCP)? 20 : ((ValidQuic)? 8 + MIN_LONG_HEADER_LENGTH_V1 : 8)) + 20;

CXPLAT_ROUTE Route = {0};
CxPlatSocketGetLocalAddress(Binding, &Route.LocalAddress);
CxPlatSocketGetRemoteAddress(Binding, &Route.RemoteAddress);
Expand All @@ -152,21 +157,34 @@ void RunAttackRandom(CXPLAT_SOCKET* Binding, uint16_t Length, bool ValidQuic, bo
uint64_t ConnectionId = 0;
CxPlatRandom(sizeof(ConnectionId), &ConnectionId);

while (CxPlatTimeDiff64(TimeStart, CxPlatTimeMs64()) < TimeoutMs) {
uint64_t BucketTime = CxPlatTimeMs64(), CurTime;
uint64_t BucketCount = 0;
uint64_t BucketThreshold = CXPLAT_MAX(1, AttackRate / ThreadCount);

while (CxPlatTimeDiff64(TimeStart, (CurTime = CxPlatTimeMs64())) < TimeoutMs) {

if (CxPlatTimeDiff64(BucketTime, CurTime) > 1000) {
BucketTime = CurTime;
BucketCount = 0;
}

if (BucketCount >= BucketThreshold) {
continue;
}

CXPLAT_SEND_CONFIG SendConfig = {&Route, Length, CXPLAT_ECN_NON_ECT, 0 };
CXPLAT_SEND_CONFIG SendConfig = {&Route, DatagramLength, CXPLAT_ECN_NON_ECT, 0 };
CXPLAT_SEND_DATA* SendData = CxPlatSendDataAlloc(Binding, &SendConfig);
if (SendData == nullptr) {
continue;
}

do {
QUIC_BUFFER* SendBuffer = CxPlatSendDataAllocBuffer(SendData, Length);
QUIC_BUFFER* SendBuffer = CxPlatSendDataAllocBuffer(SendData, DatagramLength);
if (SendBuffer == nullptr) {
continue;
}

CxPlatRandom(Length, SendBuffer->Buffer);
CxPlatRandom(DatagramLength, SendBuffer->Buffer);

if (ValidQuic) {
QUIC_LONG_HEADER_V1* Header =
Expand All @@ -182,20 +200,22 @@ void RunAttackRandom(CXPLAT_SOCKET* Binding, uint16_t Length, bool ValidQuic, bo
Header->DestCid[8] = 8;
Header->DestCid[17] = 0;
QuicVarIntEncode(
Length - (MIN_LONG_HEADER_LENGTH_V1 + 19),
DatagramLength - (MIN_LONG_HEADER_LENGTH_V1 + 19),
Header->DestCid + 18);
}

InterlockedExchangeAdd64(&TotalPacketCount, 1);
InterlockedExchangeAdd64(&TotalByteCount, Length);
InterlockedExchangeAdd64(&TotalByteCount, DatagramLength + HeadersLength);
} while (CxPlatTimeDiff64(TimeStart, CxPlatTimeMs64()) < TimeoutMs &&
!CxPlatSendDataIsFull(SendData));

CxPlatSocketSend(
Binding,
&Route,
SendData);


BucketCount++;

if (TCP) {
CxPlatSendDataFree(SendData);
Route.LocalAddress.Ipv4.sin_port++;
Expand Down Expand Up @@ -252,7 +272,20 @@ void RunAttackValidInitial(CXPLAT_SOCKET* Binding)
CxPlatRandom(sizeof(uint64_t), DestCid);
CxPlatRandom(sizeof(uint64_t), SrcCid);

while (CxPlatTimeDiff64(TimeStart, CxPlatTimeMs64()) < TimeoutMs) {
uint64_t BucketTime = CxPlatTimeMs64(), CurTime;
uint64_t BucketCount = 0;
uint64_t BucketThreshold = CXPLAT_MAX(1, AttackRate / ThreadCount);

while (CxPlatTimeDiff64(TimeStart, (CurTime = CxPlatTimeMs64())) < TimeoutMs) {

if (CxPlatTimeDiff64(BucketTime, CurTime) > 1000) {
BucketTime = CurTime;
BucketCount = 0;
}

if (BucketCount >= BucketThreshold) {
continue;
}

CXPLAT_SEND_CONFIG SendConfig = {&Route, DatagramLength, CXPLAT_ECN_NON_ECT, 0 };
CXPLAT_SEND_DATA* SendData = CxPlatSendDataAlloc(Binding, &SendConfig);
Expand Down Expand Up @@ -310,14 +343,16 @@ void RunAttackValidInitial(CXPLAT_SOCKET* Binding)
}

InterlockedExchangeAdd64(&TotalPacketCount, 1);
InterlockedExchangeAdd64(&TotalByteCount, DatagramLength);
InterlockedExchangeAdd64(&TotalByteCount, DatagramLength + MIN_LONG_HEADER_LENGTH_V1);
} while (CxPlatTimeDiff64(TimeStart, CxPlatTimeMs64()) < TimeoutMs &&
!CxPlatSendDataIsFull(SendData));

CxPlatSocketSend(
Binding,
&Route,
SendData);

BucketCount++;
}
}

Expand Down Expand Up @@ -368,7 +403,6 @@ CXPLAT_THREAD_CALLBACK(RunAttackThread, /* Context */)
void RunAttack()
{
Writer = new PacketWriter(Version, Alpn, ServerName);

CXPLAT_THREAD* Threads =
(CXPLAT_THREAD*)CXPLAT_ALLOC_PAGED(ThreadCount * sizeof(CXPLAT_THREAD), QUIC_POOL_TOOL);

Expand Down Expand Up @@ -437,7 +471,10 @@ main(
TryGetValue(argc, argv, "alpn", &Alpn);
TryGetValue(argc, argv, "sni", &ServerName);
TryGetValue(argc, argv, "timeout", &TimeoutMs);
TryGetValue(argc, argv, "threads", &ThreadCount);
TryGetValue(argc, argv, "rate", &AttackRate);
if (!TryGetValue(argc, argv, "threads", &ThreadCount)) {
ThreadCount = ATTACK_THREADS_DEFAULT;
};

if (IpAddress == nullptr) {
if (ServerName == nullptr) {
Expand Down

0 comments on commit f79e0f3

Please sign in to comment.