Skip to content

Commit

Permalink
Rollup merge of rust-lang#126417 - beetrees:f16-f128-inline-asm-x86, …
Browse files Browse the repository at this point in the history
…r=Amanieu

Add `f16` and `f128` inline ASM support for `x86` and `x86-64`

This PR adds `f16` and `f128` input and output support to inline ASM on `x86` and `x86-64`. `f16` vector sizes are taken from [here](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html).

Relevant issue: rust-lang#125398
Tracking issue: rust-lang#116909

``@rustbot`` label +F-f16_and_f128
  • Loading branch information
matthiaskrgr authored Jun 15, 2024
2 parents dad74aa + dfc5514 commit 0f2cc21
Show file tree
Hide file tree
Showing 6 changed files with 350 additions and 42 deletions.
100 changes: 100 additions & 0 deletions compiler/rustc_codegen_llvm/src/asm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -959,6 +959,43 @@ fn llvm_fixup_input<'ll, 'tcx>(
InlineAsmRegClass::X86(X86InlineAsmRegClass::xmm_reg | X86InlineAsmRegClass::zmm_reg),
Abi::Vector { .. },
) if layout.size.bytes() == 64 => bx.bitcast(value, bx.cx.type_vector(bx.cx.type_f64(), 8)),
(
InlineAsmRegClass::X86(
X86InlineAsmRegClass::xmm_reg
| X86InlineAsmRegClass::ymm_reg
| X86InlineAsmRegClass::zmm_reg,
),
Abi::Scalar(s),
) if bx.sess().asm_arch == Some(InlineAsmArch::X86)
&& s.primitive() == Primitive::Float(Float::F128) =>
{
bx.bitcast(value, bx.type_vector(bx.type_i32(), 4))
}
(
InlineAsmRegClass::X86(
X86InlineAsmRegClass::xmm_reg
| X86InlineAsmRegClass::ymm_reg
| X86InlineAsmRegClass::zmm_reg,
),
Abi::Scalar(s),
) if s.primitive() == Primitive::Float(Float::F16) => {
let value = bx.insert_element(
bx.const_undef(bx.type_vector(bx.type_f16(), 8)),
value,
bx.const_usize(0),
);
bx.bitcast(value, bx.type_vector(bx.type_i16(), 8))
}
(
InlineAsmRegClass::X86(
X86InlineAsmRegClass::xmm_reg
| X86InlineAsmRegClass::ymm_reg
| X86InlineAsmRegClass::zmm_reg,
),
Abi::Vector { element, count: count @ (8 | 16) },
) if element.primitive() == Primitive::Float(Float::F16) => {
bx.bitcast(value, bx.type_vector(bx.type_i16(), count))
}
(
InlineAsmRegClass::Arm(ArmInlineAsmRegClass::sreg | ArmInlineAsmRegClass::sreg_low16),
Abi::Scalar(s),
Expand Down Expand Up @@ -1036,6 +1073,39 @@ fn llvm_fixup_output<'ll, 'tcx>(
InlineAsmRegClass::X86(X86InlineAsmRegClass::xmm_reg | X86InlineAsmRegClass::zmm_reg),
Abi::Vector { .. },
) if layout.size.bytes() == 64 => bx.bitcast(value, layout.llvm_type(bx.cx)),
(
InlineAsmRegClass::X86(
X86InlineAsmRegClass::xmm_reg
| X86InlineAsmRegClass::ymm_reg
| X86InlineAsmRegClass::zmm_reg,
),
Abi::Scalar(s),
) if bx.sess().asm_arch == Some(InlineAsmArch::X86)
&& s.primitive() == Primitive::Float(Float::F128) =>
{
bx.bitcast(value, bx.type_f128())
}
(
InlineAsmRegClass::X86(
X86InlineAsmRegClass::xmm_reg
| X86InlineAsmRegClass::ymm_reg
| X86InlineAsmRegClass::zmm_reg,
),
Abi::Scalar(s),
) if s.primitive() == Primitive::Float(Float::F16) => {
let value = bx.bitcast(value, bx.type_vector(bx.type_f16(), 8));
bx.extract_element(value, bx.const_usize(0))
}
(
InlineAsmRegClass::X86(
X86InlineAsmRegClass::xmm_reg
| X86InlineAsmRegClass::ymm_reg
| X86InlineAsmRegClass::zmm_reg,
),
Abi::Vector { element, count: count @ (8 | 16) },
) if element.primitive() == Primitive::Float(Float::F16) => {
bx.bitcast(value, bx.type_vector(bx.type_f16(), count))
}
(
InlineAsmRegClass::Arm(ArmInlineAsmRegClass::sreg | ArmInlineAsmRegClass::sreg_low16),
Abi::Scalar(s),
Expand Down Expand Up @@ -1109,6 +1179,36 @@ fn llvm_fixup_output_type<'ll, 'tcx>(
InlineAsmRegClass::X86(X86InlineAsmRegClass::xmm_reg | X86InlineAsmRegClass::zmm_reg),
Abi::Vector { .. },
) if layout.size.bytes() == 64 => cx.type_vector(cx.type_f64(), 8),
(
InlineAsmRegClass::X86(
X86InlineAsmRegClass::xmm_reg
| X86InlineAsmRegClass::ymm_reg
| X86InlineAsmRegClass::zmm_reg,
),
Abi::Scalar(s),
) if cx.sess().asm_arch == Some(InlineAsmArch::X86)
&& s.primitive() == Primitive::Float(Float::F128) =>
{
cx.type_vector(cx.type_i32(), 4)
}
(
InlineAsmRegClass::X86(
X86InlineAsmRegClass::xmm_reg
| X86InlineAsmRegClass::ymm_reg
| X86InlineAsmRegClass::zmm_reg,
),
Abi::Scalar(s),
) if s.primitive() == Primitive::Float(Float::F16) => cx.type_vector(cx.type_i16(), 8),
(
InlineAsmRegClass::X86(
X86InlineAsmRegClass::xmm_reg
| X86InlineAsmRegClass::ymm_reg
| X86InlineAsmRegClass::zmm_reg,
),
Abi::Vector { element, count: count @ (8 | 16) },
) if element.primitive() == Primitive::Float(Float::F16) => {
cx.type_vector(cx.type_i16(), count)
}
(
InlineAsmRegClass::Arm(ArmInlineAsmRegClass::sreg | ArmInlineAsmRegClass::sreg_low16),
Abi::Scalar(s),
Expand Down
4 changes: 4 additions & 0 deletions compiler/rustc_hir_analysis/src/check/intrinsicck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,10 @@ impl<'a, 'tcx> InlineAsmCtxt<'a, 'tcx> {
ty::Int(IntTy::I64) | ty::Uint(UintTy::U64) => Some(InlineAsmType::I64),
ty::Int(IntTy::I128) | ty::Uint(UintTy::U128) => Some(InlineAsmType::I128),
ty::Int(IntTy::Isize) | ty::Uint(UintTy::Usize) => Some(asm_ty_isize),
ty::Float(FloatTy::F16) => Some(InlineAsmType::F16),
ty::Float(FloatTy::F32) => Some(InlineAsmType::F32),
ty::Float(FloatTy::F64) => Some(InlineAsmType::F64),
ty::Float(FloatTy::F128) => Some(InlineAsmType::F128),
ty::FnPtr(_) => Some(asm_ty_isize),
ty::RawPtr(ty, _) if self.is_thin_ptr_ty(ty) => Some(asm_ty_isize),
ty::Adt(adt, args) if adt.repr().simd() => {
Expand Down Expand Up @@ -105,8 +107,10 @@ impl<'a, 'tcx> InlineAsmCtxt<'a, 'tcx> {
width => bug!("unsupported pointer width: {width}"),
})
}
ty::Float(FloatTy::F16) => Some(InlineAsmType::VecF16(size)),
ty::Float(FloatTy::F32) => Some(InlineAsmType::VecF32(size)),
ty::Float(FloatTy::F64) => Some(InlineAsmType::VecF64(size)),
ty::Float(FloatTy::F128) => Some(InlineAsmType::VecF128(size)),
_ => None,
}
}
Expand Down
12 changes: 12 additions & 0 deletions compiler/rustc_target/src/asm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -707,15 +707,19 @@ pub enum InlineAsmType {
I32,
I64,
I128,
F16,
F32,
F64,
F128,
VecI8(u64),
VecI16(u64),
VecI32(u64),
VecI64(u64),
VecI128(u64),
VecF16(u64),
VecF32(u64),
VecF64(u64),
VecF128(u64),
}

impl InlineAsmType {
Expand All @@ -730,15 +734,19 @@ impl InlineAsmType {
Self::I32 => 4,
Self::I64 => 8,
Self::I128 => 16,
Self::F16 => 2,
Self::F32 => 4,
Self::F64 => 8,
Self::F128 => 16,
Self::VecI8(n) => n * 1,
Self::VecI16(n) => n * 2,
Self::VecI32(n) => n * 4,
Self::VecI64(n) => n * 8,
Self::VecI128(n) => n * 16,
Self::VecF16(n) => n * 2,
Self::VecF32(n) => n * 4,
Self::VecF64(n) => n * 8,
Self::VecF128(n) => n * 16,
})
}
}
Expand All @@ -751,15 +759,19 @@ impl fmt::Display for InlineAsmType {
Self::I32 => f.write_str("i32"),
Self::I64 => f.write_str("i64"),
Self::I128 => f.write_str("i128"),
Self::F16 => f.write_str("f16"),
Self::F32 => f.write_str("f32"),
Self::F64 => f.write_str("f64"),
Self::F128 => f.write_str("f128"),
Self::VecI8(n) => write!(f, "i8x{n}"),
Self::VecI16(n) => write!(f, "i16x{n}"),
Self::VecI32(n) => write!(f, "i32x{n}"),
Self::VecI64(n) => write!(f, "i64x{n}"),
Self::VecI128(n) => write!(f, "i128x{n}"),
Self::VecF16(n) => write!(f, "f16x{n}"),
Self::VecF32(n) => write!(f, "f32x{n}"),
Self::VecF64(n) => write!(f, "f64x{n}"),
Self::VecF128(n) => write!(f, "f128x{n}"),
}
}
}
Expand Down
22 changes: 11 additions & 11 deletions compiler/rustc_target/src/asm/x86.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,26 +107,26 @@ impl X86InlineAsmRegClass {
match self {
Self::reg | Self::reg_abcd => {
if arch == InlineAsmArch::X86_64 {
types! { _: I16, I32, I64, F32, F64; }
types! { _: I16, I32, I64, F16, F32, F64; }
} else {
types! { _: I16, I32, F32; }
types! { _: I16, I32, F16, F32; }
}
}
Self::reg_byte => types! { _: I8; },
Self::xmm_reg => types! {
sse: I32, I64, F32, F64,
VecI8(16), VecI16(8), VecI32(4), VecI64(2), VecF32(4), VecF64(2);
sse: I32, I64, F16, F32, F64, F128,
VecI8(16), VecI16(8), VecI32(4), VecI64(2), VecF16(8), VecF32(4), VecF64(2);
},
Self::ymm_reg => types! {
avx: I32, I64, F32, F64,
VecI8(16), VecI16(8), VecI32(4), VecI64(2), VecF32(4), VecF64(2),
VecI8(32), VecI16(16), VecI32(8), VecI64(4), VecF32(8), VecF64(4);
avx: I32, I64, F16, F32, F64, F128,
VecI8(16), VecI16(8), VecI32(4), VecI64(2), VecF16(8), VecF32(4), VecF64(2),
VecI8(32), VecI16(16), VecI32(8), VecI64(4), VecF16(16), VecF32(8), VecF64(4);
},
Self::zmm_reg => types! {
avx512f: I32, I64, F32, F64,
VecI8(16), VecI16(8), VecI32(4), VecI64(2), VecF32(4), VecF64(2),
VecI8(32), VecI16(16), VecI32(8), VecI64(4), VecF32(8), VecF64(4),
VecI8(64), VecI16(32), VecI32(16), VecI64(8), VecF32(16), VecF64(8);
avx512f: I32, I64, F16, F32, F64, F128,
VecI8(16), VecI16(8), VecI32(4), VecI64(2), VecF16(8), VecF32(4), VecF64(2),
VecI8(32), VecI16(16), VecI32(8), VecI64(4), VecF16(16), VecF32(8), VecF64(4),
VecI8(64), VecI16(32), VecI32(16), VecI64(8), VecF16(32), VecF32(16), VecF64(8);
},
Self::kreg => types! {
avx512f: I8, I16;
Expand Down
Loading

0 comments on commit 0f2cc21

Please sign in to comment.