Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: "Types in a binary operation should match, but found T and T" #4648

Merged
merged 11 commits into from
Mar 29, 2024
82 changes: 44 additions & 38 deletions compiler/noirc_frontend/src/hir/type_check/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,14 +152,16 @@ impl<'interner> TypeChecker<'interner> {
Ok((typ, use_impl)) => {
if use_impl {
let id = infix_expr.trait_method_id;
// Assume operators have no trait generics
self.verify_trait_constraint(
&lhs_type,
id.trait_id,
&[],
*expr_id,
span,
);

// Delay checking the trait constraint until the end of the function.
// Checking it now could bind an unbound type variable to any type
// that implements the trait.
let constraint = crate::hir_def::traits::TraitConstraint {
typ: lhs_type.clone(),
trait_id: id.trait_id,
trait_generics: Vec::new(),
};
self.trait_constraints.push((constraint, *expr_id));
self.typecheck_operator_method(*expr_id, id, &lhs_type, span);
}
typ
Expand Down Expand Up @@ -836,6 +838,10 @@ impl<'interner> TypeChecker<'interner> {
match (lhs_type, rhs_type) {
// Avoid reporting errors multiple times
(Error, _) | (_, Error) => Ok((Bool, false)),
(Alias(alias, args), other) | (other, Alias(alias, args)) => {
let alias = alias.borrow().get_type(args);
self.comparator_operand_type_rules(&alias, other, op, span)
}

// Matches on TypeVariable must be first to follow any type
// bindings.
Expand All @@ -844,12 +850,8 @@ impl<'interner> TypeChecker<'interner> {
return self.comparator_operand_type_rules(other, binding, op, span);
}

self.bind_type_variables_for_infix(lhs_type, op, rhs_type, span);
Ok((Bool, false))
}
(Alias(alias, args), other) | (other, Alias(alias, args)) => {
let alias = alias.borrow().get_type(args);
self.comparator_operand_type_rules(&alias, other, op, span)
let use_impl = self.bind_type_variables_for_infix(lhs_type, op, rhs_type, span);
Ok((Bool, use_impl))
}
(Integer(sign_x, bit_width_x), Integer(sign_y, bit_width_y)) => {
if sign_x != sign_y {
Expand Down Expand Up @@ -1079,36 +1081,42 @@ impl<'interner> TypeChecker<'interner> {
}
}

/// Handles the TypeVariable case for checking binary operators.
/// Returns true if we should use the impl for the operator instead of the primitive
/// version of it.
fn bind_type_variables_for_infix(
&mut self,
lhs_type: &Type,
op: &HirBinaryOp,
rhs_type: &Type,
span: Span,
) {
) -> bool {
self.unify(lhs_type, rhs_type, || TypeCheckError::TypeMismatchWithSource {
expected: lhs_type.clone(),
actual: rhs_type.clone(),
source: Source::Binary,
span,
});

// In addition to unifying both types, we also have to bind either
// the lhs or rhs to an integer type variable. This ensures if both lhs
// and rhs are type variables, that they will have the correct integer
// type variable kind instead of TypeVariableKind::Normal.
let target = if op.kind.is_valid_for_field_type() {
Type::polymorphic_integer_or_field(self.interner)
} else {
Type::polymorphic_integer(self.interner)
};
let use_impl = !lhs_type.is_numeric();

self.unify(lhs_type, &target, || TypeCheckError::TypeMismatchWithSource {
expected: lhs_type.clone(),
actual: rhs_type.clone(),
source: Source::Binary,
span,
});
// if the type variable is an integer or field we have to narrow it to only an integer
jfecher marked this conversation as resolved.
Show resolved Hide resolved
if !op.kind.is_valid_for_field_type() && lhs_type.is_numeric() {
// In addition to unifying both types, we also have to bind either
// the lhs or rhs to an integer type variable. This ensures if both lhs
// and rhs are type variables, that they will have the correct integer
// type variable kind instead of TypeVariableKind::Normal.
let target = Type::polymorphic_integer(self.interner);

self.unify(lhs_type, &target, || TypeCheckError::TypeMismatchWithSource {
expected: lhs_type.clone(),
actual: rhs_type.clone(),
jfecher marked this conversation as resolved.
Show resolved Hide resolved
source: Source::Binary,
span,
});
}

use_impl
}

// Given a binary operator and another type. This method will produce the output type
Expand All @@ -1130,6 +1138,10 @@ impl<'interner> TypeChecker<'interner> {
match (lhs_type, rhs_type) {
// An error type on either side will always return an error
(Error, _) | (_, Error) => Ok((Error, false)),
(Alias(alias, args), other) | (other, Alias(alias, args)) => {
let alias = alias.borrow().get_type(args);
self.infix_operand_type_rules(&alias, op, other, span)
}

// Matches on TypeVariable must be first so that we follow any type
// bindings.
Expand All @@ -1138,14 +1150,8 @@ impl<'interner> TypeChecker<'interner> {
return self.infix_operand_type_rules(binding, op, other, span);
}

self.bind_type_variables_for_infix(lhs_type, op, rhs_type, span);

// Both types are unified so the choice of which to return is arbitrary
Ok((other.clone(), false))
}
(Alias(alias, args), other) | (other, Alias(alias, args)) => {
let alias = alias.borrow().get_type(args);
self.infix_operand_type_rules(&alias, op, other, span)
let use_impl = self.bind_type_variables_for_infix(lhs_type, op, rhs_type, span);
Ok((other.clone(), use_impl))
}
(Integer(sign_x, bit_width_x), Integer(sign_y, bit_width_y)) => {
if sign_x != sign_y {
Expand Down
50 changes: 24 additions & 26 deletions compiler/noirc_frontend/src/hir/type_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,31 +86,13 @@ pub fn type_check_func(interner: &mut NodeInterner, func_id: FuncId) -> Vec<Type

let function_last_type = type_checker.check_function_body(function_body_id);

// Verify any remaining trait constraints arising from the function body
for (constraint, expr_id) in std::mem::take(&mut type_checker.trait_constraints) {
let span = type_checker.interner.expr_span(&expr_id);
type_checker.verify_trait_constraint(
&constraint.typ,
constraint.trait_id,
&constraint.trait_generics,
expr_id,
span,
);
}

errors.append(&mut type_checker.errors);

// Now remove all the `where` clause constraints we added
for constraint in &expected_trait_constraints {
interner.remove_assumed_trait_implementations_for_trait(constraint.trait_id);
}

// Check declared return type and actual return type
if !can_ignore_ret {
let (expr_span, empty_function) = function_info(interner, function_body_id);
let func_span = interner.expr_span(function_body_id); // XXX: We could be more specific and return the span of the last stmt, however stmts do not have spans yet
let (expr_span, empty_function) = function_info(type_checker.interner, function_body_id);
let func_span = type_checker.interner.expr_span(function_body_id); // XXX: We could be more specific and return the span of the last stmt, however stmts do not have spans yet
if let Type::TraitAsType(trait_id, _, generics) = &declared_return_type {
if interner
if type_checker
.interner
.lookup_trait_implementation(&function_last_type, *trait_id, generics)
.is_err()
{
Expand All @@ -126,7 +108,7 @@ pub fn type_check_func(interner: &mut NodeInterner, func_id: FuncId) -> Vec<Type
function_last_type.unify_with_coercions(
&declared_return_type,
*function_body_id,
interner,
type_checker.interner,
&mut errors,
|| {
let mut error = TypeCheckError::TypeMismatchWithSource {
Expand All @@ -137,16 +119,32 @@ pub fn type_check_func(interner: &mut NodeInterner, func_id: FuncId) -> Vec<Type
};

if empty_function {
error = error.add_context(
"implicitly returns `()` as its body has no tail or `return` expression",
);
error = error.add_context("implicitly returns `()` as its body has no tail or `return` expression");
}
error
},
);
}
}

// Verify any remaining trait constraints arising from the function body
for (constraint, expr_id) in std::mem::take(&mut type_checker.trait_constraints) {
let span = type_checker.interner.expr_span(&expr_id);
type_checker.verify_trait_constraint(
&constraint.typ,
constraint.trait_id,
&constraint.trait_generics,
expr_id,
span,
);
}

// Now remove all the `where` clause constraints we added
for constraint in &expected_trait_constraints {
type_checker.interner.remove_assumed_trait_implementations_for_trait(constraint.trait_id);
}

errors.append(&mut type_checker.errors);
errors
}

Expand Down
10 changes: 10 additions & 0 deletions compiler/noirc_frontend/src/hir_def/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,7 @@
TypeBinding::Bound(binding) => binding.is_bindable(),
TypeBinding::Unbound(_) => true,
},
Type::Alias(alias, args) => alias.borrow().get_type(args).is_bindable(),
_ => false,
}
}
Expand All @@ -605,6 +606,15 @@
matches!(self.follow_bindings(), Type::Integer(Signedness::Unsigned, _))
}

pub fn is_numeric(&self) -> bool {
use Type::*;
use TypeVariableKind as K;
matches!(
self.follow_bindings(),
FieldElement | Integer(..) | Bool | TypeVariable(_, K::Integer | K::IntegerOrField)
)
}

fn contains_numeric_typevar(&self, target_id: TypeVariableId) -> bool {
// True if the given type is a NamedGeneric with the target_id
let named_generic_id_matches_target = |typ: &Type| {
Expand Down Expand Up @@ -1510,7 +1520,7 @@
Type::Tuple(fields)
}
Type::Forall(typevars, typ) => {
// Trying to substitute_helper a variable de, substitute_bound_typevarsfined within a nested Forall

Check warning on line 1523 in compiler/noirc_frontend/src/hir_def/types.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (typevarsfined)
// is usually impossible and indicative of an error in the type checker somewhere.
for var in typevars {
assert!(!type_bindings.contains_key(&var.id()));
Expand Down
12 changes: 6 additions & 6 deletions compiler/noirc_frontend/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1033,19 +1033,19 @@ mod test {
fn resolve_complex_closures() {
let src = r#"
fn main(x: Field) -> pub Field {
let closure_without_captures = |x| x + x;
let closure_without_captures = |x: Field| -> Field { x + x };
let a = closure_without_captures(1);

let closure_capturing_a_param = |y| y + x;
let closure_capturing_a_param = |y: Field| -> Field { y + x };
let b = closure_capturing_a_param(2);

let closure_capturing_a_local_var = |y| y + b;
let closure_capturing_a_local_var = |y: Field| -> Field { y + b };
let c = closure_capturing_a_local_var(3);

let closure_with_transitive_captures = |y| {
let closure_with_transitive_captures = |y: Field| -> Field {
let d = 5;
let nested_closure = |z| {
let doubly_nested_closure = |w| w + x + b;
let nested_closure = |z: Field| -> Field {
let doubly_nested_closure = |w: Field| -> Field { w + x + b };
a + z + y + d + x + doubly_nested_closure(4) + x + y
};
let res = nested_closure(5);
Expand Down
14 changes: 7 additions & 7 deletions noir_stdlib/src/cmp.nr
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ trait Eq {

impl Eq for Field { fn eq(self, other: Field) -> bool { self == other } }

impl Eq for u1 { fn eq(self, other: u1) -> bool { self == other } }
impl Eq for u8 { fn eq(self, other: u8) -> bool { self == other } }
impl Eq for u32 { fn eq(self, other: u32) -> bool { self == other } }
impl Eq for u64 { fn eq(self, other: u64) -> bool { self == other } }
impl Eq for u32 { fn eq(self, other: u32) -> bool { self == other } }
impl Eq for u8 { fn eq(self, other: u8) -> bool { self == other } }
impl Eq for u1 { fn eq(self, other: u1) -> bool { self == other } }

impl Eq for i8 { fn eq(self, other: i8) -> bool { self == other } }
impl Eq for i32 { fn eq(self, other: i32) -> bool { self == other } }
Expand Down Expand Up @@ -107,8 +107,8 @@ trait Ord {

// Note: Field deliberately does not implement Ord

impl Ord for u8 {
fn cmp(self, other: u8) -> Ordering {
impl Ord for u64 {
fn cmp(self, other: u64) -> Ordering {
if self < other {
Ordering::less()
} else if self > other {
Expand All @@ -131,8 +131,8 @@ impl Ord for u32 {
}
}

impl Ord for u64 {
fn cmp(self, other: u64) -> Ordering {
impl Ord for u8 {
fn cmp(self, other: u8) -> Ordering {
if self < other {
Ordering::less()
} else if self > other {
Expand Down
Loading
Loading