Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds break expressions #62

Merged
merged 3 commits into from
Nov 23, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 8 additions & 25 deletions crates/mun/test/main.mun
Original file line number Diff line number Diff line change
@@ -1,25 +1,8 @@
// function to subtract two floats
fn subtract(a:float, b:float):float {
a-b
}

// function to subtract two floats
fn multiply(a:float, b:float):float {
a*b
}

fn main():int {
add(5, 3)
}

fn add_impl(a:int, b:int):int {
a+b
}

fn add(a:int, b:int):int {
add_impl(a,b)
}

fn test():int {
add(4,5)
}
fn foo():int {
break; // error: not in a loop
loop { break 3; break 3.0; } // error: mismatched type
let a:int = loop { break 3.0; } // error: mismatched type
loop { break 3; }
let a:int = loop { break loop { break 3; } }
loop { break loop { break 3.0; } } // error: mismatched type
}
53 changes: 52 additions & 1 deletion crates/mun_codegen/src/ir/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,17 @@ use mun_hir::{
};
use std::{collections::HashMap, mem, sync::Arc};

use inkwell::basic_block::BasicBlock;
use inkwell::values::PointerValue;

struct LoopInfo {
break_values: Vec<(
inkwell::values::BasicValueEnum,
inkwell::basic_block::BasicBlock,
)>,
exit_block: BasicBlock,
}

pub(crate) struct BodyIrGenerator<'a, 'b, D: IrDatabase> {
db: &'a D,
module: &'a Module,
Expand All @@ -25,6 +34,7 @@ pub(crate) struct BodyIrGenerator<'a, 'b, D: IrDatabase> {
pat_to_name: HashMap<PatId, String>,
function_map: &'a HashMap<mun_hir::Function, FunctionValue>,
dispatch_table: &'b DispatchTable,
active_loop: Option<LoopInfo>,
}

impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
Expand Down Expand Up @@ -58,6 +68,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
pat_to_name: HashMap::default(),
function_map,
dispatch_table,
active_loop: None,
}
}

Expand Down Expand Up @@ -132,6 +143,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
} => self.gen_if(expr, *condition, *then_branch, *else_branch),
Expr::Return { expr: ret_expr } => self.gen_return(expr, *ret_expr),
Expr::Loop { body } => self.gen_loop(expr, *body),
Expr::Break { expr: break_expr } => self.gen_break(expr, *break_expr),
_ => unimplemented!("unimplemented expr type {:?}", &body[expr]),
}
}
Expand Down Expand Up @@ -575,9 +587,31 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
None
}

fn gen_break(&mut self, _expr: ExprId, break_expr: Option<ExprId>) -> Option<BasicValueEnum> {
let break_value = break_expr.and_then(|expr| self.gen_expr(expr));
let loop_info = self.active_loop.as_mut().unwrap();
if let Some(break_value) = break_value {
loop_info
.break_values
.push((break_value, self.builder.get_insert_block().unwrap()));
}
self.builder
.build_unconditional_branch(&loop_info.exit_block);
None
}

fn gen_loop(&mut self, _expr: ExprId, body_expr: ExprId) -> Option<BasicValueEnum> {
let context = self.module.get_context();
let loop_block = context.append_basic_block(&self.fn_value, "loop");
let exit_block = context.append_basic_block(&self.fn_value, "exit");

// Build a new loop info struct
let loop_info = LoopInfo {
exit_block,
break_values: Vec::new(),
};

let prev_loop = std::mem::replace(&mut self.active_loop, Some(loop_info));

// Insert an explicit fall through from the current block to the loop
self.builder.build_unconditional_branch(&loop_block);
Expand All @@ -589,6 +623,23 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
// Jump to the start of the loop
self.builder.build_unconditional_branch(&loop_block);

None
let LoopInfo {
exit_block,
break_values,
} = std::mem::replace(&mut self.active_loop, prev_loop).unwrap();

// Move the builder to the exit block
self.builder.position_at_end(&exit_block);

if !break_values.is_empty() {
let (value, _) = break_values.first().unwrap();
let phi = self.builder.build_phi(value.get_type(), "exit");
for (ref value, ref block) in break_values {
phi.add_incoming(&[(value, block)])
}
Some(phi.as_basic_value())
} else {
None
}
}
}
29 changes: 29 additions & 0 deletions crates/mun_codegen/src/snapshots/test__loop_break_expr.snap
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
---
source: crates/mun_codegen/src/test.rs
expression: "fn foo(n:int):int {\n loop {\n if n > 5 {\n break n;\n }\n if n > 10 {\n break 10;\n }\n n += 1;\n }\n}"
---
; ModuleID = 'main.mun'
source_filename = "main.mun"

