// SPDX-License-Identifier: LGPL-2.1-or-later // Copyright 2018 Jacob Lifshay extern crate shader_compiler_backend; extern crate spirv_parser; mod parsed_shader_compile; mod parsed_shader_create; use parsed_shader_compile::ParsedShaderCompile; use shader_compiler_backend::Module; use spirv_parser::{BuiltIn, Decoration, ExecutionMode, ExecutionModel, IdRef, Instruction}; use std::cell::RefCell; use std::collections::HashSet; use std::fmt; use std::hash::{Hash, Hasher}; use std::iter; use std::ops::{Index, IndexMut}; use std::rc::Rc; #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)] pub enum CompiledFunctionKey { ComputeShaderEntrypoint, } pub struct Context { types: pointer_type::ContextTypes, next_struct_id: usize, } impl Default for Context { fn default() -> Context { Context { types: Default::default(), next_struct_id: 0, } } } mod pointer_type { use super::{Context, Type}; use std::cell::RefCell; use std::fmt; use std::hash::{Hash, Hasher}; use std::rc::{Rc, Weak}; #[derive(Default)] pub struct ContextTypes(Vec>); #[derive(Clone, Debug)] enum PointerTypeState { Void, Normal(Weak), Unresolved, } #[derive(Clone)] pub struct PointerType { pointee: RefCell, } impl fmt::Debug for PointerType { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let mut state = f.debug_struct("PointerType"); if let PointerTypeState::Unresolved = *self.pointee.borrow() { state.field("pointee", &PointerTypeState::Unresolved); } else { state.field("pointee", &self.pointee()); } state.finish() } } impl PointerType { pub fn new(context: &mut Context, pointee: Option>) -> Self { Self { pointee: RefCell::new(match pointee { Some(pointee) => { let weak = Rc::downgrade(&pointee); context.types.0.push(pointee); PointerTypeState::Normal(weak) } None => PointerTypeState::Void, }), } } pub fn new_void() -> Self { Self { pointee: RefCell::new(PointerTypeState::Void), } } pub fn unresolved() -> Self { Self { pointee: RefCell::new(PointerTypeState::Unresolved), } } pub fn resolve(&self, context: &mut Context, new_pointee: Option>) { let mut pointee = self.pointee.borrow_mut(); match &*pointee { PointerTypeState::Unresolved => {} _ => unreachable!("pointer already resolved"), } *pointee = Self::new(context, new_pointee).pointee.into_inner(); } pub fn pointee(&self) -> Option> { match *self.pointee.borrow() { PointerTypeState::Normal(ref pointee) => Some( pointee .upgrade() .expect("PointerType is not valid after the associated Context is dropped"), ), PointerTypeState::Void => None, PointerTypeState::Unresolved => { unreachable!("pointee() called on unresolved pointer") } } } } impl PartialEq for PointerType { fn eq(&self, rhs: &Self) -> bool { self.pointee() == rhs.pointee() } } impl Eq for PointerType {} impl Hash for PointerType { fn hash(&self, hasher: &mut H) { self.pointee().hash(hasher); } } } pub use pointer_type::PointerType; #[derive(Clone, Eq, PartialEq, Hash, Debug)] pub enum ScalarType { I8, U8, I16, U16, I32, U32, I64, U64, F16, F32, F64, Bool, Pointer(PointerType), } #[derive(Clone, Eq, PartialEq, Hash, Debug)] pub struct VectorType { pub element: ScalarType, pub element_count: usize, } #[derive(Clone, Eq, PartialEq, Hash, Debug)] pub struct StructMember { pub decorations: Vec, pub member_type: Rc, } #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)] pub struct StructId(usize); impl StructId { pub fn new(context: &mut Context) -> Self { let retval = StructId(context.next_struct_id); context.next_struct_id += 1; retval } } #[derive(Clone)] pub struct StructType { pub id: StructId, pub decorations: Vec, pub members: Vec, } impl Eq for StructType {} impl PartialEq for StructType { fn eq(&self, rhs: &Self) -> bool { self.id == rhs.id } } impl Hash for StructType { fn hash(&self, h: &mut H) { self.id.hash(h) } } impl fmt::Debug for StructType { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { thread_local! { static CURRENTLY_FORMATTING: RefCell> = RefCell::new(HashSet::new()); } struct CurrentlyFormatting { id: StructId, was_formatting: bool, } impl CurrentlyFormatting { fn new(id: StructId) -> Self { let was_formatting = CURRENTLY_FORMATTING .with(|currently_formatting| !currently_formatting.borrow_mut().insert(id)); Self { id, was_formatting } } } impl Drop for CurrentlyFormatting { fn drop(&mut self) { if !self.was_formatting { CURRENTLY_FORMATTING.with(|currently_formatting| { currently_formatting.borrow_mut().remove(&self.id); }); } } } let currently_formatting = CurrentlyFormatting::new(self.id); let mut state = f.debug_struct("StructType"); state.field("id", &self.id); if !currently_formatting.was_formatting { state.field("decorations", &self.decorations); state.field("members", &self.members); } state.finish() } } #[derive(Clone, Eq, PartialEq, Hash, Debug)] pub struct ArrayType { pub decorations: Vec, pub element: Rc, pub element_count: Option, } #[derive(Clone, Eq, PartialEq, Hash, Debug)] pub enum Type { Scalar(ScalarType), Vector(VectorType), Struct(StructType), Array(ArrayType), } impl Type { pub fn is_pointer(&self) -> bool { if let Type::Scalar(ScalarType::Pointer(_)) = self { true } else { false } } pub fn is_scalar(&self) -> bool { if let Type::Scalar(_) = self { true } else { false } } pub fn is_vector(&self) -> bool { if let Type::Vector(_) = self { true } else { false } } pub fn get_pointee(&self) -> Option> { if let Type::Scalar(ScalarType::Pointer(pointer)) = self { pointer.pointee() } else { unreachable!("not a pointer") } } pub fn get_nonvoid_pointee(&self) -> Rc { self.get_pointee().expect("void is not allowed here") } pub fn get_scalar(&self) -> &ScalarType { if let Type::Scalar(scalar) = self { scalar } else { unreachable!("not a scalar type") } } pub fn get_vector(&self) -> &VectorType { if let Type::Vector(vector) = self { vector } else { unreachable!("not a vector type") } } } /// value that can be either defined or undefined #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)] pub enum Undefable { Undefined, Defined(T), } impl Undefable { pub fn unwrap(self) -> T { match self { Undefable::Undefined => panic!("Undefable::unwrap called on Undefined"), Undefable::Defined(v) => v, } } } impl From for Undefable { fn from(v: T) -> Undefable { Undefable::Defined(v) } } #[derive(Copy, Clone, Debug)] pub enum ScalarConstant { U8(Undefable), U16(Undefable), U32(Undefable), U64(Undefable), I8(Undefable), I16(Undefable), I32(Undefable), I64(Undefable), F16(Undefable), F32(Undefable), F64(Undefable), Bool(Undefable), } macro_rules! define_scalar_vector_constant_impl_without_from { ($type:ident, $name:ident, $get_name:ident) => { impl ScalarConstant { pub fn $get_name(self) -> Undefable<$type> { match self { ScalarConstant::$name(v) => v, _ => unreachable!(concat!("expected a constant ", stringify!($type))), } } } impl VectorConstant { pub fn $get_name(&self) -> &Vec> { match self { VectorConstant::$name(v) => v, _ => unreachable!(concat!( "expected a constant vector with ", stringify!($type), " elements" )), } } } }; } macro_rules! define_scalar_vector_constant_impl { ($type:ident, $name:ident, $get_name:ident) => { define_scalar_vector_constant_impl_without_from!($type, $name, $get_name); impl From> for ScalarConstant { fn from(v: Undefable<$type>) -> ScalarConstant { ScalarConstant::$name(v) } } impl From>> for VectorConstant { fn from(v: Vec>) -> VectorConstant { VectorConstant::$name(v) } } }; } define_scalar_vector_constant_impl!(u8, U8, get_u8); define_scalar_vector_constant_impl!(u16, U16, get_u16); define_scalar_vector_constant_impl!(u32, U32, get_u32); define_scalar_vector_constant_impl!(u64, U64, get_u64); define_scalar_vector_constant_impl!(i8, I8, get_i8); define_scalar_vector_constant_impl!(i16, I16, get_i16); define_scalar_vector_constant_impl!(i32, I32, get_i32); define_scalar_vector_constant_impl!(i64, I64, get_i64); define_scalar_vector_constant_impl_without_from!(u16, F16, get_f16); define_scalar_vector_constant_impl!(f32, F32, get_f32); define_scalar_vector_constant_impl!(f64, F64, get_f64); define_scalar_vector_constant_impl!(bool, Bool, get_bool); impl ScalarConstant { pub fn get_type(self) -> Type { Type::Scalar(self.get_scalar_type()) } pub fn get_scalar_type(self) -> ScalarType { match self { ScalarConstant::U8(_) => ScalarType::U8, ScalarConstant::U16(_) => ScalarType::U16, ScalarConstant::U32(_) => ScalarType::U32, ScalarConstant::U64(_) => ScalarType::U64, ScalarConstant::I8(_) => ScalarType::I8, ScalarConstant::I16(_) => ScalarType::I16, ScalarConstant::I32(_) => ScalarType::I32, ScalarConstant::I64(_) => ScalarType::I64, ScalarConstant::F16(_) => ScalarType::F16, ScalarConstant::F32(_) => ScalarType::F32, ScalarConstant::F64(_) => ScalarType::F64, ScalarConstant::Bool(_) => ScalarType::Bool, } } } #[derive(Clone, Debug)] pub enum VectorConstant { U8(Vec>), U16(Vec>), U32(Vec>), U64(Vec>), I8(Vec>), I16(Vec>), I32(Vec>), I64(Vec>), F16(Vec>), F32(Vec>), F64(Vec>), Bool(Vec>), } impl VectorConstant { pub fn get_element_type(&self) -> ScalarType { match self { VectorConstant::U8(_) => ScalarType::U8, VectorConstant::U16(_) => ScalarType::U16, VectorConstant::U32(_) => ScalarType::U32, VectorConstant::U64(_) => ScalarType::U64, VectorConstant::I8(_) => ScalarType::I8, VectorConstant::I16(_) => ScalarType::I16, VectorConstant::I32(_) => ScalarType::I32, VectorConstant::I64(_) => ScalarType::I64, VectorConstant::F16(_) => ScalarType::F16, VectorConstant::F32(_) => ScalarType::F32, VectorConstant::F64(_) => ScalarType::F64, VectorConstant::Bool(_) => ScalarType::Bool, } } pub fn get_element_count(&self) -> usize { match self { VectorConstant::U8(v) => v.len(), VectorConstant::U16(v) => v.len(), VectorConstant::U32(v) => v.len(), VectorConstant::U64(v) => v.len(), VectorConstant::I8(v) => v.len(), VectorConstant::I16(v) => v.len(), VectorConstant::I32(v) => v.len(), VectorConstant::I64(v) => v.len(), VectorConstant::F16(v) => v.len(), VectorConstant::F32(v) => v.len(), VectorConstant::F64(v) => v.len(), VectorConstant::Bool(v) => v.len(), } } pub fn get_type(&self) -> Type { Type::Vector(VectorType { element: self.get_element_type(), element_count: self.get_element_count(), }) } } #[derive(Clone, Debug)] pub enum Constant { Scalar(ScalarConstant), Vector(VectorConstant), } impl Constant { pub fn get_type(&self) -> Type { match self { Constant::Scalar(v) => v.get_type(), Constant::Vector(v) => v.get_type(), } } pub fn get_scalar(&self) -> &ScalarConstant { match self { Constant::Scalar(v) => v, _ => unreachable!("not a scalar constant"), } } } #[derive(Debug, Clone)] struct MemberDecoration { member: u32, decoration: Decoration, } #[derive(Debug, Clone)] struct BuiltInVariable { built_in: BuiltIn, } impl BuiltInVariable { fn get_type(&self, _context: &mut Context) -> Rc { match self.built_in { BuiltIn::GlobalInvocationId => Rc::new(Type::Vector(VectorType { element: ScalarType::U32, element_count: 3, })), _ => unreachable!("unknown built-in"), } } } #[derive(Debug, Clone)] struct UniformVariable { binding: u32, descriptor_set: u32, variable_type: Rc, } #[derive(Debug)] enum IdKind<'a, C: shader_compiler_backend::Context<'a>> { Undefined, DecorationGroup, Type(Rc), VoidType, FunctionType { return_type: Option>, arguments: Vec>, }, ForwardPointer(Rc), BuiltInVariable(BuiltInVariable), Constant(Rc), UniformVariable(UniformVariable), Function(Option), BasicBlock { basic_block: C::BasicBlock, buildable_basic_block: Option, }, } #[derive(Debug)] struct IdProperties<'a, C: shader_compiler_backend::Context<'a>> { kind: IdKind<'a, C>, decorations: Vec, member_decorations: Vec, } impl<'a, C: shader_compiler_backend::Context<'a>> IdProperties<'a, C> { fn is_empty(&self) -> bool { match self.kind { IdKind::Undefined => {} _ => return false, } self.decorations.is_empty() && self.member_decorations.is_empty() } fn set_kind(&mut self, kind: IdKind<'a, C>) { match &self.kind { IdKind::Undefined => {} _ => unreachable!("duplicate id"), } self.kind = kind; } fn get_type(&self) -> Option<&Rc> { match &self.kind { IdKind::Type(t) => Some(t), IdKind::VoidType => None, _ => unreachable!("id is not type"), } } fn get_nonvoid_type(&self) -> &Rc { self.get_type().expect("void is not allowed here") } fn get_constant(&self) -> &Rc { match &self.kind { IdKind::Constant(c) => c, _ => unreachable!("id is not a constant"), } } fn assert_no_member_decorations(&self, id: IdRef) { for member_decoration in &self.member_decorations { unreachable!( "member decoration not allowed on {}: {:?}", id, member_decoration ); } } fn assert_no_decorations(&self, id: IdRef) { self.assert_no_member_decorations(id); for decoration in &self.decorations { unreachable!("decoration not allowed on {}: {:?}", id, decoration); } } } struct Ids<'a, C: shader_compiler_backend::Context<'a>>(Vec>); impl<'a, C: shader_compiler_backend::Context<'a>> Ids<'a, C> { pub fn iter(&self) -> impl Iterator)> { (1..self.0.len()).map(move |index| (IdRef(index as u32), &self.0[index])) } } impl<'a, C: shader_compiler_backend::Context<'a>> fmt::Debug for Ids<'a, C> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_map() .entries( self.0 .iter() .enumerate() .filter_map(|(id_index, id_properties)| { if id_properties.is_empty() { return None; } Some((IdRef(id_index as u32), id_properties)) }), ) .finish() } } impl<'a, C: shader_compiler_backend::Context<'a>> Index for Ids<'a, C> { type Output = IdProperties<'a, C>; fn index<'b>(&'b self, index: IdRef) -> &'b IdProperties<'a, C> { &self.0[index.0 as usize] } } impl<'a, C: shader_compiler_backend::Context<'a>> IndexMut for Ids<'a, C> { fn index_mut(&mut self, index: IdRef) -> &mut IdProperties<'a, C> { &mut self.0[index.0 as usize] } } struct ParsedShaderFunction { instructions: Vec, decorations: Vec, } impl fmt::Debug for ParsedShaderFunction { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "ParsedShaderFunction:\n")?; for instruction in &self.instructions { write!(f, "{}", instruction)?; } Ok(()) } } #[derive(Debug)] struct ParsedShader<'a, C: shader_compiler_backend::Context<'a>> { ids: Ids<'a, C>, main_function_id: IdRef, interface_variables: Vec, execution_modes: Vec, workgroup_size: Option<(u32, u32, u32)>, } struct ShaderEntryPoint { main_function_id: IdRef, interface_variables: Vec, } impl<'a, C: shader_compiler_backend::Context<'a>> ParsedShader<'a, C> { fn create( context: &mut Context, stage_info: ShaderStageCreateInfo, execution_model: ExecutionModel, ) -> Self { parsed_shader_create::create(context, stage_info, execution_model) } } #[derive(Clone, Debug)] pub struct GenericPipelineOptions { pub optimization_mode: shader_compiler_backend::OptimizationMode, } #[derive(Clone, Debug)] pub enum DescriptorLayout { Sampler { count: usize }, CombinedImageSampler { count: usize }, SampledImage { count: usize }, StorageImage { count: usize }, UniformTexelBuffer { count: usize }, StorageTexelBuffer { count: usize }, UniformBuffer { count: usize }, StorageBuffer { count: usize }, UniformBufferDynamic { count: usize }, StorageBufferDynamic { count: usize }, InputAttachment { count: usize }, } #[derive(Clone, Debug)] pub struct DescriptorSetLayout { pub bindings: Vec>, } #[derive(Clone, Debug)] pub struct PipelineLayout { pub push_constants_size: usize, pub descriptor_sets: Vec, } #[derive(Debug)] pub struct ComputePipeline {} #[derive(Clone, Debug)] pub struct ComputePipelineOptions { pub generic_options: GenericPipelineOptions, } #[derive(Copy, Clone, Debug)] pub struct Specialization<'a> { pub id: u32, pub bytes: &'a [u8], } #[derive(Copy, Clone, Debug)] pub struct ShaderStageCreateInfo<'a> { pub code: &'a [u32], pub entry_point_name: &'a str, pub specializations: &'a [Specialization<'a>], } impl ComputePipeline { pub fn new( options: &ComputePipelineOptions, compute_shader_stage: ShaderStageCreateInfo, pipeline_layout: PipelineLayout, backend_compiler: C, ) -> ComputePipeline { let mut frontend_context = Context::default(); struct CompilerUser<'a> { frontend_context: Context, compute_shader_stage: ShaderStageCreateInfo<'a>, } #[derive(Debug)] enum CompileError {} impl<'cu> shader_compiler_backend::CompilerUser for CompilerUser<'cu> { type FunctionKey = CompiledFunctionKey; type Error = CompileError; fn create_error(message: String) -> CompileError { panic!("compile error: {}", message) } fn run<'a, C: shader_compiler_backend::Context<'a>>( self, context: &'a C, ) -> Result< shader_compiler_backend::CompileInputs<'a, C, CompiledFunctionKey>, CompileError, > { let backend_context = context; let CompilerUser { mut frontend_context, compute_shader_stage, } = self; let parsed_shader = ParsedShader::create( &mut frontend_context, compute_shader_stage, ExecutionModel::GLCompute, ); let mut module = backend_context.create_module(""); let function = parsed_shader.compile( &mut frontend_context, backend_context, &mut module, "fn_", ); Ok(shader_compiler_backend::CompileInputs { module: module.verify().unwrap(), callable_functions: iter::once(( CompiledFunctionKey::ComputeShaderEntrypoint, function, )) .collect(), }) } } let compile_results = backend_compiler .run( CompilerUser { frontend_context, compute_shader_stage, }, shader_compiler_backend::CompilerIndependentConfig { optimization_mode: options.generic_options.optimization_mode, } .into(), ) .unwrap(); unimplemented!() } }