Skip to content

Commit

Permalink
[wip] load all policies through the policy loader
Browse files Browse the repository at this point in the history
- and fallback to native just when needed
  • Loading branch information
mikz committed Jan 31, 2018
1 parent fcf9023 commit 95ddd64
Show file tree
Hide file tree
Showing 10 changed files with 199 additions and 86 deletions.
2 changes: 1 addition & 1 deletion gateway/cpanfile
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
requires 'Test::APIcast', '0.03';
requires 'Test::APIcast', '0.05';
requires 'Crypt::JWT';
2 changes: 1 addition & 1 deletion gateway/src/apicast/configuration.lua
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ local function build_policy_chain(policies)
local chain = {}

for i=1, #policies do
chain[i] = policy_chain.load(policies[i].name, policies[i].configuration)
chain[i] = policy_chain.load_policy(policies[i].name, policies[i].version, policies[i].configuration)
end

return policy_chain.new(chain)
Expand Down
2 changes: 1 addition & 1 deletion gateway/src/apicast/policy/local_chain/policy.lua
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ local function build_default_chain()

if resty_env.get('APICAST_MODULE') then
-- Needed to keep compatibility with the old module system.
module = 'apicast.module'
module = assert(require('apicast.module'), 'could not load custom module')
else
module = 'apicast.policy.apicast'
end
Expand Down
13 changes: 10 additions & 3 deletions gateway/src/apicast/policy_chain.lua
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ local rawset = rawset
local type = type
local require = require
local insert = table.insert
local sub = string.sub
local format = string.format
local noop = function() end

require('apicast.loader')
Expand Down Expand Up @@ -47,7 +49,7 @@ function _M.build(modules)

for i=1, #list do
-- TODO: make this error better, possibly not crash and just log and skip the module
chain[i] = _M.load(list[i]) or error("module " .. list[i] .. ' could not be loaded')
chain[i] = _M.load_policy(list[i]) or error(format('module %q could not be loaded', list[i]))
end

return _M.new(chain)
Expand All @@ -73,9 +75,14 @@ end
-- @tparam string|table module the module or its name
-- @tparam ?table ... params needed to initialize the module
-- @treturn object The module instantiated
function _M.load(module, ...)
function _M.load_policy(module, version, ...)
if type(module) == 'string' then
local mod = policy_loader.call(module)
if sub(module, 1, 14) == 'apicast.policy' then
module = sub(module, 16)
version = 'builtin'
end

local mod = policy_loader(module, version or 'builtin')

if mod then
return mod.new(...)
Expand Down
117 changes: 94 additions & 23 deletions gateway/src/apicast/policy_loader.lua
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@ local loadfile = loadfile
local getenv = os.getenv
local insert = table.insert
local setmetatable = setmetatable
local concat = table.concat

local _M = {

}
local _M = { }

local searchpath = package.searchpath
local root_loaded = package.loaded
Expand All @@ -26,21 +25,34 @@ preload['apicast.policy_loader'] = function() return _M end

local prequire = function(...) return pcall(root_require, ...) end

--- create a require function not using the global namespace
-- loading code from policy namespace should have no effect on the global namespace
-- but poliocy can load shared libraries that would be cached globally
local function gen_require(package)

local function not_found(modname, err)
return error(format("module '%s' not found:%s", modname, err), 0)
end

--- helper function to safely use the native require function
local function fallback(modname, err)
ngx.log(ngx.DEBUG, 'native require for : ', modname)
local ok, mod = prequire(modname)
local mod

if not ok then return error(err or format("module '%s' not found\n", modname)) end
mod = package.loaded[modname]

return mod
end
if not mod then
local ok
ngx.log(ngx.DEBUG, 'native require for: ', modname)
ok, mod = prequire(modname)

return function(modname)
local mod = root_loaded[modname] or package.loaded[modname]
if not ok then return not_found(modname, err) end
end

if mod then return mod end
return mod
end

--- helper function to find and return correct loader for a module
local function find_loader(modname)
local loader, file, err, ret

for i=1, #package.searchers do
Expand All @@ -55,11 +67,31 @@ local function gen_require(package)
end
end

return loader, file, err
end

--- reimplemented require function
-- - return a module if it was already loaded (globally or locally)
-- - try to find loader function
-- - fallback to global require
-- @tparam string modname module name
-- @tparam boolean exclusive load only policy code, turns off the fallback loader
return function(modname, exclusive)
ngx.log(ngx.DEBUG, 'sandbox require: ', modname)
local mod = root_loaded[modname]

