Skip to content

Commit

Permalink
detect: unify functions for multi-buffer
Browse files Browse the repository at this point in the history
Ticket: 6575

Multi buffers keywords now use a single registration function
DetectAppLayerMultiRegister with a GetBuffer argument.

This GetBuffer function pointer is similar to the ones used by
single-buffer keyword, except that it takes an additional
parameter which is the index of the buffer to get.
Under the hood, an anonymous union between these 2 functions
pointers types is used.

In the end, this deduplicates code, especially the calls to
DetectEngineContentInspection
  • Loading branch information
catenacyber authored and victorjulien committed May 24, 2024
1 parent 55bc5f2 commit ce16a56
Show file tree
Hide file tree
Showing 21 changed files with 366 additions and 1,600 deletions.
90 changes: 5 additions & 85 deletions src/detect-dns-answer-name.c
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,6 @@
#include "util-profiling.h"
#include "rust.h"

typedef struct PrefilterMpm {
int list_id;
const MpmCtx *mpm_ctx;
const DetectEngineTransforms *transforms;
} PrefilterMpm;

static int detect_buffer_id = 0;

static int DetectSetup(DetectEngineCtx *de_ctx, Signature *s, const char *str)
Expand All @@ -50,8 +44,9 @@ static int DetectSetup(DetectEngineCtx *de_ctx, Signature *s, const char *str)
return 0;
}

