diff --git a/wasmplugin/plugin.go b/wasmplugin/plugin.go index 6e8930451f598..a78bbd6ccf489 100644 --- a/wasmplugin/plugin.go +++ b/wasmplugin/plugin.go @@ -97,6 +97,35 @@ func (ctx *corazaPlugin) NewHttpContext(contextID uint32) types.HttpContext { } } +type interruptionPhase int8 + +func (p interruptionPhase) isInterrupted() bool { + return p != interruptionPhaseNone +} + +func (p interruptionPhase) String() string { + switch p { + case interruptionPhaseHttpRequestHeaders: + return "http_request_headers" + case interruptionPhaseHttpRequestBody: + return "http_request_body" + case interruptionPhaseHttpResponseHeaders: + return "http_response_headers" + case interruptionPhaseHttpResponseBody: + return "http_response_body" + default: + return "no interruption yet" + } +} + +const ( + interruptionPhaseNone = iota + interruptionPhaseHttpRequestHeaders = iota + interruptionPhaseHttpRequestBody = iota + interruptionPhaseHttpResponseHeaders = iota + interruptionPhaseHttpResponseBody = iota +) + type httpContext struct { // Embed the default http context here, // so that we don't need to reimplement all the methods. @@ -108,7 +137,7 @@ type httpContext struct { processedResponseBody bool bodyReadIndex int metrics *wafMetrics - interruptionHandled bool + interruptedAt interruptionPhase logger debuglog.Logger metricLabels map[string]string } @@ -180,7 +209,7 @@ func (ctx *httpContext) OnHttpRequestHeaders(numHeaders int, endOfStream bool) t interruption := tx.ProcessRequestHeaders() if interruption != nil { - return ctx.handleInterruption("http_request_headers", interruption) + return ctx.handleInterruption(interruptionPhaseHttpRequestHeaders, interruption) } return types.ActionContinue @@ -189,8 +218,10 @@ func (ctx *httpContext) OnHttpRequestHeaders(numHeaders int, endOfStream bool) t func (ctx *httpContext) OnHttpRequestBody(bodySize int, endOfStream bool) types.Action { defer logTime("OnHttpRequestBody", currentTime()) - if ctx.interruptionHandled { - ctx.logger.Error().Msg("Interruption already handled") + if ctx.interruptedAt.isInterrupted() { + ctx.logger.Error(). + Str("interruption_handled_phase", ctx.interruptedAt.String()). + Msg("Interruption already handled") return types.ActionPause } @@ -216,7 +247,7 @@ func (ctx *httpContext) OnHttpRequestBody(bodySize int, endOfStream bool) types. } if interruption != nil { - return ctx.handleInterruption("http_request_body", interruption) + return ctx.handleInterruption(interruptionPhaseHttpRequestBody, interruption) } return types.ActionContinue @@ -232,7 +263,7 @@ func (ctx *httpContext) OnHttpRequestBody(bodySize int, endOfStream bool) types. } if interruption != nil { - return ctx.handleInterruption("http_request_body", interruption) + return ctx.handleInterruption(interruptionPhaseHttpRequestBody, interruption) } ctx.bodyReadIndex += bodySize @@ -265,7 +296,7 @@ func (ctx *httpContext) OnHttpRequestBody(bodySize int, endOfStream bool) types. return types.ActionContinue } if interruption != nil { - return ctx.handleInterruption("http_request_body", interruption) + return ctx.handleInterruption(interruptionPhaseHttpRequestBody, interruption) } return types.ActionContinue @@ -277,13 +308,15 @@ func (ctx *httpContext) OnHttpRequestBody(bodySize int, endOfStream bool) types. func (ctx *httpContext) OnHttpResponseHeaders(numHeaders int, endOfStream bool) types.Action { defer logTime("OnHttpResponseHeaders", currentTime()) - if ctx.interruptionHandled { + if ctx.interruptedAt.isInterrupted() { // Handling the interruption (see handleInterruption) generates a HttpResponse with the required interruption status code. // If handleInterruption is raised during OnHttpRequestHeaders or OnHttpRequestBody, the crafted response is sent // downstream via the filter chain, therefore OnHttpResponseHeaders is called. It has to continue to properly send back the interruption action. // A doublecheck might be eventually added, checking that the :status header matches the expected interruption status code. // See https://github.com/corazawaf/coraza-proxy-wasm/pull/126 - ctx.logger.Debug().Msg("Interruption already handled, sending downstream the local response") + ctx.logger.Debug(). + Str("interruption_handled_phase", ctx.interruptedAt.String()). + Msg("Interruption already handled, sending downstream the local response") return types.ActionContinue } @@ -304,7 +337,7 @@ func (ctx *httpContext) OnHttpResponseHeaders(numHeaders int, endOfStream bool) return types.ActionContinue } if interruption != nil { - return ctx.handleInterruption("http_response_headers", interruption) + return ctx.handleInterruption(interruptionPhaseHttpResponseHeaders, interruption) } } @@ -334,7 +367,7 @@ func (ctx *httpContext) OnHttpResponseHeaders(numHeaders int, endOfStream bool) interruption := tx.ProcessResponseHeaders(code, ctx.httpProtocol) if interruption != nil { - return ctx.handleInterruption("http_response_headers", interruption) + return ctx.handleInterruption(interruptionPhaseHttpResponseHeaders, interruption) } return types.ActionContinue @@ -343,13 +376,14 @@ func (ctx *httpContext) OnHttpResponseHeaders(numHeaders int, endOfStream bool) func (ctx *httpContext) OnHttpResponseBody(bodySize int, endOfStream bool) types.Action { defer logTime("OnHttpResponseBody", currentTime()) - if ctx.interruptionHandled { + if ctx.interruptedAt.isInterrupted() { // At response body phase, proxy-wasm currently relies on emptying the response body as a way of // interruption the response. See https://github.com/corazawaf/coraza-proxy-wasm/issues/26. // If OnHttpResponseBody is called again and an interruption has already been raised, it means that // we have to keep going with the sanitization of the response, emptying it. // Sending the crafted HttpResponse with empty body, we don't expect to trigger OnHttpResponseBody - ctx.logger.Warn(). + ctx.logger.Debug(). + Str("interruption_handled_phase", ctx.interruptedAt.String()). Msg("Response body interruption already handled, keeping replacing the body") // Interruption happened, we don't want to send response body data return replaceResponseBodyWhenInterrupted(ctx.logger, bodySize) @@ -376,7 +410,7 @@ func (ctx *httpContext) OnHttpResponseBody(bodySize int, endOfStream bool) types // Proxy-wasm can not anymore deny the response. The best interruption is emptying the body // Coraza Multiphase evaluation will help here avoiding late interruptions ctx.bodyReadIndex = bodySize // hacky: bodyReadIndex stores the body size that has to be replaced - return ctx.handleInterruption("http_response_body", interruption) + return ctx.handleInterruption(interruptionPhaseHttpResponseBody, interruption) } return types.ActionContinue } @@ -393,7 +427,7 @@ func (ctx *httpContext) OnHttpResponseBody(bodySize int, endOfStream bool) types // it is internally needed to replace the full body if the tx is interrupted ctx.bodyReadIndex += bodySize if interruption != nil { - return ctx.handleInterruption("http_response_body", interruption) + return ctx.handleInterruption(interruptionPhaseHttpResponseBody, interruption) } } else if err != types.ErrorStatusNotFound { ctx.logger.Error(). @@ -418,7 +452,7 @@ func (ctx *httpContext) OnHttpResponseBody(bodySize int, endOfStream bool) types return types.ActionContinue } if interruption != nil { - return ctx.handleInterruption("http_response_body", interruption) + return ctx.handleInterruption(interruptionPhaseHttpResponseBody, interruption) } return types.ActionContinue } @@ -456,21 +490,21 @@ func (ctx *httpContext) OnHttpStreamDone() { const noGRPCStream int32 = -1 const defaultInterruptionStatusCode int = 403 -func (ctx *httpContext) handleInterruption(phase string, interruption *ctypes.Interruption) types.Action { - if ctx.interruptionHandled { +func (ctx *httpContext) handleInterruption(phase interruptionPhase, interruption *ctypes.Interruption) types.Action { + if ctx.interruptedAt.isInterrupted() { // handleInterruption should never be called more than once panic("Interruption already handled") } - ctx.metrics.CountTXInterruption(phase, interruption.RuleID, ctx.metricLabels) + ctx.metrics.CountTXInterruption(phase.String(), interruption.RuleID, ctx.metricLabels) ctx.logger.Info(). Str("action", interruption.Action). - Str("phase", phase). + Str("phase", phase.String()). Msg("Transaction interrupted") - ctx.interruptionHandled = true - if phase == "http_response_body" { + ctx.interruptedAt = phase + if phase == interruptionPhaseHttpResponseBody { return replaceResponseBodyWhenInterrupted(ctx.logger, ctx.bodyReadIndex) }