From 84c6fb3f7066562fc753b34a58359e292ae16928 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Wed, 27 Nov 2024 14:53:38 +0000 Subject: [PATCH] Fix elem ty lookup and refactor build_alloca_i8_ptr --- hugr-core/src/std_extensions/collections.rs | 5 + hugr-llvm/src/extension/collections.rs | 94 +++++++++---------- ...sion__collections__test__const@llvm14.snap | 3 +- ...tions__test__const@pre-mem2reg@llvm14.snap | 2 +- 4 files changed, 52 insertions(+), 52 deletions(-) diff --git a/hugr-core/src/std_extensions/collections.rs b/hugr-core/src/std_extensions/collections.rs index c418f0b92..d8de58304 100644 --- a/hugr-core/src/std_extensions/collections.rs +++ b/hugr-core/src/std_extensions/collections.rs @@ -60,6 +60,11 @@ impl ListValue { list_custom_type(self.1.clone()) } + /// Returns the type of values inside the `[ListValue]`. + pub fn get_element_type(&self) -> &Type { + &self.1 + } + /// Returns the values contained inside the `[ListValue]`. pub fn get_contents(&self) -> &[Value] { &self.0 diff --git a/hugr-llvm/src/extension/collections.rs b/hugr-llvm/src/extension/collections.rs index 5a872a082..8720041d8 100644 --- a/hugr-llvm/src/extension/collections.rs +++ b/hugr-llvm/src/extension/collections.rs @@ -1,8 +1,6 @@ use anyhow::{anyhow, bail, Ok, Result}; -use hugr_core::core::Either; -use hugr_core::core::Either::{Left, Right}; use hugr_core::{ - ops::{constant::CustomConst, ExtensionOp, NamedOp}, + ops::{ExtensionOp, NamedOp}, std_extensions::collections::{self, ListOp, ListValue}, types::{SumType, Type, TypeArg}, HugrView, @@ -111,8 +109,8 @@ impl From for CollectionsRtFunc { } } -/// A helper trait for customising the lowering [hugr_core::std_extensions::collections] -/// types, [CustomConst]s, and ops. +/// A helper trait for customising the lowering of [hugr_core::std_extensions::collections] +/// types, [hugr_core::ops::constant::CustomConst]s, and ops. pub trait CollectionsCodegen: Clone { /// Return the llvm type of [hugr_core::std_extensions::collections::LIST_TYPENAME]. fn list_type<'c>(&self, session: TypingSession<'c, '_>) -> BasicTypeEnum<'c> { @@ -138,43 +136,6 @@ pub trait CollectionsCodegen: Clone { } } -/// Helper function to allocate space on the stack for a given type. -/// -/// Returns an i8 pointer to the allocated memory. -fn build_alloca_i8_ptr<'c, H: HugrView>( - ctx: &mut EmitFuncContext<'c, '_, H>, - ty_or_val: Either, BasicValueEnum<'c>>, -) -> Result> { - let builder = ctx.builder(); - let ty = match ty_or_val { - Left(ty) => ty, - Right(val) => val.get_type(), - }; - let ptr = builder.build_alloca(ty, "")?; - - if let Right(val) = ty_or_val { - builder.build_store(ptr, val)?; - } - let i8_ptr = builder.build_pointer_cast( - ptr, - ctx.iw_context().i8_type().ptr_type(AddressSpace::default()), - "", - )?; - Ok(i8_ptr) -} - -/// Helper function to load a value from an i8 pointer. -fn build_load_i8_ptr<'c, H: HugrView>( - ctx: &mut EmitFuncContext<'c, '_, H>, - i8_ptr: PointerValue<'c>, - ty: BasicTypeEnum<'c>, -) -> Result> { - let builder = ctx.builder(); - let ptr = builder.build_pointer_cast(i8_ptr, ty.ptr_type(AddressSpace::default()), "")?; - let val = builder.build_load(ptr, "")?; - Ok(val) -} - /// A trivial implementation of [CollectionsCodegen] which passes all methods /// through to their default implementations. #[derive(Default, Clone)] @@ -251,14 +212,14 @@ fn emit_list_op<'c, H: HugrView>( match op { ListOp::push => { let [list, elem] = args.inputs.try_into().unwrap(); - let elem_ptr = build_alloca_i8_ptr(ctx, Right(elem))?; + let elem_ptr = build_alloca_i8_ptr(ctx, elem_ty, Some(elem))?; ctx.builder() .build_call(func, &[list.into(), elem_ptr.into()], "")?; args.outputs.finish(ctx.builder(), vec![list])?; } ListOp::pop => { let [list] = args.inputs.try_into().unwrap(); - let out_ptr = build_alloca_i8_ptr(ctx, Left(elem_ty))?; + let out_ptr = build_alloca_i8_ptr(ctx, elem_ty, None)?; let ok = ctx .builder() .build_call(func, &[list.into(), out_ptr.into()], "")? @@ -271,7 +232,7 @@ fn emit_list_op<'c, H: HugrView>( } ListOp::get => { let [list, idx] = args.inputs.try_into().unwrap(); - let out_ptr = build_alloca_i8_ptr(ctx, Left(elem_ty))?; + let out_ptr = build_alloca_i8_ptr(ctx, elem_ty, None)?; let ok = ctx .builder() .build_call(func, &[list.into(), idx.into(), out_ptr.into()], "")? @@ -284,7 +245,7 @@ fn emit_list_op<'c, H: HugrView>( } ListOp::set => { let [list, idx, elem] = args.inputs.try_into().unwrap(); - let elem_ptr = build_alloca_i8_ptr(ctx, Right(elem))?; + let elem_ptr = build_alloca_i8_ptr(ctx, elem_ty, Some(elem))?; let ok = ctx .builder() .build_call(func, &[list.into(), idx.into(), elem_ptr.into()], "")? @@ -298,7 +259,7 @@ fn emit_list_op<'c, H: HugrView>( } ListOp::insert => { let [list, idx, elem] = args.inputs.try_into().unwrap(); - let elem_ptr = build_alloca_i8_ptr(ctx, Right(elem))?; + let elem_ptr = build_alloca_i8_ptr(ctx, elem_ty, Some(elem))?; let ok = ctx .builder() .build_call(func, &[list.into(), idx.into(), elem_ptr.into()], "")? @@ -332,7 +293,7 @@ fn emit_list_value<'c, H: HugrView>( ccg: &(impl CollectionsCodegen + 'c), val: &ListValue, ) -> Result> { - let elem_ty = ctx.llvm_type(&val.get_type())?; + let elem_ty = ctx.llvm_type(val.get_element_type())?; let iwc = ctx.typing_session().iw_context(); let capacity = iwc .i64_type() @@ -359,13 +320,48 @@ fn emit_list_value<'c, H: HugrView>( let rt_push = CollectionsRtFunc::Push.get_extern(ctx, ccg)?; for v in val.get_contents() { let elem = emit_value(ctx, v)?; - let elem_ptr = build_alloca_i8_ptr(ctx, Right(elem))?; + let elem_ptr = build_alloca_i8_ptr(ctx, elem_ty, Some(elem))?; ctx.builder() .build_call(rt_push, &[list.into(), elem_ptr.into()], "")?; } Ok(list) } +/// Helper function to allocate space on the stack for a given type. +/// +/// Optionally also stores a value at that location. +/// +/// Returns an i8 pointer to the allocated memory. +fn build_alloca_i8_ptr<'c, H: HugrView>( + ctx: &mut EmitFuncContext<'c, '_, H>, + ty: BasicTypeEnum<'c>, + value: Option>, +) -> Result> { + let builder = ctx.builder(); + let ptr = builder.build_alloca(ty, "")?; + if let Some(val) = value { + builder.build_store(ptr, val)?; + } + let i8_ptr = builder.build_pointer_cast( + ptr, + ctx.iw_context().i8_type().ptr_type(AddressSpace::default()), + "", + )?; + Ok(i8_ptr) +} + +/// Helper function to load a value from an i8 pointer. +fn build_load_i8_ptr<'c, H: HugrView>( + ctx: &mut EmitFuncContext<'c, '_, H>, + i8_ptr: PointerValue<'c>, + ty: BasicTypeEnum<'c>, +) -> Result> { + let builder = ctx.builder(); + let ptr = builder.build_pointer_cast(i8_ptr, ty.ptr_type(AddressSpace::default()), "")?; + let val = builder.build_load(ptr, "")?; + Ok(val) +} + #[cfg(test)] mod test { use hugr_core::{ diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__const@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__const@llvm14.snap index 8102b8690..8ad058cf3 100644 --- a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__const@llvm14.snap +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__const@llvm14.snap @@ -1,6 +1,5 @@ --- source: hugr-llvm/src/extension/collections.rs -assertion_line: 592 expression: mod_str --- ; ModuleID = 'test_context' @@ -11,7 +10,7 @@ alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - %0 = call i8* @__rt__list__new(i64 3, i64 ptrtoint (i8** getelementptr (i8*, i8** null, i32 1) to i64), i64 8, i8* null) + %0 = call i8* @__rt__list__new(i64 3, i64 ptrtoint (i64* getelementptr (i64, i64* null, i32 1) to i64), i64 8, i8* null) %1 = alloca i64, align 8 store i64 1, i64* %1, align 4 %2 = bitcast i64* %1 to i8* diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__const@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__const@pre-mem2reg@llvm14.snap index d3a6c8b10..5522be9ad 100644 --- a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__const@pre-mem2reg@llvm14.snap +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__const@pre-mem2reg@llvm14.snap @@ -12,7 +12,7 @@ alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - %0 = call i8* @__rt__list__new(i64 3, i64 ptrtoint (i8** getelementptr (i8*, i8** null, i32 1) to i64), i64 8, i8* null) + %0 = call i8* @__rt__list__new(i64 3, i64 ptrtoint (i64* getelementptr (i64, i64* null, i32 1) to i64), i64 8, i8* null) %1 = alloca i64, align 8 store i64 1, i64* %1, align 4 %2 = bitcast i64* %1 to i8*