From 92c677520e32d002dd367af84f5fe703f2da4b33 Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Mon, 7 May 2018 22:11:56 -0300 Subject: [PATCH] Automatic cast of numbers to other numbers, and symbols to enums --- spec/compiler/codegen/automatic_cast.cr | 218 +++++++++++ spec/compiler/normalize/def_spec.cr | 12 +- spec/compiler/semantic/automatic_cast.cr | 370 ++++++++++++++++++ spec/compiler/semantic/def_spec.cr | 4 +- spec/compiler/semantic/uninitialized_spec.cr | 4 +- src/compiler/crystal/codegen/cast.cr | 33 ++ src/compiler/crystal/codegen/codegen.cr | 2 + src/compiler/crystal/semantic/ast.cr | 33 +- src/compiler/crystal/semantic/bindings.cr | 15 + src/compiler/crystal/semantic/call.cr | 91 +++-- src/compiler/crystal/semantic/call_error.cr | 9 +- .../class_vars_initializer_visitor.cr | 12 +- .../crystal/semantic/cleanup_transformer.cr | 4 + src/compiler/crystal/semantic/cover.cr | 8 + .../crystal/semantic/default_arguments.cr | 6 +- .../instance_vars_initializer_visitor.cr | 27 +- src/compiler/crystal/semantic/main_visitor.cr | 89 ++++- src/compiler/crystal/semantic/match.cr | 1 + .../crystal/semantic/method_lookup.cr | 6 +- src/compiler/crystal/semantic/restrictions.cr | 48 +++ src/compiler/crystal/semantic/to_s.cr | 8 +- src/compiler/crystal/semantic/transformer.cr | 2 +- src/compiler/crystal/syntax/ast.cr | 15 + src/compiler/crystal/types.cr | 85 ++++ src/llvm/lib_llvm.cr | 3 + src/llvm/value_methods.cr | 8 + 26 files changed, 1022 insertions(+), 91 deletions(-) create mode 100644 spec/compiler/codegen/automatic_cast.cr create mode 100644 spec/compiler/semantic/automatic_cast.cr diff --git a/spec/compiler/codegen/automatic_cast.cr b/spec/compiler/codegen/automatic_cast.cr new file mode 100644 index 000000000000..e1de666b33fc --- /dev/null +++ b/spec/compiler/codegen/automatic_cast.cr @@ -0,0 +1,218 @@ +require "../../spec_helper" + +describe "Code gen: automatic cast" do + it "casts literal integer (Int32 -> Int64)" do + run(%( + def foo(x : Int64) + x + end + + foo(12345) + )).to_i.should eq(12345) + end + + it "casts literal integer (Int64 -> Int32, ok)" do + run(%( + def foo(x : Int32) + x + end + + foo(2147483647_i64) + )).to_i.should eq(2147483647) + end + + it "casts literal integer (Int32 -> Float32)" do + run(%( + def foo(x : Float32) + x + end + + foo(12345).to_i + )).to_i.should eq(12345) + end + + it "casts literal integer (Int32 -> Float64)" do + run(%( + def foo(x : Float64) + x + end + + foo(12345).to_i + )).to_i.should eq(12345) + end + + it "casts literal float (Float32 -> Float64)" do + run(%( + def foo(x : Float64) + x + end + + foo(12345.0_f32).to_i + )).to_i.should eq(12345) + end + + it "casts literal float (Float64 -> Float32)" do + run(%( + def foo(x : Float32) + x + end + + foo(12345.0).to_i + )).to_i.should eq(12345) + end + + it "casts symbol literal to enum" do + run(%( + :four + + enum Foo + One + Two + Three + end + + def foo(x : Foo) + x + end + + foo(:three) + )).to_i.should eq(2) + end + + it "casts Int32 to Int64 in ivar assignment" do + run(%( + class Foo + @x : Int64 + + def initialize + @x = 10 + end + + def x + @x + end + end + + Foo.new.x + )).to_i.should eq(10) + end + + it "casts Symbol to Enum in ivar assignment" do + run(%( + enum E + One + Two + Three + end + + class Foo + @x : E + + def initialize + @x = :three + end + + def x + @x + end + end + + Foo.new.x + )).to_i.should eq(2) + end + + it "casts Int32 to Int64 in cvar assignment" do + run(%( + class Foo + @@x : Int64 = 0_i64 + + def self.x + @@x = 10 + @@x + end + end + + Foo.x + )).to_i.should eq(10) + end + + it "casts Int32 to Int64 in lvar assignment" do + run(%( + x : Int64 + x = 123 + x + )).to_i.should eq(123) + end + + it "casts Int32 to Int64 in ivar type declaration" do + run(%( + class Foo + @x : Int64 = 10 + + def x + @x + end + end + + Foo.new.x + )).to_i.should eq(10) + end + + it "casts Symbol to Enum in ivar type declaration" do + run(%( + enum Color + Red + Green + Blue + end + + class Foo + @x : Color = :blue + + def x + @x + end + end + + Foo.new.x + )).to_i.should eq(2) + end + + it "casts Int32 to Int64 in cvar type declaration" do + run(%( + class Foo + @@x : Int64 = 10 + + def self.x + @@x + end + end + + Foo.x + )).to_i.should eq(10) + end + + it "casts Int32 -> Int64 in arg restriction" do + run(%( + def foo(x : Int64 = 123) + x + end + + foo + )).to_i.should eq(123) + end + + it "casts Int32 to Int64 in ivar type declaration in generic" do + run(%( + class Foo(T) + @x : T = 10 + + def x + @x + end + end + + Foo(Int64).new.x + )).to_i.should eq(10) + end +end diff --git a/spec/compiler/normalize/def_spec.cr b/spec/compiler/normalize/def_spec.cr index 95c96c8224f7..625fd6bdb892 100644 --- a/spec/compiler/normalize/def_spec.cr +++ b/spec/compiler/normalize/def_spec.cr @@ -33,8 +33,11 @@ describe "Normalize: def" do a_def = parse("def foo(x, y : Int32 = 1, z : Int64 = 2i64); x + y + z; end").as(Def) actual = a_def.expand_default_arguments(Program.new, 1) expected = parse("def foo(x); y = 1; z = 2i64; x + y + z; end").as(Def) - expected.body.as(Expressions).expressions.insert 1, TypeRestriction.new Var.new("y"), Path.new(["Int32"]) - expected.body.as(Expressions).expressions.insert 3, TypeRestriction.new Var.new("z"), Path.new(["Int64"]) + + exps = expected.body.as(Expressions).expressions + exps[0] = AssignWithRestriction.new(exps[0].as(Assign), Path.new("Int32")) + exps[1] = AssignWithRestriction.new(exps[1].as(Assign), Path.new("Int64")) + actual.should eq(expected) end @@ -42,7 +45,10 @@ describe "Normalize: def" do a_def = parse("def foo(x, y : Int32 = 1, z : Int64 = 2i64); x + y + z; end").as(Def) actual = a_def.expand_default_arguments(Program.new, 2) expected = parse("def foo(x, y : Int32); z = 2i64; x + y + z; end").as(Def) - expected.body.as(Expressions).expressions.insert 1, TypeRestriction.new Var.new("z"), Path.new(["Int64"]) + + exps = expected.body.as(Expressions).expressions + exps[0] = AssignWithRestriction.new(exps[0].as(Assign), Path.new("Int64")) + actual.should eq(expected) end diff --git a/spec/compiler/semantic/automatic_cast.cr b/spec/compiler/semantic/automatic_cast.cr new file mode 100644 index 000000000000..7be5d32cb5d5 --- /dev/null +++ b/spec/compiler/semantic/automatic_cast.cr @@ -0,0 +1,370 @@ +require "../../spec_helper" + +describe "Semantic: automatic cast" do + it "casts literal integer (Int32 -> no restriction)" do + assert_type(%( + def foo(x) + x + 1 + end + + foo(12345) + ), inject_primitives: true) { int32 } + end + + it "casts literal integer (Int32 -> Int64)" do + assert_type(%( + def foo(x : Int64) + x + end + + foo(12345) + )) { int64 } + end + + it "casts literal integer (Int64 -> Int32, ok)" do + assert_type(%( + def foo(x : Int32) + x + end + + foo(2147483647_i64) + )) { int32 } + end + + it "casts literal integer (Int64 -> Int32, too big)" do + assert_error %( + def foo(x : Int32) + x + end + + foo(2147483648_i64) + ), + "no overload matches" + end + + it "casts literal integer (Int32 -> Float32)" do + assert_type(%( + def foo(x : Float32) + x + end + + foo(12345) + )) { float32 } + end + + it "casts literal integer (Int32 -> Float64)" do + assert_type(%( + def foo(x : Float64) + x + end + + foo(12345) + )) { float64 } + end + + it "casts literal float (Float32 -> Float64)" do + assert_type(%( + def foo(x : Float64) + x + end + + foo(1.23_f32) + )) { float64 } + end + + it "casts literal float (Float64 -> Float32)" do + assert_type(%( + def foo(x : Float32) + x + end + + foo(1.23) + )) { float32 } + end + + it "matches correct overload" do + assert_type(%( + def foo(x : Int32) + x + end + + def foo(x : Int64) + x + end + + foo(1_i64) + )) { int64 } + end + + it "casts literal integer through alias with union" do + assert_type(%( + alias A = Int64 | String + + def foo(x : A) + x + end + + foo(12345) + )) { int64 } + end + + it "says ambiguous call for integer" do + assert_error %( + def foo(x : Int8) + x + end + + def foo(x : UInt8) + x + end + + foo(1) + ), + "ambiguous" + end + + it "says ambiguous call for integer (2)" do + assert_error %( + def foo(x : Int8 | UInt8) + x + end + + foo(1) + ), + "ambiguous" + end + + it "casts symbol literal to enum" do + assert_type(%( + enum Foo + One + Two + Three + end + + def foo(x : Foo) + x + end + + foo(:one) + )) { types["Foo"] } + end + + it "casts literal integer through alias with union" do + assert_type(%( + enum Foo + One + Two + end + + alias A = Foo | String + + def foo(x : A) + x + end + + foo(:two) + )) { types["Foo"] } + end + + it "errors if symbol name doesn't match enum member" do + assert_error %( + enum Foo + One + Two + Three + end + + def foo(x : Foo) + x + end + + foo(:four) + ), + "no overload matches" + end + + it "says ambiguous call for symbol" do + assert_error %( + enum Foo + One + Two + Three + end + + enum Foo2 + One + Two + Three + end + + def foo(x : Foo) + x + end + + def foo(x : Foo2) + x + end + + foo(:one) + ), + "ambiguous" + end + + it "casts Int32 to Int64 in ivar assignment" do + assert_type(%( + class Foo + @x : Int64 + + def initialize + @x = 10 + end + + def x + @x + end + end + + Foo.new.x + )) { int64 } + end + + it "casts Symbol to Enum in ivar assignment" do + assert_type(%( + enum E + One + Two + Three + end + + class Foo + @x : E + + def initialize + @x = :two + end + + def x + @x + end + end + + Foo.new.x + )) { types["E"] } + end + + it "casts Int32 to Int64 in cvar assignment" do + assert_type(%( + class Foo + @@x : Int64 = 0_i64 + + def self.x + @@x = 10 + @@x + end + end + + Foo.x + )) { int64 } + end + + it "casts Int32 to Int64 in lvar assignment" do + assert_type(%( + x : Int64 + x = 123 + x + )) { int64 } + end + + it "casts Int32 to Int64 in ivar type declaration" do + assert_type(%( + class Foo + @x : Int64 = 10 + + def x + @x + end + end + + Foo.new.x + )) { int64 } + end + + it "casts Symbol to Enum in ivar type declaration" do + assert_type(%( + enum Color + Red + Green + Blue + end + + class Foo + @x : Color = :red + + def x + @x + end + end + + Foo.new.x + )) { types["Color"] } + end + + it "casts Int32 to Int64 in cvar type declaration" do + assert_type(%( + class Foo + @@x : Int64 = 10 + + def self.x + @@x + end + end + + Foo.x + )) { int64 } + end + + it "casts Symbol to Enum in cvar type declaration" do + assert_type(%( + enum Color + Red + Green + Blue + end + + class Foo + @@x : Color = :red + + def self.x + @@x + end + end + + Foo.x + )) { types["Color"] } + end + + it "casts Int32 -> Int64 in arg restriction" do + assert_type(%( + def foo(x : Int64 = 0) + x + end + + foo + )) { int64 } + end + + it "casts Int32 to Int64 in ivar type declaration in generic" do + assert_type(%( + class Foo(T) + @x : T = 10 + + def x + @x + end + end + + Foo(Int64).new.x + )) { int64 } + end +end diff --git a/spec/compiler/semantic/def_spec.cr b/spec/compiler/semantic/def_spec.cr index ee961481fcfc..c5e6d4588dee 100644 --- a/spec/compiler/semantic/def_spec.cr +++ b/spec/compiler/semantic/def_spec.cr @@ -173,12 +173,12 @@ describe "Semantic: def" do it "errors when default value is incompatible with type restriction" do assert_error " - def foo(x : Int64 = 1) + def foo(x : Int64 = 'a') end foo ", - "can't restrict Int32 to Int64" + "can't restrict Char to Int64" end it "types call with global scope" do diff --git a/spec/compiler/semantic/uninitialized_spec.cr b/spec/compiler/semantic/uninitialized_spec.cr index fb9dc0b0e993..36de678b392c 100644 --- a/spec/compiler/semantic/uninitialized_spec.cr +++ b/spec/compiler/semantic/uninitialized_spec.cr @@ -58,9 +58,9 @@ describe "Semantic: uninitialized" do it "errors if declares var and then assigns other type" do assert_error %( x = uninitialized Int32 - x = 1_i64 + x = 'a' ), - "type must be Int32, not (Int32 | Int64)" + "type must be Int32, not (Char | Int32)" end it "errors if declaring variable multiple times with different types (#917)" do diff --git a/src/compiler/crystal/codegen/cast.cr b/src/compiler/crystal/codegen/cast.cr index 1460ebea5647..f441d9be399c 100644 --- a/src/compiler/crystal/codegen/cast.cr +++ b/src/compiler/crystal/codegen/cast.cr @@ -488,6 +488,39 @@ class Crystal::CodeGenVisitor target_pointer end + # This is the case of the automatic cast between integer types + def downcast_distinct(value, to_type : IntegerType, from_type : IntegerType) + codegen_cast(from_type, to_type, value) + end + + # This is the case of the automatic cast between integer type and float type + def downcast_distinct(value, to_type : FloatType, from_type : IntegerType) + codegen_cast(from_type, to_type, value) + end + + # This is the case of the automatic cast between float types + def downcast_distinct(value, to_type : FloatType, from_type : FloatType) + codegen_cast(from_type, to_type, value) + end + + # This is the case of the automatic cast between symbol and enum + def downcast_distinct(value, to_type : EnumType, from_type : SymbolType) + # value has the value of the symbol inside the symbol table, + # so we first get which symbol name that is, and then match + # it to one of the enum members + index = value.const_int_get_sext_value + symbol = @symbols_by_index[index].underscore + + to_type.types.each do |name, value| + if name.underscore == symbol + accept(value.as(Const).value) + return @last + end + end + + raise "Bug: expected to find enum member of #{to_type} matching symbol #{symbol}" + end + def downcast_distinct(value, to_type : Type, from_type : Type) raise "BUG: trying to downcast #{to_type} <- #{from_type}" end diff --git a/src/compiler/crystal/codegen/codegen.cr b/src/compiler/crystal/codegen/codegen.cr index 25287dad1723..aa7197cf582f 100644 --- a/src/compiler/crystal/codegen/codegen.cr +++ b/src/compiler/crystal/codegen/codegen.cr @@ -179,9 +179,11 @@ module Crystal @in_lib = false @strings = {} of StringKey => LLVM::Value @symbols = {} of String => Int32 + @symbols_by_index = [] of String @symbol_table_values = [] of LLVM::Value program.symbols.each_with_index do |sym, index| @symbols[sym] = index + @symbols_by_index << sym @symbol_table_values << build_string_constant(sym, sym) end diff --git a/src/compiler/crystal/semantic/ast.cr b/src/compiler/crystal/semantic/ast.cr index 102eb8b49b38..a5b5b5ed735c 100644 --- a/src/compiler/crystal/semantic/ast.cr +++ b/src/compiler/crystal/semantic/ast.cr @@ -77,21 +77,20 @@ module Crystal def_equals_and_hash type end - # Fictitious node to represent a type restriction - # - # It is used for type restrection of method arguments. - class TypeRestriction < ASTNode - getter obj - getter to + # Fictitious node to represent an assignment with a type restriction, + # created to match the assignment of a method argument's default value. + class AssignWithRestriction < ASTNode + property assign + property restriction - def initialize(@obj : ASTNode, @to : ASTNode) + def initialize(@assign : Assign, @restriction : ASTNode) end def clone_without_location - TypeRestriction.new @obj.clone, @to.clone + AssignWithRestriction.new @assign.clone, @restriction.clone end - def_equals_and_hash obj, to + def_equals_and_hash assign, restriction end class Arg @@ -734,4 +733,20 @@ module Crystal self end end + + class NumberLiteral + def can_be_autocast_to?(other_type) + case {self.type, other_type} + when {IntegerType, IntegerType} + min, max = other_type.range + min <= integer_value <= max + when {IntegerType, FloatType} + true + when {FloatType, FloatType} + true + else + false + end + end + end end diff --git a/src/compiler/crystal/semantic/bindings.cr b/src/compiler/crystal/semantic/bindings.cr index 247f3c7d8e38..13b50219ca90 100644 --- a/src/compiler/crystal/semantic/bindings.cr +++ b/src/compiler/crystal/semantic/bindings.cr @@ -17,6 +17,21 @@ module Crystal @type end + def type(*, with_literals = false) + type = self.type + + if with_literals + case self + when NumberLiteral + return NumberLiteralType.new(type.program, self) + when SymbolLiteral + return SymbolLiteralType.new(type.program, self) + end + end + + type + end + def set_type(type : Type) type = type.remove_alias_if_simple if !type.no_return? && (freeze_type = @freeze_type) && !type.implements?(freeze_type) diff --git a/src/compiler/crystal/semantic/call.cr b/src/compiler/crystal/semantic/call.cr index b1297d721941..ca2cccef1321 100644 --- a/src/compiler/crystal/semantic/call.cr +++ b/src/compiler/crystal/semantic/call.cr @@ -13,6 +13,9 @@ class Crystal::Call property? uses_with_scope = false getter? raises = false + class RetryLookupWithLiterals < ::Exception + end + def program scope.program end @@ -93,16 +96,22 @@ class Crystal::Call end def lookup_matches + lookup_matches(with_literals: false) + rescue ex : RetryLookupWithLiterals + lookup_matches(with_literals: true) + end + + def lookup_matches(*, with_literals = false) if args.any? { |arg| arg.is_a?(Splat) || arg.is_a?(DoubleSplat) } - lookup_matches_with_splat + lookup_matches_with_splat(with_literals) else - arg_types = args.map(&.type) - named_args_types = NamedArgumentType.from_args(named_args) - lookup_matches_without_splat arg_types, named_args_types + arg_types = args.map(&.type(with_literals: with_literals)) + named_args_types = NamedArgumentType.from_args(named_args, with_literals) + lookup_matches_without_splat arg_types, named_args_types, with_literals end end - def lookup_matches_with_splat + def lookup_matches_with_splat(with_literals) # Check if all splat are of tuples arg_types = Array(Type).new(args.size * 2) named_args_types = nil @@ -133,7 +142,7 @@ class Crystal::Call arg.raise "argument to double splat must be a named tuple, not #{arg_type}" end else - arg_types << arg.type + arg_types << arg.type(with_literals: with_literals) end end @@ -143,66 +152,69 @@ class Crystal::Call named_args_types ||= [] of NamedArgumentType named_args.each do |named_arg| raise "duplicate key: #{named_arg.name}" if named_args_types.any? &.name.==(named_arg.name) - named_args_types << NamedArgumentType.new(named_arg.name, named_arg.value.type) + named_args_types << NamedArgumentType.new( + named_arg.name, + named_arg.value.type(with_literals: with_literals), + ) end end - lookup_matches_without_splat arg_types, named_args_types + lookup_matches_without_splat arg_types, named_args_types, with_literals: with_literals end - def lookup_matches_without_splat(arg_types, named_args_types) + def lookup_matches_without_splat(arg_types, named_args_types, with_literals) if obj = @obj - lookup_matches_in(obj.type, arg_types, named_args_types) + lookup_matches_in(obj.type, arg_types, named_args_types, with_literals: with_literals) elsif name == "super" - lookup_super_matches(arg_types, named_args_types) + lookup_super_matches(arg_types, named_args_types, with_literals: with_literals) elsif name == "previous_def" - lookup_previous_def_matches(arg_types, named_args_types) + lookup_previous_def_matches(arg_types, named_args_types, with_literals: with_literals) elsif with_scope = @with_scope - lookup_matches_with_scope_in with_scope, arg_types, named_args_types + lookup_matches_with_scope_in with_scope, arg_types, named_args_types, with_literals: with_literals else - lookup_matches_in scope, arg_types, named_args_types + lookup_matches_in scope, arg_types, named_args_types, with_literals: with_literals end end - def lookup_matches_in(owner : AliasType, arg_types, named_args_types, self_type = nil, def_name = self.name, search_in_parents = true) - lookup_matches_in(owner.remove_alias, arg_types, named_args_types, search_in_parents: search_in_parents) + def lookup_matches_in(owner : AliasType, arg_types, named_args_types, self_type = nil, def_name = self.name, search_in_parents = true, with_literals = false) + lookup_matches_in(owner.remove_alias, arg_types, named_args_types, search_in_parents: search_in_parents, with_literals: with_literals) end - def lookup_matches_in(owner : UnionType, arg_types, named_args_types, self_type = nil, def_name = self.name, search_in_parents = true) - owner.union_types.flat_map { |type| lookup_matches_in(type, arg_types, named_args_types, search_in_parents: search_in_parents) } + def lookup_matches_in(owner : UnionType, arg_types, named_args_types, self_type = nil, def_name = self.name, search_in_parents = true, with_literals = false) + owner.union_types.flat_map { |type| lookup_matches_in(type, arg_types, named_args_types, search_in_parents: search_in_parents, with_literals: with_literals) } end - def lookup_matches_in(owner : Program, arg_types, named_args_types, self_type = nil, def_name = self.name, search_in_parents = true) - lookup_matches_in_type(owner, arg_types, named_args_types, self_type, def_name, search_in_parents) + def lookup_matches_in(owner : Program, arg_types, named_args_types, self_type = nil, def_name = self.name, search_in_parents = true, with_literals = false) + lookup_matches_in_type(owner, arg_types, named_args_types, self_type, def_name, search_in_parents: search_in_parents, with_literals: with_literals) end - def lookup_matches_in(owner : FileModule, arg_types, named_args_types, self_type = nil, def_name = self.name, search_in_parents = true) - lookup_matches_in program, arg_types, named_args_types, search_in_parents: search_in_parents + def lookup_matches_in(owner : FileModule, arg_types, named_args_types, self_type = nil, def_name = self.name, search_in_parents = true, with_literals = false) + lookup_matches_in program, arg_types, named_args_types, search_in_parents: search_in_parents, with_literals: with_literals end - def lookup_matches_in(owner : NonGenericModuleType | GenericModuleInstanceType | GenericType, arg_types, named_args_types, self_type = nil, def_name = self.name, search_in_parents = true) + def lookup_matches_in(owner : NonGenericModuleType | GenericModuleInstanceType | GenericType, arg_types, named_args_types, self_type = nil, def_name = self.name, search_in_parents = true, with_literals = false) attach_subclass_observer owner including_types = owner.including_types if including_types - lookup_matches_in(including_types, arg_types, named_args_types, search_in_parents: search_in_parents) + lookup_matches_in(including_types, arg_types, named_args_types, search_in_parents: search_in_parents, with_literals: with_literals) else [] of Def end end - def lookup_matches_in(owner : LibType, arg_types, named_args_types, self_type = nil, def_name = self.name, search_in_parents = true) + def lookup_matches_in(owner : LibType, arg_types, named_args_types, self_type = nil, def_name = self.name, search_in_parents = true, with_literals = false) raise "lib fun call is not supported in dispatch" end - def lookup_matches_in(owner : Type, arg_types, named_args_types, self_type = nil, def_name = self.name, search_in_parents = true) - lookup_matches_in_type(owner, arg_types, named_args_types, self_type, def_name, search_in_parents) + def lookup_matches_in(owner : Type, arg_types, named_args_types, self_type = nil, def_name = self.name, search_in_parents = true, with_literals = false) + lookup_matches_in_type(owner, arg_types, named_args_types, self_type, def_name, search_in_parents: search_in_parents, with_literals: with_literals) end - def lookup_matches_with_scope_in(owner, arg_types, named_args_types) + def lookup_matches_with_scope_in(owner, arg_types, named_args_types, with_literals = false) signature = CallSignature.new(name, arg_types, block, named_args_types) - matches = lookup_matches_checking_expansion(owner, signature) + matches = lookup_matches_checking_expansion(owner, signature, with_literals: with_literals) if matches.empty? && owner.class? && owner.abstract? matches = owner.virtual_type.lookup_matches(signature) @@ -210,14 +222,14 @@ class Crystal::Call if matches.empty? @uses_with_scope = false - return lookup_matches_in scope, arg_types, named_args_types + return lookup_matches_in scope, arg_types, named_args_types, with_literals: with_literals end @uses_with_scope = true instantiate matches, owner, self_type: nil end - def lookup_matches_in_type(owner, arg_types, named_args_types, self_type, def_name, search_in_parents, search_in_toplevel = true) + def lookup_matches_in_type(owner, arg_types, named_args_types, self_type, def_name, search_in_parents, search_in_toplevel = true, with_literals = false) signature = CallSignature.new(def_name, arg_types, block, named_args_types) matches = check_tuple_indexer(owner, def_name, args, arg_types) @@ -254,7 +266,7 @@ class Crystal::Call # compile errors, which will anyway appear once you add concrete # subclasses and instances. if def_name == "new" || !(!owner.metaclass? && owner.abstract? && (owner.leaf? || owner.is_a?(GenericClassInstanceType))) - raise_matches_not_found(matches.owner || owner, def_name, arg_types, named_args_types, matches) + raise_matches_not_found(matches.owner || owner, def_name, arg_types, named_args_types, matches, with_literals: with_literals) end end @@ -271,7 +283,7 @@ class Crystal::Call instantiate matches, owner, self_type end - def lookup_matches_checking_expansion(owner, signature, search_in_parents = true) + def lookup_matches_checking_expansion(owner, signature, search_in_parents = true, with_literals = false) # If this call is an expansion (because of default or named args) we must # resolve the call in the type that defined the original method, without # triggering a virtual lookup. But the context of lookup must be preseved. @@ -460,6 +472,7 @@ class Crystal::Call in_bounds = (0 <= index < instance_type.size) if nilable || in_bounds indexer_def = yield instance_type, (in_bounds ? index : -1) + arg_types.map!(&.remove_literal) indexer_match = Match.new(indexer_def, arg_types, MatchContext.new(owner, owner)) return Matches.new([indexer_match] of Match, true) elsif instance_type.size == 0 @@ -483,6 +496,7 @@ class Crystal::Call index = instance_type.name_index(name) if index || nilable indexer_def = yield instance_type, (index || -1) + arg_types.map!(&.remove_literal) indexer_match = Match.new(indexer_def, arg_types, MatchContext.new(owner, owner)) return Matches.new([indexer_match] of Match, true) else @@ -554,7 +568,7 @@ class Crystal::Call end end - def lookup_super_matches(arg_types, named_args_types) + def lookup_super_matches(arg_types, named_args_types, with_literals) if scope.is_a?(Program) raise "there's no superclass in this scope" end @@ -592,16 +606,16 @@ class Crystal::Call if parents && parents.size > 0 parents.each_with_index do |parent, i| if parent.lookup_first_def(enclosing_def.name, block) - return lookup_matches_in_type(parent, arg_types, named_args_types, scope, enclosing_def.name, !in_initialize, search_in_toplevel: false) + return lookup_matches_in_type(parent, arg_types, named_args_types, scope, enclosing_def.name, !in_initialize, search_in_toplevel: false, with_literals: with_literals) end end - lookup_matches_in_type(parents.last, arg_types, named_args_types, scope, enclosing_def.name, !in_initialize, search_in_toplevel: false) + lookup_matches_in_type(parents.last, arg_types, named_args_types, scope, enclosing_def.name, !in_initialize, search_in_toplevel: false, with_literals: with_literals) else raise "there's no superclass in this scope" end end - def lookup_previous_def_matches(arg_types, named_args_types) + def lookup_previous_def_matches(arg_types, named_args_types, with_literals) enclosing_def = enclosing_def("previous_def") previous_item = enclosing_def.previous @@ -613,11 +627,12 @@ class Crystal::Call signature = CallSignature.new(previous.name, arg_types, block, named_args_types) context = MatchContext.new(scope, scope, def_free_vars: previous.free_vars) + arg_types.map!(&.remove_literal) match = Match.new(previous, arg_types, context, named_args_types) matches = Matches.new([match] of Match, true) unless signature.match(previous_item, context) - raise_matches_not_found scope, previous.name, arg_types, named_args_types, matches + raise_matches_not_found scope, previous.name, arg_types, named_args_types, matches, with_literals: with_literals end unless scope.is_a?(Program) diff --git a/src/compiler/crystal/semantic/call_error.cr b/src/compiler/crystal/semantic/call_error.cr index 57a2ca412101..a3aedd440d8c 100644 --- a/src/compiler/crystal/semantic/call_error.cr +++ b/src/compiler/crystal/semantic/call_error.cr @@ -37,7 +37,7 @@ class Crystal::Path end class Crystal::Call - def raise_matches_not_found(owner, def_name, arg_types, named_args_types, matches = nil) + def raise_matches_not_found(owner, def_name, arg_types, named_args_types, matches = nil, with_literals = false) # Special case: Foo+:Class#new if owner.is_a?(VirtualMetaclassType) && def_name == "new" raise_matches_not_found_for_virtual_metaclass_new owner @@ -212,6 +212,13 @@ class Crystal::Call end end + # If we made a lookup without the special rule for literals, + # and we have literals in the call, try again with that special rule. + if with_literals == false && (args.any? { |arg| arg.is_a?(NumberLiteral) || arg.is_a?(SymbolLiteral) } || + named_args.try &.any? { |arg| arg.value.is_a?(NumberLiteral) || arg.value.is_a?(SymbolLiteral) }) + ::raise RetryLookupWithLiterals.new + end + if args.size == 1 && args.first.type.includes_type?(program.nil) owner_trace = args.first.find_owner_trace(program, program.nil) end diff --git a/src/compiler/crystal/semantic/class_vars_initializer_visitor.cr b/src/compiler/crystal/semantic/class_vars_initializer_visitor.cr index 05b786c919e8..8dfd4f09fe8c 100644 --- a/src/compiler/crystal/semantic/class_vars_initializer_visitor.cr +++ b/src/compiler/crystal/semantic/class_vars_initializer_visitor.cr @@ -68,7 +68,17 @@ module Crystal end main_visitor.pushing_type(owner.as(ModuleType)) do - node.accept main_visitor + # Check if we can autocast + if (node.is_a?(NumberLiteral) || node.is_a?(SymbolLiteral)) && + (class_var_type = class_var.type?) + cloned_node = node.clone + cloned_node.accept MainVisitor.new(self) + if casted_value = MainVisitor.check_automatic_cast(cloned_node, class_var_type) + node = initializer.node = casted_value + end + end + + node.accept main_visitor unless node.type? end unless had_class_var diff --git a/src/compiler/crystal/semantic/cleanup_transformer.cr b/src/compiler/crystal/semantic/cleanup_transformer.cr index a9b75ec4b2e2..e108c12f1276 100644 --- a/src/compiler/crystal/semantic/cleanup_transformer.cr +++ b/src/compiler/crystal/semantic/cleanup_transformer.cr @@ -753,6 +753,10 @@ module Crystal node end + def transform(node : AssignWithRestriction) + transform(node.assign) + end + @false_literal : BoolLiteral? def false_literal diff --git a/src/compiler/crystal/semantic/cover.cr b/src/compiler/crystal/semantic/cover.cr index ee5cb10776e5..228103c1172f 100644 --- a/src/compiler/crystal/semantic/cover.cr +++ b/src/compiler/crystal/semantic/cover.cr @@ -276,4 +276,12 @@ module Crystal class AliasType delegate cover, cover_size, to: aliased_type end + + class NumberLiteralType + delegate cover, cover_size, to: (@matched_type || literal.type) + end + + class SymbolLiteralType + delegate cover, cover_size, to: (@matched_type || literal.type) + end end diff --git a/src/compiler/crystal/semantic/default_arguments.cr b/src/compiler/crystal/semantic/default_arguments.cr index a04c1602fb6d..4ab90cb08cf2 100644 --- a/src/compiler/crystal/semantic/default_arguments.cr +++ b/src/compiler/crystal/semantic/default_arguments.cr @@ -121,11 +121,13 @@ class Crystal::Def if default_value.is_a?(MagicConstant) expansion.args.push arg.clone else - new_body << Assign.new(Var.new(arg.name).at(arg), default_value).at(arg) + assign = Assign.new(Var.new(arg.name).at(arg), default_value).at(arg) if restriction = arg.restriction - new_body << TypeRestriction.new(Var.new(arg.name).at(arg), restriction).at(arg) + assign = AssignWithRestriction.new(assign, restriction) end + + new_body << assign end end end diff --git a/src/compiler/crystal/semantic/instance_vars_initializer_visitor.cr b/src/compiler/crystal/semantic/instance_vars_initializer_visitor.cr index 7b144c2f8788..227242bd8b98 100644 --- a/src/compiler/crystal/semantic/instance_vars_initializer_visitor.cr +++ b/src/compiler/crystal/semantic/instance_vars_initializer_visitor.cr @@ -70,6 +70,8 @@ class Crystal::InstanceVarsInitializerVisitor < Crystal::SemanticVisitor end def finish + scope_initializers = [] of InstanceVarInitializerContainer::InstanceVarInitializer? + # First declare them, so when we type all of them we will have # the info of which instance vars have initializers (so they are not nil) initializers.each do |i| @@ -78,18 +80,31 @@ class Crystal::InstanceVarsInitializerVisitor < Crystal::SemanticVisitor program.undefined_instance_variable(i.target, scope, nil) end - scope.add_instance_var_initializer(i.target.name, i.value, scope.is_a?(GenericType) ? nil : i.meta_vars) + scope_initializers << + scope.add_instance_var_initializer(i.target.name, i.value, scope.is_a?(GenericType) ? nil : i.meta_vars) end # Now type them - initializers.each do |i| + initializers.each_with_index do |i, index| scope = i.scope + value = i.value - unless scope.is_a?(GenericType) - ivar_visitor = MainVisitor.new(program, meta_vars: i.meta_vars) - ivar_visitor.scope = scope - i.value.accept ivar_visitor + next if scope.is_a?(GenericType) + + # Check if we can autocast + if (value.is_a?(NumberLiteral) || value.is_a?(SymbolLiteral)) && + (scope_initializer = scope_initializers[index]) + cloned_value = value.clone + cloned_value.accept MainVisitor.new(program) + if casted_value = MainVisitor.check_automatic_cast(cloned_value, scope.lookup_instance_var(i.target.name).type) + scope_initializer.value = casted_value + next + end end + + ivar_visitor = MainVisitor.new(program, meta_vars: i.meta_vars) + ivar_visitor.scope = scope + value.accept ivar_visitor end end end diff --git a/src/compiler/crystal/semantic/main_visitor.cr b/src/compiler/crystal/semantic/main_visitor.cr index f43974280f6d..8e2f02ec4a96 100644 --- a/src/compiler/crystal/semantic/main_visitor.cr +++ b/src/compiler/crystal/semantic/main_visitor.cr @@ -741,19 +741,40 @@ module Crystal false end - def type_assign(target : Var, value, node) + def type_assign(target : Var, value, node, restriction = nil) value.accept self + var_name = target.name + meta_var = (@meta_vars[var_name] ||= new_meta_var(var_name)) + + if freeze_type = meta_var.freeze_type + if casted_value = check_automatic_cast(value, freeze_type, node) + value = casted_value + end + end + + # If this assign comes from a AssignWithRestriction node, check the restriction + + if restriction && (value_type = value.type?) + if value_type.restrict(restriction, match_context.not_nil!) + # OK + else + # Check autocast too + restriction_type = scope.lookup_type(restriction, free_vars: free_vars) + if casted_value = check_automatic_cast(value, restriction_type, node) + value = casted_value + else + node.raise "can't restrict #{value.type} to #{restriction}" + end + end + end + target.bind_to value node.bind_to value - var_name = target.name - value_type_filters = @type_filters @type_filters = nil - meta_var = (@meta_vars[var_name] ||= new_meta_var(var_name)) - # Save variable assignment location for debugging output meta_var.location ||= target.location @@ -817,6 +838,9 @@ module Crystal value.accept self var = lookup_instance_var target + if casted_value = check_automatic_cast(value, var.type, node) + value = casted_value + end target.bind_to var node.bind_to value @@ -911,6 +935,10 @@ module Crystal var = lookup_class_var(target) check_class_var_is_thread_local(target, var, attributes) + if casted_value = check_automatic_cast(value, var.type, node) + value = casted_value + end + target.bind_to var node.bind_to value @@ -931,6 +959,34 @@ module Crystal raise "BUG: unknown assign target in MainVisitor: #{target}" end + # See if we can automatically cast the value if the types don't exactly match + def check_automatic_cast(value, var_type, assign = nil) + MainVisitor.check_automatic_cast(value, var_type, assign) + end + + def self.check_automatic_cast(value, var_type, assign = nil) + if value.is_a?(NumberLiteral) && value.type != var_type && (var_type.is_a?(IntegerType) || var_type.is_a?(FloatType)) + if value.can_be_autocast_to?(var_type) + value.type = var_type + value.kind = var_type.kind + assign.value = value if assign + return value + end + elsif value.is_a?(SymbolLiteral) && var_type.is_a?(EnumType) + member = var_type.find_member(value.value) + if member + path = Path.new(member.name) + path.target_const = member + path.type = var_type + value = path + assign.value = value if assign + return value + end + end + + nil + end + def visit(node : Yield) call = @call unless call @@ -2938,22 +2994,13 @@ module Crystal false end - def visit(node : TypeRestriction) - obj = node.obj - to = node.to - - obj.accept self - - unless context = match_context - node.raise "BUG: there is no match context" - end - - if type = obj.type.restrict(to, context) - node.type = type - else - node.raise "can't restrict #{obj.type} to #{to}" - end - + def visit(node : AssignWithRestriction) + type_assign( + node.assign.target.as(Var), + node.assign.value, + node.assign, + restriction: node.restriction) + node.bind_to(node.assign) false end diff --git a/src/compiler/crystal/semantic/match.cr b/src/compiler/crystal/semantic/match.cr index c2df0fa1277a..9f3c8da499ca 100644 --- a/src/compiler/crystal/semantic/match.cr +++ b/src/compiler/crystal/semantic/match.cr @@ -60,6 +60,7 @@ module Crystal def set_free_var(name, type) free_vars = @free_vars ||= {} of String => TypeVar + type = type.remove_literal if type.is_a?(Type) free_vars[name] = type end diff --git a/src/compiler/crystal/semantic/method_lookup.cr b/src/compiler/crystal/semantic/method_lookup.cr index 13cfe91f318c..702fa8b8a4df 100644 --- a/src/compiler/crystal/semantic/method_lookup.cr +++ b/src/compiler/crystal/semantic/method_lookup.cr @@ -2,8 +2,10 @@ require "../types" module Crystal record NamedArgumentType, name : String, type : Type do - def self.from_args(named_args : Array(NamedArgument)?) - named_args.try &.map { |named_arg| new(named_arg.name, named_arg.value.type) } + def self.from_args(named_args : Array(NamedArgument)?, with_literals = false) + named_args.try &.map do |named_arg| + new(named_arg.name, named_arg.value.type(with_literals: with_literals)) + end end end diff --git a/src/compiler/crystal/semantic/restrictions.cr b/src/compiler/crystal/semantic/restrictions.cr index 90dc0d6df2c1..48f63e0e4037 100644 --- a/src/compiler/crystal/semantic/restrictions.cr +++ b/src/compiler/crystal/semantic/restrictions.cr @@ -1117,6 +1117,54 @@ module Crystal true end end + + class NumberLiteralType + def restrict(other, context) + if other.is_a?(IntegerType) || other.is_a?(FloatType) + if literal.can_be_autocast_to?(other) + if @matched_type && @matched_type != other + literal.raise "ambiguous call matches both #{@matched_type} and #{other}" + end + + @matched_type = other + other + else + literal.type.restrict(other, context) + end + else + type = super(other, context) || + literal.type.restrict(other, context) + if type == self + type = @matched_type || literal.type + end + type + end + end + end + + class SymbolLiteralType + def restrict(other, context) + if other.is_a?(EnumType) + if other.find_member(literal.value) + if @matched_type && @matched_type != other + literal.raise "ambiguous call matches both #{@matched_type} and #{other}" + end + + @matched_type = other + other + else + literal.type.restrict(other, context) + end + else + type = super(other, context) || + literal.type.restrict(other, context) + if type == self + type = @matched_type || literal.type + end + type + end + end + end end private def get_generic_type(node, context) diff --git a/src/compiler/crystal/semantic/to_s.cr b/src/compiler/crystal/semantic/to_s.cr index fe5799f968b5..34b6d3a10ac7 100644 --- a/src/compiler/crystal/semantic/to_s.cr +++ b/src/compiler/crystal/semantic/to_s.cr @@ -48,11 +48,13 @@ module Crystal false end - def visit(node : TypeRestriction) + def visit(node : AssignWithRestriction) @str << "# type restriction: " - node.obj.accept self + node.assign.target.accept self @str << " : " - node.to.accept self + node.restriction.accept self + @str << " = " + node.assign.value.accept self false end diff --git a/src/compiler/crystal/semantic/transformer.cr b/src/compiler/crystal/semantic/transformer.cr index ca9459042263..1246f7e79b73 100644 --- a/src/compiler/crystal/semantic/transformer.cr +++ b/src/compiler/crystal/semantic/transformer.cr @@ -2,7 +2,7 @@ require "../syntax/transformer" module Crystal class Transformer - def transform(node : MetaVar | MetaMacroVar | Primitive | TypeFilteredNode | TupleIndexer | TypeNode | TypeRestriction | YieldBlockBinder | MacroId) + def transform(node : MetaVar | MetaMacroVar | Primitive | TypeFilteredNode | TupleIndexer | TypeNode | AssignWithRestriction | YieldBlockBinder | MacroId) node end diff --git a/src/compiler/crystal/syntax/ast.cr b/src/compiler/crystal/syntax/ast.cr index a004894cb78b..7a30306c583e 100644 --- a/src/compiler/crystal/syntax/ast.cr +++ b/src/compiler/crystal/syntax/ast.cr @@ -222,6 +222,21 @@ module Crystal @value[0] == '+' || @value[0] == '-' end + def integer_value + case kind + when :i8 then value.to_i8 + when :i16 then value.to_i16 + when :i32 then value.to_i32 + when :i64 then value.to_i64 + when :u8 then value.to_u8 + when :u16 then value.to_u16 + when :u32 then value.to_u32 + when :u64 then value.to_u64 + else + raise "Bug: called 'integer_value' for non-integer literal" + end + end + def clone_without_location NumberLiteral.new(@value, @kind) end diff --git a/src/compiler/crystal/types.cr b/src/compiler/crystal/types.cr index 87768ebafdc4..53013df12964 100644 --- a/src/compiler/crystal/types.cr +++ b/src/compiler/crystal/types.cr @@ -554,6 +554,10 @@ module Crystal self end + def remove_literal + self + end + def generic_nest 0 end @@ -1176,6 +1180,29 @@ module Crystal def normal_rank (@rank - 1) / 2 end + + def range + case kind + when :i8 + {Int8::MIN, Int8::MAX} + when :i16 + {Int16::MIN, Int16::MAX} + when :i32 + {Int32::MIN, Int32::MAX} + when :i64 + {Int64::MIN, Int64::MAX} + when :u8 + {UInt8::MIN, UInt8::MAX} + when :u16 + {UInt16::MIN, UInt16::MAX} + when :u32 + {UInt32::MIN, UInt32::MAX} + when :u64 + {UInt64::MIN, UInt64::MAX} + else + raise "Bug: called 'range' for non-integer literal" + end + end end class FloatType < PrimitiveType @@ -1206,6 +1233,45 @@ module Crystal class VoidType < NamedType end + # Type for a number literal: it has the specific type of the number literal + # but can also match other types (like ints and floats) if the literal + # fits in those types. + class NumberLiteralType < Type + getter literal : NumberLiteral + @matched_type : Type? + + def initialize(program, @literal) + super(program) + end + + def remove_literal + literal.type + end + + def to_s_with_options(io : IO, skip_union_parens : Bool = false, generic_args : Bool = true, codegen = false) + io << @literal.type + end + end + + # Type for a symbol literal: it has the specific type of the symbol literal (SymbolType) + # but can also match enums if their members match the symbol's name. + class SymbolLiteralType < Type + getter literal : SymbolLiteral + @matched_type : Type? + + def initialize(program, @literal) + super(program) + end + + def remove_literal + literal.type + end + + def to_s_with_options(io : IO, skip_union_parens : Bool = false, generic_args : Bool = true, codegen = false) + io << @literal.type + end + end + # Any thing that can be passed as a generic type variable. # # For example, in: @@ -1360,6 +1426,15 @@ module Crystal value = initializer.value.clone value.accept visitor instance_var = instance.lookup_instance_var(initializer.name) + + # Check if automatic cast can be done + if instance_var.type != value.type && + (value.is_a?(NumberLiteral) || value.is_a?(SymbolLiteral)) + if casted_value = MainVisitor.check_automatic_cast(value, instance_var.type) + value = casted_value + end + end + instance_var.bind_to(value) instance.add_instance_var_initializer(initializer.name, value, meta_vars) end @@ -2451,6 +2526,16 @@ module Crystal true end + def find_member(name) + name = name.underscore + types.each do |member_name, member| + if name == member_name.underscore + return member.as(Const) + end + end + nil + end + def type_desc "enum" end diff --git a/src/llvm/lib_llvm.cr b/src/llvm/lib_llvm.cr index 7c6fc5c333ef..02b1d8f22a57 100644 --- a/src/llvm/lib_llvm.cr +++ b/src/llvm/lib_llvm.cr @@ -352,4 +352,7 @@ lib LibLLVM fun create_builder_in_context = LLVMCreateBuilderInContext(c : ContextRef) : BuilderRef fun get_type_context = LLVMGetTypeContext(TypeRef) : ContextRef + + fun const_int_get_sext_value = LLVMConstIntGetSExtValue(ValueRef) : Int64 + fun const_int_get_zext_value = LLVMConstIntGetZExtValue(ValueRef) : UInt64 end diff --git a/src/llvm/value_methods.cr b/src/llvm/value_methods.cr index 3b69d12e06e7..b66cecabe0a0 100644 --- a/src/llvm/value_methods.cr +++ b/src/llvm/value_methods.cr @@ -87,6 +87,14 @@ module LLVM::ValueMethods LibLLVM.set_alignment(self, bytes) end + def const_int_get_sext_value + LibLLVM.const_int_get_sext_value(self) + end + + def const_int_get_zext_value + LibLLVM.const_int_get_zext_value(self) + end + def to_value LLVM::Value.new unwrap end