de1288c2c69da2228f7380d83d0dec0af2ec7e2b
[kazan.git] / shader-compiler / src / lib.rs
1 // SPDX-License-Identifier: LGPL-2.1-or-later
2 // Copyright 2018 Jacob Lifshay
3
4 extern crate shader_compiler_backend;
5 extern crate spirv_parser;
6
7 mod parsed_shader_compile;
8 mod parsed_shader_create;
9
10 use parsed_shader_compile::ParsedShaderCompile;
11 use shader_compiler_backend::Module;
12 use spirv_parser::{BuiltIn, Decoration, ExecutionMode, ExecutionModel, IdRef, Instruction};
13 use std::cell::RefCell;
14 use std::collections::HashSet;
15 use std::fmt;
16 use std::hash::{Hash, Hasher};
17 use std::iter;
18 use std::ops::{Index, IndexMut};
19 use std::rc::Rc;
20
/// Key under which the backend returns a compiled, callable function.
///
/// The compute-shader entry point is currently the only function handed back.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum CompiledFunctionKey {
    /// Entry point generated for a compute shader.
    ComputeShaderEntrypoint,
}
25
26 pub struct Context {
27 types: pointer_type::ContextTypes,
28 next_struct_id: usize,
29 }
30
31 impl Default for Context {
32 fn default() -> Context {
33 Context {
34 types: Default::default(),
35 next_struct_id: 0,
36 }
37 }
38 }
39
40 mod pointer_type {
41 use super::{Context, FrontendType};
42 use std::cell::RefCell;
43 use std::fmt;
44 use std::hash::{Hash, Hasher};
45 use std::rc::{Rc, Weak};
46
47 #[derive(Default)]
48 pub struct ContextTypes(Vec<Rc<FrontendType>>);
49
50 #[derive(Clone, Debug)]
51 enum PointerTypeState {
52 Void,
53 Normal(Weak<FrontendType>),
54 Unresolved,
55 }
56
57 #[derive(Clone)]
58 pub struct PointerType {
59 pointee: RefCell<PointerTypeState>,
60 }
61
62 impl fmt::Debug for PointerType {
63 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
64 let mut state = f.debug_struct("PointerType");
65 if let PointerTypeState::Unresolved = *self.pointee.borrow() {
66 state.field("pointee", &PointerTypeState::Unresolved);
67 } else {
68 state.field("pointee", &self.pointee());
69 }
70 state.finish()
71 }
72 }
73
74 impl PointerType {
75 pub fn new(context: &mut Context, pointee: Option<Rc<FrontendType>>) -> Self {
76 Self {
77 pointee: RefCell::new(match pointee {
78 Some(pointee) => {
79 let weak = Rc::downgrade(&pointee);
80 context.types.0.push(pointee);
81 PointerTypeState::Normal(weak)
82 }
83 None => PointerTypeState::Void,
84 }),
85 }
86 }
87 pub fn new_void() -> Self {
88 Self {
89 pointee: RefCell::new(PointerTypeState::Void),
90 }
91 }
92 pub fn unresolved() -> Self {
93 Self {
94 pointee: RefCell::new(PointerTypeState::Unresolved),
95 }
96 }
97 pub fn resolve(&self, context: &mut Context, new_pointee: Option<Rc<FrontendType>>) {
98 let mut pointee = self.pointee.borrow_mut();
99 match &*pointee {
100 PointerTypeState::Unresolved => {}
101 _ => unreachable!("pointer already resolved"),
102 }
103 *pointee = Self::new(context, new_pointee).pointee.into_inner();
104 }
105 pub fn pointee(&self) -> Option<Rc<FrontendType>> {
106 match *self.pointee.borrow() {
107 PointerTypeState::Normal(ref pointee) => Some(
108 pointee
109 .upgrade()
110 .expect("PointerType is not valid after the associated Context is dropped"),
111 ),
112 PointerTypeState::Void => None,
113 PointerTypeState::Unresolved => {
114 unreachable!("pointee() called on unresolved pointer")
115 }
116 }
117 }
118 }
119
120 impl PartialEq for PointerType {
121 fn eq(&self, rhs: &Self) -> bool {
122 self.pointee() == rhs.pointee()
123 }
124 }
125
126 impl Eq for PointerType {}
127
128 impl Hash for PointerType {
129 fn hash<H: Hasher>(&self, hasher: &mut H) {
130 self.pointee().hash(hasher);
131 }
132 }
133 }
134
135 pub use pointer_type::PointerType;
136
137 #[derive(Clone, Eq, PartialEq, Hash, Debug)]
138 pub enum ScalarType {
139 I8,
140 U8,
141 I16,
142 U16,
143 I32,
144 U32,
145 I64,
146 U64,
147 F16,
148 F32,
149 F64,
150 Bool,
151 Pointer(PointerType),
152 }
153
154 #[derive(Clone, Eq, PartialEq, Hash, Debug)]
155 pub struct VectorType {
156 pub element: ScalarType,
157 pub element_count: usize,
158 }
159
160 #[derive(Clone, Eq, PartialEq, Hash, Debug)]
161 pub struct StructMember {
162 pub decorations: Vec<Decoration>,
163 pub member_type: Rc<FrontendType>,
164 }
165
166 #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
167 pub struct StructId(usize);
168
169 impl StructId {
170 pub fn new(context: &mut Context) -> Self {
171 let retval = StructId(context.next_struct_id);
172 context.next_struct_id += 1;
173 retval
174 }
175 }
176
177 #[derive(Clone)]
178 pub struct StructType {
179 pub id: StructId,
180 pub decorations: Vec<Decoration>,
181 pub members: Vec<StructMember>,
182 }
183
184 impl Eq for StructType {}
185
186 impl PartialEq for StructType {
187 fn eq(&self, rhs: &Self) -> bool {
188 self.id == rhs.id
189 }
190 }
191
192 impl Hash for StructType {
193 fn hash<H: Hasher>(&self, h: &mut H) {
194 self.id.hash(h)
195 }
196 }
197
198 impl fmt::Debug for StructType {
199 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
200 thread_local! {
201 static CURRENTLY_FORMATTING: RefCell<HashSet<StructId>> = RefCell::new(HashSet::new());
202 }
203 struct CurrentlyFormatting {
204 id: StructId,
205 was_formatting: bool,
206 }
207 impl CurrentlyFormatting {
208 fn new(id: StructId) -> Self {
209 let was_formatting = CURRENTLY_FORMATTING
210 .with(|currently_formatting| !currently_formatting.borrow_mut().insert(id));
211 Self { id, was_formatting }
212 }
213 }
214 impl Drop for CurrentlyFormatting {
215 fn drop(&mut self) {
216 if !self.was_formatting {
217 CURRENTLY_FORMATTING.with(|currently_formatting| {
218 currently_formatting.borrow_mut().remove(&self.id);
219 });
220 }
221 }
222 }
223 let currently_formatting = CurrentlyFormatting::new(self.id);
224 let mut state = f.debug_struct("StructType");
225 state.field("id", &self.id);
226 if !currently_formatting.was_formatting {
227 state.field("decorations", &self.decorations);
228 state.field("members", &self.members);
229 }
230 state.finish()
231 }
232 }
233
234 #[derive(Clone, Eq, PartialEq, Hash, Debug)]
235 pub struct ArrayType {
236 pub decorations: Vec<Decoration>,
237 pub element: Rc<FrontendType>,
238 pub element_count: Option<usize>,
239 }
240
241 #[derive(Clone, Eq, PartialEq, Hash, Debug)]
242 pub enum FrontendType {
243 Scalar(ScalarType),
244 Vector(VectorType),
245 Struct(StructType),
246 Array(ArrayType),
247 }
248
249 impl FrontendType {
250 pub fn is_pointer(&self) -> bool {
251 if let FrontendType::Scalar(ScalarType::Pointer(_)) = self {
252 true
253 } else {
254 false
255 }
256 }
257 pub fn is_scalar(&self) -> bool {
258 if let FrontendType::Scalar(_) = self {
259 true
260 } else {
261 false
262 }
263 }
264 pub fn is_vector(&self) -> bool {
265 if let FrontendType::Vector(_) = self {
266 true
267 } else {
268 false
269 }
270 }
271 pub fn get_pointee(&self) -> Option<Rc<FrontendType>> {
272 if let FrontendType::Scalar(ScalarType::Pointer(pointer)) = self {
273 pointer.pointee()
274 } else {
275 unreachable!("not a pointer")
276 }
277 }
278 pub fn get_nonvoid_pointee(&self) -> Rc<FrontendType> {
279 self.get_pointee().expect("void is not allowed here")
280 }
281 pub fn get_scalar(&self) -> &ScalarType {
282 if let FrontendType::Scalar(scalar) = self {
283 scalar
284 } else {
285 unreachable!("not a scalar type")
286 }
287 }
288 pub fn get_vector(&self) -> &VectorType {
289 if let FrontendType::Vector(vector) = self {
290 vector
291 } else {
292 unreachable!("not a vector type")
293 }
294 }
295 }
296
/// A value that is either defined or explicitly undefined.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Undefable<T> {
    /// No defined value.
    Undefined,
    /// A concrete value.
    Defined(T),
}

