Skip to content

Commit

Permalink
Remove Lua state pool and simplify Lua state management
Browse files Browse the repository at this point in the history
The Lua state pool and related functions have been removed to simplify the codebase. Lua states are now created and closed directly within functions, with necessary libraries and modules registered each time. This change reduces complexity and eliminates potential concurrency issues related to the pool management.

Signed-off-by: Christian Roessner <[email protected]>
  • Loading branch information
Christian Roessner committed Sep 2, 2024
1 parent 4c24310 commit 3c7e277
Show file tree
Hide file tree
Showing 14 changed files with 158 additions and 488 deletions.
81 changes: 37 additions & 44 deletions server/backend/lua.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,6 @@ var LuaRequestChan chan *LuaRequest
// LuaMainWorkerEndChan is a channel that signals the termination of the main Lua worker.
var LuaMainWorkerEndChan chan Done

// LuaPool is a pool of Lua state instances.
var LuaPool = lualib.NewLuaBackendResultStatePool(
global.LuaBackendResultAuthenticated,
global.LuaBackendResultUserFound,
global.LuaBackendResultAccountField,
global.LuaBackendResultTOTPSecretField,
global.LuaBackendResultTOTPRecoveryField,
global.LuaBAckendResultUniqueUserIDField,
global.LuaBackendResultDisplayNameField,
global.LuaBackendResultAttributes,
)

// LuaRequest is a subset from the Authentication struct.
// LuaRequest is a struct that includes various information for a request to Lua.
type LuaRequest struct {
Expand Down Expand Up @@ -114,14 +102,15 @@ func LuaMainWorker(ctx context.Context) {
}
}

