From df3e648abea964008c500af99553418970d5cdd6 Mon Sep 17 00:00:00 2001 From: CapsAdmin Date: Wed, 29 Jun 2022 01:26:26 +0200 Subject: [PATCH] improve mutation tracking and fix some bugs --- .vscode/launch.json | 36 ++- build_output.lua | 312 +++++++++++++------- nattlua/analyzer/base/lexical_scope.lua | 14 +- nattlua/analyzer/control_flow.lua | 2 +- nattlua/analyzer/mutations.lua | 19 +- nattlua/analyzer/statements/generic_for.lua | 2 + test/nattlua/analyzer/generic_for.lua | 14 + test/nattlua/analyzer/if.lua | 9 + 8 files changed, 263 insertions(+), 145 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index 2d486cb7..03d3a30e 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -1,20 +1,18 @@ { - // Use IntelliSense to learn about possible attributes. - // Hover to view descriptions of existing attributes. - // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 - "version": "0.2.0", - "configurations": [ - { - "name": "debug related tests", - "type": "lua-local", - "request": "launch", - "args": [ - "${relativeFile}" - ], - "program": { - "lua": "luajit", - "file": "test/run.lua", - } - }, - ] -} \ No newline at end of file + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "debug related tests", + "type": "lua-local", + "request": "launch", + "args": ["${relativeFile}"], + "program": { + "lua": "luajit", + "file": "test/run.lua" + } + } + ] +} diff --git a/build_output.lua b/build_output.lua index ebf66691..74a855eb 100755 --- a/build_output.lua +++ b/build_output.lua @@ -468,6 +468,27 @@ do return str end + local function string_lengthsplit(str, len) + if #str > len then + local tbl = {} + local max = math.floor(#str / len) + + for i = 0, max do + local left = i * len + 1 + local right = (i * len) + len + local res = str:sub(left, right) + + if res ~= "" then table.insert(tbl, res) end + end + + return tbl + end + + return {str} + end + + local MAX_WIDTH = 127 + function helpers.BuildSourceCodePointMessage( lua_code, path, @@ -476,6 +497,26 @@ do stop, size ) + do + local new_str = "" + local pos = 1 + + for i, chunk in ipairs(string_lengthsplit(lua_code, MAX_WIDTH)) do + if pos < start and i > 1 then + start = start + 1 + end + + if pos < stop and i > 1 then + stop = stop + 1 + end + + new_str = new_str .. chunk .. "\n" + pos = pos + #chunk + end + + lua_code = new_str + end + size = size or 2 start = clamp(start or 1, 1, #lua_code) stop = clamp(stop or 1, 1, #lua_code) @@ -539,6 +580,7 @@ do if #line > longest_line then longest_line = #line end end + longest_line = math.min(longest_line, MAX_WIDTH) table.insert( lines, 1, @@ -726,13 +768,15 @@ end META:GetSet("Data", nil) function META:GetLuaType() + local contract = self:GetContract() + if - self.Contract and - self.Contract.TypeOverride and - self.Contract.TypeOverride.Type == "string" and - self.Contract.TypeOverride.Data + contract and + contract.TypeOverride and + contract.TypeOverride.Type == "string" and + contract.TypeOverride.Data then - return self.Contract.TypeOverride.Data + return contract.TypeOverride.Data end return self.TypeOverride and @@ -948,7 +992,7 @@ do META:GetSet("MetaTable", nil) function META:GetMetaTable() - local contract = self.Contract + local contract = self:GetContract() if contract and contract.MetaTable then return contract.MetaTable end @@ -3605,25 +3649,27 @@ function META:CopyLiteralness(from) if self:Equal(from) then return true end - for _, keyval_from in ipairs(from:GetData()) do - local keyval, reason = self:FindKeyVal(keyval_from.key) + if from.Type == "table" then + for _, keyval_from in ipairs(from:GetData()) do + local keyval, reason = self:FindKeyVal(keyval_from.key) - if not keyval then return type_errors.other(reason) end + if not keyval then return type_errors.other(reason) end - if keyval_from.key.Type == "table" then - self.suppress = true - keyval.key:CopyLiteralness(keyval_from.key) -- TODO: never called - self.suppress = false - else - keyval.key:SetLiteral(keyval_from.key:IsLiteral()) - end + if keyval_from.key.Type == "table" then + self.suppress = true + keyval.key:CopyLiteralness(keyval_from.key) -- TODO: never called + self.suppress = false + else + keyval.key:SetLiteral(keyval_from.key:IsLiteral()) + end - if keyval_from.val.Type == "table" then - self.suppress = true - keyval.val:CopyLiteralness(keyval_from.val) - self.suppress = false - else - keyval.val:SetLiteral(keyval_from.val:IsLiteral()) + if keyval_from.val.Type == "table" then + self.suppress = true + keyval.val:CopyLiteralness(keyval_from.val) + self.suppress = false + else + keyval.val:SetLiteral(keyval_from.val:IsLiteral()) + end end end @@ -3697,6 +3743,12 @@ end function META:HasLiteralKeys() if self.suppress then return true end + local contract = self:GetContract() + + if contract and contract ~= self and not contract:HasLiteralKeys() then + return false + end + for _, v in ipairs(self:GetData()) do if v.val ~= self and @@ -3960,6 +4012,10 @@ _G.arg = _ + + + + @@ -3998,11 +4054,6 @@ IMPORTS['nattlua/definitions/lua/luajit.nlua'] = assert(loadstring([=======[ ret IMPORTS['nattlua/definitions/lua/debug.nlua'] = assert(loadstring([=======[ return function() - - - - - end ]=======], '@nattlua/definitions/lua/debug.nlua'))() IMPORTS['nattlua/definitions/lua/package.nlua'] = assert(loadstring([=======[ return function() end ]=======], '@nattlua/definitions/lua/package.nlua'))() IMPORTS['nattlua/definitions/lua/bit.nlua'] = assert(loadstring([=======[ return function() @@ -13150,7 +13201,11 @@ do -- runtime self:check_integer_division_operator(self:GetToken()) while - runtime_syntax:GetBinaryOperatorInfo(self:GetToken()) and + ( + runtime_syntax:GetBinaryOperatorInfo(self:GetToken()) and + not self:IsValue("=", 1) + ) + and runtime_syntax:GetBinaryOperatorInfo(self:GetToken()).left_priority > priority do local left_node = node @@ -13701,6 +13756,40 @@ function META:ParseCallOrAssignmentStatement() self:SuppressOnNode() local left = self:ParseMultipleValues(math.huge, self.ExpectRuntimeExpression, 0) + if + ( + self:IsValue("+") or + self:IsValue("-") or + self:IsValue("*") or + self:IsValue("/") or + self:IsValue("%") or + self:IsValue("^") or + self:IsValue("..") + ) and + self:IsValue("=", 1) + then + -- roblox compound assignment + local op_token = self:ParseToken() + local eq_token = self:ParseToken() + local bop = self:StartNode("expression", "binary_operator") + bop.left = left[1] + bop.value = op_token + bop.right = self:ExpectRuntimeExpression(0) + self:EndNode(bop) + local node = self:StartNode("statement", "assignment", left[1]) + node.tokens["="] = eq_token + node.left = left + + for i, v in ipairs(node.left) do + v.is_left_assignment = true + end + + node.right = {bop} + self:ReRunOnNode(node.left) + node = self:EndNode(node) + return node + end + if self:IsValue("=") then local node = self:StartNode("statement", "assignment", left[1]) node.tokens["="] = self:ExpectValue("=") @@ -14550,11 +14639,15 @@ function META:ParseStatement() self:ParseLocalAnalyzerFunctionStatement() or self:ParseLocalTypeAssignmentStatement() or self:ParseLocalDestructureAssignmentStatement() or - self.TealCompat and - self:ParseLocalTealRecord() + ( + self.TealCompat and + self:ParseLocalTealRecord() + ) or - self.TealCompat and - self:ParseLocalTealEnumStatement() + ( + self.TealCompat and + self:ParseLocalTealEnumStatement() + ) or self:ParseLocalAssignmentStatement() or self:ParseTypeAssignmentStatement() or @@ -15112,14 +15205,26 @@ function META:SetStatement(statement) self.statement = statement end +function META:SetLoopIteration(i) + self.loop_iteration = i +end + function META:GetStatementType() return self.statement and self.statement.kind end function META.IsPartOfTestStatementAs(a, b) - return a:GetStatementType() == "if" and + local yes = a:GetStatementType() == "if" and b:GetStatementType() == "if" and a.statement == b.statement + + if yes then + local a_iteration = a:GetMemberInParents("loop_iteration") + local b_iteration = b:GetMemberInParents("loop_iteration") + return a_iteration == b_iteration + end + + return yes end function META:FindFirstConditionalScope() @@ -16340,7 +16445,7 @@ return function(META) self:ApplyMutationsAfterReturn( self:GetScope(), - nil, + self:GetScope():GetNearestFunctionScope(), false, self:GetTrackedUpvalues(old), self:GetTrackedTables() @@ -16509,11 +16614,10 @@ return function(META) for _, frame in ipairs(self:GetCallStack()) do local parent_scope = frame.scope - if - not parent_scope:IsCertain() or - parent_scope.uncertain_function_return == true - then - if parent_scope:IsCertainFromScope(scope) then return false end + if parent_scope.uncertain_function_return then return true end + + if not parent_scope:IsCertain() and parent_scope:IsCertainFromScope(scope) then + return false end end @@ -16576,7 +16680,7 @@ local ipairs = ipairs local table = _G.table local Union = IMPORTS['nattlua.types.union']("nattlua.types.union").Union -local function get_value_from_scope(current_if_statement, mutations, scope, obj) +local function get_value_from_scope(mutations, scope, obj) do do local last_scope @@ -16599,25 +16703,14 @@ local function get_value_from_scope(current_if_statement, mutations, scope, obj) if ( scope:IsPartOfTestStatementAs(mut.scope) or - ( - current_if_statement and - mut.scope.statement == current_if_statement - ) - or ( mut.from_tracking and - not mut.scope:IsCertainFromScope(scope) - ) - or - ( - obj.Type == "table" and - obj:GetContract() ~= mut.contract + not mut.scope:Contains(scope) ) ) and scope ~= mut.scope then - -- not inside the same if statement" table.remove(mutations, i) end end @@ -16741,51 +16834,47 @@ local function get_value_from_scope(current_if_statement, mutations, scope, obj) value = union:GetData()[1] if obj.Type == "upvalue" then value:SetUpvalue(obj) end + + return value end - if value.Type == "union" then - local found_scope, data = scope:FindResponsibleConditionalScopeFromUpvalue(obj) + local found_scope, data = scope:FindResponsibleConditionalScopeFromUpvalue(obj) - if found_scope then - local stack = data.stack + if not found_scope or not data.stack then return value end - if stack then - if - found_scope:IsElseConditionalScope() or - ( - found_scope ~= scope and - scope:IsPartOfTestStatementAs(found_scope) - ) - then - local union = stack[#stack].falsy + local stack = data.stack - if union:GetLength() == 0 then - union = Union() + if + found_scope:IsElseConditionalScope() or + ( + found_scope ~= scope and + scope:IsPartOfTestStatementAs(found_scope) + ) + then + local union = stack[#stack].falsy - for _, val in ipairs(stack) do - union:AddType(val.falsy) - end - end + if union:GetLength() == 0 then + union = Union() - if obj.Type == "upvalue" then union:SetUpvalue(obj) end + for _, val in ipairs(stack) do + union:AddType(val.falsy) + end + end - return union - else - local union = Union() + if obj.Type == "upvalue" then union:SetUpvalue(obj) end - for _, val in ipairs(stack) do - union:AddType(val.truthy) - end + return union + end - if obj.Type == "upvalue" then union:SetUpvalue(obj) end + local union = Union() - return union - end - end - end + for _, val in ipairs(stack) do + union:AddType(val.truthy) end - return value + if obj.Type == "upvalue" then union:SetUpvalue(obj) end + + return union end local function initialize_table_mutation_tracker(tbl, scope, key, hash) @@ -16827,7 +16916,7 @@ return function(META) local scope = self:GetScope() initialize_table_mutation_tracker(tbl, scope, key, hash) - return get_value_from_scope(self.current_if_statement, shallow_copy(tbl.mutations[hash]), scope, tbl) + return get_value_from_scope(shallow_copy(tbl.mutations[hash]), scope, tbl) end function META:MutateTable(tbl, key, val, scope_override, from_tracking) @@ -16853,7 +16942,7 @@ return function(META) function META:GetMutatedUpvalue(upvalue) upvalue.mutations = upvalue.mutations or {} - return get_value_from_scope(self.current_if_statement, shallow_copy(upvalue.mutations), self:GetScope(), upvalue) + return get_value_from_scope(shallow_copy(upvalue.mutations), self:GetScope(), upvalue) end function META:MutateUpvalue(upvalue, val, scope_override, from_tracking) @@ -16895,7 +16984,7 @@ return function(META) if mut.from_tracking then table.remove(obj.mutations, i) end end - else + elseif obj.mutations then for _, mutations in pairs(obj.mutations) do for i = #mutations, 1, -1 do local mut = mutations[i] @@ -17839,7 +17928,8 @@ local function mutate_type(self, i, arg, contract, arguments) env.mutated_types = env.mutated_types or {} arg:PushContract(contract) arg.argument_index = i - table.insert(env.mutated_types, arg) + arg.mutations = nil + table.insert(env.mutated_types, {arg = arg, mutations = arg.mutations}) arguments:Set(i, arg) end @@ -17848,10 +17938,11 @@ local function restore_mutated_types(self) if not env.mutated_types or not env.mutated_types[1] then return end - for _, arg in ipairs(env.mutated_types) do - arg:PopContract() - arg.argument_index = nil - self:MutateUpvalue(arg:GetUpvalue(), arg) + for _, data in ipairs(env.mutated_types) do + data.arg:PopContract() + data.arg.argument_index = nil + data.arg.mutations = data.mutations + self:MutateUpvalue(data.arg:GetUpvalue(), data.arg) end env.mutated_types = {} @@ -21108,7 +21199,9 @@ return { self:CreateLocalValue(identifier.value.value, obj) end + self:CreateAndPushScope():SetLoopIteration(i) self:AnalyzeStatements(statement.statements) + self:PopScope() if self._continue_ then self._continue_ = nil end @@ -23233,6 +23326,22 @@ type ipairs = function=(t: Table)>(empty_function, Table, number) type tonumber = function=(e: number | string, base: number | nil)>(number | nil) _G.arg = _ as List<|any|> +analyzer function setfenv(val: Function, table: Table) + if val and (val:IsLiteral() or val.Type == "function") then + if val.Type == "number" then + analyzer:SetEnvironmentOverride(analyzer.environment_nodes[val:GetData()], table, "runtime") + elseif val:GetFunctionBodyNode() then + analyzer:SetEnvironmentOverride(val:GetFunctionBodyNode(), table, "runtime") + end + end +end + +analyzer function getfenv(func: Function | nil) + if not func then return analyzer:GetDefaultEnvironment("typesystem") end + + return analyzer:GetGlobalEnvironmentOverride(func:GetFunctionBodyNode() or func, "runtime") +end + analyzer function type_print(...: ...any) print(...) end @@ -23939,23 +24048,8 @@ type debug = { setfenv = function=(object: any, Table: Table)>(any), setuservalue = function=(udata: userdata, value: Table | nil)>(userdata), } - -analyzer function debug.setfenv(val: Function, table: Table) - if val and (val:IsLiteral() or val.Type == "function") then - if val.Type == "number" then - analyzer:SetEnvironmentOverride(analyzer.environment_nodes[val:GetData()], table, "runtime") - elseif val:GetFunctionBodyNode() then - analyzer:SetEnvironmentOverride(val:GetFunctionBodyNode(), table, "runtime") - end - end -end - -analyzer function debug.getfenv(func: Function) - return analyzer:GetGlobalEnvironmentOverride(func:GetFunctionBodyNode() or func, "runtime") -end - -type getfenv = debug.getfenv -type setfenv = debug.setfenv end +type debug.getfenv = getfenv +type debug.setfenv = setfenv end IMPORTS['nattlua/definitions/lua/package.nlua'] = function() type package = { searchpath = function=(name: string, path: string, sep: string, rep: string)>(string | nil, string | nil) | function=(name: string, path: string, sep: string)>(string | nil, string | nil) | function=(name: string, path: string)>(string | nil, string | nil), seeall = function=(module: Table)>(nil), diff --git a/nattlua/analyzer/base/lexical_scope.lua b/nattlua/analyzer/base/lexical_scope.lua index 3714dda1..dd2027a5 100644 --- a/nattlua/analyzer/base/lexical_scope.lua +++ b/nattlua/analyzer/base/lexical_scope.lua @@ -230,14 +230,26 @@ function META:SetStatement(statement) self.statement = statement end +function META:SetLoopIteration(i) + self.loop_iteration = i +end + function META:GetStatementType() return self.statement and self.statement.kind end function META.IsPartOfTestStatementAs(a, b) - return a:GetStatementType() == "if" and + local yes = a:GetStatementType() == "if" and b:GetStatementType() == "if" and a.statement == b.statement + + if yes then + local a_iteration = a:GetMemberInParents("loop_iteration") + local b_iteration = b:GetMemberInParents("loop_iteration") + return a_iteration == b_iteration + end + + return yes end function META:FindFirstConditionalScope() diff --git a/nattlua/analyzer/control_flow.lua b/nattlua/analyzer/control_flow.lua index 42096c71..9e71ce3b 100644 --- a/nattlua/analyzer/control_flow.lua +++ b/nattlua/analyzer/control_flow.lua @@ -157,7 +157,7 @@ return function(META) self:ApplyMutationsAfterReturn( self:GetScope(), - nil, + self:GetScope():GetNearestFunctionScope(), false, self:GetTrackedUpvalues(old), self:GetTrackedTables() diff --git a/nattlua/analyzer/mutations.lua b/nattlua/analyzer/mutations.lua index 5dde1666..cf6aaff5 100644 --- a/nattlua/analyzer/mutations.lua +++ b/nattlua/analyzer/mutations.lua @@ -7,7 +7,7 @@ local ipairs = ipairs local table = _G.table local Union = require("nattlua.types.union").Union -local function get_value_from_scope(current_if_statement, mutations, scope, obj) +local function get_value_from_scope(mutations, scope, obj) do do local last_scope @@ -30,25 +30,14 @@ local function get_value_from_scope(current_if_statement, mutations, scope, obj) if ( scope:IsPartOfTestStatementAs(mut.scope) or - ( - current_if_statement and - mut.scope.statement == current_if_statement - ) - or ( mut.from_tracking and - not mut.scope:IsCertainFromScope(scope) - ) - or - ( - obj.Type == "table" and - obj:GetContract() ~= mut.contract + not mut.scope:Contains(scope) ) ) and scope ~= mut.scope then - -- not inside the same if statement" table.remove(mutations, i) end end @@ -254,7 +243,7 @@ return function(META) local scope = self:GetScope() initialize_table_mutation_tracker(tbl, scope, key, hash) - return get_value_from_scope(self.current_if_statement, shallow_copy(tbl.mutations[hash]), scope, tbl) + return get_value_from_scope(shallow_copy(tbl.mutations[hash]), scope, tbl) end function META:MutateTable(tbl, key, val, scope_override, from_tracking) @@ -280,7 +269,7 @@ return function(META) function META:GetMutatedUpvalue(upvalue) upvalue.mutations = upvalue.mutations or {} - return get_value_from_scope(self.current_if_statement, shallow_copy(upvalue.mutations), self:GetScope(), upvalue) + return get_value_from_scope(shallow_copy(upvalue.mutations), self:GetScope(), upvalue) end function META:MutateUpvalue(upvalue, val, scope_override, from_tracking) diff --git a/nattlua/analyzer/statements/generic_for.lua b/nattlua/analyzer/statements/generic_for.lua index f5cbcccc..e2a6422c 100644 --- a/nattlua/analyzer/statements/generic_for.lua +++ b/nattlua/analyzer/statements/generic_for.lua @@ -56,7 +56,9 @@ return { self:CreateLocalValue(identifier.value.value, obj) end + self:CreateAndPushScope():SetLoopIteration(i) self:AnalyzeStatements(statement.statements) + self:PopScope() if self._continue_ then self._continue_ = nil end diff --git a/test/nattlua/analyzer/generic_for.lua b/test/nattlua/analyzer/generic_for.lua index 694e0b26..b4088f25 100644 --- a/test/nattlua/analyzer/generic_for.lua +++ b/test/nattlua/analyzer/generic_for.lua @@ -56,3 +56,17 @@ analyze[[ attest.equal(sum, _ as number) ]] +analyze[[ + local e = { + SOCK_SEQPACKET = 5, + SOCK_DCCP = 6, + } + local what = "SOCK_" + + for k, v in pairs(e) do + if k:sub(0, #what) == what then + local lol = k:sub(#what + 1) + lol:lower() + end + end +]] \ No newline at end of file diff --git a/test/nattlua/analyzer/if.lua b/test/nattlua/analyzer/if.lua index 85622511..8ac70923 100644 --- a/test/nattlua/analyzer/if.lua +++ b/test/nattlua/analyzer/if.lua @@ -1438,6 +1438,15 @@ analyze[[ local str = last_error() attest.equal(str, _ as nil | "hello") ]] +analyze[[ + local ffi = require("ffi") + + do + assert(ffi.sizeof("int") == 4) + end + + attest.truthy(ffi.sizeof) +]] if false then analyze[==[