Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

riscv64: Implement SIMD swizzle and shuffle #6515

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,6 @@ fn ignore(testsuite: &str, testname: &str, strategy: &str) -> bool {
"simd_i8x16_arith2",
"simd_i8x16_cmp",
"simd_int_to_int_extend",
"simd_lane",
"simd_load",
"simd_load_extend",
"simd_load_zero",
Expand Down
6 changes: 6 additions & 0 deletions cranelift/codegen/src/isa/riscv64/inst.isle
Original file line number Diff line number Diff line change
Expand Up @@ -1573,6 +1573,12 @@

;; UImm5 Helpers

;; Extractor that matches a `Value` equivalent to a replicated UImm5 on all lanes.
;; TODO: Try matching vconst here as well
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you either resolve this TODO in this PR or turn it into TODO(#1234) with a reference to a follow up issue?

(decl replicated_uimm5 (UImm5) Value)
(extractor (replicated_uimm5 n)
(def_inst (splat (uimm5_from_value n))))

;; Helper to go directly from a `Value`, when it's an `iconst`, to an `UImm5`.
(decl uimm5_from_value (UImm5) Value)
(extractor (uimm5_from_value n)
Expand Down
28 changes: 25 additions & 3 deletions cranelift/codegen/src/isa/riscv64/inst/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -654,17 +654,39 @@ fn riscv64_get_operands<F: Fn(VReg) -> VReg>(inst: &Inst, collector: &mut Operan

collector.reg_use(vs1);
collector.reg_use(vs2);
collector.reg_def(vd);

// If the operation forbids source/destination overlap, then we must
// register it as an early_def. This encodes the constraint that
// these must not overlap.
if op.forbids_src_dst_overlaps() {
collector.reg_early_def(vd);
} else {
collector.reg_def(vd);
}

vec_mask_operands(mask, collector);
}
&Inst::VecAluRRImm5 {
vd, vs2, ref mask, ..
op,
vd,
vs2,
ref mask,
..
} => {
debug_assert_eq!(vd.to_reg().class(), RegClass::Vector);
debug_assert_eq!(vs2.class(), RegClass::Vector);

collector.reg_use(vs2);
collector.reg_def(vd);

// If the operation forbids source/destination overlap, then we must
// register it as an early_def. This encodes the constraint that
// these must not overlap.
if op.forbids_src_dst_overlaps() {
collector.reg_early_def(vd);
} else {
collector.reg_def(vd);
}

vec_mask_operands(mask, collector);
}
&Inst::VecAluRR {
Expand Down
30 changes: 26 additions & 4 deletions cranelift/codegen/src/isa/riscv64/inst/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ impl VecAluOpRRR {
VecAluOpRRR::VssubuVV | VecAluOpRRR::VssubuVX => 0b100010,
VecAluOpRRR::VssubVV | VecAluOpRRR::VssubVX => 0b100011,
VecAluOpRRR::VfsgnjnVV => 0b001001,
VecAluOpRRR::VrgatherVV | VecAluOpRRR::VrgatherVX => 0b001100,
VecAluOpRRR::VmsltVX => 0b011011,
}
}
Expand All @@ -318,7 +319,8 @@ impl VecAluOpRRR {
| VecAluOpRRR::VminVV
| VecAluOpRRR::VmaxuVV
| VecAluOpRRR::VmaxVV
| VecAluOpRRR::VmergeVVM => VecOpCategory::OPIVV,
| VecAluOpRRR::VmergeVVM
| VecAluOpRRR::VrgatherVV => VecOpCategory::OPIVV,
VecAluOpRRR::VmulVV
| VecAluOpRRR::VmulhVV
| VecAluOpRRR::VmulhuVV
Expand All @@ -343,7 +345,8 @@ impl VecAluOpRRR {
| VecAluOpRRR::VmaxVX
| VecAluOpRRR::VslidedownVX
| VecAluOpRRR::VmergeVXM
| VecAluOpRRR::VmsltVX => VecOpCategory::OPIVX,
| VecAluOpRRR::VmsltVX
| VecAluOpRRR::VrgatherVX => VecOpCategory::OPIVX,
VecAluOpRRR::VfaddVV
| VecAluOpRRR::VfsubVV
| VecAluOpRRR::VfmulVV
Expand All @@ -368,6 +371,14 @@ impl VecAluOpRRR {
_ => unreachable!(),
}
}

/// Some instructions do not allow the source and destination registers to overlap.
pub fn forbids_src_dst_overlaps(&self) -> bool {
match self {
VecAluOpRRR::VrgatherVV | VecAluOpRRR::VrgatherVX => true,
_ => false,
}
}
}

impl fmt::Display for VecAluOpRRR {
Expand Down Expand Up @@ -408,6 +419,7 @@ impl VecAluOpRRImm5 {
VecAluOpRRImm5::VmergeVIM => 0b010111,
VecAluOpRRImm5::VsadduVI => 0b100000,
VecAluOpRRImm5::VsaddVI => 0b100001,
VecAluOpRRImm5::VrgatherVI => 0b001100,
}
}

Expand All @@ -424,7 +436,8 @@ impl VecAluOpRRImm5 {
| VecAluOpRRImm5::VslidedownVI
| VecAluOpRRImm5::VmergeVIM
| VecAluOpRRImm5::VsadduVI
| VecAluOpRRImm5::VsaddVI => VecOpCategory::OPIVI,
| VecAluOpRRImm5::VsaddVI
| VecAluOpRRImm5::VrgatherVI => VecOpCategory::OPIVI,
}
}

Expand All @@ -433,7 +446,8 @@ impl VecAluOpRRImm5 {
VecAluOpRRImm5::VsllVI
| VecAluOpRRImm5::VsrlVI
| VecAluOpRRImm5::VsraVI
| VecAluOpRRImm5::VslidedownVI => true,
| VecAluOpRRImm5::VslidedownVI
| VecAluOpRRImm5::VrgatherVI => true,
VecAluOpRRImm5::VaddVI
| VecAluOpRRImm5::VrsubVI
| VecAluOpRRImm5::VandVI
Expand All @@ -444,6 +458,14 @@ impl VecAluOpRRImm5 {
| VecAluOpRRImm5::VsaddVI => false,
}
}

/// Some instructions do not allow the source and destination registers to overlap.
pub fn forbids_src_dst_overlaps(&self) -> bool {
match self {
VecAluOpRRImm5::VrgatherVI => true,
_ => false,
}
}
}

impl fmt::Display for VecAluOpRRImm5 {
Expand Down
40 changes: 39 additions & 1 deletion cranelift/codegen/src/isa/riscv64/inst_vector.isle
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@
(VmergeVVM)
(VredmaxuVS)
(VredminuVS)
(VrgatherVV)

;; Vector-Scalar Opcodes
(VaddVX)
Expand Down Expand Up @@ -145,6 +146,7 @@
(VfrdivVF)
(VmergeVXM)
(VfmergeVFM)
(VrgatherVX)
(VmsltVX)
))