define i64 @foo(i64) {
body:
br label %loop

loop: ; preds = %if_merge6, %body
%n.0 = phi i64 [ %0, %body ], [ %add, %if_merge6 ]
%greater = icmp sgt i64 %n.0, 5
br i1 %greater, label %exit, label %if_merge

exit: ; preds = %if_merge, %loop
%exit8 = phi i64 [ %n.0, %loop ], [ 10, %if_merge ]
ret i64 %exit8

if_merge: ; preds = %loop
%greater4 = icmp sgt i64 %n.0, 10
br i1 %greater4, label %exit, label %if_merge6

if_merge6: ; preds = %if_merge
%add = add i64 %n.0, 1
br label %loop
}

19 changes: 19 additions & 0 deletions crates/mun_codegen/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,25 @@ fn loop_expr() {
)
}

#[test]
fn loop_break_expr() {
test_snapshot(
r#"
fn foo(n:int):int {
loop {
if n > 5 {
break n;
}
if n > 10 {
break 10;
}
n += 1;
}
}
"#,
)
}

fn test_snapshot(text: &str) {
let text = text.trim().replace("\n ", "\n");

Expand Down
24 changes: 24 additions & 0 deletions crates/mun_hir/src/diagnostics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -360,3 +360,27 @@ impl Diagnostic for ReturnMissingExpression {
self
}
}

#[derive(Debug)]
pub struct BreakOutsideLoop {
pub file: FileId,
pub break_expr: SyntaxNodePtr,
}