impl<T> Undefable<T> {
    /// Returns the contained value.
    ///
    /// # Panics
    ///
    /// Panics if `self` is `Undefined`.
    pub fn unwrap(self) -> T {
        if let Undefable::Defined(value) = self {
            value
        } else {
            panic!("Undefable::unwrap called on Undefined")
        }
    }
}

impl<T> From<T> for Undefable<T> {
    /// Wraps `value` as `Defined`.
    fn from(value: T) -> Undefable<T> {
        Undefable::Defined(value)
    }
}
318
319 #[derive(Copy, Clone, Debug)]
320 pub enum ScalarConstant {
321 U8(Undefable<u8>),
322 U16(Undefable<u16>),
323 U32(Undefable<u32>),
324 U64(Undefable<u64>),
325 I8(Undefable<i8>),
326 I16(Undefable<i16>),
327 I32(Undefable<i32>),
328 I64(Undefable<i64>),
329 F16(Undefable<u16>),
330 F32(Undefable<f32>),
331 F64(Undefable<f64>),
332 Bool(Undefable<bool>),
333 }
334
335 macro_rules! define_scalar_vector_constant_impl_without_from {
336 ($type:ident, $name:ident, $get_name:ident) => {
337 impl ScalarConstant {
338 pub fn $get_name(self) -> Undefable<$type> {
339 match self {
340 ScalarConstant::$name(v) => v,
341 _ => unreachable!(concat!("expected a constant ", stringify!($type))),
342 }
343 }
344 }
345 impl VectorConstant {
346 pub fn $get_name(&self) -> &Vec<Undefable<$type>> {
347 match self {
348 VectorConstant::$name(v) => v,
349 _ => unreachable!(concat!(
350 "expected a constant vector with ",
351 stringify!($type),
352 " elements"
353 )),
354 }
355 }
356 }
357 };
358 }
359
360 macro_rules! define_scalar_vector_constant_impl {
361 ($type:ident, $name:ident, $get_name:ident) => {
362 define_scalar_vector_constant_impl_without_from!($type, $name, $get_name);
363 impl From<Undefable<$type>> for ScalarConstant {
364 fn from(v: Undefable<$type>) -> ScalarConstant {
365 ScalarConstant::$name(v)
366 }
367 }
368 impl From<Vec<Undefable<$type>>> for VectorConstant {
369 fn from(v: Vec<Undefable<$type>>) -> VectorConstant {
370 VectorConstant::$name(v)
371 }
372 }
373 };
374 }
375
376 define_scalar_vector_constant_impl!(u8, U8, get_u8);
377 define_scalar_vector_constant_impl!(u16, U16, get_u16);
378 define_scalar_vector_constant_impl!(u32, U32, get_u32);
379 define_scalar_vector_constant_impl!(u64, U64, get_u64);
380 define_scalar_vector_constant_impl!(i8, I8, get_i8);
381 define_scalar_vector_constant_impl!(i16, I16, get_i16);
382 define_scalar_vector_constant_impl!(i32, I32, get_i32);
383 define_scalar_vector_constant_impl!(i64, I64, get_i64);
384 define_scalar_vector_constant_impl_without_from!(u16, F16, get_f16);
385 define_scalar_vector_constant_impl!(f32, F32, get_f32);
386 define_scalar_vector_constant_impl!(f64, F64, get_f64);
387 define_scalar_vector_constant_impl!(bool, Bool, get_bool);
388
389 impl ScalarConstant {
390 pub fn get_type(self) -> FrontendType {
391 FrontendType::Scalar(self.get_scalar_type())
392 }
393 pub fn get_scalar_type(self) -> ScalarType {
394 match self {
395 ScalarConstant::U8(_) => ScalarType::U8,
396 ScalarConstant::U16(_) => ScalarType::U16,
397 ScalarConstant::U32(_) => ScalarType::U32,
398 ScalarConstant::U64(_) => ScalarType::U64,
399 ScalarConstant::I8(_) => ScalarType::I8,
400 ScalarConstant::I16(_) => ScalarType::I16,
401 ScalarConstant::I32(_) => ScalarType::I32,
402 ScalarConstant::I64(_) => ScalarType::I64,
403 ScalarConstant::F16(_) => ScalarType::F16,
404 ScalarConstant::F32(_) => ScalarType::F32,
405 ScalarConstant::F64(_) => ScalarType::F64,
406 ScalarConstant::Bool(_) => ScalarType::Bool,
407 }
408 }
409 }
410
411 #[derive(Clone, Debug)]
412 pub enum VectorConstant {
413 U8(Vec<Undefable<u8>>),
414 U16(Vec<Undefable<u16>>),
415 U32(Vec<Undefable<u32>>),
416 U64(Vec<Undefable<u64>>),
417 I8(Vec<Undefable<i8>>),
418 I16(Vec<Undefable<i16>>),
419 I32(Vec<Undefable<i32>>),
420 I64(Vec<Undefable<i64>>),
421 F16(Vec<Undefable<u16>>),
422 F32(Vec<Undefable<f32>>),
423 F64(Vec<Undefable<f64>>),
424 Bool(Vec<Undefable<bool>>),
425 }
426
427 impl VectorConstant {
428 pub fn get_element_type(&self) -> ScalarType {
429 match self {
430 VectorConstant::U8(_) => ScalarType::U8,
431 VectorConstant::U16(_) => ScalarType::U16,
432 VectorConstant::U32(_) => ScalarType::U32,
433 VectorConstant::U64(_) => ScalarType::U64,
434 VectorConstant::I8(_) => ScalarType::I8,
435 VectorConstant::I16(_) => ScalarType::I16,
436 VectorConstant::I32(_) => ScalarType::I32,
437 VectorConstant::I64(_) => ScalarType::I64,
438 VectorConstant::F16(_) => ScalarType::F16,
439 VectorConstant::F32(_) => ScalarType::F32,
440 VectorConstant::F64(_) => ScalarType::F64,
441 VectorConstant::Bool(_) => ScalarType::Bool,
442 }
443 }
444 pub fn get_element_count(&self) -> usize {
445 match self {
446 VectorConstant::U8(v) => v.len(),
447 VectorConstant::U16(v) => v.len(),
448 VectorConstant::U32(v) => v.len(),
449 VectorConstant::U64(v) => v.len(),
450 VectorConstant::I8(v) => v.len(),
451 VectorConstant::I16(v) => v.len(),
452 VectorConstant::I32(v) => v.len(),
453 VectorConstant::I64(v) => v.len(),
454 VectorConstant::F16(v) => v.len(),
455 VectorConstant::F32(v) => v.len(),
456 VectorConstant::F64(v) => v.len(),
457 VectorConstant::Bool(v) => v.len(),
458 }
459 }
460 pub fn get_type(&self) -> FrontendType {
461 FrontendType::Vector(VectorType {
462 element: self.get_element_type(),
463 element_count: self.get_element_count(),
464 })
465 }
466 }
467
468 #[derive(Clone, Debug)]
469 pub enum Constant {
470 Scalar(ScalarConstant),
471 Vector(VectorConstant),
472 }
473
474 impl Constant {
475 pub fn get_type(&self) -> FrontendType {
476 match self {
477 Constant::Scalar(v) => v.get_type(),
478 Constant::Vector(v) => v.get_type(),
479 }
480 }
481 pub fn get_scalar(&self) -> &ScalarConstant {
482 match self {
483 Constant::Scalar(v) => v,
484 _ => unreachable!("not a scalar constant"),
485 }
486 }
487 }
488
489 #[derive(Debug, Clone)]
490 struct MemberDecoration {
491 member: u32,
492 decoration: Decoration,
493 }
494
495 #[derive(Debug, Clone)]
496 struct BuiltInVariable {
497 built_in: BuiltIn,
498 }
499
500 impl BuiltInVariable {
501 fn get_type(&self, _context: &mut Context) -> Rc<FrontendType> {
502 match self.built_in {
503 BuiltIn::GlobalInvocationId => Rc::new(FrontendType::Vector(VectorType {
504 element: ScalarType::U32,
505 element_count: 3,
506 })),
507 _ => unreachable!("unknown built-in"),
508 }
509 }
510 }
511
512 #[derive(Debug, Clone)]
513 struct UniformVariable {
514 binding: u32,
515 descriptor_set: u32,
516 variable_type: Rc<FrontendType>,
517 }
518
519 #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
520 enum CrossLaneBehavior {
521 Uniform,
522 Nonuniform,
523 }
524
525 #[derive(Debug)]
526 struct FrontendValue<'a, C: shader_compiler_backend::Context<'a>> {
527 frontend_type: Rc<FrontendType>,
528 backend_value: Option<C::Value>,
529 cross_lane_behavior: CrossLaneBehavior,
530 }
531
532 #[derive(Debug)]
533 enum IdKind<'a, C: shader_compiler_backend::Context<'a>> {
534 Undefined,
535 DecorationGroup,
536 Type(Rc<FrontendType>),
537 VoidType,
538 FunctionType {
539 return_type: Option<Rc<FrontendType>>,
540 arguments: Vec<Rc<FrontendType>>,
541 },
542 ForwardPointer(Rc<FrontendType>),
543 BuiltInVariable(BuiltInVariable),
544 Constant(Rc<Constant>),
545 UniformVariable(UniformVariable),
546 Function(Option<ParsedShaderFunction>),
547 BasicBlock {
548 basic_block: C::BasicBlock,
549 buildable_basic_block: Option<C::BuildableBasicBlock>,
550 },
551 Value(FrontendValue<'a, C>),
552 }
553
554 #[derive(Debug)]
555 struct IdProperties<'a, C: shader_compiler_backend::Context<'a>> {
556 kind: IdKind<'a, C>,
557 decorations: Vec<Decoration>,
558 member_decorations: Vec<MemberDecoration>,
559 }
560
561 impl<'a, C: shader_compiler_backend::Context<'a>> IdProperties<'a, C> {
562 fn is_empty(&self) -> bool {
563 match self.kind {
564 IdKind::Undefined => {}
565 _ => return false,
566 }
567 self.decorations.is_empty() && self.member_decorations.is_empty()
568 }
569 fn set_kind(&mut self, kind: IdKind<'a, C>) {
570 match &self.kind {
571 IdKind::Undefined => {}
572 _ => unreachable!("duplicate id"),
573 }
574 self.kind = kind;
575 }
576 fn get_type(&self) -> Option<&Rc<FrontendType>> {
577 match &self.kind {
578 IdKind::Type(t) => Some(t),
579 IdKind::VoidType => None,
580 _ => unreachable!("id is not type"),
581 }
582 }
583 fn get_nonvoid_type(&self) -> &Rc<FrontendType> {
584 self.get_type().expect("void is not allowed here")
585 }
586 fn get_constant(&self) -> &Rc<Constant> {
587 match &self.kind {
588 IdKind::Constant(c) => c,
589 _ => unreachable!("id is not a constant"),
590 }
591 }
592 fn get_value(&self) -> &FrontendValue<'a, C> {
593 match &self.kind {
594 IdKind::Value(retval) => retval,
595 _ => unreachable!("id is not a value"),
596 }
597 }
598 fn get_value_mut(&mut self) -> &mut FrontendValue<'a, C> {
599 match &mut self.kind {
600 IdKind::Value(retval) => retval,
601 _ => unreachable!("id is not a value"),
602 }
603 }
604 fn assert_no_member_decorations(&self, id: IdRef) {
605 for member_decoration in &self.member_decorations {
606 unreachable!(
607 "member decoration not allowed on {}: {:?}",
608 id, member_decoration
609 );
610 }
611 }
612 fn assert_no_decorations(&self, id: IdRef) {
613 self.assert_no_member_decorations(id);
614 for decoration in &self.decorations {
615 unreachable!("decoration not allowed on {}: {:?}", id, decoration);
616 }
617 }
618 }
619
620 struct Ids<'a, C: shader_compiler_backend::Context<'a>>(Vec<IdProperties<'a, C>>);
621
622 impl<'a, C: shader_compiler_backend::Context<'a>> Ids<'a, C> {
623 pub fn iter(&self) -> impl Iterator<Item = (IdRef, &IdProperties<'a, C>)> {
624 (1..self.0.len()).map(move |index| (IdRef(index as u32), &self.0[index]))
625 }
626 }
627
628 impl<'a, C: shader_compiler_backend::Context<'a>> fmt::Debug for Ids<'a, C> {
629 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
630 f.debug_map()
631 .entries(
632 self.0
633 .iter()
634 .enumerate()
635 .filter_map(|(id_index, id_properties)| {
636 if id_properties.is_empty() {
637 return None;
638 }
639 Some((IdRef(id_index as u32), id_properties))
640 }),
641 )
642 .finish()
643 }
644 }
645
646 impl<'a, C: shader_compiler_backend::Context<'a>> Index<IdRef> for Ids<'a, C> {
647 type Output = IdProperties<'a, C>;
648 fn index<'b>(&'b self, index: IdRef) -> &'b IdProperties<'a, C> {
649 &self.0[index.0 as usize]
650 }
651 }
652
653 impl<'a, C: shader_compiler_backend::Context<'a>> IndexMut<IdRef> for Ids<'a, C> {
654 fn index_mut(&mut self, index: IdRef) -> &mut IdProperties<'a, C> {
655 &mut self.0[index.0 as usize]
656 }
657 }
658
659 struct ParsedShaderFunction {
660 instructions: Vec<Instruction>,
661 decorations: Vec<Decoration>,
662 }
663
664 impl fmt::Debug for ParsedShaderFunction {
665 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
666 write!(f, "ParsedShaderFunction:\n")?;
667 for instruction in &self.instructions {
668 write!(f, "{}", instruction)?;
669 }
670 Ok(())
671 }
672 }
673
674 #[derive(Debug)]
675 struct ParsedShader<'a, C: shader_compiler_backend::Context<'a>> {
676 ids: Ids<'a, C>,
677 main_function_id: IdRef,
678 interface_variables: Vec<IdRef>,
679 execution_modes: Vec<ExecutionMode>,
680 workgroup_size: Option<(u32, u32, u32)>,
681 }
682
683 struct ShaderEntryPoint {
684 main_function_id: IdRef,
685 interface_variables: Vec<IdRef>,
686 }
687
688 impl<'a, C: shader_compiler_backend::Context<'a>> ParsedShader<'a, C> {
689 fn create(
690 context: &mut Context,
691 stage_info: ShaderStageCreateInfo,
692 execution_model: ExecutionModel,
693 ) -> Self {
694 parsed_shader_create::create(context, stage_info, execution_model)
695 }
696 }
697
698 #[derive(Clone, Debug)]
699 pub struct GenericPipelineOptions {
700 pub optimization_mode: shader_compiler_backend::OptimizationMode,
701 }
702
/// Layout of one descriptor binding. The variants mirror Vulkan's `VkDescriptorType` names.
///
/// `count` is presumably the number of descriptors in the binding — TODO confirm.
#[derive(Debug, Clone)]
pub enum DescriptorLayout {
    Sampler { count: usize },
    CombinedImageSampler { count: usize },
    SampledImage { count: usize },
    StorageImage { count: usize },
    UniformTexelBuffer { count: usize },
    StorageTexelBuffer { count: usize },
    UniformBuffer { count: usize },
    StorageBuffer { count: usize },
    UniformBufferDynamic { count: usize },
    StorageBufferDynamic { count: usize },
    InputAttachment { count: usize },
}

