Skip to content

Commit

Permalink
Caching constant system
Browse files Browse the repository at this point in the history
  • Loading branch information
khyperia committed Sep 16, 2020
1 parent 0a2629e commit 85e7a6a
Show file tree
Hide file tree
Showing 6 changed files with 197 additions and 250 deletions.
90 changes: 44 additions & 46 deletions rustc_codegen_spirv/src/builder/builder_methods.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::Builder;
use crate::builder_spirv::{BuilderCursor, SpirvValueExt};
use crate::builder_spirv::{BuilderCursor, SpirvConst, SpirvValueExt};
use crate::spirv_type::SpirvType;
use rspirv::dr::{InsertPoint, Instruction, Operand};
use rspirv::spirv::{MemorySemantics, Op, Scope, StorageClass};
Expand Down Expand Up @@ -316,8 +316,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
lhs: Self::Value,
rhs: Self::Value,
) -> (Self::Value, Self::Value) {
let bool = SpirvType::Bool.def(self);
let fals = self.emit_global().constant_false(bool).with_type(bool);
let fals = self.constant_bool(false);
match oop {
OverflowOp::Add => (self.add(lhs, rhs), fals),
OverflowOp::Sub => (self.sub(lhs, rhs), fals),
Expand Down Expand Up @@ -924,24 +923,29 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
use RealPredicate::*;
assert_ty_eq!(self, lhs.ty, rhs.ty);
let b = SpirvType::Bool.def(self);
let mut e = self.emit();
match op {
RealPredicateFalse => return e.constant_false(b).with_type(b),
RealPredicateTrue => return e.constant_true(b).with_type(b),
RealOEQ => e.f_ord_equal(b, None, lhs.def, rhs.def),
RealOGT => e.f_ord_greater_than(b, None, lhs.def, rhs.def),
RealOGE => e.f_ord_greater_than_equal(b, None, lhs.def, rhs.def),
RealOLT => e.f_ord_less_than(b, None, lhs.def, rhs.def),
RealOLE => e.f_ord_less_than_equal(b, None, lhs.def, rhs.def),
RealONE => e.f_ord_not_equal(b, None, lhs.def, rhs.def),
RealORD => e.ordered(b, None, lhs.def, rhs.def),
RealUNO => e.unordered(b, None, lhs.def, rhs.def),
RealUEQ => e.f_unord_equal(b, None, lhs.def, rhs.def),
RealUGT => e.f_unord_greater_than(b, None, lhs.def, rhs.def),
RealUGE => e.f_unord_greater_than_equal(b, None, lhs.def, rhs.def),
RealULT => e.f_unord_less_than(b, None, lhs.def, rhs.def),
RealULE => e.f_unord_less_than_equal(b, None, lhs.def, rhs.def),
RealUNE => e.f_unord_not_equal(b, None, lhs.def, rhs.def),
RealPredicateFalse => return self.cx.constant_bool(false),
RealPredicateTrue => return self.cx.constant_bool(true),
RealOEQ => self.emit().f_ord_equal(b, None, lhs.def, rhs.def),
RealOGT => self.emit().f_ord_greater_than(b, None, lhs.def, rhs.def),
RealOGE => self
.emit()
.f_ord_greater_than_equal(b, None, lhs.def, rhs.def),
RealOLT => self.emit().f_ord_less_than(b, None, lhs.def, rhs.def),
RealOLE => self.emit().f_ord_less_than_equal(b, None, lhs.def, rhs.def),
RealONE => self.emit().f_ord_not_equal(b, None, lhs.def, rhs.def),
RealORD => self.emit().ordered(b, None, lhs.def, rhs.def),
RealUNO => self.emit().unordered(b, None, lhs.def, rhs.def),
RealUEQ => self.emit().f_unord_equal(b, None, lhs.def, rhs.def),
RealUGT => self.emit().f_unord_greater_than(b, None, lhs.def, rhs.def),
RealUGE => self
.emit()
.f_unord_greater_than_equal(b, None, lhs.def, rhs.def),
RealULT => self.emit().f_unord_less_than(b, None, lhs.def, rhs.def),
RealULE => self
.emit()
.f_unord_less_than_equal(b, None, lhs.def, rhs.def),
RealUNE => self.emit().f_unord_not_equal(b, None, lhs.def, rhs.def),
}
.unwrap()
.with_type(b)
Expand Down Expand Up @@ -995,14 +999,14 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
self.builder.lookup_const(fill_byte.def),
self.builder.lookup_const(size.def),
) {
(Ok(fill_byte), Ok(size)) => {
(Some(fill_byte), Some(size)) => {
let fill_byte = match fill_byte {
Operand::LiteralInt32(v) => v as u8,
other => panic!("memset fill_byte constant value not supported: {}", other),
SpirvConst::U32(_, v) => v as u8,
other => panic!("memset fill_byte constant value not supported: {:?}", other),
};
let size = match size {
Operand::LiteralInt32(v) => v as usize,
other => panic!("memset size constant value not supported: {}", other),
SpirvConst::U32(_, v) => v as usize,
other => panic!("memset size constant value not supported: {:?}", other),
};
let pat = elem_ty_spv
.memset_const_pattern(self, fill_byte)
Expand All @@ -1021,13 +1025,13 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
}
}
}
(Ok(_fill_byte), Err(_)) => {
(Some(_fill_byte), None) => {
panic!("memset constant fill_byte dynamic size not implemented yet")
}
(Err(_), Ok(size)) => {
(None, Some(size)) => {
let size = match size {
Operand::LiteralInt32(v) => v as usize,
other => panic!("memset size constant value not supported: {}", other),
SpirvConst::U32(_, v) => v as usize,
other => panic!("memset size constant value not supported: {:?}", other),
};
let pat = elem_ty_spv
.memset_dynamic_pattern(self, fill_byte.def)
Expand All @@ -1046,7 +1050,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
}
}
}
(Err(_), Err(_)) => panic!("memset dynamic fill_byte dynamic size not implemented yet"),
(None, None) => panic!("memset dynamic fill_byte dynamic size not implemented yet"),
}
}