Expand All @@ -163,6 +165,7 @@
(VxorVI)
(VslidedownVI)
(VmergeVIM)
(VrgatherVI)
))

;; Imm only ALU Ops
Expand Down Expand Up @@ -718,6 +721,25 @@
(rule (rv_vredmaxu_vs vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VredmaxuVS) vs2 vs1 mask vstate))

;; Helper for emitting the `vrgather.vv` instruction.
;;
;; vd[i] = (vs1[i] >= VLMAX) ? 0 : vs2[vs1[i]];
(decl rv_vrgather_vv (VReg VReg VecOpMasking VState) VReg)
(rule (rv_vrgather_vv vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VrgatherVV) vs2 vs1 mask vstate))

;; Helper for emitting the `vrgather.vx` instruction.
;;
;; vd[i] = (x[rs1] >= VLMAX) ? 0 : vs2[x[rs1]]
(decl rv_vrgather_vx (VReg XReg VecOpMasking VState) VReg)
(rule (rv_vrgather_vx vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VrgatherVX) vs2 vs1 mask vstate))

;; Helper for emitting the `vrgather.vi` instruction.
(decl rv_vrgather_vi (VReg UImm5 VecOpMasking VState) VReg)
(rule (rv_vrgather_vi vs2 imm mask vstate)
(vec_alu_rr_uimm5 (VecAluOpRRImm5.VrgatherVI) vs2 imm mask vstate))

