start adding graphics pipeline
[kazan.git] / src / spirv_to_llvm / spirv_to_llvm.h
1 /*
2 * Copyright 2017 Jacob Lifshay
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a copy
5 * of this software and associated documentation files (the "Software"), to deal
6 * in the Software without restriction, including without limitation the rights
7 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 * copies of the Software, and to permit persons to whom the Software is
9 * furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice shall be included in all
12 * copies or substantial portions of the Software.
13 *
14 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
20 * SOFTWARE.
21 *
22 */
23 #ifndef SPIRV_TO_LLVM_SPIRV_TO_LLVM_H_
24 #define SPIRV_TO_LLVM_SPIRV_TO_LLVM_H_
25
26 #include "spirv/parser.h"
27 #include <stdexcept>
28 #include <memory>
29 #include <vector>
30 #include <string>
31 #include <cassert>
32 #include <type_traits>
33 #include <utility>
34 #include <cstddef>
35 #include "llvm_wrapper/llvm_wrapper.h"
36
37 namespace vulkan_cpu
38 {
39 namespace spirv_to_llvm
40 {
41 struct LLVM_type_and_alignment
42 {
43 ::LLVMTypeRef type;
44 std::size_t alignment;
45 constexpr LLVM_type_and_alignment() noexcept : type(nullptr), alignment(0)
46 {
47 }
48 constexpr LLVM_type_and_alignment(::LLVMTypeRef type, std::size_t alignment) noexcept
49 : type(type),
50 alignment(alignment)
51 {
52 }
53 };
54
55 class Simple_type_descriptor;
56 class Vector_type_descriptor;
57 class Matrix_type_descriptor;
58 class Array_type_descriptor;
59 class Pointer_type_descriptor;
60 class Function_type_descriptor;
61 class Struct_type_descriptor;
62 class Type_descriptor
63 {
64 Type_descriptor(const Type_descriptor &) = delete;
65 Type_descriptor &operator=(const Type_descriptor &) = delete;
66
67 public:
68 struct Type_visitor
69 {
70 virtual ~Type_visitor() = default;
71 virtual void visit(Simple_type_descriptor &type) = 0;
72 virtual void visit(Vector_type_descriptor &type) = 0;
73 virtual void visit(Matrix_type_descriptor &type) = 0;
74 virtual void visit(Array_type_descriptor &type) = 0;
75 virtual void visit(Pointer_type_descriptor &type) = 0;
76 virtual void visit(Function_type_descriptor &type) = 0;
77 virtual void visit(Struct_type_descriptor &type) = 0;
78 };
79
80 public:
81 const std::vector<spirv::Decoration_with_parameters> decorations;
82
83 public:
84 explicit Type_descriptor(std::vector<spirv::Decoration_with_parameters> decorations) noexcept
85 : decorations(std::move(decorations))
86 {
87 }
88 virtual ~Type_descriptor() = default;
89 virtual LLVM_type_and_alignment get_or_make_type() = 0;
90 virtual void visit(Type_visitor &type_visitor) = 0;
91 void visit(Type_visitor &&type_visitor)
92 {
93 visit(type_visitor);
94 }
95 template <typename Fn>
96 typename std::enable_if<!std::is_convertible<Fn &&, const Type_visitor &>::value, void>::type
97 visit(Fn &&fn)
98 {
99 struct Visitor final : public Type_visitor
100 {
101 Fn &fn;
102 virtual void visit(Simple_type_descriptor &type) override
103 {
104 std::forward<Fn>(fn)(type);
105 }
106 virtual void visit(Vector_type_descriptor &type) override
107 {
108 std::forward<Fn>(fn)(type);
109 }
110 virtual void visit(Matrix_type_descriptor &type) override
111 {
112 std::forward<Fn>(fn)(type);
113 }
114 virtual void visit(Array_type_descriptor &type) override
115 {
116 std::forward<Fn>(fn)(type);
117 }
118 virtual void visit(Pointer_type_descriptor &type) override
119 {
120 std::forward<Fn>(fn)(type);
121 }
122 virtual void visit(Function_type_descriptor &type) override
123 {
124 std::forward<Fn>(fn)(type);
125 }
126 virtual void visit(Struct_type_descriptor &type) override
127 {
128 std::forward<Fn>(fn)(type);
129 }
130 explicit Visitor(Fn &fn) noexcept : fn(fn)
131 {
132 }
133 };
134 visit(Visitor(fn));
135 }
136 class Recursion_checker;
137 class Recursion_checker_state
138 {
139 friend class Recursion_checker;
140
141 private:
142 std::size_t recursion_count = 0;
143 };
144 class Recursion_checker
145 {
146 Recursion_checker(const Recursion_checker &) = delete;
147 Recursion_checker &operator=(const Recursion_checker &) = delete;
148
149 private:
150 Recursion_checker_state &state;
151
152 public:
153 explicit Recursion_checker(Recursion_checker_state &state,
154 std::size_t instruction_start_index)
155 : state(state)
156 {
157 state.recursion_count++;
158 if(state.recursion_count > 5)
159 throw spirv::Parser_error(instruction_start_index,
160 instruction_start_index,
161 "too many recursions making type");
162 }
163 ~Recursion_checker()
164 {
165 state.recursion_count--;
166 }
167 std::size_t get_recursion_count() const noexcept
168 {
169 return state.recursion_count;
170 }
171 bool is_nested_recursion() const noexcept
172 {
173 return get_recursion_count() > 1;
174 }
175 };
176 };
177
178 class Simple_type_descriptor final : public Type_descriptor
179 {
180 private:
181 LLVM_type_and_alignment type;
182
183 public:
184 explicit Simple_type_descriptor(std::vector<spirv::Decoration_with_parameters> decorations,
185 LLVM_type_and_alignment type) noexcept
186 : Type_descriptor(std::move(decorations)),
187 type(type)
188 {
189 }
190 virtual LLVM_type_and_alignment get_or_make_type() override
191 {
192 return type;
193 }
194 virtual void visit(Type_visitor &type_visitor) override
195 {
196 type_visitor.visit(*this);
197 }
198 };
199
200 class Vector_type_descriptor final : public Type_descriptor
201 {
202 private:
203 LLVM_type_and_alignment type;
204 std::shared_ptr<Simple_type_descriptor> element_type;
205 std::size_t element_count;
206
207 public:
208 explicit Vector_type_descriptor(std::vector<spirv::Decoration_with_parameters> decorations,
209 std::shared_ptr<Simple_type_descriptor> element_type,
210 std::size_t element_count,
211 ::LLVMTargetDataRef target_data) noexcept
212 : Type_descriptor(std::move(decorations)),
213 type(make_vector_type(element_type, element_count, target_data)),
214 element_type(std::move(element_type)),
215 element_count(element_count)
216 {
217 }
218 static LLVM_type_and_alignment make_vector_type(
219 const std::shared_ptr<Simple_type_descriptor> &element_type,
220 std::size_t element_count,
221 ::LLVMTargetDataRef target_data)
222 {
223 auto llvm_element_type = element_type->get_or_make_type();
224 auto type = ::LLVMVectorType(llvm_element_type.type, element_count);
225 std::size_t alignment = ::LLVMPreferredAlignmentOfType(target_data, type);
226 constexpr std::size_t max_abi_alignment = alignof(std::max_align_t);
227 if(alignment > max_abi_alignment)
228 alignment = max_abi_alignment;
229 return {type, alignment};
230 }
231 virtual LLVM_type_and_alignment get_or_make_type() override
232 {
233 return type;
234 }
235 virtual void visit(Type_visitor &type_visitor) override
236 {
237 type_visitor.visit(*this);
238 }
239 const std::shared_ptr<Simple_type_descriptor> &get_element_type() const noexcept
240 {
241 return element_type;
242 }
243 std::size_t get_element_count() const noexcept
244 {
245 return element_count;
246 }
247 };
248
249 class Matrix_type_descriptor final : public Type_descriptor
250 {
251 private:
252 LLVM_type_and_alignment type;
253 std::shared_ptr<Vector_type_descriptor> column_type;
254 std::size_t column_count;
255
256 public:
257 explicit Matrix_type_descriptor(std::vector<spirv::Decoration_with_parameters> decorations,
258 std::shared_ptr<Vector_type_descriptor> column_type,
259 std::size_t column_count,
260 ::LLVMTargetDataRef target_data) noexcept
261 : Type_descriptor(std::move(decorations)),
262 type(Vector_type_descriptor::make_vector_type(column_type->get_element_type(),
263 column_type->get_element_count()
264 * column_count,
265 target_data)),
266 column_type(std::move(column_type)),
267 column_count(column_count)
268 {
269 }
270 virtual LLVM_type_and_alignment get_or_make_type() override
271 {
272 return type;
273 }
274 virtual void visit(Type_visitor &type_visitor) override
275 {
276 type_visitor.visit(*this);
277 }
278 const std::shared_ptr<Vector_type_descriptor> &get_column_type() const noexcept
279 {
280 return column_type;
281 }
282 std::size_t get_column_count() const noexcept
283 {
284 return column_count;
285 }
286 };
287
288 class Array_type_descriptor final : public Type_descriptor
289 {
290 private:
291 LLVM_type_and_alignment type;
292 std::shared_ptr<Type_descriptor> element_type;
293 std::size_t element_count;
294 std::size_t instruction_start_index;
295 Recursion_checker_state recursion_checker_state;
296
297 public:
298 explicit Array_type_descriptor(std::vector<spirv::Decoration_with_parameters> decorations,
299 std::shared_ptr<Type_descriptor> element_type,
300 std::size_t element_count,
301 std::size_t instruction_start_index) noexcept
302 : Type_descriptor(std::move(decorations)),
303 type(),
304 element_type(std::move(element_type)),
305 element_count(element_count),
306 instruction_start_index(instruction_start_index)
307 {
308 }
309 virtual LLVM_type_and_alignment get_or_make_type() override
310 {
311 if(!type.type)
312 {
313 Recursion_checker recursion_checker(recursion_checker_state, instruction_start_index);
314 auto llvm_element_type = element_type->get_or_make_type();
315 type = LLVM_type_and_alignment(::LLVMArrayType(llvm_element_type.type, element_count),
316 llvm_element_type.alignment);
317 }
318 return type;
319 }
320 virtual void visit(Type_visitor &type_visitor) override
321 {
322 type_visitor.visit(*this);
323 }
324 const std::shared_ptr<Type_descriptor> &get_element_type() const noexcept
325 {
326 return element_type;
327 }
328 std::size_t get_element_count() const noexcept
329 {
330 return element_count;
331 }
332 };
333
334 class Pointer_type_descriptor final : public Type_descriptor
335 {
336 private:
337 std::shared_ptr<Type_descriptor> base;
338 std::size_t instruction_start_index;
339 LLVM_type_and_alignment type;
340 Recursion_checker_state recursion_checker_state;
341
342 public:
343 Pointer_type_descriptor(std::vector<spirv::Decoration_with_parameters> decorations,
344 std::shared_ptr<Type_descriptor> base,
345 std::size_t instruction_start_index,
346 ::LLVMTargetDataRef target_data) noexcept
347 : Type_descriptor(std::move(decorations)),
348 base(std::move(base)),
349 instruction_start_index(instruction_start_index),
350 type(nullptr, llvm_wrapper::Target_data::get_pointer_alignment(target_data))
351 {
352 }
353 const std::shared_ptr<Type_descriptor> &get_base_type() const noexcept
354 {
355 return base;
356 }
357 void set_base_type(std::shared_ptr<Type_descriptor> new_base) noexcept
358 {
359 assert(!base);
360 assert(new_base);
361 base = std::move(new_base);
362 }
363 explicit Pointer_type_descriptor(std::vector<spirv::Decoration_with_parameters> decorations,
364 std::size_t instruction_start_index,
365 ::LLVMTargetDataRef target_data) noexcept
366 : Type_descriptor(std::move(decorations)),
367 base(nullptr),
368 instruction_start_index(instruction_start_index),
369 type(nullptr, llvm_wrapper::Target_data::get_pointer_alignment(target_data))
370 {
371 }
372 virtual LLVM_type_and_alignment get_or_make_type() override
373 {
374 if(!type.type)
375 {
376 Recursion_checker recursion_checker(recursion_checker_state, instruction_start_index);
377 if(!base)
378 throw spirv::Parser_error(
379 instruction_start_index,
380 instruction_start_index,
381 "attempting to create type from pointer forward declaration");
382 auto base_type = base->get_or_make_type();
383 constexpr unsigned default_address_space = 0;
384 type.type = ::LLVMPointerType(base_type.type, default_address_space);
385 }
386 return type;
387 }
388 virtual void visit(Type_visitor &type_visitor) override
389 {
390 type_visitor.visit(*this);
391 }
392 };
393
394 class Function_type_descriptor final : public Type_descriptor
395 {
396 private:
397 std::shared_ptr<Type_descriptor> return_type;
398 std::vector<std::shared_ptr<Type_descriptor>> args;
399 LLVM_type_and_alignment type;
400 Recursion_checker_state recursion_checker_state;
401 std::size_t instruction_start_index;
402 bool is_var_arg;
403
404 public:
405 explicit Function_type_descriptor(std::vector<spirv::Decoration_with_parameters> decorations,
406 std::shared_ptr<Type_descriptor> return_type,
407 std::vector<std::shared_ptr<Type_descriptor>> args,
408 std::size_t instruction_start_index,
409 ::LLVMTargetDataRef target_data,
410 bool is_var_arg = false) noexcept
411 : Type_descriptor(std::move(decorations)),
412 return_type(std::move(return_type)),
413 args(std::move(args)),
414 type(nullptr, llvm_wrapper::Target_data::get_pointer_alignment(target_data)),
415 instruction_start_index(instruction_start_index),
416 is_var_arg(is_var_arg)
417 {
418 }
419 virtual LLVM_type_and_alignment get_or_make_type() override
420 {
421 if(!type.type)
422 {
423 Recursion_checker recursion_checker(recursion_checker_state, instruction_start_index);
424 std::vector<::LLVMTypeRef> llvm_args;
425 llvm_args.reserve(args.size());
426 auto llvm_return_type = return_type->get_or_make_type();
427 for(auto &arg : args)
428 llvm_args.push_back(arg->get_or_make_type().type);
429 type.type = ::LLVMFunctionType(
430 llvm_return_type.type, llvm_args.data(), llvm_args.size(), is_var_arg);
431 }
432 return type;
433 }
434 virtual void visit(Type_visitor &type_visitor) override
435 {
436 type_visitor.visit(*this);
437 }
438 };
439
440 class Struct_type_descriptor final : public Type_descriptor
441 {
442 public:
443 struct Member
444 {
445 std::vector<spirv::Decoration_with_parameters> decorations;
446 std::size_t llvm_member_index = -1;
447 std::shared_ptr<Type_descriptor> type;
448 explicit Member(std::vector<spirv::Decoration_with_parameters> decorations,
449 std::shared_ptr<Type_descriptor> type) noexcept
450 : decorations(std::move(decorations)),
451 type(std::move(type))
452 {
453 }
454 };
455
456 private:
457 std::vector<Member> members;
458 util::Enum_map<spirv::Built_in, std::size_t> builtin_members;
459 LLVM_type_and_alignment type;
460 bool is_complete;
461 Recursion_checker_state recursion_checker_state;
462 std::size_t instruction_start_index;
463 ::LLVMContextRef context;
464 ::LLVMTargetDataRef target_data;
465 void complete_type();
466 void on_add_member(std::size_t added_member_index) noexcept
467 {
468 assert(!is_complete);
469 auto &member = members[added_member_index];
470 for(auto &decoration : member.decorations)
471 if(decoration.value == spirv::Decoration::built_in)
472 builtin_members[util::get<spirv::Decoration_built_in_parameters>(
473 decoration.parameters)
474 .built_in] = added_member_index;
475 }
476
477 public:
478 std::size_t add_member(Member member)
479 {
480 std::size_t index = members.size();
481 members.push_back(std::move(member));
482 on_add_member(index);
483 return index;
484 }
485 const std::vector<Member> &get_members(bool need_llvm_member_indexes)
486 {
487 if(need_llvm_member_indexes)
488 get_or_make_type();
489 return members;
490 }
491 explicit Struct_type_descriptor(std::vector<spirv::Decoration_with_parameters> decorations,
492 ::LLVMContextRef context,
493 ::LLVMTargetDataRef target_data,
494 const char *name,
495 std::size_t instruction_start_index,
496 std::vector<Member> members = {})
497 : Type_descriptor(std::move(decorations)),
498 members(std::move(members)),
499 builtin_members{},
500 type(::LLVMStructCreateNamed(context, name), 0),
501 is_complete(false),
502 instruction_start_index(instruction_start_index),
503 context(context),
504 target_data(target_data)
505 {
506 for(std::size_t member_index = 0; member_index < members.size(); member_index++)
507 on_add_member(member_index);
508 }
509 virtual LLVM_type_and_alignment get_or_make_type() override
510 {
511 if(!is_complete)
512 {
513 Recursion_checker recursion_checker(recursion_checker_state, instruction_start_index);
514 if(!recursion_checker.is_nested_recursion())
515 complete_type();
516 }
517 return type;
518 }
519 virtual void visit(Type_visitor &type_visitor) override
520 {
521 type_visitor.visit(*this);
522 }
523 };
524
525 class Constant_descriptor
526 {
527 Constant_descriptor(const Constant_descriptor &) = delete;
528 Constant_descriptor &operator=(const Constant_descriptor &) = delete;
529
530 public:
531 const std::shared_ptr<Type_descriptor> type;
532
533 public:
534 explicit Constant_descriptor(std::shared_ptr<Type_descriptor> type) noexcept
535 : type(std::move(type))
536 {
537 }
538 ~Constant_descriptor() = default;
539 virtual ::LLVMValueRef get_or_make_value() = 0;
540 };
541
542 class Simple_constant_descriptor final : public Constant_descriptor
543 {
544 private:
545 ::LLVMValueRef value;
546
547 public:
548 explicit Simple_constant_descriptor(std::shared_ptr<Type_descriptor> type,
549 ::LLVMValueRef value) noexcept
550 : Constant_descriptor(std::move(type)),
551 value(value)
552 {
553 }
554 virtual ::LLVMValueRef get_or_make_value() override
555 {
556 return value;
557 }
558 };
559
560 struct Converted_module
561 {
562 llvm_wrapper::Module module;
563 std::string entry_function_name;
564 std::shared_ptr<Struct_type_descriptor> io_struct;
565 std::size_t inputs_member;
566 std::shared_ptr<Struct_type_descriptor> inputs_struct;
567 std::size_t outputs_member;
568 std::shared_ptr<Struct_type_descriptor> outputs_struct;
569 Converted_module() = default;
570 explicit Converted_module(llvm_wrapper::Module module,
571 std::string entry_function_name,
572 std::shared_ptr<Struct_type_descriptor> io_struct,
573 std::size_t inputs_member,
574 std::shared_ptr<Struct_type_descriptor> inputs_struct,
575 std::size_t outputs_member,
576 std::shared_ptr<Struct_type_descriptor> outputs_struct) noexcept
577 : module(std::move(module)),
578 entry_function_name(std::move(entry_function_name)),
579 io_struct(std::move(io_struct)),
580 inputs_member(inputs_member),
581 inputs_struct(std::move(inputs_struct)),
582 outputs_member(outputs_member),
583 outputs_struct(std::move(outputs_struct))
584 {
585 }
586 };
587
588 class Spirv_to_llvm;
589
/** Declaration of the conversion entry point; the definition is not part of this header.
 *
 * Takes an LLVM context and target machine, a buffer of `shader_size` SPIR-V words starting at
 * `shader_words`, a shader id, the execution model and the entry point name, and returns a
 * `Converted_module`.
 *
 * NOTE(review): judging by the names, this translates the named entry point of the SPIR-V
 * shader into an LLVM module - confirm against the definition. Presumably `shader_size` is a
 * count of `spirv::Word` elements rather than bytes; how `shader_id` is used is not visible
 * here - TODO confirm both.
 */
590 Converted_module spirv_to_llvm(::LLVMContextRef context,
591 ::LLVMTargetMachineRef target_machine,
592 const spirv::Word *shader_words,
593 std::size_t shader_size,
594 std::uint64_t shader_id,
595 spirv::Execution_model execution_model,
596 util::string_view entry_point_name);
597 }
598 }
599
600 #endif /* SPIRV_TO_LLVM_SPIRV_TO_LLVM_H_ */