impl Diagnostic for BreakOutsideLoop {
fn message(&self) -> String {
"`break` outside of a loop".to_owned()
}

fn file(&self) -> FileId {
self.file
}

fn syntax_node_ptr(&self) -> SyntaxNodePtr {
self.break_expr
}

fn as_any(&self) -> &(dyn Any + Send + 'static) {
self
}
}
15 changes: 15 additions & 0 deletions crates/mun_hir/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,9 @@ pub enum Expr {
Return {
expr: Option<ExprId>,
},
Break {
expr: Option<ExprId>,
},
Loop {
body: ExprId,
},
Expand Down Expand Up @@ -293,6 +296,11 @@ impl Expr {
f(*expr);
}
}
Expr::Break { expr } => {
if let Some(expr) = expr {
f(*expr);
}
}
Expr::Loop { body } => {
f(*body);
}
Expand Down Expand Up @@ -461,6 +469,7 @@ where
match expr.kind() {
ast::ExprKind::LoopExpr(expr) => self.collect_loop(expr),
ast::ExprKind::ReturnExpr(r) => self.collect_return(r),
ast::ExprKind::BreakExpr(r) => self.collect_break(r),
ast::ExprKind::BlockExpr(b) => self.collect_block(b),
ast::ExprKind::Literal(e) => {
let lit = match e.kind() {
Expand Down Expand Up @@ -634,6 +643,12 @@ where
self.alloc_expr(Expr::Return { expr }, syntax_node_ptr)
}

fn collect_break(&mut self, expr: ast::BreakExpr) -> ExprId {
let syntax_node_ptr = AstPtr::new(&expr.clone().into());
let expr = expr.expr().map(|e| self.collect_expr(e));
self.alloc_expr(Expr::Break { expr }, syntax_node_ptr)
}

fn collect_loop(&mut self, expr: ast::LoopExpr) -> ExprId {
let syntax_node_ptr = AstPtr::new(&expr.clone().into());
let body = self.collect_block_opt(expr.loop_body());
Expand Down
82 changes: 76 additions & 6 deletions crates/mun_hir/src/ty/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,12 @@ struct InferenceResultBuilder<'a, D: HirDatabase> {

type_variables: TypeVariableTable,

/// Information on the current loop that we're processing (or None if we're not in a loop) the
/// entry contains the current type of the loop statement (initially `never`) and the expected
/// type of the loop expression. Both these values are updated when a break statement is
/// encountered.
active_loop: Option<(Ty, Expectation)>,

/// The return type of the function being inferred.
return_ty: Ty,
}
Expand All @@ -115,6 +121,7 @@ impl<'a, D: HirDatabase> InferenceResultBuilder<'a, D> {
type_of_expr: ArenaMap::default(),
type_of_pat: ArenaMap::default(),
diagnostics: Vec::default(),
active_loop: None,
type_variables: TypeVariableTable::default(),
db,
body,
Expand Down Expand Up @@ -306,10 +313,8 @@ impl<'a, D: HirDatabase> InferenceResultBuilder<'a, D> {

Ty::simple(TypeCtor::Never)
}
Expr::Loop { body } => {
self.infer_expr(*body, &Expectation::has_type(Ty::Empty));
Ty::simple(TypeCtor::Never)
}
Expr::Break { expr } => self.infer_break(tgt_expr, *expr),
Expr::Loop { body } => self.infer_loop_expr(tgt_expr, *body, expected),
_ => Ty::Unknown,
// Expr::UnaryOp { expr: _, op: _ } => {}
// Expr::Block { statements: _, tail: _ } => {}
Expand Down Expand Up @@ -513,6 +518,61 @@ impl<'a, D: HirDatabase> InferenceResultBuilder<'a, D> {
}
}

fn infer_break(&mut self, tgt_expr: ExprId, expr: Option<ExprId>) -> Ty {
// Fetch the expected type
let expected = if let Some((_, info)) = &self.active_loop {
info.clone()
} else {
self.diagnostics
.push(InferenceDiagnostic::BreakOutsideLoop { id: tgt_expr });
return Ty::simple(TypeCtor::Never);
};

// Infer the type of the break expression
let ty = if let Some(expr) = expr {
self.infer_expr_inner(expr, &expected)
} else {
Ty::Empty
};

// Verify that it matches what we expected
let ty = if !expected.is_none() && ty != expected.ty {
self.diagnostics.push(InferenceDiagnostic::MismatchedTypes {
expected: expected.ty.clone(),
found: ty.clone(),
id: tgt_expr,
});
expected.ty
} else {
ty
};

// Update the expected type for the rest of the loop
self.active_loop = Some((ty.clone(), Expectation::has_type(ty)));

Ty::simple(TypeCtor::Never)
}

fn infer_loop_expr(&mut self, _tgt_expr: ExprId, body: ExprId, expected: &Expectation) -> Ty {
self.infer_loop_block(body, expected)
}

/// Infers the type of a loop body, taking into account breaks.
fn infer_loop_block(&mut self, body: ExprId, expected: &Expectation) -> Ty {
// Take the previous loop information and replace it with a new entry
let top_level_loop = std::mem::replace(
&mut self.active_loop,
Some((Ty::simple(TypeCtor::Never), expected.clone())),
);

// Infer the body of the loop
self.infer_expr_coerce(body, &Expectation::has_type(Ty::Empty));

// Take the result of the loop information and replace with top level loop
let (ty, _) = std::mem::replace(&mut self.active_loop, top_level_loop).unwrap();
ty
}

pub fn report_pat_inference_failure(&mut self, _pat: PatId) {
// self.diagnostics.push(InferenceDiagnostic::PatInferenceFailed {
// pat
Expand Down Expand Up @@ -573,8 +633,8 @@ impl From<PatId> for ExprOrPatId {

mod diagnostics {
use crate::diagnostics::{
CannotApplyBinaryOp, ExpectedFunction, IncompatibleBranch, InvalidLHS, MismatchedType,
MissingElseBranch, ParameterCountMismatch, ReturnMissingExpression,
BreakOutsideLoop, CannotApplyBinaryOp, ExpectedFunction, IncompatibleBranch, InvalidLHS,
MismatchedType, MissingElseBranch, ParameterCountMismatch, ReturnMissingExpression,
};
use crate::{
code_model::src::HasSource,
Expand Down Expand Up @@ -627,6 +687,9 @@ mod diagnostics {
ReturnMissingExpression {
id: ExprId,
},
BreakOutsideLoop {
id: ExprId,
},
}

impl InferenceDiagnostic {
Expand Down Expand Up @@ -736,6 +799,13 @@ mod diagnostics {
return_expr: id,
});
}
InferenceDiagnostic::BreakOutsideLoop { id } => {
let id = body.expr_syntax(*id).unwrap().ast.syntax_node_ptr();
sink.push(BreakOutsideLoop {
file,
break_expr: id,
});
}
}
}
}
Expand Down
Loading