1 // SPDX-License-Identifier: LGPL-2.1-or-later
2 // Copyright 2018 Jacob Lifshay
4 extern crate shader_compiler_backend;
5 extern crate spirv_parser;
7 mod parsed_shader_compile;
8 mod parsed_shader_create;
10 use parsed_shader_compile::ParsedShaderCompile;
11 use shader_compiler_backend::Module;
12 use spirv_parser::{BuiltIn, Decoration, ExecutionMode, ExecutionModel, IdRef, Instruction};
13 use std::cell::RefCell;
14 use std::collections::HashSet;
16 use std::hash::{Hash, Hasher};
18 use std::ops::{Index, IndexMut};
21 #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
22 pub enum CompiledFunctionKey {
23 ComputeShaderEntrypoint,
27 types: pointer_type::ContextTypes,
28 next_struct_id: usize,
31 impl Default for Context {
32 fn default() -> Context {
34 types: Default::default(),
41 use super::{Context, Type};
42 use std::cell::RefCell;
44 use std::hash::{Hash, Hasher};
45 use std::rc::{Rc, Weak};
48 pub struct ContextTypes(Vec<Rc<Type>>);
50 #[derive(Clone, Debug)]
51 enum PointerTypeState {
58 pub struct PointerType {
59 pointee: RefCell<PointerTypeState>,
62 impl fmt::Debug for PointerType {
63 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
64 let mut state = f.debug_struct("PointerType");
65 if let PointerTypeState::Unresolved = *self.pointee.borrow() {
66 state.field("pointee", &PointerTypeState::Unresolved);
68 state.field("pointee", &self.pointee());
75 pub fn new(context: &mut Context, pointee: Option<Rc<Type>>) -> Self {
77 pointee: RefCell::new(match pointee {
79 let weak = Rc::downgrade(&pointee);
80 context.types.0.push(pointee);
81 PointerTypeState::Normal(weak)
83 None => PointerTypeState::Void,
87 pub fn new_void() -> Self {
89 pointee: RefCell::new(PointerTypeState::Void),
92 pub fn unresolved() -> Self {
94 pointee: RefCell::new(PointerTypeState::Unresolved),
97 pub fn resolve(&self, context: &mut Context, new_pointee: Option<Rc<Type>>) {
98 let mut pointee = self.pointee.borrow_mut();
100 PointerTypeState::Unresolved => {}
101 _ => unreachable!("pointer already resolved"),
103 *pointee = Self::new(context, new_pointee).pointee.into_inner();
105 pub fn pointee(&self) -> Option<Rc<Type>> {
106 match *self.pointee.borrow() {
107 PointerTypeState::Normal(ref pointee) => Some(
110 .expect("PointerType is not valid after the associated Context is dropped"),
112 PointerTypeState::Void => None,
113 PointerTypeState::Unresolved => {
114 unreachable!("pointee() called on unresolved pointer")
120 impl PartialEq for PointerType {
121 fn eq(&self, rhs: &Self) -> bool {
122 self.pointee() == rhs.pointee()
126 impl Eq for PointerType {}
128 impl Hash for PointerType {
129 fn hash<H: Hasher>(&self, hasher: &mut H) {
130 self.pointee().hash(hasher);
135 pub use pointer_type::PointerType;
137 #[derive(Clone, Eq, PartialEq, Hash, Debug)]
138 pub enum ScalarType {
151 Pointer(PointerType),
154 #[derive(Clone, Eq, PartialEq, Hash, Debug)]
155 pub struct VectorType {
156 pub element: ScalarType,
157 pub element_count: usize,
160 #[derive(Clone, Eq, PartialEq, Hash, Debug)]
161 pub struct StructMember {
162 pub decorations: Vec<Decoration>,
163 pub member_type: Rc<Type>,
166 #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
167 pub struct StructId(usize);
170 pub fn new(context: &mut Context) -> Self {
171 let retval = StructId(context.next_struct_id);
172 context.next_struct_id += 1;
178 pub struct StructType {
180 pub decorations: Vec<Decoration>,
181 pub members: Vec<StructMember>,
184 impl Eq for StructType {}
186 impl PartialEq for StructType {
187 fn eq(&self, rhs: &Self) -> bool {
192 impl Hash for StructType {
193 fn hash<H: Hasher>(&self, h: &mut H) {
198 impl fmt::Debug for StructType {
199 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
201 static CURRENTLY_FORMATTING: RefCell<HashSet<StructId>> = RefCell::new(HashSet::new());
203 struct CurrentlyFormatting {
205 was_formatting: bool,
207 impl CurrentlyFormatting {
208 fn new(id: StructId) -> Self {
209 let was_formatting = CURRENTLY_FORMATTING
210 .with(|currently_formatting| !currently_formatting.borrow_mut().insert(id));
211 Self { id, was_formatting }
214 impl Drop for CurrentlyFormatting {
216 if !self.was_formatting {
217 CURRENTLY_FORMATTING.with(|currently_formatting| {
218 currently_formatting.borrow_mut().remove(&self.id);
223 let currently_formatting = CurrentlyFormatting::new(self.id);
224 let mut state = f.debug_struct("StructType");
225 state.field("id", &self.id);
226 if !currently_formatting.was_formatting {
227 state.field("decorations", &self.decorations);
228 state.field("members", &self.members);
234 #[derive(Clone, Eq, PartialEq, Hash, Debug)]
235 pub struct ArrayType {
236 pub decorations: Vec<Decoration>,
237 pub element: Rc<Type>,
238 pub element_count: Option<usize>,
241 #[derive(Clone, Eq, PartialEq, Hash, Debug)]
250 pub fn is_pointer(&self) -> bool {
251 if let Type::Scalar(ScalarType::Pointer(_)) = self {
257 pub fn is_scalar(&self) -> bool {
258 if let Type::Scalar(_) = self {
264 pub fn is_vector(&self) -> bool {
265 if let Type::Vector(_) = self {
271 pub fn get_pointee(&self) -> Option<Rc<Type>> {
272 if let Type::Scalar(ScalarType::Pointer(pointer)) = self {
275 unreachable!("not a pointer")
278 pub fn get_nonvoid_pointee(&self) -> Rc<Type> {
279 self.get_pointee().expect("void is not allowed here")
281 pub fn get_scalar(&self) -> &ScalarType {
282 if let Type::Scalar(scalar) = self {
285 unreachable!("not a scalar type")
288 pub fn get_vector(&self) -> &VectorType {
289 if let Type::Vector(vector) = self {
292 unreachable!("not a vector type")
297 /// value that can be either defined or undefined
298 #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
299 pub enum Undefable<T> {
304 impl<T> Undefable<T> {
305 pub fn unwrap(self) -> T {
307 Undefable::Undefined => panic!("Undefable::unwrap called on Undefined"),
308 Undefable::Defined(v) => v,
313 impl<T> From<T> for Undefable<T> {
314 fn from(v: T) -> Undefable<T> {
315 Undefable::Defined(v)
319 #[derive(Copy, Clone, Debug)]
320 pub enum ScalarConstant {
332 Bool(Undefable<bool>),
335 macro_rules! define_scalar_vector_constant_impl_without_from {
336 ($type:ident, $name:ident, $get_name:ident) => {
337 impl ScalarConstant {
338 pub fn $get_name(self) -> Undefable<$type> {
340 ScalarConstant::$name(v) => v,
341 _ => unreachable!(concat!("expected a constant ", stringify!($type))),
345 impl VectorConstant {
346 pub fn $get_name(&self) -> &Vec<Undefable<$type>> {
348 VectorConstant::$name(v) => v,
349 _ => unreachable!(concat!(
350 "expected a constant vector with ",
360 macro_rules! define_scalar_vector_constant_impl {
361 ($type:ident, $name:ident, $get_name:ident) => {
362 define_scalar_vector_constant_impl_without_from!($type, $name, $get_name);
363 impl From<Undefable<$type>> for ScalarConstant {
364 fn from(v: Undefable<$type>) -> ScalarConstant {
365 ScalarConstant::$name(v)
368 impl From<Vec<Undefable<$type>>> for VectorConstant {
369 fn from(v: Vec<Undefable<$type>>) -> VectorConstant {
370 VectorConstant::$name(v)
376 define_scalar_vector_constant_impl!(u8, U8, get_u8);
377 define_scalar_vector_constant_impl!(u16, U16, get_u16);
378 define_scalar_vector_constant_impl!(u32, U32, get_u32);
379 define_scalar_vector_constant_impl!(u64, U64, get_u64);
380 define_scalar_vector_constant_impl!(i8, I8, get_i8);
381 define_scalar_vector_constant_impl!(i16, I16, get_i16);
382 define_scalar_vector_constant_impl!(i32, I32, get_i32);
383 define_scalar_vector_constant_impl!(i64, I64, get_i64);
384 define_scalar_vector_constant_impl_without_from!(u16, F16, get_f16);
385 define_scalar_vector_constant_impl!(f32, F32, get_f32);
386 define_scalar_vector_constant_impl!(f64, F64, get_f64);
387 define_scalar_vector_constant_impl!(bool, Bool, get_bool);
389 impl ScalarConstant {
390 pub fn get_type(self) -> Type {
391 Type::Scalar(self.get_scalar_type())
393 pub fn get_scalar_type(self) -> ScalarType {
395 ScalarConstant::U8(_) => ScalarType::U8,
396 ScalarConstant::U16(_) => ScalarType::U16,
397 ScalarConstant::U32(_) => ScalarType::U32,
398 ScalarConstant::U64(_) => ScalarType::U64,
399 ScalarConstant::I8(_) => ScalarType::I8,
400 ScalarConstant::I16(_) => ScalarType::I16,
401 ScalarConstant::I32(_) => ScalarType::I32,
402 ScalarConstant::I64(_) => ScalarType::I64,
403 ScalarConstant::F16(_) => ScalarType::F16,
404 ScalarConstant::F32(_) => ScalarType::F32,
405 ScalarConstant::F64(_) => ScalarType::F64,
406 ScalarConstant::Bool(_) => ScalarType::Bool,
411 #[derive(Clone, Debug)]
412 pub enum VectorConstant {
413 U8(Vec<Undefable<u8>>),
414 U16(Vec<Undefable<u16>>),
415 U32(Vec<Undefable<u32>>),
416 U64(Vec<Undefable<u64>>),
417 I8(Vec<Undefable<i8>>),
418 I16(Vec<Undefable<i16>>),
419 I32(Vec<Undefable<i32>>),
420 I64(Vec<Undefable<i64>>),
421 F16(Vec<Undefable<u16>>),
422 F32(Vec<Undefable<f32>>),
423 F64(Vec<Undefable<f64>>),
424 Bool(Vec<Undefable<bool>>),
427 impl VectorConstant {
428 pub fn get_element_type(&self) -> ScalarType {
430 VectorConstant::U8(_) => ScalarType::U8,
431 VectorConstant::U16(_) => ScalarType::U16,
432 VectorConstant::U32(_) => ScalarType::U32,
433 VectorConstant::U64(_) => ScalarType::U64,
434 VectorConstant::I8(_) => ScalarType::I8,
435 VectorConstant::I16(_) => ScalarType::I16,
436 VectorConstant::I32(_) => ScalarType::I32,
437 VectorConstant::I64(_) => ScalarType::I64,
438 VectorConstant::F16(_) => ScalarType::F16,
439 VectorConstant::F32(_) => ScalarType::F32,
440 VectorConstant::F64(_) => ScalarType::F64,
441 VectorConstant::Bool(_) => ScalarType::Bool,
444 pub fn get_element_count(&self) -> usize {
446 VectorConstant::U8(v) => v.len(),
447 VectorConstant::U16(v) => v.len(),
448 VectorConstant::U32(v) => v.len(),
449 VectorConstant::U64(v) => v.len(),
450 VectorConstant::I8(v) => v.len(),
451 VectorConstant::I16(v) => v.len(),
452 VectorConstant::I32(v) => v.len(),
453 VectorConstant::I64(v) => v.len(),
454 VectorConstant::F16(v) => v.len(),
455 VectorConstant::F32(v) => v.len(),
456 VectorConstant::F64(v) => v.len(),
457 VectorConstant::Bool(v) => v.len(),
460 pub fn get_type(&self) -> Type {
461 Type::Vector(VectorType {
462 element: self.get_element_type(),
463 element_count: self.get_element_count(),
468 #[derive(Clone, Debug)]
470 Scalar(ScalarConstant),
471 Vector(VectorConstant),
475 pub fn get_type(&self) -> Type {
477 Constant::Scalar(v) => v.get_type(),
478 Constant::Vector(v) => v.get_type(),
481 pub fn get_scalar(&self) -> &ScalarConstant {
483 Constant::Scalar(v) => v,
484 _ => unreachable!("not a scalar constant"),
489 #[derive(Debug, Clone)]
490 struct MemberDecoration {
492 decoration: Decoration,
495 #[derive(Debug, Clone)]
496 struct BuiltInVariable {
500 impl BuiltInVariable {
501 fn get_type(&self, _context: &mut Context) -> Rc<Type> {
502 match self.built_in {
503 BuiltIn::GlobalInvocationId => Rc::new(Type::Vector(VectorType {
504 element: ScalarType::U32,
507 _ => unreachable!("unknown built-in"),
512 #[derive(Debug, Clone)]
513 struct UniformVariable {
516 variable_type: Rc<Type>,
526 return_type: Option<Rc<Type>>,
527 arguments: Vec<Rc<Type>>,
529 ForwardPointer(Rc<Type>),
530 BuiltInVariable(BuiltInVariable),
531 Constant(Rc<Constant>),
532 UniformVariable(UniformVariable),
533 Function(Option<ParsedShaderFunction>),
537 struct IdProperties {
539 decorations: Vec<Decoration>,
540 member_decorations: Vec<MemberDecoration>,
544 fn is_empty(&self) -> bool {
546 IdKind::Undefined => {}
549 self.decorations.is_empty() && self.member_decorations.is_empty()
551 fn set_kind(&mut self, kind: IdKind) {
553 IdKind::Undefined => {}
554 _ => unreachable!("duplicate id"),
558 fn get_type(&self) -> Option<&Rc<Type>> {
560 IdKind::Type(t) => Some(t),
561 IdKind::VoidType => None,
562 _ => unreachable!("id is not type"),
565 fn get_nonvoid_type(&self) -> &Rc<Type> {
566 self.get_type().expect("void is not allowed here")
568 fn get_constant(&self) -> &Rc<Constant> {
570 IdKind::Constant(c) => c,
571 _ => unreachable!("id is not a constant"),
574 fn assert_no_member_decorations(&self, id: IdRef) {
575 for member_decoration in &self.member_decorations {
577 "member decoration not allowed on {}: {:?}",
578 id, member_decoration
582 fn assert_no_decorations(&self, id: IdRef) {
583 self.assert_no_member_decorations(id);
584 for decoration in &self.decorations {
585 unreachable!("decoration not allowed on {}: {:?}", id, decoration);
590 struct Ids(Vec<IdProperties>);
593 pub fn iter(&self) -> impl Iterator<Item = (IdRef, &IdProperties)> {
594 (1..self.0.len()).map(move |index| (IdRef(index as u32), &self.0[index]))
598 impl fmt::Debug for Ids {
599 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
605 .filter_map(|(id_index, id_properties)| {
606 if id_properties.is_empty() {
609 Some((IdRef(id_index as u32), id_properties))
616 impl Index<IdRef> for Ids {
617 type Output = IdProperties;
618 fn index(&self, index: IdRef) -> &IdProperties {
619 &self.0[index.0 as usize]
623 impl IndexMut<IdRef> for Ids {
624 fn index_mut(&mut self, index: IdRef) -> &mut IdProperties {
625 &mut self.0[index.0 as usize]
629 struct ParsedShaderFunction {
630 instructions: Vec<Instruction>,
631 decorations: Vec<Decoration>,
634 impl fmt::Debug for ParsedShaderFunction {
635 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
636 write!(f, "ParsedShaderFunction:\n")?;
637 for instruction in &self.instructions {
638 write!(f, "{}", instruction)?;
645 struct ParsedShader {
647 main_function_id: IdRef,
648 interface_variables: Vec<IdRef>,
649 execution_modes: Vec<ExecutionMode>,
650 workgroup_size: Option<(u32, u32, u32)>,
653 struct ShaderEntryPoint {
654 main_function_id: IdRef,
655 interface_variables: Vec<IdRef>,
660 context: &mut Context,
661 stage_info: ShaderStageCreateInfo,
662 execution_model: ExecutionModel,
664 parsed_shader_create::create(context, stage_info, execution_model)
668 #[derive(Clone, Debug)]
669 pub struct GenericPipelineOptions {
670 pub optimization_mode: shader_compiler_backend::OptimizationMode,
673 #[derive(Clone, Debug)]
674 pub enum DescriptorLayout {
675 Sampler { count: usize },
676 CombinedImageSampler { count: usize },
677 SampledImage { count: usize },
678 StorageImage { count: usize },
679 UniformTexelBuffer { count: usize },
680 StorageTexelBuffer { count: usize },
681 UniformBuffer { count: usize },
682 StorageBuffer { count: usize },
683 UniformBufferDynamic { count: usize },
684 StorageBufferDynamic { count: usize },
685 InputAttachment { count: usize },
688 #[derive(Clone, Debug)]
689 pub struct DescriptorSetLayout {
690 pub bindings: Vec<Option<DescriptorLayout>>,
693 #[derive(Clone, Debug)]
694 pub struct PipelineLayout {
695 pub push_constants_size: usize,
696 pub descriptor_sets: Vec<DescriptorSetLayout>,
700 pub struct ComputePipeline {}
702 #[derive(Clone, Debug)]
703 pub struct ComputePipelineOptions {
704 pub generic_options: GenericPipelineOptions,
707 #[derive(Copy, Clone, Debug)]
708 pub struct Specialization<'a> {
713 #[derive(Copy, Clone, Debug)]
714 pub struct ShaderStageCreateInfo<'a> {
716 pub entry_point_name: &'a str,
717 pub specializations: &'a [Specialization<'a>],
720 impl ComputePipeline {
721 pub fn new<C: shader_compiler_backend::Compiler>(
722 options: &ComputePipelineOptions,
723 compute_shader_stage: ShaderStageCreateInfo,
724 pipeline_layout: PipelineLayout,
726 ) -> ComputePipeline {
727 let mut frontend_context = Context::default();
728 let parsed_shader = ParsedShader::create(
729 &mut frontend_context,
730 compute_shader_stage,
731 ExecutionModel::GLCompute,
733 println!("parsed_shader:\n{:#?}", parsed_shader);
734 struct CompilerUser {
735 frontend_context: Context,
736 parsed_shader: ParsedShader,
740 impl shader_compiler_backend::CompilerUser for CompilerUser {
741 type FunctionKey = CompiledFunctionKey;
742 type Error = CompileError;
743 fn create_error(message: String) -> CompileError {
744 panic!("compile error: {}", message)
746 fn run<'a, C: shader_compiler_backend::Context<'a>>(
750 shader_compiler_backend::CompileInputs<'a, C, CompiledFunctionKey>,
753 let backend_context = context;
755 mut frontend_context,
758 let mut module = backend_context.create_module("");
760 parsed_shader.compile(&mut frontend_context, backend_context, &mut module);
761 Ok(shader_compiler_backend::CompileInputs {
762 module: module.verify().unwrap(),
763 callable_functions: iter::once((
764 CompiledFunctionKey::ComputeShaderEntrypoint,
771 let compile_results = backend_compiler
777 shader_compiler_backend::CompilerIndependentConfig {
778 optimization_mode: options.generic_options.optimization_mode,