diff --git a/spec/compiler/parser/parser_spec.cr b/spec/compiler/parser/parser_spec.cr index abe382f0d5fc..8e4e0777b813 100644 --- a/spec/compiler/parser/parser_spec.cr +++ b/spec/compiler/parser/parser_spec.cr @@ -4,8 +4,8 @@ private def regex(string, options = Regex::Options::None) RegexLiteral.new(StringLiteral.new(string), options) end -private def it_parses(string, expected_node, file = __FILE__, line = __LINE__) - it "parses #{string.dump}", file, line do +private def it_parses(string, expected_node, file = __FILE__, line = __LINE__, *, focus : Bool = false) + it "parses #{string.dump}", file, line, focus: focus do parser = Parser.new(string) parser.filename = "/foo/bar/baz.cr" node = parser.parse @@ -316,6 +316,26 @@ module Crystal it_parses "def foo(@@var = 1); 1; end", Def.new("foo", [Arg.new("var", 1.int32)], [Assign.new("@@var".class_var, "var".var), 1.int32] of ASTNode) it_parses "def foo(&@block); end", Def.new("foo", body: Assign.new("@block".instance_var, "block".var), block_arg: Arg.new("block"), yields: 0) + # Defs with annotated parameters + it_parses "def foo(@[Foo] var); end", Def.new("foo", ["var".arg(annotations: ["Foo".ann])]) + it_parses "def foo(@[Foo] outer inner); end", Def.new("foo", ["inner".arg(annotations: ["Foo".ann], external_name: "outer")]) + it_parses "def foo(@[Foo] var); end", Def.new("foo", ["var".arg(annotations: ["Foo".ann])]) + it_parses "def foo(a, @[Foo] var); end", Def.new("foo", ["a".arg, "var".arg(annotations: ["Foo".ann])]) + it_parses "def foo(a, @[Foo] &block); end", Def.new("foo", ["a".arg], block_arg: "block".arg(annotations: ["Foo".ann]), yields: 0) + it_parses "def foo(@[Foo] @var); end", Def.new("foo", ["var".arg(annotations: ["Foo".ann])], [Assign.new("@var".instance_var, "var".var)] of ASTNode) + it_parses "def foo(@[Foo] var : Int32); end", Def.new("foo", ["var".arg(restriction: "Int32".path, annotations: ["Foo".ann])]) + it_parses "def foo(@[Foo] @[Bar] var : Int32); end", Def.new("foo", ["var".arg(restriction: "Int32".path, annotations: ["Foo".ann, "Bar".ann])]) + it_parses "def foo(@[Foo] &@block); end", Def.new("foo", body: Assign.new("@block".instance_var, "block".var), block_arg: "block".arg(annotations: ["Foo".ann]), yields: 0) + it_parses "def foo(@[Foo] *args); end", Def.new("foo", args: ["args".arg(annotations: ["Foo".ann])], splat_index: 0) + it_parses "def foo(@[Foo] **args); end", Def.new("foo", double_splat: "args".arg(annotations: ["Foo".ann])) + it_parses <<-CR, Def.new("foo", ["id".arg(restriction: "Int32".path, annotations: ["Foo".ann]), "name".arg(restriction: "String".path, annotations: ["Bar".ann])]) + def foo( + @[Foo] + id : Int32, + @[Bar] name : String + ); end + CR + it_parses "def foo(\n&block\n); end", Def.new("foo", block_arg: Arg.new("block"), yields: 0) it_parses "def foo(&block :\n Int ->); end", Def.new("foo", block_arg: Arg.new("block", restriction: ProcNotation.new(["Int".path] of ASTNode)), yields: 1) it_parses "def foo(&block : Int ->\n); end", Def.new("foo", block_arg: Arg.new("block", restriction: ProcNotation.new(["Int".path] of ASTNode)), yields: 1) @@ -983,6 +1003,22 @@ module Crystal it_parses "macro foo;bar(end: 1);end", Macro.new("foo", body: Expressions.from(["bar(".macro_literal, "end: 1);".macro_literal] of ASTNode)) it_parses "def foo;bar(end: 1);end", Def.new("foo", body: Expressions.from([Call.new(nil, "bar", named_args: [NamedArgument.new("end", 1.int32)])] of ASTNode)) + # Macros with annotated parameters + it_parses "macro foo(@[Foo] var);end", Macro.new("foo", ["var".arg(annotations: ["Foo".ann])], Expressions.new) + it_parses "macro foo(@[Foo] outer inner);end", Macro.new("foo", ["inner".arg(annotations: ["Foo".ann], external_name: "outer")], Expressions.new) + it_parses "macro foo(@[Foo] var);end", Macro.new("foo", ["var".arg(annotations: ["Foo".ann])], Expressions.new) + it_parses "macro foo(a, @[Foo] var);end", Macro.new("foo", ["a".arg, "var".arg(annotations: ["Foo".ann])], Expressions.new) + it_parses "macro foo(a, @[Foo] &block);end", Macro.new("foo", ["a".arg], Expressions.new, block_arg: "block".arg(annotations: ["Foo".ann])) + it_parses "macro foo(@[Foo] *args);end", Macro.new("foo", ["args".arg(annotations: ["Foo".ann])], Expressions.new, splat_index: 0) + it_parses "macro foo(@[Foo] **args);end", Macro.new("foo", body: Expressions.new, double_splat: "args".arg(annotations: ["Foo".ann])) + it_parses <<-CR, Macro.new("foo", ["id".arg(annotations: ["Foo".ann]), "name".arg(annotations: ["Bar".ann])], Expressions.new) + macro foo( + @[Foo] + id, + @[Bar] name + );end + CR + assert_syntax_error "macro foo; {% foo = 1 }; end" assert_syntax_error "macro def foo : String; 1; end" diff --git a/spec/compiler/parser/to_s_spec.cr b/spec/compiler/parser/to_s_spec.cr index a5072cdcebca..4e72464f22c2 100644 --- a/spec/compiler/parser/to_s_spec.cr +++ b/spec/compiler/parser/to_s_spec.cr @@ -70,7 +70,9 @@ describe "ASTNode#to_s" do expect_to_s %[1 & 2 & (3 | 4)], %[(1 & 2) & (3 | 4)] expect_to_s %[(1 & 2) & (3 | 4)] expect_to_s "def foo(x : T = 1)\nend" + expect_to_s "def foo(@[Foo] x : T = 1)\nend" expect_to_s "def foo(x : X, y : Y) forall X, Y\nend" + expect_to_s "def foo(x : X, @[Foo] y : Y) forall X, Y\nend" expect_to_s %(foo : A | (B -> C)) expect_to_s %(foo : (A | B).class) expect_to_s %[%("\#{foo}")], %["\\"\#{foo}\\""] @@ -82,22 +84,39 @@ describe "ASTNode#to_s" do expect_to_s "_foo.bar" expect_to_s "1.responds_to?(:to_s)" expect_to_s "1.responds_to?(:\"&&\")" + expect_to_s "macro foo(&block)\nend" + expect_to_s "macro foo(&)\nend" + expect_to_s "macro foo(*, __var var)\nend" + expect_to_s "macro foo(*, var)\nend" + expect_to_s "macro foo(*var)\nend" + expect_to_s "macro foo(@[Foo] &)\nend" + expect_to_s "macro foo(@[Foo] &block)\nend" expect_to_s "macro foo(x, *y)\nend" + expect_to_s "macro foo(x, @[Foo] *y)\nend" + expect_to_s "macro foo(@[Foo] x, @[Foo] *y)\nend" expect_to_s "{ {1, 2, 3} }" expect_to_s "{ {1 => 2} }" expect_to_s "{ {1, 2, 3} => 4 }" expect_to_s "{ {foo: 2} }" expect_to_s "def foo(*args)\nend" + expect_to_s "def foo(@[Foo] *args)\nend" expect_to_s "def foo(*args : _)\nend" expect_to_s "def foo(**args)\nend" + expect_to_s "def foo(@[Foo] **args)\nend" expect_to_s "def foo(**args : T)\nend" expect_to_s "def foo(x, **args)\nend" + expect_to_s "def foo(x, @[Foo] **args)\nend" expect_to_s "def foo(x, **args, &block)\nend" + expect_to_s "def foo(@[Foo] x, @[Bar] **args, @[Baz] &block)\nend" expect_to_s "def foo(x, **args, &block : (_ -> _))\nend" expect_to_s "def foo(& : (->))\nend" + expect_to_s "macro foo(@[Foo] id)\nend" expect_to_s "macro foo(**args)\nend" + expect_to_s "macro foo(@[Foo] **args)\nend" expect_to_s "macro foo(x, **args)\nend" + expect_to_s "macro foo(x, @[Foo] **args)\nend" expect_to_s "def foo(x y)\nend" + expect_to_s "def foo(@[Foo] x y)\nend" expect_to_s %(foo("bar baz": 2)) expect_to_s %(Foo("bar baz": Int32)) expect_to_s %(Foo()) diff --git a/spec/compiler/semantic/annotation_spec.cr b/spec/compiler/semantic/annotation_spec.cr index 2e6670b62f09..89dd5e734243 100644 --- a/spec/compiler/semantic/annotation_spec.cr +++ b/spec/compiler/semantic/annotation_spec.cr @@ -916,6 +916,78 @@ describe "Semantic: annotation" do {{ Child.superclass.annotation(Ann)[0] }} )) { int32 } end + + it "finds annotation on method arg" do + assert_type(%( + annotation Ann; end + + def foo( + @[Ann] foo : Int32 + ) + end + + {% if @top_level.methods.find(&.name.==("foo")).args.first.annotation(Ann) %} + 1 + {% else %} + 'a' + {% end %} + )) { int32 } + end + + it "finds annotation on method splat arg" do + assert_type(%( + annotation Ann; end + + def foo( + id : Int32, + @[Ann] *nums : Int32 + ) + end + + {% if @top_level.methods.find(&.name.==("foo")).args[1].annotation(Ann) %} + 1 + {% else %} + 'a' + {% end %} + )) { int32 } + end + + it "finds annotation on method double splat arg" do + assert_type(%( + annotation Ann; end + + def foo( + id : Int32, + @[Ann] **nums + ) + end + + {% if @top_level.methods.find(&.name.==("foo")).double_splat.annotation(Ann) %} + 1 + {% else %} + 'a' + {% end %} + )) { int32 } + end + + it "finds annotation on an restricted method block arg" do + assert_type(%( + annotation Ann; end + + def foo( + id : Int32, + @[Ann] &block : Int32 -> + ) + yield 10 + end + + {% if @top_level.methods.find(&.name.==("foo")).block_arg.annotation(Ann) %} + 1 + {% else %} + 'a' + {% end %} + )) { int32 } + end end it "errors when annotate instance variable in subclass" do diff --git a/spec/support/syntax.cr b/spec/support/syntax.cr index 67634180f565..5514fc1d4acf 100644 --- a/spec/support/syntax.cr +++ b/spec/support/syntax.cr @@ -54,8 +54,12 @@ class String Var.new self end - def arg(default_value = nil, restriction = nil, external_name = nil) - Arg.new self, default_value: default_value, restriction: restriction, external_name: external_name + def ann + Annotation.new path + end + + def arg(default_value = nil, restriction = nil, external_name = nil, annotations = nil) + Arg.new self, default_value: default_value, restriction: restriction, external_name: external_name, parsed_annotations: annotations end def call diff --git a/src/compiler/crystal/macros.cr b/src/compiler/crystal/macros.cr index c1123ddbc787..d4ce2d4b995c 100644 --- a/src/compiler/crystal/macros.cr +++ b/src/compiler/crystal/macros.cr @@ -1114,6 +1114,16 @@ module Crystal::Macros # A def argument. class Arg < ASTNode + # Returns the last `Annotation` with the given `type` + # attached to this arg or `NilLiteral` if there are none. + def annotation(type : TypeNode) : Annotation | NilLiteral + end + + # Returns an array of annotations with the given `type` + # attached to this arg, or an empty `ArrayLiteral` if there are none. + def annotations(type : TypeNode) : ArrayLiteral(Annotation) + end + # Returns the external name of this argument. # # For example, for `def write(to file)` returns `to`. diff --git a/src/compiler/crystal/macros/methods.cr b/src/compiler/crystal/macros/methods.cr index 6ff11ad10ee5..4058f8590080 100644 --- a/src/compiler/crystal/macros/methods.cr +++ b/src/compiler/crystal/macros/methods.cr @@ -1342,6 +1342,16 @@ module Crystal interpret_check_args { default_value || Nop.new } when "restriction" interpret_check_args { restriction || Nop.new } + when "annotation" + fetch_annotation(self, method, args, named_args, block) do |type| + self.annotation(type) + end + when "annotations" + fetch_annotation(self, method, args, named_args, block) do |type| + annotations = self.annotations(type) + return ArrayLiteral.new if annotations.nil? + ArrayLiteral.map(annotations, &.itself) + end else super end diff --git a/src/compiler/crystal/semantic/ast.cr b/src/compiler/crystal/semantic/ast.cr index fd13848b1978..3eee08ca3537 100644 --- a/src/compiler/crystal/semantic/ast.cr +++ b/src/compiler/crystal/semantic/ast.cr @@ -121,6 +121,8 @@ module Crystal end class Arg + include Annotatable + def initialize(@name : String, @default_value : ASTNode? = nil, @restriction : ASTNode? = nil, external_name : String? = nil, @type : Type? = nil) @external_name = external_name || @name end diff --git a/src/compiler/crystal/semantic/to_s.cr b/src/compiler/crystal/semantic/to_s.cr index 38932f11aa05..8c91426bc379 100644 --- a/src/compiler/crystal/semantic/to_s.cr +++ b/src/compiler/crystal/semantic/to_s.cr @@ -3,6 +3,19 @@ require "../syntax/to_s" module Crystal class ToSVisitor def visit(node : Arg) + if parsed_annotations = node.parsed_annotations + parsed_annotations.each do |ann| + ann.accept self + @str << ' ' + end + end + + case @current_arg_type + when .splat? then @str << '*' + when .double_splat? then @str << "**" + when .block_arg? then @str << '&' + end + if node.external_name != node.name visit_named_arg_name(node.external_name) @str << ' ' @@ -24,6 +37,8 @@ module Crystal default_value.accept self end false + ensure + @current_arg_type = :none end def visit(node : Primitive) diff --git a/src/compiler/crystal/semantic/top_level_visitor.cr b/src/compiler/crystal/semantic/top_level_visitor.cr index 9a16767b8757..dd062f77a515 100644 --- a/src/compiler/crystal/semantic/top_level_visitor.cr +++ b/src/compiler/crystal/semantic/top_level_visitor.cr @@ -327,6 +327,10 @@ class Crystal::TopLevelVisitor < Crystal::SemanticVisitor node.doc ||= annotations_doc(annotations) check_ditto node, node.location + node.args.each &.accept self + node.double_splat.try &.accept self + node.block_arg.try &.accept self + node.set_type @program.nil if node.name == "finished" @@ -347,6 +351,16 @@ class Crystal::TopLevelVisitor < Crystal::SemanticVisitor false end + def visit(node : Arg) + if anns = node.parsed_annotations + process_annotations anns do |annotation_type, ann| + node.add_annotation annotation_type, ann + end + end + + false + end + def visit(node : Def) check_outside_exp node, "declare def" @@ -363,6 +377,10 @@ class Crystal::TopLevelVisitor < Crystal::SemanticVisitor node.doc ||= annotations_doc(annotations) check_ditto node, node.location + node.args.each &.accept self + node.double_splat.try &.accept self + node.block_arg.try &.accept self + is_instance_method = false target_type = case receiver = node.receiver diff --git a/src/compiler/crystal/syntax/ast.cr b/src/compiler/crystal/syntax/ast.cr index 50b6f5da6503..1d8da8f60652 100644 --- a/src/compiler/crystal/syntax/ast.cr +++ b/src/compiler/crystal/syntax/ast.cr @@ -989,8 +989,9 @@ module Crystal property default_value : ASTNode? property restriction : ASTNode? property doc : String? + property parsed_annotations : Array(Annotation)? - def initialize(@name : String, @default_value : ASTNode? = nil, @restriction : ASTNode? = nil, external_name : String? = nil) + def initialize(@name : String, @default_value : ASTNode? = nil, @restriction : ASTNode? = nil, external_name : String? = nil, @parsed_annotations : Array(Annotation)? = nil) @external_name = external_name || @name end @@ -1004,10 +1005,10 @@ module Crystal end def clone_without_location - Arg.new @name, @default_value.clone, @restriction.clone, @external_name.clone + Arg.new @name, @default_value.clone, @restriction.clone, @external_name.clone, @parsed_annotations.clone end - def_equals_and_hash name, default_value, restriction, external_name + def_equals_and_hash name, default_value, restriction, external_name, parsed_annotations end # The Proc notation in the type grammar: diff --git a/src/compiler/crystal/syntax/parser.cr b/src/compiler/crystal/syntax/parser.cr index 5be39462c8c8..08a04342a27d 100644 --- a/src/compiler/crystal/syntax/parser.cr +++ b/src/compiler/crystal/syntax/parser.cr @@ -3721,9 +3721,18 @@ module Crystal double_splat : Bool def parse_arg(args, extra_assigns, parentheses, found_default_value, found_splat, found_double_splat, allow_restrictions) + annotations = nil + + # Parse annotations first since they would be before any actual arg tokens. + # Do this in a loop to account for multiple annotations. + while @token.type.op_at_lsquare? + (annotations ||= Array(Annotation).new) << parse_annotation + skip_space_or_newline + end + if @token.type.op_amp? next_token_skip_space_or_newline - block_arg = parse_block_arg(extra_assigns) + block_arg = parse_block_arg(extra_assigns, annotations) skip_space_or_newline # When block_arg.name is empty, this is an anonymous parameter. # An anonymous parameter should not conflict other parameters names. @@ -3854,14 +3863,14 @@ module Crystal raise "BUG: arg_name is nil" unless arg_name - arg = Arg.new(arg_name, default_value, restriction, external_name: external_name).at(arg_location) + arg = Arg.new(arg_name, default_value, restriction, external_name: external_name, parsed_annotations: annotations).at(arg_location) args << arg push_var arg ArgExtras.new(nil, !!default_value, splat, !!double_splat) end - def parse_block_arg(extra_assigns) + def parse_block_arg(extra_assigns, annotations) name_location = @token.location if @token.type.op_rparen? || @token.type.newline? || @token.type.op_colon? @@ -3882,7 +3891,7 @@ module Crystal type_spec = parse_bare_proc_type end - block_arg = Arg.new(arg_name, restriction: type_spec).at(name_location) + block_arg = Arg.new(arg_name, restriction: type_spec, parsed_annotations: annotations).at(name_location) push_var block_arg diff --git a/src/compiler/crystal/syntax/to_s.cr b/src/compiler/crystal/syntax/to_s.cr index a4ab440fc8b5..a9a12a35c3ae 100644 --- a/src/compiler/crystal/syntax/to_s.cr +++ b/src/compiler/crystal/syntax/to_s.cr @@ -16,6 +16,14 @@ module Crystal class ToSVisitor < Visitor @str : IO @macro_expansion_pragmas : Hash(Int32, Array(Lexer::LocPragma))? + @current_arg_type : DefArgType = :none + + private enum DefArgType + NONE + SPLAT + DOUBLE_SPLAT + BLOCK_ARG + end def initialize(@str = IO::Memory.new, @macro_expansion_pragmas = nil, @emit_doc = false) @indent = 0 @@ -613,19 +621,19 @@ module Crystal printed_arg = false node.args.each_with_index do |arg, i| @str << ", " if printed_arg - @str << '*' if node.splat_index == i + @current_arg_type = :splat if node.splat_index == i arg.accept self printed_arg = true end if double_splat = node.double_splat + @current_arg_type = :double_splat @str << ", " if printed_arg - @str << "**" double_splat.accept self printed_arg = true end if block_arg = node.block_arg + @current_arg_type = :block_arg @str << ", " if printed_arg - @str << '&' block_arg.accept self printed_arg = true end @@ -659,19 +667,19 @@ module Crystal printed_arg = false node.args.each_with_index do |arg, i| @str << ", " if printed_arg - @str << '*' if i == node.splat_index + @current_arg_type = :splat if i == node.splat_index arg.accept self printed_arg = true end if double_splat = node.double_splat @str << ", " if printed_arg - @str << "**" + @current_arg_type = :double_splat double_splat.accept self printed_arg = true end if block_arg = node.block_arg @str << ", " if printed_arg - @str << '&' + @current_arg_type = :block_arg block_arg.accept self end @str << ')' @@ -771,6 +779,19 @@ module Crystal end def visit(node : Arg) + if parsed_annotations = node.parsed_annotations + parsed_annotations.each do |ann| + ann.accept self + @str << ' ' + end + end + + case @current_arg_type + when .splat? then @str << '*' + when .double_splat? then @str << "**" + when .block_arg? then @str << '&' + end + if node.external_name != node.name visit_named_arg_name(node.external_name) @str << ' ' @@ -789,6 +810,8 @@ module Crystal default_value.accept self end false + ensure + @current_arg_type = :none end def visit(node : ProcNotation)