1 // SPDX-License-Identifier: LGPL-2.1-or-later
2 // Copyright 2018 Jacob Lifshay
4 extern crate shader_compiler_backend;
5 extern crate spirv_parser;
7 mod parsed_shader_compile;
8 mod parsed_shader_create;
10 use parsed_shader_compile::ParsedShaderCompile;
11 use shader_compiler_backend::Module;
12 use spirv_parser::{BuiltIn, Decoration, ExecutionMode, ExecutionModel, IdRef, Instruction};
13 use std::cell::RefCell;
14 use std::collections::HashSet;
16 use std::hash::{Hash, Hasher};
18 use std::ops::{Index, IndexMut};
/// Key under which the backend compiler registers a callable compiled function.
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
pub enum CompiledFunctionKey {
    /// The entry point of a compute shader.
    ComputeShaderEntrypoint,
}
// Fields of the frontend `Context` (the struct header is not visible in this view).
// `types` holds strong references to pointee types: `PointerType::new` pushes the pointee
// into it and keeps only a `Weak` reference itself.
27 types: pointer_type::ContextTypes,
// Next id handed out by `StructId::new`, which returns the current value and increments it.
28 next_struct_id: usize,
// Default (empty) context; `types` starts as `ContextTypes::default()`.
// NOTE(review): the initial `next_struct_id` value is not visible in this view — confirm.
31 impl Default for Context {
32 fn default() -> Context {
34 types: Default::default(),
41 use super::{Context, FrontendType};
42 use std::cell::RefCell;
44 use std::hash::{Hash, Hasher};
45 use std::rc::{Rc, Weak};
48 pub struct ContextTypes(Vec<Rc<FrontendType>>);
50 #[derive(Clone, Debug)]
51 enum PointerTypeState {
53 Normal(Weak<FrontendType>),
58 pub struct PointerType {
59 pointee: RefCell<PointerTypeState>,
62 impl fmt::Debug for PointerType {
63 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
64 let mut state = f.debug_struct("PointerType");
65 if let PointerTypeState::Unresolved = *self.pointee.borrow() {
66 state.field("pointee", &PointerTypeState::Unresolved);
68 state.field("pointee", &self.pointee());
75 pub fn new(context: &mut Context, pointee: Option<Rc<FrontendType>>) -> Self {
77 pointee: RefCell::new(match pointee {
79 let weak = Rc::downgrade(&pointee);
80 context.types.0.push(pointee);
81 PointerTypeState::Normal(weak)
83 None => PointerTypeState::Void,
87 pub fn new_void() -> Self {
89 pointee: RefCell::new(PointerTypeState::Void),
92 pub fn unresolved() -> Self {
94 pointee: RefCell::new(PointerTypeState::Unresolved),
97 pub fn resolve(&self, context: &mut Context, new_pointee: Option<Rc<FrontendType>>) {
98 let mut pointee = self.pointee.borrow_mut();
100 PointerTypeState::Unresolved => {}
101 _ => unreachable!("pointer already resolved"),
103 *pointee = Self::new(context, new_pointee).pointee.into_inner();
105 pub fn pointee(&self) -> Option<Rc<FrontendType>> {
106 match *self.pointee.borrow() {
107 PointerTypeState::Normal(ref pointee) => Some(
110 .expect("PointerType is not valid after the associated Context is dropped"),
112 PointerTypeState::Void => None,
113 PointerTypeState::Unresolved => {
114 unreachable!("pointee() called on unresolved pointer")
120 impl PartialEq for PointerType {
121 fn eq(&self, rhs: &Self) -> bool {
122 self.pointee() == rhs.pointee()
126 impl Eq for PointerType {}
128 impl Hash for PointerType {
129 fn hash<H: Hasher>(&self, hasher: &mut H) {
130 self.pointee().hash(hasher);
135 pub use pointer_type::PointerType;
137 #[derive(Clone, Eq, PartialEq, Hash, Debug)]
138 pub enum ScalarType {
151 Pointer(PointerType),
154 #[derive(Clone, Eq, PartialEq, Hash, Debug)]
155 pub struct VectorType {
156 pub element: ScalarType,
157 pub element_count: usize,
160 #[derive(Clone, Eq, PartialEq, Hash, Debug)]
161 pub struct StructMember {
162 pub decorations: Vec<Decoration>,
163 pub member_type: Rc<FrontendType>,
166 #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
167 pub struct StructId(usize);
170 pub fn new(context: &mut Context) -> Self {
171 let retval = StructId(context.next_struct_id);
172 context.next_struct_id += 1;
178 pub struct StructType {
180 pub decorations: Vec<Decoration>,
181 pub members: Vec<StructMember>,
184 impl Eq for StructType {}
186 impl PartialEq for StructType {
187 fn eq(&self, rhs: &Self) -> bool {
192 impl Hash for StructType {
193 fn hash<H: Hasher>(&self, h: &mut H) {
198 impl fmt::Debug for StructType {
199 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
201 static CURRENTLY_FORMATTING: RefCell<HashSet<StructId>> = RefCell::new(HashSet::new());
203 struct CurrentlyFormatting {
205 was_formatting: bool,
207 impl CurrentlyFormatting {
208 fn new(id: StructId) -> Self {
209 let was_formatting = CURRENTLY_FORMATTING
210 .with(|currently_formatting| !currently_formatting.borrow_mut().insert(id));
211 Self { id, was_formatting }
214 impl Drop for CurrentlyFormatting {
216 if !self.was_formatting {
217 CURRENTLY_FORMATTING.with(|currently_formatting| {
218 currently_formatting.borrow_mut().remove(&self.id);
223 let currently_formatting = CurrentlyFormatting::new(self.id);
224 let mut state = f.debug_struct("StructType");
225 state.field("id", &self.id);
226 if !currently_formatting.was_formatting {
227 state.field("decorations", &self.decorations);
228 state.field("members", &self.members);
234 #[derive(Clone, Eq, PartialEq, Hash, Debug)]
235 pub struct ArrayType {
236 pub decorations: Vec<Decoration>,
237 pub element: Rc<FrontendType>,
238 pub element_count: Option<usize>,
// Every type the frontend can represent. The variant lines are not visible in this view;
// the accessor methods on this type match on `FrontendType::Scalar(ScalarType)` and
// `FrontendType::Vector(VectorType)`.
// NOTE(review): the remaining variants presumably wrap `StructType` and `ArrayType` — confirm.
241 #[derive(Clone, Eq, PartialEq, Hash, Debug)]
242 pub enum FrontendType {
250 pub fn is_pointer(&self) -> bool {
251 if let FrontendType::Scalar(ScalarType::Pointer(_)) = self {
257 pub fn is_scalar(&self) -> bool {
258 if let FrontendType::Scalar(_) = self {
264 pub fn is_vector(&self) -> bool {
265 if let FrontendType::Vector(_) = self {
271 pub fn get_pointee(&self) -> Option<Rc<FrontendType>> {
272 if let FrontendType::Scalar(ScalarType::Pointer(pointer)) = self {
275 unreachable!("not a pointer")
278 pub fn get_nonvoid_pointee(&self) -> Rc<FrontendType> {
279 self.get_pointee().expect("void is not allowed here")
281 pub fn get_scalar(&self) -> &ScalarType {
282 if let FrontendType::Scalar(scalar) = self {
285 unreachable!("not a scalar type")
288 pub fn get_vector(&self) -> &VectorType {
289 if let FrontendType::Vector(vector) = self {
292 unreachable!("not a vector type")
// `Undefable<T>`: a value of type `T` that may instead be undefined.
// The variant lines are not visible in this view; the impls for this type use
// `Undefable::Undefined` and `Undefable::Defined(v)`.
// NOTE(review): presumably models SPIR-V undefined values (`OpUndef`) — confirm.
297 /// value that can be either defined or undefined
298 #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
299 pub enum Undefable<T> {
304 impl<T> Undefable<T> {
305 pub fn unwrap(self) -> T {
307 Undefable::Undefined => panic!("Undefable::unwrap called on Undefined"),
308 Undefable::Defined(v) => v,
313 impl<T> From<T> for Undefable<T> {
314 fn from(v: T) -> Undefable<T> {
315 Undefable::Defined(v)
319 #[derive(Copy, Clone, Debug)]
320 pub enum ScalarConstant {
332 Bool(Undefable<bool>),
// Generates the getter `$get_name` on both `ScalarConstant` (returning the `$name`
// variant's `Undefable<$type>`) and `VectorConstant` (returning a reference to the `$name`
// variant's `Vec<Undefable<$type>>`). Both getters panic via `unreachable!` for any other
// variant. No `From` impls are generated, so this form can be used for a variant whose
// payload type already has `From` impls from another variant (e.g. `F16`, which shares
// `u16` with `U16`).
335 macro_rules! define_scalar_vector_constant_impl_without_from {
336 ($type:ident, $name:ident, $get_name:ident) => {
337 impl ScalarConstant {
338 pub fn $get_name(self) -> Undefable<$type> {
340 ScalarConstant::$name(v) => v,
341 _ => unreachable!(concat!("expected a constant ", stringify!($type))),
345 impl VectorConstant {
346 pub fn $get_name(&self) -> &Vec<Undefable<$type>> {
348 VectorConstant::$name(v) => v,
349 _ => unreachable!(concat!(
350 "expected a constant vector with ",
/// Like `define_scalar_vector_constant_impl_without_from!`, but also implements
/// `From<Undefable<$type>>` for `ScalarConstant` and `From<Vec<Undefable<$type>>>` for
/// `VectorConstant`, each wrapping its argument in the `$name` variant.
macro_rules! define_scalar_vector_constant_impl {
    ($type:ident, $name:ident, $get_name:ident) => {
        impl From<Undefable<$type>> for ScalarConstant {
            fn from(value: Undefable<$type>) -> ScalarConstant {
                ScalarConstant::$name(value)
            }
        }
        impl From<Vec<Undefable<$type>>> for VectorConstant {
            fn from(values: Vec<Undefable<$type>>) -> VectorConstant {
                VectorConstant::$name(values)
            }
        }
        define_scalar_vector_constant_impl_without_from!($type, $name, $get_name);
    };
}
376 define_scalar_vector_constant_impl!(u8, U8, get_u8);
377 define_scalar_vector_constant_impl!(u16, U16, get_u16);
378 define_scalar_vector_constant_impl!(u32, U32, get_u32);
379 define_scalar_vector_constant_impl!(u64, U64, get_u64);
380 define_scalar_vector_constant_impl!(i8, I8, get_i8);
381 define_scalar_vector_constant_impl!(i16, I16, get_i16);
382 define_scalar_vector_constant_impl!(i32, I32, get_i32);
383 define_scalar_vector_constant_impl!(i64, I64, get_i64);
384 define_scalar_vector_constant_impl_without_from!(u16, F16, get_f16);
385 define_scalar_vector_constant_impl!(f32, F32, get_f32);
386 define_scalar_vector_constant_impl!(f64, F64, get_f64);
387 define_scalar_vector_constant_impl!(bool, Bool, get_bool);
389 impl ScalarConstant {
390 pub fn get_type(self) -> FrontendType {
391 FrontendType::Scalar(self.get_scalar_type())
393 pub fn get_scalar_type(self) -> ScalarType {
395 ScalarConstant::U8(_) => ScalarType::U8,
396 ScalarConstant::U16(_) => ScalarType::U16,
397 ScalarConstant::U32(_) => ScalarType::U32,
398 ScalarConstant::U64(_) => ScalarType::U64,
399 ScalarConstant::I8(_) => ScalarType::I8,
400 ScalarConstant::I16(_) => ScalarType::I16,
401 ScalarConstant::I32(_) => ScalarType::I32,
402 ScalarConstant::I64(_) => ScalarType::I64,
403 ScalarConstant::F16(_) => ScalarType::F16,
404 ScalarConstant::F32(_) => ScalarType::F32,
405 ScalarConstant::F64(_) => ScalarType::F64,
406 ScalarConstant::Bool(_) => ScalarType::Bool,
411 #[derive(Clone, Debug)]
412 pub enum VectorConstant {
413 U8(Vec<Undefable<u8>>),
414 U16(Vec<Undefable<u16>>),
415 U32(Vec<Undefable<u32>>),
416 U64(Vec<Undefable<u64>>),
417 I8(Vec<Undefable<i8>>),
418 I16(Vec<Undefable<i16>>),
419 I32(Vec<Undefable<i32>>),
420 I64(Vec<Undefable<i64>>),
421 F16(Vec<Undefable<u16>>),
422 F32(Vec<Undefable<f32>>),
423 F64(Vec<Undefable<f64>>),
424 Bool(Vec<Undefable<bool>>),
427 impl VectorConstant {
428 pub fn get_element_type(&self) -> ScalarType {
430 VectorConstant::U8(_) => ScalarType::U8,
431 VectorConstant::U16(_) => ScalarType::U16,
432 VectorConstant::U32(_) => ScalarType::U32,
433 VectorConstant::U64(_) => ScalarType::U64,
434 VectorConstant::I8(_) => ScalarType::I8,
435 VectorConstant::I16(_) => ScalarType::I16,
436 VectorConstant::I32(_) => ScalarType::I32,
437 VectorConstant::I64(_) => ScalarType::I64,
438 VectorConstant::F16(_) => ScalarType::F16,
439 VectorConstant::F32(_) => ScalarType::F32,
440 VectorConstant::F64(_) => ScalarType::F64,
441 VectorConstant::Bool(_) => ScalarType::Bool,
444 pub fn get_element_count(&self) -> usize {
446 VectorConstant::U8(v) => v.len(),
447 VectorConstant::U16(v) => v.len(),
448 VectorConstant::U32(v) => v.len(),
449 VectorConstant::U64(v) => v.len(),
450 VectorConstant::I8(v) => v.len(),
451 VectorConstant::I16(v) => v.len(),
452 VectorConstant::I32(v) => v.len(),
453 VectorConstant::I64(v) => v.len(),
454 VectorConstant::F16(v) => v.len(),
455 VectorConstant::F32(v) => v.len(),
456 VectorConstant::F64(v) => v.len(),
457 VectorConstant::Bool(v) => v.len(),
460 pub fn get_type(&self) -> FrontendType {
461 FrontendType::Vector(VectorType {
462 element: self.get_element_type(),
463 element_count: self.get_element_count(),
468 #[derive(Clone, Debug)]
470 Scalar(ScalarConstant),
471 Vector(VectorConstant),
475 pub fn get_type(&self) -> FrontendType {
477 Constant::Scalar(v) => v.get_type(),
478 Constant::Vector(v) => v.get_type(),
481 pub fn get_scalar(&self) -> &ScalarConstant {
483 Constant::Scalar(v) => v,
484 _ => unreachable!("not a scalar constant"),
// A decoration applied to one member of a struct type.
// NOTE(review): the field identifying the member is not visible in this view — presumably
// the member index from `OpMemberDecorate`; confirm.
489 #[derive(Debug, Clone)]
490 struct MemberDecoration {
492 decoration: Decoration,
495 #[derive(Debug, Clone)]
496 struct BuiltInVariable {
500 impl BuiltInVariable {
501 fn get_type(&self, _context: &mut Context) -> Rc<FrontendType> {
502 match self.built_in {
503 BuiltIn::GlobalInvocationId => Rc::new(FrontendType::Vector(VectorType {
504 element: ScalarType::U32,
507 _ => unreachable!("unknown built-in"),
// A uniform variable and its frontend type.
// NOTE(review): two fields before `variable_type` are not visible in this view —
// presumably the descriptor set / binding location; confirm.
512 #[derive(Debug, Clone)]
513 struct UniformVariable {
516 variable_type: Rc<FrontendType>,
// How a value behaves across invocations (lanes); stored per value in `FrontendValue`.
// NOTE(review): the variant lines are not visible in this view — presumably uniform vs.
// non-uniform; confirm.
519 #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
520 enum CrossLaneBehavior {
526 struct FrontendValue<'a, C: shader_compiler_backend::Context<'a>> {
527 frontend_type: Rc<FrontendType>,
528 backend_value: Option<C::Value>,
529 cross_lane_behavior: CrossLaneBehavior,
// What a SPIR-V id refers to. `IdProperties` starts ids as `IdKind::Undefined` and allows
// the kind to be set once (`set_kind` panics on "duplicate id").
// NOTE(review): several short variant lines (including `Undefined` and `VoidType`, both
// used by `IdProperties`) are not visible in this view.
533 enum IdKind<'a, C: shader_compiler_backend::Context<'a>> {
// A non-void type; returned by `IdProperties::get_type`.
536 Type(Rc<FrontendType>),
// Fields of a function-type variant (variant header not visible here).
539 return_type: Option<Rc<FrontendType>>,
540 arguments: Vec<Rc<FrontendType>>,
542 ForwardPointer(Rc<FrontendType>),
543 BuiltInVariable(BuiltInVariable),
544 Constant(Rc<Constant>),
545 UniformVariable(UniformVariable),
546 Function(Option<ParsedShaderFunction>),
// Fields of a basic-block variant (variant header not visible here).
548 basic_block: C::BasicBlock,
549 buildable_basic_block: Option<C::BuildableBasicBlock>,
// A value; accessed through `IdProperties::get_value`/`get_value_mut`.
551 Value(FrontendValue<'a, C>),
555 struct IdProperties<'a, C: shader_compiler_backend::Context<'a>> {
557 decorations: Vec<Decoration>,
558 member_decorations: Vec<MemberDecoration>,
561 impl<'a, C: shader_compiler_backend::Context<'a>> IdProperties<'a, C> {
562 fn is_empty(&self) -> bool {
564 IdKind::Undefined => {}
567 self.decorations.is_empty() && self.member_decorations.is_empty()
569 fn set_kind(&mut self, kind: IdKind<'a, C>) {
571 IdKind::Undefined => {}
572 _ => unreachable!("duplicate id"),
576 fn get_type(&self) -> Option<&Rc<FrontendType>> {
578 IdKind::Type(t) => Some(t),
579 IdKind::VoidType => None,
580 _ => unreachable!("id is not type"),
583 fn get_nonvoid_type(&self) -> &Rc<FrontendType> {
584 self.get_type().expect("void is not allowed here")
586 fn get_constant(&self) -> &Rc<Constant> {
588 IdKind::Constant(c) => c,
589 _ => unreachable!("id is not a constant"),
592 fn get_value(&self) -> &FrontendValue<'a, C> {
594 IdKind::Value(retval) => retval,
595 _ => unreachable!("id is not a value"),
598 fn get_value_mut(&mut self) -> &mut FrontendValue<'a, C> {
599 match &mut self.kind {
600 IdKind::Value(retval) => retval,
601 _ => unreachable!("id is not a value"),
604 fn assert_no_member_decorations(&self, id: IdRef) {
605 for member_decoration in &self.member_decorations {
607 "member decoration not allowed on {}: {:?}",
608 id, member_decoration
612 fn assert_no_decorations(&self, id: IdRef) {
613 self.assert_no_member_decorations(id);
614 for decoration in &self.decorations {
615 unreachable!("decoration not allowed on {}: {:?}", id, decoration);
620 struct Ids<'a, C: shader_compiler_backend::Context<'a>>(Vec<IdProperties<'a, C>>);
622 impl<'a, C: shader_compiler_backend::Context<'a>> Ids<'a, C> {
623 pub fn iter(&self) -> impl Iterator<Item = (IdRef, &IdProperties<'a, C>)> {
624 (1..self.0.len()).map(move |index| (IdRef(index as u32), &self.0[index]))
628 impl<'a, C: shader_compiler_backend::Context<'a>> fmt::Debug for Ids<'a, C> {
629 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
635 .filter_map(|(id_index, id_properties)| {
636 if id_properties.is_empty() {
639 Some((IdRef(id_index as u32), id_properties))
646 impl<'a, C: shader_compiler_backend::Context<'a>> Index<IdRef> for Ids<'a, C> {
647 type Output = IdProperties<'a, C>;
648 fn index<'b>(&'b self, index: IdRef) -> &'b IdProperties<'a, C> {
649 &self.0[index.0 as usize]
653 impl<'a, C: shader_compiler_backend::Context<'a>> IndexMut<IdRef> for Ids<'a, C> {
654 fn index_mut(&mut self, index: IdRef) -> &mut IdProperties<'a, C> {
655 &mut self.0[index.0 as usize]
659 struct ParsedShaderFunction {
660 instructions: Vec<Instruction>,
661 decorations: Vec<Decoration>,
664 impl fmt::Debug for ParsedShaderFunction {
665 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
666 write!(f, "ParsedShaderFunction:\n")?;
667 for instruction in &self.instructions {
668 write!(f, "{}", instruction)?;
// Result of parsing a shader module for one entry point.
// NOTE(review): the first field (line 676) is not visible in this view — presumably the
// per-id table (`Ids`); confirm.
675 struct ParsedShader<'a, C: shader_compiler_backend::Context<'a>> {
// Id of the entry point's main function.
677 main_function_id: IdRef,
// Interface variable ids declared by the entry point.
678 interface_variables: Vec<IdRef>,
679 execution_modes: Vec<ExecutionMode>,
// NOTE(review): presumably the local workgroup size from the execution modes; confirm.
680 workgroup_size: Option<(u32, u32, u32)>,
683 struct ShaderEntryPoint {
684 main_function_id: IdRef,
685 interface_variables: Vec<IdRef>,
// Constructor for `ParsedShader`. The function header line is not visible in this view;
// `ComputePipeline::new` calls it as `ParsedShader::create(...)`.
// It parses `stage_info` for the given `execution_model` by delegating to
// `parsed_shader_create::create`.
688 impl<'a, C: shader_compiler_backend::Context<'a>> ParsedShader<'a, C> {
690 context: &mut Context,
691 stage_info: ShaderStageCreateInfo,
692 execution_model: ExecutionModel,
694 parsed_shader_create::create(context, stage_info, execution_model)
698 #[derive(Clone, Debug)]
699 pub struct GenericPipelineOptions {
700 pub optimization_mode: shader_compiler_backend::OptimizationMode,
/// Layout of a single descriptor binding: the descriptor kind plus how many
/// descriptors the binding holds.
#[derive(Debug, Clone)]
pub enum DescriptorLayout {
    Sampler { count: usize },
    CombinedImageSampler { count: usize },
    SampledImage { count: usize },
    StorageImage { count: usize },
    UniformTexelBuffer { count: usize },
    StorageTexelBuffer { count: usize },
    UniformBuffer { count: usize },
    StorageBuffer { count: usize },
    UniformBufferDynamic { count: usize },
    StorageBufferDynamic { count: usize },
    InputAttachment { count: usize },
}
718 #[derive(Clone, Debug)]
719 pub struct DescriptorSetLayout {
720 pub bindings: Vec<Option<DescriptorLayout>>,
723 #[derive(Clone, Debug)]
724 pub struct PipelineLayout {
725 pub push_constants_size: usize,
726 pub descriptor_sets: Vec<DescriptorSetLayout>,
// A compute pipeline, created with `ComputePipeline::new`; currently has no fields.
// NOTE(review): the attribute line preceding this struct is not visible in this view.
730 pub struct ComputePipeline {}
732 #[derive(Clone, Debug)]
733 pub struct ComputePipelineOptions {
734 pub generic_options: GenericPipelineOptions,
// A specialization entry referenced by `ShaderStageCreateInfo::specializations`.
// NOTE(review): the fields are not visible in this view — presumably a specialization
// constant id and its value; confirm.
737 #[derive(Copy, Clone, Debug)]
738 pub struct Specialization<'a> {
// Describes one shader stage to compile; passed to `ParsedShader::create`.
// NOTE(review): the first field (line 745) is not visible in this view — presumably the
// SPIR-V code; confirm.
743 #[derive(Copy, Clone, Debug)]
744 pub struct ShaderStageCreateInfo<'a> {
// Name of the entry point to use.
746 pub entry_point_name: &'a str,
747 pub specializations: &'a [Specialization<'a>],
750 impl ComputePipeline {
751 pub fn new<C: shader_compiler_backend::Compiler>(
752 options: &ComputePipelineOptions,
753 compute_shader_stage: ShaderStageCreateInfo,
754 pipeline_layout: PipelineLayout,
756 ) -> ComputePipeline {
757 let mut frontend_context = Context::default();
758 struct CompilerUser<'a> {
759 frontend_context: Context,
760 compute_shader_stage: ShaderStageCreateInfo<'a>,
764 impl<'cu> shader_compiler_backend::CompilerUser for CompilerUser<'cu> {
765 type FunctionKey = CompiledFunctionKey;
766 type Error = CompileError;
767 fn create_error(message: String) -> CompileError {
768 panic!("compile error: {}", message)
770 fn run<'a, C: shader_compiler_backend::Context<'a>>(
774 shader_compiler_backend::CompileInputs<'a, C, CompiledFunctionKey>,
777 let backend_context = context;
779 mut frontend_context,
780 compute_shader_stage,
782 let parsed_shader = ParsedShader::create(
783 &mut frontend_context,
784 compute_shader_stage,
785 ExecutionModel::GLCompute,
787 let mut module = backend_context.create_module("");
788 let function = parsed_shader.compile(
789 &mut frontend_context,
794 Ok(shader_compiler_backend::CompileInputs {
795 module: module.verify().unwrap(),
796 callable_functions: iter::once((
797 CompiledFunctionKey::ComputeShaderEntrypoint,
804 let compile_results = backend_compiler
808 compute_shader_stage,
810 shader_compiler_backend::CompilerIndependentConfig {
811 optimization_mode: options.generic_options.optimization_mode,