/// Layout of one descriptor set, indexed by binding number; `None` marks an unused binding.
#[derive(Debug, Clone)]
pub struct DescriptorSetLayout {
    pub bindings: Vec<Option<DescriptorLayout>>,
}

/// Resources visible to a pipeline.
#[derive(Debug, Clone)]
pub struct PipelineLayout {
    /// Size of the push-constant block (presumably in bytes — TODO confirm).
    pub push_constants_size: usize,
    /// Descriptor set layouts, indexed by set number.
    pub descriptor_sets: Vec<DescriptorSetLayout>,
}
728
729 #[derive(Debug)]
730 pub struct ComputePipeline {}
731
732 #[derive(Clone, Debug)]
733 pub struct ComputePipelineOptions {
734 pub generic_options: GenericPipelineOptions,
735 }
736
/// Value supplied for one specialization constant.
#[derive(Clone, Copy, Debug)]
pub struct Specialization<'a> {
    /// Specialization constant id (presumably the SPIR-V `SpecId` — TODO confirm).
    pub id: u32,
    /// Raw bytes of the value.
    pub bytes: &'a [u8],
}

/// Input describing one shader stage.
#[derive(Clone, Copy, Debug)]
pub struct ShaderStageCreateInfo<'a> {
    /// Shader code as 32-bit words (parsed with `spirv_parser`).
    pub code: &'a [u32],
    /// Name of the entry point to use.
    pub entry_point_name: &'a str,
    /// Specialization constant values.
    pub specializations: &'a [Specialization<'a>],
}
749
750 impl ComputePipeline {
751 pub fn new<C: shader_compiler_backend::Compiler>(
752 options: &ComputePipelineOptions,
753 compute_shader_stage: ShaderStageCreateInfo,
754 pipeline_layout: PipelineLayout,
755 backend_compiler: C,
756 ) -> ComputePipeline {
757 let mut frontend_context = Context::default();
758 struct CompilerUser<'a> {
759 frontend_context: Context,
760 compute_shader_stage: ShaderStageCreateInfo<'a>,
761 }
762 #[derive(Debug)]
763 enum CompileError {}
764 impl<'cu> shader_compiler_backend::CompilerUser for CompilerUser<'cu> {
765 type FunctionKey = CompiledFunctionKey;
766 type Error = CompileError;
767 fn create_error(message: String) -> CompileError {
768 panic!("compile error: {}", message)
769 }
770 fn run<'a, C: shader_compiler_backend::Context<'a>>(
771 self,
772 context: &'a C,
773 ) -> Result<
774 shader_compiler_backend::CompileInputs<'a, C, CompiledFunctionKey>,
775 CompileError,
776 > {
777 let backend_context = context;
778 let CompilerUser {
779 mut frontend_context,
780 compute_shader_stage,
781 } = self;
782 let parsed_shader = ParsedShader::create(
783 &mut frontend_context,
784 compute_shader_stage,
785 ExecutionModel::GLCompute,
786 );
787 let mut module = backend_context.create_module("");
788 let function = parsed_shader.compile(
789 &mut frontend_context,
790 backend_context,
791 &mut module,
792 "fn_",
793 );
794 Ok(shader_compiler_backend::CompileInputs {
795 module: module.verify().unwrap(),
796 callable_functions: iter::once((
797 CompiledFunctionKey::ComputeShaderEntrypoint,
798 function,
799 ))
800 .collect(),
801 })
802 }
803 }
804 let compile_results = backend_compiler
805 .run(
806 CompilerUser {
807 frontend_context,
808 compute_shader_stage,
809 },
810 shader_compiler_backend::CompilerIndependentConfig {
811 optimization_mode: options.generic_options.optimization_mode,
812 }
813 .into(),
814 )
815 .unwrap();
816 unimplemented!()
817 }
818 }