;; Helper for emitting the `vmslt.vx` (Vector Mask Set Less Than) instruction.
(decl rv_vmslt_vx (VReg XReg VecOpMasking VState) VReg)
(rule (rv_vmslt_vx vs2 vs1 mask vstate)
Expand Down Expand Up @@ -757,4 +779,20 @@
;; Materialize the mask into an X register, and move it into the bottom of
;; the vector register.
(rule (gen_vec_mask mask)
(rv_vmv_sx (imm $I64 mask) (vstate_from_type $I64X2)))
(rv_vmv_sx (imm $I64 mask) (vstate_from_type $I64X2)))


;; Loads a `VCodeConstant` value into a vector register. For some special `VCodeConstant`s
;; we can use a dedicated instruction, otherwise we load the value from the pool.
;;
;; Type is the preferred type to use when loading the constant.
(decl gen_constant (Type VCodeConstant) VReg)

;; The fallback case is to load the constant from the pool.
(rule (gen_constant ty n)
(vec_load
(element_width_from_type ty)
(VecAMode.UnitStride (gen_const_amode n))
(mem_flags_trusted)
(unmasked)
ty))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: missing trailing newline

33 changes: 27 additions & 6 deletions cranelift/codegen/src/isa/riscv64/lower.isle
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,7 @@
;; ;;;; Rules for `vconst` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(rule (lower (has_type (ty_vec_fits_in_register ty) (vconst n)))
(vec_load
(element_width_from_type ty)
(VecAMode.UnitStride (gen_const_amode (const_to_vconst n)))
(mem_flags_trusted)
(unmasked)
ty))
(gen_constant ty (const_to_vconst n)))

;;;; Rules for `f32const` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

Expand Down Expand Up @@ -1407,3 +1402,29 @@
;; use the original type as a VState and avoid a state change.
(x_mask XReg (rv_vmv_xs mask (vstate_from_type $I64X2))))
(gen_andi x_mask (ty_lane_mask ty))))

;;;; Rules for `swizzle` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(rule 0 (lower (has_type (ty_vec_fits_in_register ty) (swizzle x y)))
(rv_vrgather_vv x y (unmasked) ty))

(rule 1 (lower (has_type (ty_vec_fits_in_register ty) (swizzle x (splat y))))
(rv_vrgather_vx x y (unmasked) ty))

(rule 2 (lower (has_type (ty_vec_fits_in_register ty) (swizzle x (replicated_uimm5 y))))
(rv_vrgather_vi x y (unmasked) ty))

;;;; Rules for `shuffle` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

;; Use a vrgather to load all 0-15 lanes from x. And then modify the mask to load all
;; 16-31 lanes from y. Finally, use a vor to combine the two vectors.
;;
;; vrgather will insert a 0 for lanes that are out of bounds, so we can let it load
;; negative and out of bounds indexes.
(rule (lower (has_type (ty_vec_fits_in_register ty @ $I8X16) (shuffle x y (vconst_from_immediate mask))))
(if-let neg16 (imm5_from_i8 -16))
(let ((x_mask VReg (gen_constant ty mask))
(x_lanes VReg (rv_vrgather_vv x x_mask (unmasked) ty))
(y_mask VReg (rv_vadd_vi x_mask neg16 (unmasked) ty))
(y_lanes VReg (rv_vrgather_vv y y_mask (unmasked) ty)))
(rv_vor_vv x_lanes y_lanes (unmasked) ty)))
7 changes: 7 additions & 0 deletions cranelift/codegen/src/machinst/isle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,13 @@ macro_rules! isle_lower_prelude_methods {
Some(u128::from_le_bytes(bytes.try_into().ok()?))
}