static InspectionBuffer *GetBuffer(DetectEngineThreadCtx *det_ctx, uint8_t flags,
const DetectEngineTransforms *transforms, void *txv, uint32_t index, int list_id)
static InspectionBuffer *GetBuffer(DetectEngineThreadCtx *det_ctx,
const DetectEngineTransforms *transforms, Flow *f, uint8_t flags, void *txv, int list_id,
uint32_t index)
{
InspectionBuffer *buffer = InspectionBufferMultipleForListGet(det_ctx, list_id, index);
if (buffer == NULL) {
Expand All @@ -74,74 +69,6 @@ static InspectionBuffer *GetBuffer(DetectEngineThreadCtx *det_ctx, uint8_t flags
return buffer;
}

static uint8_t DetectEngineInspectCb(DetectEngineCtx *de_ctx, DetectEngineThreadCtx *det_ctx,
const struct DetectEngineAppInspectionEngine_ *engine, const Signature *s, Flow *f,
uint8_t flags, void *alstate, void *txv, uint64_t tx_id)
{
const DetectEngineTransforms *transforms = NULL;
if (!engine->mpm) {
transforms = engine->v2.transforms;
}

for (uint32_t i = 0;; i++) {
InspectionBuffer *buffer = GetBuffer(det_ctx, flags, transforms, txv, i, engine->sm_list);
if (buffer == NULL || buffer->inspect == NULL) {
break;
}

const bool match = DetectEngineContentInspectionBuffer(de_ctx, det_ctx, s, engine->smd,
NULL, f, buffer, DETECT_ENGINE_CONTENT_INSPECTION_MODE_STATE);
if (match) {
return DETECT_ENGINE_INSPECT_SIG_MATCH;
}
}

return DETECT_ENGINE_INSPECT_SIG_NO_MATCH;
}

static void PrefilterTx(DetectEngineThreadCtx *det_ctx, const void *pectx, Packet *p, Flow *f,
void *txv, const uint64_t idx, const AppLayerTxData *_txd, const uint8_t flags)
{
SCEnter();

const PrefilterMpm *ctx = (const PrefilterMpm *)pectx;
const MpmCtx *mpm_ctx = ctx->mpm_ctx;
const int list_id = ctx->list_id;

for (uint32_t i = 0;; i++) {
InspectionBuffer *buffer = GetBuffer(det_ctx, flags, ctx->transforms, txv, i, list_id);
if (buffer == NULL) {
break;
}

if (buffer->inspect_len >= mpm_ctx->minlen) {
(void)mpm_table[mpm_ctx->mpm_type].Search(
mpm_ctx, &det_ctx->mtc, &det_ctx->pmq, buffer->inspect, buffer->inspect_len);
PREFILTER_PROFILING_ADD_BYTES(det_ctx, buffer->inspect_len);
}
}
}

static void PrefilterMpmFree(void *ptr)
{
SCFree(ptr);
}

static int PrefilterMpmRegister(DetectEngineCtx *de_ctx, SigGroupHead *sgh, MpmCtx *mpm_ctx,
const DetectBufferMpmRegistry *mpm_reg, int list_id)
{
PrefilterMpm *pectx = SCCalloc(1, sizeof(*pectx));
if (pectx == NULL) {
return -1;
}
pectx->list_id = list_id;
pectx->mpm_ctx = mpm_ctx;
pectx->transforms = &mpm_reg->transforms;

return PrefilterAppendTxEngine(de_ctx, sgh, PrefilterTx, mpm_reg->app_v2.alproto,
mpm_reg->app_v2.tx_min_progress, pectx, PrefilterMpmFree, mpm_reg->pname);
}

void DetectDnsAnswerNameRegister(void)
{
static const char *keyword = "dns.answer.name";
Expand All @@ -154,16 +81,9 @@ void DetectDnsAnswerNameRegister(void)

/* Register in the TO_SERVER direction, even though this is not
normal, it could be provided as part of a request. */
DetectAppLayerInspectEngineRegister(
keyword, ALPROTO_DNS, SIG_FLAG_TOSERVER, 0, DetectEngineInspectCb, NULL);
DetectAppLayerMpmRegister(
keyword, SIG_FLAG_TOSERVER, 2, PrefilterMpmRegister, NULL, ALPROTO_DNS, 1);

DetectAppLayerMultiRegister(keyword, ALPROTO_DNS, SIG_FLAG_TOSERVER, 0, GetBuffer, 2, 1);
/* Register in the TO_CLIENT direction. */
DetectAppLayerInspectEngineRegister(
keyword, ALPROTO_DNS, SIG_FLAG_TOCLIENT, 0, DetectEngineInspectCb, NULL);
DetectAppLayerMpmRegister(
keyword, SIG_FLAG_TOCLIENT, 2, PrefilterMpmRegister, NULL, ALPROTO_DNS, 1);
DetectAppLayerMultiRegister(keyword, ALPROTO_DNS, SIG_FLAG_TOCLIENT, 0, GetBuffer, 2, 1);

DetectBufferTypeSetDescriptionByName(keyword, "dns answer name");
DetectBufferTypeSupportsMultiInstance(keyword);
Expand Down
90 changes: 5 additions & 85 deletions src/detect-dns-query-name.c
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,6 @@
#include "util-profiling.h"
#include "rust.h"

typedef struct PrefilterMpm {
int list_id;
const MpmCtx *mpm_ctx;
const DetectEngineTransforms *transforms;
} PrefilterMpm;

static int detect_buffer_id = 0;

static int DetectSetup(DetectEngineCtx *de_ctx, Signature *s, const char *str)
Expand All @@ -50,8 +44,9 @@ static int DetectSetup(DetectEngineCtx *de_ctx, Signature *s, const char *str)
return 0;
}

static InspectionBuffer *GetBuffer(DetectEngineThreadCtx *det_ctx, const uint8_t flags,
const DetectEngineTransforms *transforms, void *txv, uint32_t index, int list_id)
static InspectionBuffer *GetBuffer(DetectEngineThreadCtx *det_ctx,
const DetectEngineTransforms *transforms, Flow *f, const uint8_t flags, void *txv,
int list_id, uint32_t index)
{
InspectionBuffer *buffer = InspectionBufferMultipleForListGet(det_ctx, list_id, index);
if (buffer == NULL) {
Expand All @@ -74,74 +69,6 @@ static InspectionBuffer *GetBuffer(DetectEngineThreadCtx *det_ctx, const uint8_t
return buffer;
}

static uint8_t DetectEngineInspectCb(DetectEngineCtx *de_ctx, DetectEngineThreadCtx *det_ctx,
const struct DetectEngineAppInspectionEngine_ *engine, const Signature *s, Flow *f,
uint8_t flags, void *alstate, void *txv, uint64_t tx_id)
{
const DetectEngineTransforms *transforms = NULL;
if (!engine->mpm) {
transforms = engine->v2.transforms;
}

for (uint32_t i = 0;; i++) {
InspectionBuffer *buffer = GetBuffer(det_ctx, flags, transforms, txv, i, engine->sm_list);
if (buffer == NULL || buffer->inspect == NULL) {
break;
}

const bool match = DetectEngineContentInspectionBuffer(de_ctx, det_ctx, s, engine->smd,
NULL, f, buffer, DETECT_ENGINE_CONTENT_INSPECTION_MODE_STATE);
if (match) {
return DETECT_ENGINE_INSPECT_SIG_MATCH;
}
}

return DETECT_ENGINE_INSPECT_SIG_NO_MATCH;
}

static void PrefilterTx(DetectEngineThreadCtx *det_ctx, const void *pectx, Packet *p, Flow *f,
void *txv, const uint64_t idx, const AppLayerTxData *_txd, const uint8_t flags)
{
SCEnter();

const PrefilterMpm *ctx = (const PrefilterMpm *)pectx;
const MpmCtx *mpm_ctx = ctx->mpm_ctx;
const int list_id = ctx->list_id;

for (uint32_t i = 0;; i++) {
InspectionBuffer *buffer = GetBuffer(det_ctx, flags, ctx->transforms, txv, i, list_id);
if (buffer == NULL) {
break;
}

if (buffer->inspect_len >= mpm_ctx->minlen) {
(void)mpm_table[mpm_ctx->mpm_type].Search(
mpm_ctx, &det_ctx->mtc, &det_ctx->pmq, buffer->inspect, buffer->inspect_len);
PREFILTER_PROFILING_ADD_BYTES(det_ctx, buffer->inspect_len);
}
}
}

static void PrefilterMpmFree(void *ptr)
{
SCFree(ptr);
}

static int PrefilterMpmRegister(DetectEngineCtx *de_ctx, SigGroupHead *sgh, MpmCtx *mpm_ctx,
const DetectBufferMpmRegistry *mpm_reg, int list_id)
{
PrefilterMpm *pectx = SCCalloc(1, sizeof(*pectx));
if (pectx == NULL) {
return -1;
}
pectx->list_id = list_id;
pectx->mpm_ctx = mpm_ctx;
pectx->transforms = &mpm_reg->transforms;

return PrefilterAppendTxEngine(de_ctx, sgh, PrefilterTx, mpm_reg->app_v2.alproto,
mpm_reg->app_v2.tx_min_progress, pectx, PrefilterMpmFree, mpm_reg->pname);
}

void DetectDnsQueryNameRegister(void)
{
static const char *keyword = "dns.query.name";
Expand All @@ -154,15 +81,8 @@ void DetectDnsQueryNameRegister(void)

/* Register in both directions as the query is usually echoed back
in the response. */
DetectAppLayerInspectEngineRegister(
keyword, ALPROTO_DNS, SIG_FLAG_TOSERVER, 0, DetectEngineInspectCb, NULL);
DetectAppLayerMpmRegister(
keyword, SIG_FLAG_TOSERVER, 2, PrefilterMpmRegister, NULL, ALPROTO_DNS, 1);

DetectAppLayerInspectEngineRegister(
keyword, ALPROTO_DNS, SIG_FLAG_TOCLIENT, 0, DetectEngineInspectCb, NULL);
DetectAppLayerMpmRegister(
keyword, SIG_FLAG_TOCLIENT, 2, PrefilterMpmRegister, NULL, ALPROTO_DNS, 1);
DetectAppLayerMultiRegister(keyword, ALPROTO_DNS, SIG_FLAG_TOSERVER, 0, GetBuffer, 2, 1);
DetectAppLayerMultiRegister(keyword, ALPROTO_DNS, SIG_FLAG_TOCLIENT, 0, GetBuffer, 2, 1);

DetectBufferTypeSetDescriptionByName(keyword, "dns query name");
DetectBufferTypeSupportsMultiInstance(keyword);
Expand Down
111 changes: 6 additions & 105 deletions src/detect-dns-query.c
Original file line number Diff line number Diff line change
Expand Up @@ -67,27 +67,21 @@ static void DetectDnsQueryRegisterTests(void);
#endif
static int g_dns_query_buffer_id = 0;

struct DnsQueryGetDataArgs {
uint32_t local_id; /**< used as index into thread inspect array */
void *txv;
};

static InspectionBuffer *DnsQueryGetData(DetectEngineThreadCtx *det_ctx,
const DetectEngineTransforms *transforms, Flow *f, struct DnsQueryGetDataArgs *cbdata,
int list_id)
const DetectEngineTransforms *transforms, Flow *f, const uint8_t flags, void *txv,
int list_id, uint32_t local_id)
{
SCEnter();

InspectionBuffer *buffer =
InspectionBufferMultipleForListGet(det_ctx, list_id, cbdata->local_id);
InspectionBuffer *buffer = InspectionBufferMultipleForListGet(det_ctx, list_id, local_id);
if (buffer == NULL)
return NULL;
if (buffer->initialized)
return buffer;

const uint8_t *data;
uint32_t data_len;
if (SCDnsTxGetQueryName(cbdata->txv, false, cbdata->local_id, &data, &data_len) == 0) {
if (SCDnsTxGetQueryName(txv, false, local_id, &data, &data_len) == 0) {
InspectionBufferSetupMultiEmpty(buffer);
return NULL;
}
Expand All @@ -97,96 +91,6 @@ static InspectionBuffer *DnsQueryGetData(DetectEngineThreadCtx *det_ctx,
SCReturnPtr(buffer, "InspectionBuffer");
}

static uint8_t DetectEngineInspectDnsQuery(DetectEngineCtx *de_ctx, DetectEngineThreadCtx *det_ctx,
const DetectEngineAppInspectionEngine *engine, const Signature *s, Flow *f, uint8_t flags,
void *alstate, void *txv, uint64_t tx_id)
{
uint32_t local_id = 0;

const DetectEngineTransforms *transforms = NULL;
if (!engine->mpm) {
transforms = engine->v2.transforms;
}

while(1) {
struct DnsQueryGetDataArgs cbdata = { local_id, txv, };
InspectionBuffer *buffer =
DnsQueryGetData(det_ctx, transforms, f, &cbdata, engine->sm_list);
if (buffer == NULL || buffer->inspect == NULL)
break;

const bool match = DetectEngineContentInspectionBuffer(de_ctx, det_ctx, s, engine->smd,
NULL, f, buffer, DETECT_ENGINE_CONTENT_INSPECTION_MODE_STATE);
if (match) {
return DETECT_ENGINE_INSPECT_SIG_MATCH;
}
local_id++;
}
return DETECT_ENGINE_INSPECT_SIG_NO_MATCH;
}

typedef struct PrefilterMpmDnsQuery {
int list_id;
const MpmCtx *mpm_ctx;
const DetectEngineTransforms *transforms;
} PrefilterMpmDnsQuery;

/** \brief DnsQuery DnsQuery Mpm prefilter callback
*
* \param det_ctx detection engine thread ctx
* \param p packet to inspect
* \param f flow to inspect
* \param txv tx to inspect
* \param pectx inspection context
*/
static void PrefilterTxDnsQuery(DetectEngineThreadCtx *det_ctx, const void *pectx, Packet *p,
Flow *f, void *txv, const uint64_t idx, const AppLayerTxData *_txd, const uint8_t flags)
{
SCEnter();

const PrefilterMpmDnsQuery *ctx = (const PrefilterMpmDnsQuery *)pectx;
const MpmCtx *mpm_ctx = ctx->mpm_ctx;
const int list_id = ctx->list_id;

uint32_t local_id = 0;
while(1) {
// loop until we get a NULL

struct DnsQueryGetDataArgs cbdata = { local_id, txv };
InspectionBuffer *buffer = DnsQueryGetData(det_ctx, ctx->transforms, f, &cbdata, list_id);
if (buffer == NULL)
break;

if (buffer->inspect_len >= mpm_ctx->minlen) {
(void)mpm_table[mpm_ctx->mpm_type].Search(
mpm_ctx, &det_ctx->mtc, &det_ctx->pmq, buffer->inspect, buffer->inspect_len);
PREFILTER_PROFILING_ADD_BYTES(det_ctx, buffer->inspect_len);
}

local_id++;
}
}

static void PrefilterMpmDnsQueryFree(void *ptr)
{
SCFree(ptr);
}

static int PrefilterMpmDnsQueryRegister(DetectEngineCtx *de_ctx, SigGroupHead *sgh, MpmCtx *mpm_ctx,
const DetectBufferMpmRegistry *mpm_reg, int list_id)
{
PrefilterMpmDnsQuery *pectx = SCCalloc(1, sizeof(*pectx));
if (pectx == NULL)
return -1;
pectx->list_id = list_id;
pectx->mpm_ctx = mpm_ctx;
pectx->transforms = &mpm_reg->transforms;

return PrefilterAppendTxEngine(de_ctx, sgh, PrefilterTxDnsQuery,
mpm_reg->app_v2.alproto, mpm_reg->app_v2.tx_min_progress,
pectx, PrefilterMpmDnsQueryFree, mpm_reg->pname);
}

/**
* \brief Registration function for keyword: dns_query
*/
Expand All @@ -203,11 +107,8 @@ void DetectDnsQueryRegister (void)
sigmatch_table[DETECT_AL_DNS_QUERY].flags |= SIGMATCH_NOOPT;
sigmatch_table[DETECT_AL_DNS_QUERY].flags |= SIGMATCH_INFO_STICKY_BUFFER;

DetectAppLayerMpmRegister(
"dns_query", SIG_FLAG_TOSERVER, 2, PrefilterMpmDnsQueryRegister, NULL, ALPROTO_DNS, 1);

DetectAppLayerInspectEngineRegister(
"dns_query", ALPROTO_DNS, SIG_FLAG_TOSERVER, 1, DetectEngineInspectDnsQuery, NULL);
DetectAppLayerMultiRegister(
"dns_query", ALPROTO_DNS, SIG_FLAG_TOSERVER, 1, DnsQueryGetData, 2, 1);

DetectBufferTypeSetDescriptionByName("dns_query",
"dns request query");
Expand Down
Loading

0 comments on commit ce16a56

Please sign in to comment.