// handleLuaRequest handles a Lua request by executing the compiled script and handling any errors.
// It registers libraries and globals, sets Lua request parameters, and calls the Lua command function.
// It then handles the specific return types based on the Lua request function.
// handleLuaRequest is a function that handles a Lua request. It takes a context, a LuaRequest object, and a compiled Lua script as parameters.
// It sets up the Lua state, registers libraries, and preloads modules. It sets up global variables and creates a Lua table for the request.
// It sets the Lua request parameters based on the LuaRequest object and the Lua table. Then it executes the Lua script and handles any errors.
// Finally, it handles the specific return types based on the result of the Lua script execution.
//
// Parameters:
// - ctx: The context.Context object.
// - luaRequest: The LuaRequest object containing the request parameters.
// - ctx: The Context object.
// - compiledScript: The compiled Lua script function.
// - compiledScript: The compiled Lua script.
//
// Returns: None.
func handleLuaRequest(ctx context.Context, luaRequest *LuaRequest, compiledScript *lua.FunctionProto) {
Expand All @@ -133,61 +122,67 @@ func handleLuaRequest(ctx context.Context, luaRequest *LuaRequest, compiledScrip
logs := new(lualib.CustomLogKeyValue)
luaCtx, luaCancel := context.WithTimeout(ctx, viper.GetDuration("lua_script_timeout")*time.Second)

L := LuaPool.Get()

defer LuaPool.Put(L)
defer L.SetGlobal(global.LuaDefaultTable, lua.LNil)
defer luaCancel()

L.SetContext(luaCtx)
L := lua.NewState()

defer luaCancel()
defer L.Close()

L.SetContext(luaCtx)
lualib.RegisterLibraries(L)
L.PreloadModule(global.LuaModContext, lualib.LoaderModContext(luaRequest.Context))
L.PreloadModule(global.LuaModHTTPRequest, lualib.LoaderModHTTPRequest(luaRequest.HTTPClientContext.Request))

lualib.RegisterBackendResultType(
L,
global.LuaBackendResultAuthenticated,
global.LuaBackendResultUserFound,
global.LuaBackendResultAccountField,
global.LuaBackendResultTOTPSecretField,
global.LuaBackendResultTOTPRecoveryField,
global.LuaBAckendResultUniqueUserIDField,
global.LuaBackendResultDisplayNameField,
global.LuaBackendResultAttributes,
)

if config.LoadableConfig.HaveLDAPBackend() {
L.PreloadModule(global.LuaModLDAP, LoaderModLDAP(ctx))
}

globals := setupGlobals(luaRequest, L, logs)
setupGlobals(luaRequest, L, logs)

request := L.NewTable()

luaCommand, nret = setLuaRequestParameters(luaRequest, request)

err := executeAndHandleError(compiledScript, luaCommand, luaRequest, L, request, nret, logs)

lualib.CleanupLTable(globals)

request = nil
globals = nil

// Handle the specific return types
if err == nil {
handleReturnTypes(L, nret, luaRequest, logs)
}
}

// setupGlobals registers global variables and functions used in Lua scripts.
// Registers the backend result types LuaBackendResultOk and LuaBackendResultFail with global variables 0 and 1 respectively.
// Registers the lua function ctx.Set with name "context_set" which sets a value in the LuaRequest.Context.
// Registers the lua function ctx.Get with name "context_get" which retrieves a value from the LuaRequest.Context.
// Registers the lua function ctx.Delete with name "context_delete" which deletes a value from the LuaRequest.Context.
// Registers the lua function AddCustomLog with name "custom_log_add" which adds a custom log entry to the LuaRequest.Logs.
// The registered global table is assigned to the global variable LuaDefaultTable.
// The generated table is returned from the function.
func setupGlobals(luaRequest *LuaRequest, L *lua.LState, logs *lualib.CustomLogKeyValue) *lua.LTable {
// setupGlobals sets up global variables for the Lua state. It creates a new Lua table to hold the global variables,
// and assigns values to the predefined global variables. It also registers Lua functions for custom log addition and
// setting the status message. Finally, it sets the global table in the Lua state.
//
// Parameters:
// - luaRequest: The LuaRequest object containing the request parameters.
// - L: The Lua state.
// - logs: The custom log key-value pairs.
//
// Returns: None.
func setupGlobals(luaRequest *LuaRequest, L *lua.LState, logs *lualib.CustomLogKeyValue) {
globals := L.NewTable()

globals.RawSet(lua.LString(global.LuaBackendResultOk), lua.LNumber(0))
globals.RawSet(lua.LString(global.LuaBackendResultFail), lua.LNumber(1))

globals.RawSetString(global.LuaFnAddCustomLog, L.NewFunction(lualib.AddCustomLog(logs)))
globals.RawSetString(global.LuaFnSetStatusMessage, L.NewFunction(lualib.SetStatusMessage(&luaRequest.StatusMessage)))
globals.RawSetString(global.LuaFnGetAllHTTPRequestHeaders, L.NewFunction(lualib.GetAllHTTPRequestHeaders(luaRequest.HTTPClientContext.Request)))
globals.RawSetString(global.LuaFnGetHTTPRequestHeader, L.NewFunction(lualib.GetHTTPRequestHeader(luaRequest.HTTPClientContext.Request)))

L.SetGlobal(global.LuaDefaultTable, globals)

return globals
}

// setLuaRequestParameters sets the Lua request parameters based on the given LuaRequest object and Lua table.
Expand Down Expand Up @@ -256,8 +251,6 @@ func executeAndHandleError(compiledScript *lua.FunctionProto, luaCommand string,
processError(err, luaRequest, logs)
}

lualib.CleanupLTable(request)

return err
}

Expand Down
3 changes: 3 additions & 0 deletions server/global/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -965,6 +965,9 @@ const (
// LuaModBackend is a constant that holds the name of the Lua module for the Nauthilus backend.
LuaModBackend = "nauthilus_backend"

// LuaModHTTPRequest is a constant representing the value "nauthilus_http_request".
LuaModHTTPRequest = "nauthilus_http_request"

// LuaFnCallFeature represents the function name for "nauthilus_call_feature" in Lua
LuaFnCallFeature = "nauthilus_call_feature"

Expand Down
5 changes: 3 additions & 2 deletions server/lua-plugins.d/callback/callback.lua
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
local nauthilus_util = require("nauthilus_util")
local nauthilus_redis = require("nauthilus_redis")
local nauthilus_http_request = require("nauthilus_http_request")

local crypto = require("crypto")
local json = require("json")
Expand Down Expand Up @@ -41,8 +42,8 @@ function nauthilus_run_callback(logging)
end
end

local header = nauthilus_builtin.get_http_request_header("Content-Type")
local body = nauthilus_builtin.get_http_request_body()
local header = nauthilus_http_request.get_http_request_header("Content-Type")
local body = nauthilus_http_request.get_http_request_body()

if nauthilus_util.table_length(header) == 0 or header[1] ~= "application/json" then
print_result("HTTP request header: Wrong 'Content-Type'")
Expand Down
3 changes: 2 additions & 1 deletion server/lua-plugins.d/filters/monitoring.lua
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
local nauthilus_util = require("nauthilus_util")
local nauthilus_redis = require("nauthilus_redis")
local nauthilus_backend = require("nauthilus_backend")
local nauthilus_http_request = require("nauthilus_http_request")

local crypto = require("crypto")

Expand Down Expand Up @@ -39,7 +40,7 @@ function nauthilus_call_filter(request)
end

local function get_dovecot_session()
local header = nauthilus_builtin.get_http_request_header("X-Dovecot-Session")
local header = nauthilus_http_request.get_http_request_header("X-Dovecot-Session")
if nauthilus_util.table_length(header) == 1 then
return header[1]
end
Expand Down
55 changes: 24 additions & 31 deletions server/lualib/action/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,6 @@ var (
RequestChan chan *Action
)

// LuaPool is a pool of Lua state instances.
var LuaPool = lualib.NewLuaStatePool()

// Done is an empty struct that can be used to signal the completion of a task or operation.
type Done struct{}

Expand Down Expand Up @@ -63,7 +60,7 @@ type Action struct {
// Worker struct holds the data required for a worker process.
type Worker struct {
// ctx is a pointer to a Context object used for managing and carrying context deadlines, cancel signals, and other request-scoped values across API boundaries and between processes.
ctx *context.Context
ctx context.Context

// luaActionRequest is a pointer to an Action. This specifies the action to be performed by the Lua scripting environment.
luaActionRequest *Action
Expand Down Expand Up @@ -100,7 +97,7 @@ func NewWorker() *Worker {
// If a request is received, it handles the request by running the corresponding script.
// If the context is cancelled, it sends a WorkerEndChan signal to indicate that the worker has ended.
func (aw *Worker) Work(ctx context.Context) {
aw.ctx = &ctx
aw.ctx = ctx

if !config.LoadableConfig.HaveLuaActions() {
return
Expand Down Expand Up @@ -208,38 +205,41 @@ func (aw *Worker) handleRequest(httpRequest *http.Request) {
return
}

L := LuaPool.Get()
L := lua.NewState()

defer LuaPool.Put(L)
defer L.SetGlobal(global.LuaDefaultTable, lua.LNil)
defer L.Close()

lualib.RegisterLibraries(L)
L.PreloadModule(global.LuaModContext, lualib.LoaderModContext(aw.luaActionRequest.Context))
L.PreloadModule(global.LuaModHTTPRequest, lualib.LoaderModHTTPRequest(httpRequest))

if config.LoadableConfig.HaveLDAPBackend() {
L.PreloadModule(global.LuaModLDAP, backend.LoaderModLDAP(aw.ctx))
}

logs := new(lualib.CustomLogKeyValue)
globals := aw.setupGlobals(L, logs, httpRequest)

aw.setupGlobals(L, logs)

request := aw.setupRequest(L)

for index := range aw.actionScripts {
if aw.actionScripts[index].LuaAction == aw.luaActionRequest.LuaAction && !errors.Is((*aw.ctx).Err(), context.Canceled) {
if aw.actionScripts[index].LuaAction == aw.luaActionRequest.LuaAction && !errors.Is((aw.ctx).Err(), context.Canceled) {
aw.runScript(index, L, request, logs)
}
}

lualib.CleanupLTable(request)
lualib.CleanupLTable(globals)

request = nil
globals = nil

aw.luaActionRequest.FinishedChan <- Done{}
}

// setupGlobals sets up global Lua variables for the Worker.
// It creates a new Lua table to hold the global variables.
// If the DevMode flag is true in the EnvConfig, it calls the DebugModule function to log debug information.
// It sets the global variables LString(global.LuaActionResultOk) and LString(global.LuaActionResultFail) with the corresponding values.
// It sets the global functions LString(global.LuaFnCtxSet), LString(global.LuaFnCtxGet), LString(global.LuaFnCtxDelete), and LString(global.LuaFnAddCustomLog) to their respective Lua functions
func (aw *Worker) setupGlobals(L *lua.LState, logs *lualib.CustomLogKeyValue, httpRequest *http.Request) *lua.LTable {
// setupGlobals initializes the global variables in the Lua state.
// It creates a new Lua table and sets the necessary variables and functions.
// If the DevMode configuration is enabled, it logs the Lua action request.
// The Lua table includes two variables, LuaActionResultOk and LuaActionResultFail,
// which are set to 0 and 1 respectively.
// It also includes a function LuaFnAddCustomLog, which is set to the AddCustomLog function
// from the lualib package. Finally, it sets the LuaDefaultTable global variable to the created table.
func (aw *Worker) setupGlobals(L *lua.LState, logs *lualib.CustomLogKeyValue) {
globals := L.NewTable()

if config.EnvConfig.DevMode {
Expand All @@ -250,16 +250,8 @@ func (aw *Worker) setupGlobals(L *lua.LState, logs *lualib.CustomLogKeyValue, ht
globals.RawSet(lua.LString(global.LuaActionResultFail), lua.LNumber(1))

globals.RawSetString(global.LuaFnAddCustomLog, L.NewFunction(lualib.AddCustomLog(logs)))
globals.RawSetString(global.LuaFnGetAllHTTPRequestHeaders, L.NewFunction(lualib.GetAllHTTPRequestHeaders(httpRequest)))
globals.RawSetString(global.LuaFnGetHTTPRequestHeader, L.NewFunction(lualib.GetHTTPRequestHeader(httpRequest)))

if config.LoadableConfig.HaveLDAPBackend() {
globals.RawSetString(global.LuaFnLDAPSearch, L.NewFunction(backend.LuaLDAPSearch(context.Background())))
}

L.SetGlobal(global.LuaDefaultTable, globals)

return globals
}

// setupRequest creates a Lua table representing the request data.
Expand Down Expand Up @@ -298,7 +290,8 @@ func (aw *Worker) runScript(index int, L *lua.LState, request *lua.LTable, logs

defer stopTimer()

luaCtx, luaCancel := context.WithTimeout(*(aw.ctx), viper.GetDuration("lua_script_timeout")*time.Second)
luaCtx, luaCancel := context.WithTimeout(aw.ctx, viper.GetDuration("lua_script_timeout")*time.Second)

L.SetContext(luaCtx)

if err = aw.executeScript(L, index, request); err != nil {
Expand Down
31 changes: 2 additions & 29 deletions server/lualib/backendresult.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package lualib

import (
"sync"

"github.com/croessner/nauthilus/server/global"
"github.com/croessner/nauthilus/server/lualib/convert"
"github.com/yuin/gopher-lua"
Expand Down Expand Up @@ -42,34 +40,9 @@ type LuaBackendResult struct {
Logs *CustomLogKeyValue
}

// LuaBackendResultStatePool embeds the LuaStatePool type.
// It provides methods for retrieving, returning, and shutting down Lua states.
type LuaBackendResultStatePool struct {
*LuaStatePool
}

// NewLuaBackendResultStatePool creates a new instance of LuaBackendResultStatePool that implements the LuaBaseStatePool
// interface. It initializes a LuaStatePool with a New function
func NewLuaBackendResultStatePool(methods ...string) LuaBaseStatePool {
lp := &LuaStatePool{
New: func() *lua.LState {
L := NewLStateWithDefaultLibraries()

registerBackendResultType(L, methods...)

return L
},
MaxStates: global.MaxLuaStatePoolSize,
}

lp.Cond = sync.Cond{L: &lp.Mu}

return &LuaBackendResultStatePool{lp.InitializeStatePool()}
}

// registerBackendResultType registers the Lua type "backend_result" in the given Lua state.
// RegisterBackendResultType registers the Lua type "nauthilus_backend_result" in the given Lua state.
// It sets the type metatable with the given name and creates the necessary static attributes and methods.
func registerBackendResultType(L *lua.LState, methods ...string) {
func RegisterBackendResultType(L *lua.LState, methods ...string) {
mt := L.NewTypeMetatable(global.LuaBackendResultTypeName)

L.SetGlobal(global.LuaBackendResultTypeName, mt)
Expand Down
Loading

0 comments on commit 3c7e277

Please sign in to comment.