#[inline]
fn vconst_from_immediate(&mut self, imm: Immediate) -> Option<VCodeConstant> {
Some(self.lower_ctx.use_constant(VCodeConstantData::Generated(
self.lower_ctx.get_immediate_data(imm).clone(),
)))
}

#[inline]
fn vec_mask_from_immediate(&mut self, imm: Immediate) -> Option<VecMask> {
let data = self.lower_ctx.get_immediate_data(imm);
Expand Down
5 changes: 5 additions & 0 deletions cranelift/codegen/src/prelude_lower.isle
Original file line number Diff line number Diff line change
Expand Up @@ -854,6 +854,11 @@
(decl u128_from_immediate (u128) Immediate)
(extern extractor u128_from_immediate u128_from_immediate)

;; Extracts an `Immediate` as a `VCodeConstant`.

(decl vconst_from_immediate (VCodeConstant) Immediate)
(extern extractor vconst_from_immediate vconst_from_immediate)

;; Accessor for `Constant` as u128.

(decl u128_from_constant (u128) Constant)
Expand Down
61 changes: 61 additions & 0 deletions cranelift/filetests/filetests/isa/riscv64/simd-shuffle.clif
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
test compile precise-output
set unwind_info=false
target riscv64 has_v

function %shuffle_i8x16(i8x16, i8x16) -> i8x16 {
block0(v0: i8x16, v1: i8x16):
v2 = shuffle v0, v1, [3 0 31 26 4 6 12 11 23 13 24 4 2 15 17 5]
return v2
}

; VCode:
; add sp,-16
; sd ra,8(sp)
; sd fp,0(sp)
; mv fp,sp
; block0:
; vle8.v v1,16(fp) #avl=16, #vtype=(e8, m1, ta, ma)
; vle8.v v3,32(fp) #avl=16, #vtype=(e8, m1, ta, ma)
; vle8.v v6,[const(0)] #avl=16, #vtype=(e8, m1, ta, ma)
; vrgather.vv v8,v1,v6 #avl=16, #vtype=(e8, m1, ta, ma)
; vadd.vi v10,v6,-16 #avl=16, #vtype=(e8, m1, ta, ma)
; vrgather.vv v12,v3,v10 #avl=16, #vtype=(e8, m1, ta, ma)
; vor.vv v14,v8,v12 #avl=16, #vtype=(e8, m1, ta, ma)
; vse8.v v14,0(a0) #avl=16, #vtype=(e8, m1, ta, ma)
; ld ra,8(sp)
; ld fp,0(sp)
; add sp,+16
; ret
;
; Disassembled:
; block0: ; offset 0x0
; addi sp, sp, -0x10
; sd ra, 8(sp)
; sd s0, 0(sp)
; ori s0, sp, 0
; block1: ; offset 0x10
; .byte 0x57, 0x70, 0x08, 0xcc
; addi t6, s0, 0x10
; .byte 0x87, 0x80, 0x0f, 0x02
; addi t6, s0, 0x20
; .byte 0x87, 0x81, 0x0f, 0x02
; auipc t6, 0
; addi t6, t6, 0x3c
; .byte 0x07, 0x83, 0x0f, 0x02
; .byte 0x57, 0x04, 0x13, 0x32
; .byte 0x57, 0x35, 0x68, 0x02
; .byte 0x57, 0x06, 0x35, 0x32
; .byte 0x57, 0x07, 0x86, 0x2a
; .byte 0x27, 0x07, 0x05, 0x02
; ld ra, 8(sp)
; ld s0, 0(sp)
; addi sp, sp, 0x10
; ret
; .byte 0x00, 0x00, 0x00, 0x00
; .byte 0x00, 0x00, 0x00, 0x00
; .byte 0x00, 0x00, 0x00, 0x00
; lb zero, 0x1a1(t5)
; .byte 0x04, 0x06, 0x0c, 0x0b
; auipc s10, 0x4180
; .byte 0x02, 0x0f, 0x11, 0x05

Loading