From d6788d49b047d155a0e5c7335e364b4857e24914 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Mon, 12 Nov 2018 02:48:58 -0800 Subject: [PATCH] second stage of shader parsing works for vulkan_minimal_compute's shader --- shader-compiler/src/lib.rs | 705 +++++++++++++++++++++---- spirv-parser-generator/src/generate.rs | 17 +- 2 files changed, 621 insertions(+), 101 deletions(-) diff --git a/shader-compiler/src/lib.rs b/shader-compiler/src/lib.rs index d0d1987..2e1a476 100644 --- a/shader-compiler/src/lib.rs +++ b/shader-compiler/src/lib.rs @@ -1,27 +1,38 @@ // SPDX-License-Identifier: LGPL-2.1-or-later // Copyright 2018 Jacob Lifshay -#[macro_use] extern crate shader_compiler_backend; extern crate spirv_parser; use spirv_parser::{ BuiltIn, Decoration, ExecutionMode, ExecutionModel, IdRef, Instruction, StorageClass, }; -use std::error; +use std::cell::RefCell; +use std::collections::HashSet; use std::fmt; +use std::hash::{Hash, Hasher}; use std::mem; use std::ops::{Index, IndexMut}; use std::rc::Rc; -#[derive(Default)] pub struct Context { types: pointer_type::ContextTypes, + next_struct_id: usize, +} + +impl Default for Context { + fn default() -> Context { + Context { + types: Default::default(), + next_struct_id: 0, + } + } } mod pointer_type { use super::{Context, Type}; use std::cell::RefCell; + use std::fmt; use std::hash::{Hash, Hasher}; use std::rc::{Rc, Weak}; @@ -35,11 +46,23 @@ mod pointer_type { Unresolved, } - #[derive(Clone, Debug)] + #[derive(Clone)] pub struct PointerType { pointee: RefCell, } + impl fmt::Debug for PointerType { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let mut state = f.debug_struct("PointerType"); + if let PointerTypeState::Unresolved = *self.pointee.borrow() { + state.field("pointee", &PointerTypeState::Unresolved); + } else { + state.field("pointee", &self.pointee()); + } + state.finish() + } + } + impl PointerType { pub fn new(context: &mut Context, pointee: Option>) -> Self { Self { @@ -121,24 +144,99 @@ pub enum ScalarType { } #[derive(Clone, Eq, PartialEq, Hash, Debug)] -pub enum Type { - Scalar(ScalarType), - Vector { - element: ScalarType, - element_count: usize, - }, +pub struct VectorType { + pub element: ScalarType, + pub element_count: usize, } -#[derive(Debug)] -pub struct NotAPointer; +#[derive(Clone, Eq, PartialEq, Hash, Debug)] +pub struct StructMember { + pub decorations: Vec, + pub member_type: Rc, +} + +#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)] +pub struct StructId(usize); + +impl StructId { + pub fn new(context: &mut Context) -> Self { + let retval = StructId(context.next_struct_id); + context.next_struct_id += 1; + retval + } +} + +#[derive(Clone)] +pub struct StructType { + pub id: StructId, + pub decorations: Vec, + pub members: Vec, +} -impl fmt::Display for NotAPointer { +impl Eq for StructType {} + +impl PartialEq for StructType { + fn eq(&self, rhs: &Self) -> bool { + self.id == rhs.id + } +} + +impl Hash for StructType { + fn hash(&self, h: &mut H) { + self.id.hash(h) + } +} + +impl fmt::Debug for StructType { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "not a pointer") + thread_local! { + static CURRENTLY_FORMATTING: RefCell> = RefCell::new(HashSet::new()); + } + struct CurrentlyFormatting { + id: StructId, + was_formatting: bool, + } + impl CurrentlyFormatting { + fn new(id: StructId) -> Self { + let was_formatting = CURRENTLY_FORMATTING + .with(|currently_formatting| !currently_formatting.borrow_mut().insert(id)); + Self { id, was_formatting } + } + } + impl Drop for CurrentlyFormatting { + fn drop(&mut self) { + if !self.was_formatting { + CURRENTLY_FORMATTING.with(|currently_formatting| { + currently_formatting.borrow_mut().remove(&self.id); + }); + } + } + } + let currently_formatting = CurrentlyFormatting::new(self.id); + let mut state = f.debug_struct("StructType"); + state.field("id", &self.id); + if !currently_formatting.was_formatting { + state.field("decorations", &self.decorations); + state.field("members", &self.members); + } + state.finish() } } -impl error::Error for NotAPointer {} +#[derive(Clone, Eq, PartialEq, Hash, Debug)] +pub struct ArrayType { + pub decorations: Vec, + pub element: Rc, + pub element_count: Option, +} + +#[derive(Clone, Eq, PartialEq, Hash, Debug)] +pub enum Type { + Scalar(ScalarType), + Vector(VectorType), + Struct(StructType), + Array(ArrayType), +} impl Type { pub fn is_pointer(&self) -> bool { @@ -148,53 +246,234 @@ impl Type { false } } - pub fn get_pointee(&self) -> Result>, NotAPointer> { + pub fn is_scalar(&self) -> bool { + if let Type::Scalar(_) = self { + true + } else { + false + } + } + pub fn is_vector(&self) -> bool { + if let Type::Vector(_) = self { + true + } else { + false + } + } + pub fn get_pointee(&self) -> Option> { if let Type::Scalar(ScalarType::Pointer(pointer)) = self { - Ok(pointer.pointee()) + pointer.pointee() } else { - Err(NotAPointer) + unreachable!("not a pointer") } } pub fn get_nonvoid_pointee(&self) -> Rc { - self.get_pointee() - .unwrap() - .expect("void is not allowed here") + self.get_pointee().expect("void is not allowed here") + } + pub fn get_scalar(&self) -> &ScalarType { + if let Type::Scalar(scalar) = self { + scalar + } else { + unreachable!("not a scalar type") + } + } + pub fn get_vector(&self) -> &VectorType { + if let Type::Vector(vector) = self { + vector + } else { + unreachable!("not a vector type") + } + } +} + +/// value that can be either defined or undefined +#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)] +pub enum Undefable { + Undefined, + Defined(T), +} + +impl Undefable { + pub fn unwrap(self) -> T { + match self { + Undefable::Undefined => panic!("Undefable::unwrap called on Undefined"), + Undefable::Defined(v) => v, + } + } +} + +impl From for Undefable { + fn from(v: T) -> Undefable { + Undefable::Defined(v) + } +} + +#[derive(Copy, Clone, Debug)] +pub enum ScalarConstant { + U8(Undefable), + U16(Undefable), + U32(Undefable), + U64(Undefable), + I8(Undefable), + I16(Undefable), + I32(Undefable), + I64(Undefable), + F16(Undefable), + F32(Undefable), + F64(Undefable), + Bool(Undefable), +} + +macro_rules! define_scalar_vector_constant_impl_without_from { + ($type:ident, $name:ident, $get_name:ident) => { + impl ScalarConstant { + pub fn $get_name(self) -> Undefable<$type> { + match self { + ScalarConstant::$name(v) => v, + _ => unreachable!(concat!("expected a constant ", stringify!($type))), + } + } + } + impl VectorConstant { + pub fn $get_name(&self) -> &Vec> { + match self { + VectorConstant::$name(v) => v, + _ => unreachable!(concat!( + "expected a constant vector with ", + stringify!($type), + " elements" + )), + } + } + } + }; +} + +macro_rules! define_scalar_vector_constant_impl { + ($type:ident, $name:ident, $get_name:ident) => { + define_scalar_vector_constant_impl_without_from!($type, $name, $get_name); + impl From> for ScalarConstant { + fn from(v: Undefable<$type>) -> ScalarConstant { + ScalarConstant::$name(v) + } + } + impl From>> for VectorConstant { + fn from(v: Vec>) -> VectorConstant { + VectorConstant::$name(v) + } + } + }; +} + +define_scalar_vector_constant_impl!(u8, U8, get_u8); +define_scalar_vector_constant_impl!(u16, U16, get_u16); +define_scalar_vector_constant_impl!(u32, U32, get_u32); +define_scalar_vector_constant_impl!(u64, U64, get_u64); +define_scalar_vector_constant_impl!(i8, I8, get_i8); +define_scalar_vector_constant_impl!(i16, I16, get_i16); +define_scalar_vector_constant_impl!(i32, I32, get_i32); +define_scalar_vector_constant_impl!(i64, I64, get_i64); +define_scalar_vector_constant_impl_without_from!(u16, F16, get_f16); +define_scalar_vector_constant_impl!(f32, F32, get_f32); +define_scalar_vector_constant_impl!(f64, F64, get_f64); +define_scalar_vector_constant_impl!(bool, Bool, get_bool); + +impl ScalarConstant { + pub fn get_type(self) -> Type { + Type::Scalar(self.get_scalar_type()) + } + pub fn get_scalar_type(self) -> ScalarType { + match self { + ScalarConstant::U8(_) => ScalarType::U8, + ScalarConstant::U16(_) => ScalarType::U16, + ScalarConstant::U32(_) => ScalarType::U32, + ScalarConstant::U64(_) => ScalarType::U64, + ScalarConstant::I8(_) => ScalarType::I8, + ScalarConstant::I16(_) => ScalarType::I16, + ScalarConstant::I32(_) => ScalarType::I32, + ScalarConstant::I64(_) => ScalarType::I64, + ScalarConstant::F16(_) => ScalarType::F16, + ScalarConstant::F32(_) => ScalarType::F32, + ScalarConstant::F64(_) => ScalarType::F64, + ScalarConstant::Bool(_) => ScalarType::Bool, + } + } +} + +#[derive(Clone, Debug)] +pub enum VectorConstant { + U8(Vec>), + U16(Vec>), + U32(Vec>), + U64(Vec>), + I8(Vec>), + I16(Vec>), + I32(Vec>), + I64(Vec>), + F16(Vec>), + F32(Vec>), + F64(Vec>), + Bool(Vec>), +} + +impl VectorConstant { + pub fn get_element_type(&self) -> ScalarType { + match self { + VectorConstant::U8(_) => ScalarType::U8, + VectorConstant::U16(_) => ScalarType::U16, + VectorConstant::U32(_) => ScalarType::U32, + VectorConstant::U64(_) => ScalarType::U64, + VectorConstant::I8(_) => ScalarType::I8, + VectorConstant::I16(_) => ScalarType::I16, + VectorConstant::I32(_) => ScalarType::I32, + VectorConstant::I64(_) => ScalarType::I64, + VectorConstant::F16(_) => ScalarType::F16, + VectorConstant::F32(_) => ScalarType::F32, + VectorConstant::F64(_) => ScalarType::F64, + VectorConstant::Bool(_) => ScalarType::Bool, + } + } + pub fn get_element_count(&self) -> usize { + match self { + VectorConstant::U8(v) => v.len(), + VectorConstant::U16(v) => v.len(), + VectorConstant::U32(v) => v.len(), + VectorConstant::U64(v) => v.len(), + VectorConstant::I8(v) => v.len(), + VectorConstant::I16(v) => v.len(), + VectorConstant::I32(v) => v.len(), + VectorConstant::I64(v) => v.len(), + VectorConstant::F16(v) => v.len(), + VectorConstant::F32(v) => v.len(), + VectorConstant::F64(v) => v.len(), + VectorConstant::Bool(v) => v.len(), + } + } + pub fn get_type(&self) -> Type { + Type::Vector(VectorType { + element: self.get_element_type(), + element_count: self.get_element_count(), + }) } } #[derive(Clone, Debug)] pub enum Constant { - Undef(Rc), - U8(u8), - U16(u16), - U32(u32), - U64(u64), - I8(i8), - I16(i16), - I32(i32), - I64(i64), - F16(u16), - F32(f32), - F64(f64), - Bool(bool), + Scalar(ScalarConstant), + Vector(VectorConstant), } impl Constant { - pub fn get_type(&self) -> &Type { + pub fn get_type(&self) -> Type { + match self { + Constant::Scalar(v) => v.get_type(), + Constant::Vector(v) => v.get_type(), + } + } + pub fn get_scalar(&self) -> &ScalarConstant { match self { - Constant::Undef(t) => &*t, - Constant::U8(_) => &Type::Scalar(ScalarType::U8), - Constant::U16(_) => &Type::Scalar(ScalarType::U16), - Constant::U32(_) => &Type::Scalar(ScalarType::U32), - Constant::U64(_) => &Type::Scalar(ScalarType::U64), - Constant::I8(_) => &Type::Scalar(ScalarType::I8), - Constant::I16(_) => &Type::Scalar(ScalarType::I16), - Constant::I32(_) => &Type::Scalar(ScalarType::I32), - Constant::I64(_) => &Type::Scalar(ScalarType::I64), - Constant::F16(_) => &Type::Scalar(ScalarType::F16), - Constant::F32(_) => &Type::Scalar(ScalarType::F32), - Constant::F64(_) => &Type::Scalar(ScalarType::F64), - Constant::Bool(_) => &Type::Scalar(ScalarType::Bool), + Constant::Scalar(v) => v, + _ => unreachable!("not a scalar constant"), } } } @@ -213,15 +492,22 @@ struct BuiltInVariable { impl BuiltInVariable { fn get_type(&self, _context: &mut Context) -> Rc { match self.built_in { - BuiltIn::GlobalInvocationId => Rc::new(Type::Vector { + BuiltIn::GlobalInvocationId => Rc::new(Type::Vector(VectorType { element: ScalarType::U32, element_count: 3, - }), + })), _ => unreachable!("unknown built-in"), } } } +#[derive(Debug, Clone)] +struct UniformVariable { + binding: u32, + descriptor_set: u32, + variable_type: Rc, +} + #[derive(Debug)] enum IdKind { Undefined, @@ -234,7 +520,8 @@ enum IdKind { }, ForwardPointer(Rc), BuiltInVariable(BuiltInVariable), - Constant(Constant), + Constant(Rc), + UniformVariable(UniformVariable), } #[derive(Debug)] @@ -245,6 +532,13 @@ struct IdProperties { } impl IdProperties { + fn is_empty(&self) -> bool { + match self.kind { + IdKind::Undefined => {} + _ => return false, + } + self.decorations.is_empty() && self.member_decorations.is_empty() + } fn set_kind(&mut self, kind: IdKind) { match &self.kind { IdKind::Undefined => {} @@ -252,16 +546,22 @@ impl IdProperties { } self.kind = kind; } - fn get_type(&self) -> Option> { + fn get_type(&self) -> Option<&Rc> { match &self.kind { - IdKind::Type(t) => Some(t.clone()), + IdKind::Type(t) => Some(t), IdKind::VoidType => None, _ => unreachable!("id is not type"), } } - fn get_nonvoid_type(&self) -> Rc { + fn get_nonvoid_type(&self) -> &Rc { self.get_type().expect("void is not allowed here") } + fn get_constant(&self) -> &Rc { + match &self.kind { + IdKind::Constant(c) => c, + _ => unreachable!("id is not a constant"), + } + } fn assert_no_member_decorations(&self, id: IdRef) { for member_decoration in &self.member_decorations { unreachable!( @@ -278,9 +578,26 @@ impl IdProperties { } } -#[derive(Debug)] struct Ids(Vec); +impl fmt::Debug for Ids { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_map() + .entries( + self.0 + .iter() + .enumerate() + .filter_map(|(id_index, id_properties)| { + if id_properties.is_empty() { + return None; + } + Some((IdRef(id_index as u32), id_properties)) + }), + ) + .finish() + } +} + impl Index for Ids { type Output = IdProperties; fn index(&self, index: IdRef) -> &IdProperties { @@ -298,13 +615,24 @@ struct ParsedShaderFunction { instructions: Vec, } -#[allow(dead_code)] +impl fmt::Debug for ParsedShaderFunction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "ParsedShaderFunction:\n")?; + for instruction in &self.instructions { + write!(f, "{}", instruction)?; + } + Ok(()) + } +} + +#[derive(Debug)] struct ParsedShader { ids: Ids, functions: Vec, main_function_id: IdRef, interface_variables: Vec, execution_modes: Vec, + workgroup_size: Option<(u32, u32, u32)>, } struct ShaderEntryPoint { @@ -341,6 +669,7 @@ impl ParsedShader { let mut current_function: Option = None; let mut functions = Vec::new(); let mut execution_modes = Vec::new(); + let mut workgroup_size = None; for instruction in instructions { match current_function { Some(mut function) => { @@ -441,14 +770,10 @@ impl ParsedShader { } => { ids[id_result.0].assert_no_decorations(id_result.0); let kind = IdKind::FunctionType { - return_type: ids[return_type].get_type(), + return_type: ids[return_type].get_type().map(Clone::clone), arguments: parameter_types .iter() - .map(|argument| { - ids[*argument] - .get_type() - .expect("void is not allowed as a function argument") - }) + .map(|argument| ids[*argument].get_nonvoid_type().clone()) .collect(), }; ids[id_result.0].set_kind(kind); @@ -501,17 +826,11 @@ impl ParsedShader { component_count, } => { ids[id_result.0].assert_no_decorations(id_result.0); - let element = match &*ids[component_type] - .get_type() - .expect("void is not a valid vector element type") - { - Type::Scalar(v) => v.clone(), - _ => unreachable!("vector element type must be a scalar"), - }; - ids[id_result.0].set_kind(IdKind::Type(Rc::new(Type::Vector { + let element = ids[component_type].get_nonvoid_type().get_scalar().clone(); + ids[id_result.0].set_kind(IdKind::Type(Rc::new(Type::Vector(VectorType { element, element_count: component_count as usize, - }))); + })))); } Instruction::TypeForwardPointer { pointer_type, .. } => { ids[pointer_type].set_kind(IdKind::ForwardPointer(Rc::new(Type::Scalar( @@ -524,7 +843,7 @@ impl ParsedShader { .. } => { ids[id_result.0].assert_no_decorations(id_result.0); - let pointee = ids[pointee].get_type(); + let pointee = ids[pointee].get_type().map(Clone::clone); let pointer = match mem::replace(&mut ids[id_result.0].kind, IdKind::Undefined) { IdKind::Undefined => Rc::new(Type::Scalar(ScalarType::Pointer( @@ -542,6 +861,49 @@ impl ParsedShader { }; ids[id_result.0].set_kind(IdKind::Type(pointer)); } + Instruction::TypeStruct { + id_result, + member_types, + } => { + let decorations = ids[id_result.0].decorations.clone(); + let struct_type = { + let mut members: Vec<_> = member_types + .into_iter() + .map(|member_type| StructMember { + decorations: Vec::new(), + member_type: match ids[member_type].kind { + IdKind::Type(ref t) => t.clone(), + IdKind::ForwardPointer(ref t) => t.clone(), + _ => unreachable!("invalid struct member type"), + }, + }) + .collect(); + for member_decoration in &ids[id_result.0].member_decorations { + members[member_decoration.member as usize] + .decorations + .push(member_decoration.decoration.clone()); + } + StructType { + id: StructId::new(context), + decorations, + members, + } + }; + ids[id_result.0].set_kind(IdKind::Type(Rc::new(Type::Struct(struct_type)))); + } + Instruction::TypeRuntimeArray { + id_result, + element_type, + } => { + ids[id_result.0].assert_no_member_decorations(id_result.0); + let decorations = ids[id_result.0].decorations.clone(); + let element = ids[element_type].get_nonvoid_type().clone(); + ids[id_result.0].set_kind(IdKind::Type(Rc::new(Type::Array(ArrayType { + decorations, + element, + element_count: None, + })))); + } Instruction::Variable { id_result_type, id_result, @@ -582,7 +944,46 @@ impl ParsedShader { ); ids[id_result.0].set_kind(IdKind::BuiltInVariable(built_in_variable)); } else { + let variable_type = ids[id_result_type.0].get_nonvoid_type().clone(); match storage_class { + StorageClass::Uniform => { + let mut descriptor_set = None; + let mut binding = None; + for decoration in &ids[id_result.0].decorations { + match *decoration { + Decoration::DescriptorSet { descriptor_set: v } => { + assert!( + descriptor_set.is_none(), + "duplicate DescriptorSet decoration" + ); + descriptor_set = Some(v); + } + Decoration::Binding { binding_point: v } => { + assert!( + binding.is_none(), + "duplicate Binding decoration" + ); + binding = Some(v); + } + _ => unimplemented!( + "unimplemented decoration on uniform variable: {:?}", + decoration + ), + } + } + let descriptor_set = descriptor_set + .expect("uniform variable is missing DescriptorSet decoration"); + let binding = binding + .expect("uniform variable is missing Binding decoration"); + assert!(initializer.is_none()); + ids[id_result.0].set_kind(IdKind::UniformVariable( + UniformVariable { + binding, + descriptor_set, + variable_type, + }, + )); + } StorageClass::Input => unimplemented!(), _ => unimplemented!( "unimplemented OpVariable StorageClass: {:?}", @@ -598,38 +999,54 @@ impl ParsedShader { } => { ids[id_result.0].assert_no_decorations(id_result.0); #[cfg_attr(feature = "cargo-clippy", allow(clippy::cast_lossless))] - let constant = match &*ids[id_result_type.0].get_nonvoid_type() { + let constant = match **ids[id_result_type.0].get_nonvoid_type() { Type::Scalar(ScalarType::U8) => { let converted_value = value as u8; assert_eq!(converted_value as u32, value); - Constant::U8(converted_value) + Constant::Scalar(ScalarConstant::U8(Undefable::Defined( + converted_value, + ))) } Type::Scalar(ScalarType::U16) => { let converted_value = value as u16; assert_eq!(converted_value as u32, value); - Constant::U16(converted_value) + Constant::Scalar(ScalarConstant::U16(Undefable::Defined( + converted_value, + ))) + } + Type::Scalar(ScalarType::U32) => { + Constant::Scalar(ScalarConstant::U32(Undefable::Defined(value))) } - Type::Scalar(ScalarType::U32) => Constant::U32(value), Type::Scalar(ScalarType::I8) => { let converted_value = value as i8; assert_eq!(converted_value as u32, value); - Constant::I8(converted_value) + Constant::Scalar(ScalarConstant::I8(Undefable::Defined( + converted_value, + ))) } Type::Scalar(ScalarType::I16) => { let converted_value = value as i16; assert_eq!(converted_value as u32, value); - Constant::I16(converted_value) + Constant::Scalar(ScalarConstant::I16(Undefable::Defined( + converted_value, + ))) + } + Type::Scalar(ScalarType::I32) => { + Constant::Scalar(ScalarConstant::I32(Undefable::Defined(value as i32))) } - Type::Scalar(ScalarType::I32) => Constant::I32(value as i32), Type::Scalar(ScalarType::F16) => { let converted_value = value as u16; assert_eq!(converted_value as u32, value); - Constant::F16(converted_value) + Constant::Scalar(ScalarConstant::F16(Undefable::Defined( + converted_value, + ))) } - Type::Scalar(ScalarType::F32) => Constant::F32(f32::from_bits(value)), + Type::Scalar(ScalarType::F32) => Constant::Scalar(ScalarConstant::F32( + Undefable::Defined(f32::from_bits(value)), + )), _ => unreachable!("invalid type"), }; - ids[id_result.0].set_kind(IdKind::Constant(constant)); + ids[id_result.0].set_kind(IdKind::Constant(Rc::new(constant))); } Instruction::Constant64 { id_result_type, @@ -637,35 +1054,133 @@ impl ParsedShader { value, } => { ids[id_result.0].assert_no_decorations(id_result.0); - let constant = match &*ids[id_result_type.0].get_nonvoid_type() { - Type::Scalar(ScalarType::U64) => Constant::U64(value), - Type::Scalar(ScalarType::I64) => Constant::I64(value as i64), - Type::Scalar(ScalarType::F64) => Constant::F64(f64::from_bits(value)), + let constant = match **ids[id_result_type.0].get_nonvoid_type() { + Type::Scalar(ScalarType::U64) => { + Constant::Scalar(ScalarConstant::U64(Undefable::Defined(value))) + } + Type::Scalar(ScalarType::I64) => { + Constant::Scalar(ScalarConstant::I64(Undefable::Defined(value as i64))) + } + Type::Scalar(ScalarType::F64) => Constant::Scalar(ScalarConstant::F64( + Undefable::Defined(f64::from_bits(value)), + )), _ => unreachable!("invalid type"), }; - ids[id_result.0].set_kind(IdKind::Constant(constant)); + ids[id_result.0].set_kind(IdKind::Constant(Rc::new(constant))); } Instruction::ConstantFalse { id_result_type, id_result, } => { ids[id_result.0].assert_no_decorations(id_result.0); - let constant = match &*ids[id_result_type.0].get_nonvoid_type() { - Type::Scalar(ScalarType::Bool) => Constant::Bool(false), + let constant = match **ids[id_result_type.0].get_nonvoid_type() { + Type::Scalar(ScalarType::Bool) => { + Constant::Scalar(ScalarConstant::Bool(Undefable::Defined(false))) + } _ => unreachable!("invalid type"), }; - ids[id_result.0].set_kind(IdKind::Constant(constant)); + ids[id_result.0].set_kind(IdKind::Constant(Rc::new(constant))); } Instruction::ConstantTrue { id_result_type, id_result, } => { ids[id_result.0].assert_no_decorations(id_result.0); - let constant = match &*ids[id_result_type.0].get_nonvoid_type() { - Type::Scalar(ScalarType::Bool) => Constant::Bool(true), + let constant = match **ids[id_result_type.0].get_nonvoid_type() { + Type::Scalar(ScalarType::Bool) => { + Constant::Scalar(ScalarConstant::Bool(Undefable::Defined(true))) + } _ => unreachable!("invalid type"), }; - ids[id_result.0].set_kind(IdKind::Constant(constant)); + ids[id_result.0].set_kind(IdKind::Constant(Rc::new(constant))); + } + Instruction::ConstantComposite { + id_result_type, + id_result, + constituents, + } => { + let constant = match **ids[id_result_type.0].get_nonvoid_type() { + Type::Vector(VectorType { + ref element, + element_count, + }) => { + assert_eq!(element_count, constituents.len()); + let constituents = constituents + .iter() + .map(|id| *ids[*id].get_constant().get_scalar()); + match *element { + ScalarType::U8 => { + VectorConstant::U8(constituents.map(|v| v.get_u8()).collect()) + } + ScalarType::U16 => { + VectorConstant::U16(constituents.map(|v| v.get_u16()).collect()) + } + ScalarType::U32 => { + VectorConstant::U32(constituents.map(|v| v.get_u32()).collect()) + } + ScalarType::U64 => { + VectorConstant::U64(constituents.map(|v| v.get_u64()).collect()) + } + ScalarType::I8 => { + VectorConstant::I8(constituents.map(|v| v.get_i8()).collect()) + } + ScalarType::I16 => { + VectorConstant::I16(constituents.map(|v| v.get_i16()).collect()) + } + ScalarType::I32 => { + VectorConstant::I32(constituents.map(|v| v.get_i32()).collect()) + } + ScalarType::I64 => { + VectorConstant::I64(constituents.map(|v| v.get_i64()).collect()) + } + ScalarType::F16 => { + VectorConstant::F16(constituents.map(|v| v.get_f16()).collect()) + } + ScalarType::F32 => { + VectorConstant::F32(constituents.map(|v| v.get_f32()).collect()) + } + ScalarType::F64 => { + VectorConstant::F64(constituents.map(|v| v.get_f64()).collect()) + } + ScalarType::Bool => VectorConstant::Bool( + constituents.map(|v| v.get_bool()).collect(), + ), + ScalarType::Pointer(_) => unimplemented!(), + } + } + _ => unimplemented!(), + }; + for decoration in &ids[id_result.0].decorations { + match decoration { + Decoration::BuiltIn { + built_in: BuiltIn::WorkgroupSize, + } => { + assert!( + workgroup_size.is_none(), + "duplicate WorkgroupSize decorations" + ); + workgroup_size = match constant { + VectorConstant::U32(ref v) => { + assert_eq!( + v.len(), + 3, + "invalid type for WorkgroupSize built-in" + ); + Some((v[0].unwrap(), v[1].unwrap(), v[2].unwrap())) + } + _ => unreachable!("invalid type for WorkgroupSize built-in"), + }; + } + _ => unimplemented!( + "unimplemented decoration on constant {:?}: {:?}", + Constant::Vector(constant), + decoration + ), + } + } + ids[id_result.0].assert_no_member_decorations(id_result.0); + ids[id_result.0] + .set_kind(IdKind::Constant(Rc::new(Constant::Vector(constant)))); } Instruction::MemoryModel { addressing_model, @@ -702,6 +1217,7 @@ impl ParsedShader { main_function_id, interface_variables, execution_modes, + workgroup_size, } } } @@ -741,11 +1257,12 @@ impl ComputePipeline { compute_shader_stage: ShaderStageCreateInfo, ) -> ComputePipeline { let mut context = Context::default(); - let _parsed_shader = ParsedShader::create( + let parsed_shader = ParsedShader::create( &mut context, compute_shader_stage, ExecutionModel::GLCompute, ); + println!("parsed_shader:\n{:#?}", parsed_shader); unimplemented!() } } diff --git a/spirv-parser-generator/src/generate.rs b/spirv-parser-generator/src/generate.rs index a1423b1..4f70960 100644 --- a/spirv-parser-generator/src/generate.rs +++ b/spirv-parser-generator/src/generate.rs @@ -449,7 +449,7 @@ pub(crate) fn generate( let enumerant_parse_operation; if enumerant.parameters.is_empty() { enumerant_items.push(quote!{ - #[derive(Clone, Debug, Default)] + #[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Hash)] pub struct #type_name; }); enumerant_parse_operation = quote!{(Some(#type_name), words)}; @@ -486,7 +486,7 @@ pub(crate) fn generate( }); } enumerant_items.push(quote!{ - #[derive(Clone, Debug, Default)] + #[derive(Clone, Debug, Default, Eq, PartialEq, Hash)] pub struct #type_name(#(#enumerant_parameter_declarations)*); }); let enumerant_parameter_names = &enumerant_parameter_names; @@ -527,7 +527,7 @@ pub(crate) fn generate( &mut out, "{}", quote!{ - #[derive(Clone, Debug, Default)] + #[derive(Clone, Debug, Default, Eq, PartialEq, Hash)] pub struct #kind_id { #(#enumerant_members),* } @@ -646,12 +646,15 @@ pub(crate) fn generate( }); } } - let mut derives = vec![quote!{Clone}, quote!{Debug}]; + let mut derives = vec![ + quote!{Clone}, + quote!{Debug}, + quote!{Eq}, + quote!{PartialEq}, + quote!{Hash}, + ]; if !has_any_parameters { - derives.push(quote!{Eq}); - derives.push(quote!{PartialEq}); derives.push(quote!{Copy}); - derives.push(quote!{Hash}); } writeln!( &mut out, -- 2.30.2