if mod then return mod end

local loader, file, err = find_loader(modname)

if loader then
ngx.log(ngx.DEBUG, 'sandboxed require for: ', modname, ' file: ', file)
mod = loader(modname, file)
else
elseif not exclusive then
ngx.log(ngx.DEBUG, 'fallback loader for: ', modname)
mod = fallback(modname, err)
else
return not_found(modname, err)
end

if mod ~= nil then
Expand All @@ -72,6 +104,9 @@ local function gen_require(package)
end
end

--- this is environment exposed to the policies
-- that means this is very light sandbox so policies don't mutate global env
-- and most importantly we replace the require function with our own
_M.env = {
math = math,
table = table,
Expand All @@ -84,44 +119,80 @@ _M.env = {
setmetatable = setmetatable,
getmetatable = getmetatable,
coroutine = coroutine,
ipairs = ipairs, pairs = pairs,
ipairs = ipairs, pairs = pairs, next = next,
ngx = ngx,
}
_M.env._G = _M.env

function _M.call(name, version)
ngx.log(ngx.DEBUG, 'loading policy: ', name, ' version: ', version)
local mt = {
__call = function(loader, ...) return loader.env.require(...) end
}

function _M.new(name, version)
local apicast_dir = getenv('APICAST_DIR') or '.'

local path = {
-- first path contains
format('%s/policies/%s/%s/?.lua', apicast_dir, name, version),
}

local apicast_dir = getenv('APICAST_DIR')
if version == 'builtin' then
insert(path, format('%s/src/apicast/policy/%s/?.lua', apicast_dir, name))
end

-- need to create global variable package that mimics the native one
local package = {
loaded = {},
preload = preload,
searchers = {},
searchers = {}, -- http://www.lua.org/manual/5.2/manual.html#pdf-package.searchers
searchpath = searchpath,
path = format('%s/src/?/policy.lua', apicast_dir), -- FIXME: lock the path to just one policy tree
path = concat(path, ';'),
cpath = '', -- no C libraries allowed in policies
}

-- creating new env for each policy means they can't accidentaly share global variables
local env = setmetatable({
require = gen_require(package),
package = package,
}, { __index = _M.env })

-- The first searcher simply looks for a loader in the package.preload table.
insert(package.searchers, function(modname) return package.preload[modname] end)
-- The second searcher looks for a loader as a Lua library, using the path stored at package.path.
-- The search is done as described in function package.searchpath.
insert(package.searchers, function(modname)
local file, err = searchpath(modname, package.path)
local loader

if file then
file, err = loadfile(file, 'bt', env)
loader, err = loadfile(file, 'bt', env)

ngx.log(ngx.DEBUG, 'loading file: ', file)

if loader then return loader, file end
end

return file, err
return err
end)

local mod = env.require(name)
local self = {
env = env,
name = name,
version = version,
}

return setmetatable(self, mt)
end

function _M:call(name, version)
local v = version or 'builtin'
local loader = self.new(name, v)

ngx.log(ngx.DEBUG, 'loading policy: ', name, ' version: ', v)

return mod
-- passing the "exclusive" flag for the require so it does not fallback to native require
-- it should load only policies and not other code and fail if there is no such policy
return loader('policy', true)
end

return _M
return setmetatable(_M, { __call = _M.call })
50 changes: 26 additions & 24 deletions spec/executor_spec.lua
Original file line number Diff line number Diff line change
@@ -1,42 +1,44 @@
-- Policies included by default in the executor
local default_executor_chain = {
require 'apicast.policy.load_configuration',
require 'apicast.policy.find_service',
require 'apicast.policy.local_chain'
}

local executor = require 'apicast.executor'
local policy_chain = require 'apicast.policy_chain'
local PolicyChain = require 'apicast.policy_chain'
local Policy = require 'apicast.policy'

describe('executor', function()
local phases = {
'init', 'init_worker',
'rewrite', 'access', 'balancer',
'header_filter', 'body_filter',
'post_action', 'log'
}

it('forwards all the policy methods to the policy chain', function()
local chain = PolicyChain.default()
local exec = executor.new(chain)
-- Stub all the nginx phases methods for each of the policies
for _, phase in ipairs(phases) do
for _, policy in ipairs(default_executor_chain) do
for _, phase in Policy.phases() do
for _, policy in ipairs(chain) do
stub(policy, phase)
end
end

-- For each one of the nginx phases, verify that when called on the
-- executor, each one of the policies executes the code for that phase.
for _, phase in ipairs(phases) do
executor[phase](executor)
for _, policy in ipairs(default_executor_chain) do
for _, phase in Policy.phases() do
exec[phase](exec)
for _, policy in ipairs(chain) do
assert.stub(policy[phase]).was_called()
end
end
end)

it('is initialized with default chain', function()
local default = PolicyChain.default()
local policy_chain = executor.policy_chain

assert.same(#default, #policy_chain)

for i,policy in ipairs(default) do
assert.same(policy._NAME, policy_chain[i]._NAME)
assert.same(policy._VERSION, policy_chain[i]._VERSION)

assert.equal(policy, policy_chain[i])
end
end)

it('freezes the policy chain', function()
local chain = policy_chain.new({})
local chain = PolicyChain.new({})
assert.falsy(chain.frozen)

executor.new(chain)
Expand All @@ -51,7 +53,7 @@ describe('executor', function()
local policy_2 = Policy.new('2')
policy_2.export = function() return { p2 = '2' } end

local chain = policy_chain.new({ policy_1, policy_2 })
local chain = PolicyChain.new({ policy_1, policy_2 })
local context = executor.new(chain):context('rewrite')

assert.equal('1', context.p1)
Expand All @@ -68,8 +70,8 @@ describe('executor', function()
local policy_3 = Policy.new('3')
policy_3.export = function() return { p3 = '3' } end

local inner_chain = policy_chain.new({ policy_2, policy_3 })
local outer_chain = policy_chain.new({ policy_1, inner_chain })
local inner_chain = PolicyChain.new({ policy_2, policy_3 })
local outer_chain = PolicyChain.new({ policy_1, inner_chain })

local context = executor.new(outer_chain):context('rewrite')

Expand Down
12 changes: 2 additions & 10 deletions spec/policy_chain_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,8 @@ describe('policy_chain', function()
end)

it('uses APIcast as default when no policies are specified', function()
-- Stub apicast methods to avoid calling them. We are just interested in
-- knowing whether they were called.
for _, phase in ipairs(phases) do
stub(apicast, phase)
end

for _, phase in ipairs(phases) do
_M[phase](_M)
assert.stub(apicast[phase]).was_called()
end
assert.equal(1, #_M)
assert.equal('APIcast', _M[1]._NAME)
end)

it('calls the policies in the order specified when building the chain', function()
Expand Down
11 changes: 10 additions & 1 deletion spec/policy_loader_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,20 @@ describe('APIcast Policy Loader', function()
describe('.call', function()
it('finds apicast builtin policy', function()
local root = require('apicast.policy.apicast')
local sandbox = _M.call('apicast')
local sandbox = _M:call('apicast')

assert.is_table(sandbox)
assert.are_not.same(root, sandbox)
assert.equal(root._NAME, sandbox._NAME)
assert.equal(root._VERSION, sandbox._VERSION)
end)

it('uses sandboxed load paths', function()
local ok, ret = pcall(_M.call, _M, 'unknown', '0.1')

assert.falsy(ok)
assert.match([[module 'policy' not found:
%s+no file '%g+/gateway/policies/unknown/0.1/policy.lua']], ret)
end)
end)
end)
4 changes: 3 additions & 1 deletion t/apicast-path-routing.t
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ apicast cache miss key: 42:one-key:usage%5Bone%5D=1
apicast cache write key: 42:one-key:usage%5Bone%5D=1
apicast cache miss key: 21:two-id:two-key:usage%5Btwo%5D=2
apicast cache write key: 21:two-id:two-key:usage%5Btwo%5D=2
--- no_error_log
[error]
=== TEST 2: multi service configuration with path based routing defaults to host routing
If none of the services match it goes for the host.
Expand Down Expand Up @@ -136,5 +137,6 @@ env APICAST_PATH_ROUTING_ENABLED;
--- request eval
["GET /foo?user_key=1","GET /foo?user_key=2"]
--- no_error_log
[error]
--- error_code eval
[ 412, 412 ]
Loading

0 comments on commit 95ddd64

Please sign in to comment.