From ffddcf6ab83ba5c4092a25552f3840359b5c5de0 Mon Sep 17 00:00:00 2001 From: Ali Chraghi Date: Sun, 25 Feb 2024 02:02:35 +0330 Subject: [PATCH] AstGen: order declarations correctly --- src/shader/Air.zig | 1 + src/shader/AstGen.zig | 52 ++++++++++++++++++++++-------------- src/shader/codegen/spirv.zig | 13 ++++----- 3 files changed, 40 insertions(+), 26 deletions(-) diff --git a/src/shader/Air.zig b/src/shader/Air.zig index a0c84c52a..81efaa5bf 100644 --- a/src/shader/Air.zig +++ b/src/shader/Air.zig @@ -44,6 +44,7 @@ pub fn generate( defer { astgen.instructions.deinit(allocator); astgen.scratch.deinit(allocator); + astgen.globals.deinit(allocator); astgen.global_var_refs.deinit(allocator); astgen.scope_pool.deinit(); astgen.inst_arena.deinit(); diff --git a/src/shader/AstGen.zig b/src/shader/AstGen.zig index ece3eb4e2..de276e5cd 100644 --- a/src/shader/AstGen.zig +++ b/src/shader/AstGen.zig @@ -25,6 +25,7 @@ strings: std.ArrayListUnmanaged(u8) = .{}, values: std.ArrayListUnmanaged(u8) = .{}, scratch: std.ArrayListUnmanaged(InstIndex) = .{}, global_var_refs: std.AutoArrayHashMapUnmanaged(InstIndex, void) = .{}, +globals: std.ArrayListUnmanaged(InstIndex) = .{}, has_array_length: bool = false, compute_stage: InstIndex = .none, vertex_stage: InstIndex = .none, @@ -59,9 +60,6 @@ pub const Scope = struct { }; pub fn genTranslationUnit(astgen: *AstGen) !RefIndex { - const scratch_top = astgen.scratch.items.len; - defer astgen.scratch.shrinkRetainingCapacity(scratch_top); - var root_scope = try astgen.scope_pool.create(); root_scope.* = .{ .tag = .root, .parent = undefined }; @@ -70,15 +68,24 @@ pub fn genTranslationUnit(astgen: *AstGen) !RefIndex { for (global_nodes) |node| { var global = root_scope.decls.get(node).? catch continue; - if (global == .none) { - // declaration has not analysed - global = astgen.genGlobalDecl(root_scope, node) catch |err| switch (err) { - error.AnalysisFail => continue, - error.OutOfMemory => return error.OutOfMemory, - }; - } - - try astgen.scratch.append(astgen.allocator, global); + global = switch (astgen.tree.nodeTag(node)) { + .@"fn" => blk: { + std.debug.assert(global == .none); + break :blk astgen.genFn(root_scope, node, false) catch |err| switch (err) { + error.Skiped => continue, + else => |e| e, + }; + }, + else => continue, + } catch |err| { + if (err == error.AnalysisFail) { + root_scope.decls.putAssumeCapacity(node, error.AnalysisFail); + continue; + } + return err; + }; + root_scope.decls.putAssumeCapacity(node, global); + try astgen.globals.append(astgen.allocator, global); } if (astgen.errors.list.items.len > 0) return error.AnalysisFail; @@ -91,7 +98,7 @@ pub fn genTranslationUnit(astgen: *AstGen) !RefIndex { try astgen.errors.add(Loc{ .start = 0, .end = 1 }, "entry point not found", .{}, null); } - return astgen.addRefList(astgen.scratch.items[scratch_top..]); + return astgen.addRefList(astgen.globals.items); } /// adds `nodes` to scope and checks for re-declarations @@ -122,23 +129,26 @@ fn scanDecls(astgen: *AstGen, scope: *Scope, nodes: []const NodeIndex) !void { } } -fn genGlobalDecl(astgen: *AstGen, scope: *Scope, node: NodeIndex) !InstIndex { +fn genGlobalDecl(astgen: *AstGen, scope: *Scope, node: NodeIndex) error{ OutOfMemory, AnalysisFail }!InstIndex { const decl = switch (astgen.tree.nodeTag(node)) { .global_var => astgen.genGlobalVar(scope, node), .override => astgen.genOverride(scope, node), .@"const" => astgen.genConst(scope, node), .@"struct" => astgen.genStruct(scope, node), - .@"fn" => astgen.genFn(scope, node), + .@"fn" => astgen.genFn(scope, node, false), .type_alias => astgen.genTypeAlias(scope, node), else => unreachable, - } catch |err| { - if (err == error.AnalysisFail) { + } catch |err| switch (err) { + error.AnalysisFail => { scope.decls.putAssumeCapacity(node, error.AnalysisFail); - } - return err; + return error.AnalysisFail; + }, + error.Skiped => unreachable, + else => |e| return e, }; scope.decls.putAssumeCapacity(node, decl); + try astgen.globals.append(astgen.allocator, decl); return decl; } @@ -402,7 +412,7 @@ fn genStructMembers(astgen: *AstGen, scope: *Scope, node: NodeIndex) !RefIndex { return astgen.addRefList(astgen.scratch.items[scratch_top..]); } -fn genFn(astgen: *AstGen, root_scope: *Scope, node: NodeIndex) !InstIndex { +fn genFn(astgen: *AstGen, root_scope: *Scope, node: NodeIndex, only_entry_point: bool) !InstIndex { const scratch_top = astgen.global_var_refs.count(); defer astgen.global_var_refs.shrinkRetainingCapacity(scratch_top); @@ -480,6 +490,8 @@ fn genFn(astgen: *AstGen, root_scope: *Scope, node: NodeIndex) !InstIndex { } } + if (only_entry_point and stage == .none) return error.Skiped; + if (stage == .compute) { if (return_type != .none) { try astgen.errors.add( diff --git a/src/shader/codegen/spirv.zig b/src/shader/codegen/spirv.zig index e34ddab00..270086d87 100644 --- a/src/shader/codegen/spirv.zig +++ b/src/shader/codegen/spirv.zig @@ -128,9 +128,7 @@ pub fn gen(allocator: std.mem.Allocator, air: *const Air, debug_info: DebugInfo) for (air.refToList(air.globals_index)) |inst_idx| { switch (spv.air.getInst(inst_idx)) { - .@"fn" => |@"fn"| if (@"fn".stage != .none) { - _ = try spv.emitFn(inst_idx); - }, + .@"fn" => _ = try spv.emitFn(inst_idx), .@"const" => _ = try spv.emitConst(&spv.global_section, inst_idx), .@"var" => _ = try spv.emitVarProto(&spv.global_section, inst_idx), .@"struct" => _ = try spv.emitStruct(inst_idx), @@ -1358,8 +1356,11 @@ const PtrAccess = struct { fn emitVarAccess(spv: *SpirV, section: *Section, inst: InstIndex) !PtrAccess { const decl = spv.decl_map.get(inst) orelse blk: { - std.debug.assert(spv.air.getInst(inst) == .@"const"); - _ = try spv.emitConst(section, inst); + switch (spv.air.getInst(inst)) { + .@"const" => _ = try spv.emitConst(&spv.global_section, inst), + .@"var" => _ = try spv.emitVarProto(&spv.global_section, inst), + else => unreachable, + } break :blk spv.decl_map.get(inst).?; }; @@ -1969,7 +1970,7 @@ fn emitTripleIntrinsic(spv: *SpirV, section: *Section, triple: Inst.TripleIntrin .smoothstep => 49, }; - if (triple.op == .mix) { + if (triple.op == .mix and spv.air.getInst(triple.result_type) == .vector) { const vec_type_inst = spv.air.getInst(triple.result_type).vector; var constituents = std.BoundedArray(IdRef, 4){}; constituents.appendNTimesAssumeCapacity(a3, @intFromEnum(vec_type_inst.size));