Expand Down Expand Up @@ -1074,13 +1078,13 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
other => panic!("extract_element not implemented on type {:?}", other),
};
match self.builder.lookup_const_u64(idx.def) {
Ok(const_index) => self.emit().composite_extract(
Some(const_index) => self.emit().composite_extract(
result_type,
None,
vec.def,
[const_index as u32].iter().cloned(),
),
Err(_) => self
None => self
.emit()
.vector_extract_dynamic(result_type, None, vec.def, idx.def),
}
Expand All @@ -1094,15 +1098,14 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
count: num_elts as u32,
}
.def(self);
if self.builder.lookup_const(elt.def).is_ok() {
self.emit()
.constant_composite(result_type, std::iter::repeat(elt.def).take(num_elts))
if self.builder.lookup_const(elt.def).is_some() {
self.constant_composite(result_type, vec![elt.def; num_elts])
} else {
self.emit()
.composite_construct(result_type, None, std::iter::repeat(elt.def).take(num_elts))
.unwrap()
.with_type(result_type)
}
.with_type(result_type)
}

fn extract_value(&mut self, agg_val: Self::Value, idx: u64) -> Self::Value {
Expand Down Expand Up @@ -1315,16 +1318,11 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
let (result_type, argument_types) = loop {
match self.lookup_type(llfn.ty) {
SpirvType::Pointer { pointee, .. } => {
llfn = match self.builder.lookup_global_constant_variable(llfn.def) {
// constant, known deref
Ok(v) => v.with_type(pointee),
// dynamic deref
Err(_) => self
.emit()
.load(pointee, None, llfn.def, None, empty())
.unwrap()
.with_type(pointee),
}
llfn = self
.emit()
.load(pointee, None, llfn.def, None, empty())
.unwrap()
.with_type(pointee)
}
SpirvType::Function {
return_type,
Expand Down
2 changes: 1 addition & 1 deletion rustc_codegen_spirv/src/builder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
pointee: result_pointee_type,
}
.def(self);
if self.builder.lookup_const_u64(indices[0].def) == Ok(0) {
if self.builder.lookup_const_u64(indices[0].def) == Some(0) {
if is_inbounds {
self.emit()
.in_bounds_access_chain(result_type, None, ptr.def, result_indices)
Expand Down
151 changes: 71 additions & 80 deletions rustc_codegen_spirv/src/builder_spirv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use rspirv::dr::{Block, Builder, Module, Operand};
use rspirv::spirv::{AddressingModel, Capability, MemoryModel, Op, Word};
use rspirv::{binary::Assemble, binary::Disassemble};
use std::cell::{RefCell, RefMut};
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::{fs::File, io::Write, path::Path};

#[derive(Copy, Clone, Debug, Default, Ord, PartialOrd, Eq, PartialEq)]
Expand All @@ -20,6 +22,20 @@ impl SpirvValueExt for Word {
}
}

#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
pub enum SpirvConst {
U32(Word, u32),
U64(Word, u64),
/// f32 isn't hash, so store bits
F32(Word, u32),
/// f64 isn't hash, so store bits
F64(Word, u64),
Bool(Word, bool),
Composite(Word, Vec<Word>),
Null(Word),
Undef(Word),
}

#[derive(Debug, Default, Copy, Clone)]
#[must_use = "BuilderCursor should usually be assigned to the Builder.cursor field"]
pub struct BuilderCursor {
Expand All @@ -29,6 +45,8 @@ pub struct BuilderCursor {

pub struct BuilderSpirv {
builder: RefCell<Builder>,
constants: RefCell<HashMap<SpirvConst, SpirvValue>>,
constants_inverse: RefCell<HashMap<Word, SpirvConst>>,
}

impl BuilderSpirv {
Expand All @@ -49,6 +67,8 @@ impl BuilderSpirv {
builder.memory_model(AddressingModel::Physical32, MemoryModel::OpenCL);
Self {
builder: RefCell::new(builder),
constants: Default::default(),
constants_inverse: Default::default(),
}
}

Expand Down Expand Up @@ -109,94 +129,65 @@ impl BuilderSpirv {
panic!("Function not found: {}", id);
}

pub fn def_constant(&self, ty: Word, val: Operand) -> SpirvValue {
let mut builder = self.builder.borrow_mut();
// TODO: Cache these instead of doing a search.
for inst in &builder.module_ref().types_global_values {
if inst.class.opcode == Op::Constant
&& inst.result_type == Some(ty)
&& inst.operands[0] == val
{
return inst.result_id.unwrap().with_type(ty);
}
}
match val {
Operand::LiteralInt32(v) => builder.constant_u32(ty, v),
Operand::LiteralInt64(v) => builder.constant_u64(ty, v),
Operand::LiteralFloat32(v) => builder.constant_f32(ty, v),
Operand::LiteralFloat64(v) => builder.constant_f64(ty, v),
unknown => panic!("def_constant doesn't support constant type {}", unknown),
}
.with_type(ty)
}

pub fn lookup_const_u64(&self, def: Word) -> Result<u64, &'static str> {
match self.lookup_const(def)? {
Operand::LiteralInt32(v) => Ok(v as u64),
Operand::LiteralInt64(v) => Ok(v),
_ => Err("Literal value not Int32/64"),
}
}

pub fn lookup_const(&self, def: Word) -> Result<Operand, &'static str> {
let builder = self.builder.borrow();
for inst in &builder.module_ref().types_global_values {
if inst.result_id == Some(def) {
return if inst.class.opcode == Op::Constant {
Ok(inst.operands[0].clone())
} else {
Err("Instruction not OpConstant")
};
}
}
Err("Definition not found")
}

pub fn lookup_const_bool(&self, def: Word) -> Result<bool, &'static str> {
let builder = self.builder.borrow();
for inst in &builder.module_ref().types_global_values {
if inst.result_id == Some(def) {
return match inst.class.opcode {
Op::ConstantFalse => Ok(true),
Op::ConstantTrue => Ok(true),
_ => Err("Instruction not OpConstantTrue/False"),
pub fn def_constant(&self, val: SpirvConst) -> SpirvValue {
match self.constants.borrow_mut().entry(val) {
Entry::Occupied(entry) => *entry.get(),
Entry::Vacant(entry) => {
let id = match *entry.key() {
SpirvConst::U32(ty, v) => {
self.builder.borrow_mut().constant_u32(ty, v).with_type(ty)
}
SpirvConst::U64(ty, v) => {
self.builder.borrow_mut().constant_u64(ty, v).with_type(ty)
}
SpirvConst::F32(ty, v) => self
.builder
.borrow_mut()
.constant_f32(ty, f32::from_bits(v))
.with_type(ty),
SpirvConst::F64(ty, v) => self
.builder
.borrow_mut()
.constant_f64(ty, f64::from_bits(v))
.with_type(ty),
SpirvConst::Bool(ty, v) => {
if v {
self.builder.borrow_mut().constant_true(ty).with_type(ty)
} else {
self.builder.borrow_mut().constant_false(ty).with_type(ty)
}
}
SpirvConst::Composite(ty, ref v) => self
.builder
.borrow_mut()
.constant_composite(ty, v.iter().copied())
.with_type(ty),
SpirvConst::Null(ty) => {
self.builder.borrow_mut().constant_null(ty).with_type(ty)
}
SpirvConst::Undef(ty) => {
self.builder.borrow_mut().undef(ty, None).with_type(ty)
}
};
self.constants_inverse
.borrow_mut()
.insert(id.def, entry.key().clone());
entry.insert(id);
id
}
}
Err("Definition not found")
}

pub fn lookup_global_constant_variable(&self, def: Word) -> Result<Word, &'static str> {
// TODO: Maybe assert that this indeed a constant?
let builder = self.builder.borrow();
for inst in &builder.module_ref().types_global_values {
if inst.result_id == Some(def) {
return if inst.class.opcode == Op::Variable {
if let Some(&Operand::IdRef(id_ref)) = inst.operands.get(1) {
Ok(id_ref)
} else {
Err("Instruction had no initializer")
}
} else {
Err("Instruction not OpVariable")
};
}
}
Err("Definition not found")
pub fn lookup_const(&self, def: Word) -> Option<SpirvConst> {
self.constants_inverse.borrow().get(&def).cloned()
}

pub fn find_global_constant_variable(&self, value: Word) -> Option<SpirvValue> {
let builder = self.builder.borrow();
for inst in &builder.module_ref().types_global_values {
if inst.class.opcode == Op::Variable {
if let Some(&Operand::IdRef(id_ref)) = inst.operands.get(1) {
if id_ref == value {
return Some(inst.result_id.unwrap().with_type(inst.result_type.unwrap()));
}
}
}
pub fn lookup_const_u64(&self, def: Word) -> Option<u64> {
match self.lookup_const(def)? {
SpirvConst::U32(_, v) => Some(v as u64),
SpirvConst::U64(_, v) => Some(v),
_ => None,
}
None
}

pub fn set_global_initializer(&self, global: Word, initializer: Word) {
Expand Down
Loading

0 comments on commit 85e7a6a

Please sign in to comment.