change project name to Kazan and reformat code
[kazan.git] / src / spirv_to_llvm / spirv_to_llvm.h
1 /*
2 * Copyright 2017 Jacob Lifshay
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a copy
5 * of this software and associated documentation files (the "Software"), to deal
6 * in the Software without restriction, including without limitation the rights
7 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 * copies of the Software, and to permit persons to whom the Software is
9 * furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice shall be included in all
12 * copies or substantial portions of the Software.
13 *
14 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
20 * SOFTWARE.
21 *
22 */
23 #ifndef SPIRV_TO_LLVM_SPIRV_TO_LLVM_H_
24 #define SPIRV_TO_LLVM_SPIRV_TO_LLVM_H_
25
26 #include "spirv/parser.h"
27 #include <stdexcept>
28 #include <memory>
29 #include <vector>
30 #include <string>
31 #include <cassert>
32 #include <type_traits>
33 #include <utility>
34 #include <cstddef>
35 #include "llvm_wrapper/llvm_wrapper.h"
36 #include "util/string_view.h"
37 #include "vulkan/vulkan.h"
38
39 namespace kazan
40 {
41 namespace spirv_to_llvm
42 {
43 struct LLVM_type_and_alignment
44 {
45 ::LLVMTypeRef type;
46 std::size_t alignment;
47 constexpr LLVM_type_and_alignment() noexcept : type(nullptr), alignment(0)
48 {
49 }
50 constexpr LLVM_type_and_alignment(::LLVMTypeRef type, std::size_t alignment) noexcept
51 : type(type),
52 alignment(alignment)
53 {
54 }
55 };
56
57 class Simple_type_descriptor;
58 class Vector_type_descriptor;
59 class Matrix_type_descriptor;
60 class Array_type_descriptor;
61 class Pointer_type_descriptor;
62 class Function_type_descriptor;
63 class Struct_type_descriptor;
64 class Type_descriptor
65 {
66 Type_descriptor(const Type_descriptor &) = delete;
67 Type_descriptor &operator=(const Type_descriptor &) = delete;
68
69 public:
70 struct Type_visitor
71 {
72 virtual ~Type_visitor() = default;
73 virtual void visit(Simple_type_descriptor &type) = 0;
74 virtual void visit(Vector_type_descriptor &type) = 0;
75 virtual void visit(Matrix_type_descriptor &type) = 0;
76 virtual void visit(Array_type_descriptor &type) = 0;
77 virtual void visit(Pointer_type_descriptor &type) = 0;
78 virtual void visit(Function_type_descriptor &type) = 0;
79 virtual void visit(Struct_type_descriptor &type) = 0;
80 };
81
82 public:
83 const std::vector<spirv::Decoration_with_parameters> decorations;
84
85 public:
86 explicit Type_descriptor(std::vector<spirv::Decoration_with_parameters> decorations) noexcept
87 : decorations(std::move(decorations))
88 {
89 }
90 virtual ~Type_descriptor() = default;
91 virtual LLVM_type_and_alignment get_or_make_type() = 0;
92 virtual void visit(Type_visitor &type_visitor) = 0;
93 void visit(Type_visitor &&type_visitor)
94 {
95 visit(type_visitor);
96 }
97 template <typename Fn>
98 typename std::enable_if<!std::is_convertible<Fn &&, const Type_visitor &>::value, void>::type
99 visit(Fn &&fn)
100 {
101 struct Visitor final : public Type_visitor
102 {
103 Fn &fn;
104 virtual void visit(Simple_type_descriptor &type) override
105 {
106 std::forward<Fn>(fn)(type);
107 }
108 virtual void visit(Vector_type_descriptor &type) override
109 {
110 std::forward<Fn>(fn)(type);
111 }
112 virtual void visit(Matrix_type_descriptor &type) override
113 {
114 std::forward<Fn>(fn)(type);
115 }
116 virtual void visit(Array_type_descriptor &type) override
117 {
118 std::forward<Fn>(fn)(type);
119 }
120 virtual void visit(Pointer_type_descriptor &type) override
121 {
122 std::forward<Fn>(fn)(type);
123 }
124 virtual void visit(Function_type_descriptor &type) override
125 {
126 std::forward<Fn>(fn)(type);
127 }
128 virtual void visit(Struct_type_descriptor &type) override
129 {
130 std::forward<Fn>(fn)(type);
131 }
132 explicit Visitor(Fn &fn) noexcept : fn(fn)
133 {
134 }
135 };
136 visit(Visitor(fn));
137 }
138 class Recursion_checker;
139 class Recursion_checker_state
140 {
141 friend class Recursion_checker;
142
143 private:
144 std::size_t recursion_count = 0;
145 };
146 class Recursion_checker
147 {
148 Recursion_checker(const Recursion_checker &) = delete;
149 Recursion_checker &operator=(const Recursion_checker &) = delete;
150
151 private:
152 Recursion_checker_state &state;
153
154 public:
155 explicit Recursion_checker(Recursion_checker_state &state,
156 std::size_t instruction_start_index)
157 : state(state)
158 {
159 state.recursion_count++;
160 if(state.recursion_count > 5)
161 throw spirv::Parser_error(instruction_start_index,
162 instruction_start_index,
163 "too many recursions making type");
164 }
165 ~Recursion_checker()
166 {
167 state.recursion_count--;
168 }
169 std::size_t get_recursion_count() const noexcept
170 {
171 return state.recursion_count;
172 }
173 bool is_nested_recursion() const noexcept
174 {
175 return get_recursion_count() > 1;
176 }
177 };
178 };
179
180 class Simple_type_descriptor final : public Type_descriptor
181 {
182 private:
183 LLVM_type_and_alignment type;
184
185 public:
186 explicit Simple_type_descriptor(std::vector<spirv::Decoration_with_parameters> decorations,
187 LLVM_type_and_alignment type) noexcept
188 : Type_descriptor(std::move(decorations)),
189 type(type)
190 {
191 }
192 virtual LLVM_type_and_alignment get_or_make_type() override
193 {
194 return type;
195 }
196 virtual void visit(Type_visitor &type_visitor) override
197 {
198 type_visitor.visit(*this);
199 }
200 };
201
202 class Vector_type_descriptor final : public Type_descriptor
203 {
204 private:
205 LLVM_type_and_alignment type;
206 std::shared_ptr<Simple_type_descriptor> element_type;
207 std::size_t element_count;
208
209 public:
210 explicit Vector_type_descriptor(std::vector<spirv::Decoration_with_parameters> decorations,
211 std::shared_ptr<Simple_type_descriptor> element_type,
212 std::size_t element_count,
213 ::LLVMTargetDataRef target_data) noexcept
214 : Type_descriptor(std::move(decorations)),
215 type(make_vector_type(element_type, element_count, target_data)),
216 element_type(std::move(element_type)),
217 element_count(element_count)
218 {
219 }
220 static LLVM_type_and_alignment make_vector_type(
221 const std::shared_ptr<Simple_type_descriptor> &element_type,
222 std::size_t element_count,
223 ::LLVMTargetDataRef target_data)
224 {
225 auto llvm_element_type = element_type->get_or_make_type();
226 auto type = ::LLVMVectorType(llvm_element_type.type, element_count);
227 std::size_t alignment = ::LLVMPreferredAlignmentOfType(target_data, type);
228 constexpr std::size_t max_abi_alignment = alignof(std::max_align_t);
229 if(alignment > max_abi_alignment)
230 alignment = max_abi_alignment;
231 return {type, alignment};
232 }
233 virtual LLVM_type_and_alignment get_or_make_type() override
234 {
235 return type;
236 }
237 virtual void visit(Type_visitor &type_visitor) override
238 {
239 type_visitor.visit(*this);
240 }
241 const std::shared_ptr<Simple_type_descriptor> &get_element_type() const noexcept
242 {
243 return element_type;
244 }
245 std::size_t get_element_count() const noexcept
246 {
247 return element_count;
248 }
249 };
250
251 class Matrix_type_descriptor final : public Type_descriptor
252 {
253 private:
254 LLVM_type_and_alignment type;
255 std::shared_ptr<Vector_type_descriptor> column_type;
256 std::size_t column_count;
257
258 public:
259 explicit Matrix_type_descriptor(std::vector<spirv::Decoration_with_parameters> decorations,
260 std::shared_ptr<Vector_type_descriptor> column_type,
261 std::size_t column_count,
262 ::LLVMTargetDataRef target_data) noexcept
263 : Type_descriptor(std::move(decorations)),
264 type(Vector_type_descriptor::make_vector_type(column_type->get_element_type(),
265 column_type->get_element_count()
266 * column_count,
267 target_data)),
268 column_type(std::move(column_type)),
269 column_count(column_count)
270 {
271 }
272 virtual LLVM_type_and_alignment get_or_make_type() override
273 {
274 return type;
275 }
276 virtual void visit(Type_visitor &type_visitor) override
277 {
278 type_visitor.visit(*this);
279 }
280 const std::shared_ptr<Vector_type_descriptor> &get_column_type() const noexcept
281 {
282 return column_type;
283 }
284 std::size_t get_column_count() const noexcept
285 {
286 return column_count;
287 }
288 };
289
290 class Array_type_descriptor final : public Type_descriptor
291 {
292 private:
293 LLVM_type_and_alignment type;
294 std::shared_ptr<Type_descriptor> element_type;
295 std::size_t element_count;
296 std::size_t instruction_start_index;
297 Recursion_checker_state recursion_checker_state;
298
299 public:
300 explicit Array_type_descriptor(std::vector<spirv::Decoration_with_parameters> decorations,
301 std::shared_ptr<Type_descriptor> element_type,
302 std::size_t element_count,
303 std::size_t instruction_start_index) noexcept
304 : Type_descriptor(std::move(decorations)),
305 type(),
306 element_type(std::move(element_type)),
307 element_count(element_count),
308 instruction_start_index(instruction_start_index)
309 {
310 }
311 virtual LLVM_type_and_alignment get_or_make_type() override
312 {
313 if(!type.type)
314 {
315 Recursion_checker recursion_checker(recursion_checker_state, instruction_start_index);
316 auto llvm_element_type = element_type->get_or_make_type();
317 type = LLVM_type_and_alignment(::LLVMArrayType(llvm_element_type.type, element_count),
318 llvm_element_type.alignment);
319 }
320 return type;
321 }
322 virtual void visit(Type_visitor &type_visitor) override
323 {
324 type_visitor.visit(*this);
325 }
326 const std::shared_ptr<Type_descriptor> &get_element_type() const noexcept
327 {
328 return element_type;
329 }
330 std::size_t get_element_count() const noexcept
331 {
332 return element_count;
333 }
334 };
335
336 class Pointer_type_descriptor final : public Type_descriptor
337 {
338 private:
339 std::shared_ptr<Type_descriptor> base;
340 std::size_t instruction_start_index;
341 LLVM_type_and_alignment type;
342 Recursion_checker_state recursion_checker_state;
343
344 public:
345 Pointer_type_descriptor(std::vector<spirv::Decoration_with_parameters> decorations,
346 std::shared_ptr<Type_descriptor> base,
347 std::size_t instruction_start_index,
348 ::LLVMTargetDataRef target_data) noexcept
349 : Type_descriptor(std::move(decorations)),
350 base(std::move(base)),
351 instruction_start_index(instruction_start_index),
352 type(nullptr, llvm_wrapper::Target_data::get_pointer_alignment(target_data))
353 {
354 }
355 const std::shared_ptr<Type_descriptor> &get_base_type() const noexcept
356 {
357 return base;
358 }
359 void set_base_type(std::shared_ptr<Type_descriptor> new_base) noexcept
360 {
361 assert(!base);
362 assert(new_base);
363 base = std::move(new_base);
364 }
365 explicit Pointer_type_descriptor(std::vector<spirv::Decoration_with_parameters> decorations,
366 std::size_t instruction_start_index,
367 ::LLVMTargetDataRef target_data) noexcept
368 : Type_descriptor(std::move(decorations)),
369 base(nullptr),
370 instruction_start_index(instruction_start_index),
371 type(nullptr, llvm_wrapper::Target_data::get_pointer_alignment(target_data))
372 {
373 }
374 virtual LLVM_type_and_alignment get_or_make_type() override
375 {
376 if(!type.type)
377 {
378 Recursion_checker recursion_checker(recursion_checker_state, instruction_start_index);
379 if(!base)
380 throw spirv::Parser_error(
381 instruction_start_index,
382 instruction_start_index,
383 "attempting to create type from pointer forward declaration");
384 auto base_type = base->get_or_make_type();
385 constexpr unsigned default_address_space = 0;
386 type.type = ::LLVMPointerType(base_type.type, default_address_space);
387 }
388 return type;
389 }
390 virtual void visit(Type_visitor &type_visitor) override
391 {
392 type_visitor.visit(*this);
393 }
394 };
395
396 class Function_type_descriptor final : public Type_descriptor
397 {
398 private:
399 std::shared_ptr<Type_descriptor> return_type;
400 std::vector<std::shared_ptr<Type_descriptor>> args;
401 LLVM_type_and_alignment type;
402 Recursion_checker_state recursion_checker_state;
403 std::size_t instruction_start_index;
404 bool valid_for_entry_point;
405 bool is_var_arg;
406
407 public:
408 explicit Function_type_descriptor(std::vector<spirv::Decoration_with_parameters> decorations,
409 std::shared_ptr<Type_descriptor> return_type,
410 std::vector<std::shared_ptr<Type_descriptor>> args,
411 std::size_t instruction_start_index,
412 ::LLVMTargetDataRef target_data,
413 bool valid_for_entry_point,
414 bool is_var_arg) noexcept
415 : Type_descriptor(std::move(decorations)),
416 return_type(std::move(return_type)),
417 args(std::move(args)),
418 type(nullptr, llvm_wrapper::Target_data::get_pointer_alignment(target_data)),
419 instruction_start_index(instruction_start_index),
420 valid_for_entry_point(valid_for_entry_point),
421 is_var_arg(is_var_arg)
422 {
423 }
424 virtual LLVM_type_and_alignment get_or_make_type() override
425 {
426 if(!type.type)
427 {
428 Recursion_checker recursion_checker(recursion_checker_state, instruction_start_index);
429 std::vector<::LLVMTypeRef> llvm_args;
430 llvm_args.reserve(args.size());
431 auto llvm_return_type = return_type->get_or_make_type();
432 for(auto &arg : args)
433 llvm_args.push_back(arg->get_or_make_type().type);
434 type.type = ::LLVMFunctionType(
435 llvm_return_type.type, llvm_args.data(), llvm_args.size(), is_var_arg);
436 }
437 return type;
438 }
439 virtual void visit(Type_visitor &type_visitor) override
440 {
441 type_visitor.visit(*this);
442 }
443 bool is_valid_for_entry_point() const noexcept
444 {
445 return valid_for_entry_point;
446 }
447 };
448
449 class Struct_type_descriptor final : public Type_descriptor
450 {
451 public:
452 struct Member
453 {
454 std::vector<spirv::Decoration_with_parameters> decorations;
455 std::size_t llvm_member_index = -1;
456 std::shared_ptr<Type_descriptor> type;
457 explicit Member(std::vector<spirv::Decoration_with_parameters> decorations,
458 std::shared_ptr<Type_descriptor> type) noexcept
459 : decorations(std::move(decorations)),
460 type(std::move(type))
461 {
462 }
463 };
464
465 private:
466 std::vector<Member> members;
467 util::Enum_map<spirv::Built_in, std::size_t> builtin_members;
468 LLVM_type_and_alignment type;
469 bool is_complete;
470 Recursion_checker_state recursion_checker_state;
471 std::size_t instruction_start_index;
472 ::LLVMContextRef context;
473 ::LLVMTargetDataRef target_data;
474 void complete_type();
475 void on_add_member(std::size_t added_member_index) noexcept
476 {
477 assert(!is_complete);
478 auto &member = members[added_member_index];
479 for(auto &decoration : member.decorations)
480 if(decoration.value == spirv::Decoration::built_in)
481 builtin_members[util::get<spirv::Decoration_built_in_parameters>(
482 decoration.parameters)
483 .built_in] = added_member_index;
484 }
485
486 public:
487 std::size_t add_member(Member member)
488 {
489 std::size_t index = members.size();
490 members.push_back(std::move(member));
491 on_add_member(index);
492 return index;
493 }
494 const std::vector<Member> &get_members(bool need_llvm_member_indexes)
495 {
496 if(need_llvm_member_indexes)
497 get_or_make_type();
498 return members;
499 }
500 explicit Struct_type_descriptor(std::vector<spirv::Decoration_with_parameters> decorations,
501 ::LLVMContextRef context,
502 ::LLVMTargetDataRef target_data,
503 const char *name,
504 std::size_t instruction_start_index,
505 std::vector<Member> members = {})
506 : Type_descriptor(std::move(decorations)),
507 members(std::move(members)),
508 builtin_members{},
509 type(::LLVMStructCreateNamed(context, name), 0),
510 is_complete(false),
511 instruction_start_index(instruction_start_index),
512 context(context),
513 target_data(target_data)
514 {
515 for(std::size_t member_index = 0; member_index < members.size(); member_index++)
516 on_add_member(member_index);
517 }
518 virtual LLVM_type_and_alignment get_or_make_type() override
519 {
520 if(!is_complete)
521 {
522 Recursion_checker recursion_checker(recursion_checker_state, instruction_start_index);
523 if(!recursion_checker.is_nested_recursion())
524 complete_type();
525 }
526 return type;
527 }
528 virtual void visit(Type_visitor &type_visitor) override
529 {
530 type_visitor.visit(*this);
531 }
532 };
533
534 class Constant_descriptor
535 {
536 Constant_descriptor(const Constant_descriptor &) = delete;
537 Constant_descriptor &operator=(const Constant_descriptor &) = delete;
538
539 public:
540 const std::shared_ptr<Type_descriptor> type;
541
542 public:
543 explicit Constant_descriptor(std::shared_ptr<Type_descriptor> type) noexcept
544 : type(std::move(type))
545 {
546 }
547 ~Constant_descriptor() = default;
548 virtual ::LLVMValueRef get_or_make_value() = 0;
549 };
550
551 class Simple_constant_descriptor final : public Constant_descriptor
552 {
553 private:
554 ::LLVMValueRef value;
555
556 public:
557 explicit Simple_constant_descriptor(std::shared_ptr<Type_descriptor> type,
558 ::LLVMValueRef value) noexcept
559 : Constant_descriptor(std::move(type)),
560 value(value)
561 {
562 }
563 virtual ::LLVMValueRef get_or_make_value() override
564 {
565 return value;
566 }
567 };
568
569 struct Converted_module
570 {
571 llvm_wrapper::Module module;
572 std::string entry_function_name;
573 std::shared_ptr<Struct_type_descriptor> inputs_struct;
574 std::shared_ptr<Struct_type_descriptor> outputs_struct;
575 spirv::Execution_model execution_model;
576 Converted_module() = default;
577 explicit Converted_module(llvm_wrapper::Module module,
578 std::string entry_function_name,
579 std::shared_ptr<Struct_type_descriptor> inputs_struct,
580 std::shared_ptr<Struct_type_descriptor> outputs_struct,
581 spirv::Execution_model execution_model) noexcept
582 : module(std::move(module)),
583 entry_function_name(std::move(entry_function_name)),
584 inputs_struct(std::move(inputs_struct)),
585 outputs_struct(std::move(outputs_struct)),
586 execution_model(execution_model)
587 {
588 }
589 };
590
591 struct Jit_symbol_resolver
592 {
593 typedef void (*Resolved_symbol)();
594 Resolved_symbol resolve(util::string_view name)
595 {
596 #warning finish implementing
597 return nullptr;
598 }
599 static std::uint64_t resolve(const char *name, void *user_data) noexcept
600 {
601 return reinterpret_cast<std::uint64_t>(
602 static_cast<Jit_symbol_resolver *>(user_data)->resolve(name));
603 }
604 static std::uintptr_t resolve(const std::string &name, void *user_data) noexcept
605 {
606 return reinterpret_cast<std::uintptr_t>(
607 static_cast<Jit_symbol_resolver *>(user_data)->resolve(name));
608 }
609 };
610
611 class Spirv_to_llvm;
612
613 Converted_module spirv_to_llvm(::LLVMContextRef context,
614 ::LLVMTargetMachineRef target_machine,
615 const spirv::Word *shader_words,
616 std::size_t shader_size,
617 std::uint64_t shader_id,
618 spirv::Execution_model execution_model,
619 util::string_view entry_point_name,
620 const VkPipelineVertexInputStateCreateInfo *vertex_input_state);
621 }
622 }
623
624 #endif /* SPIRV_TO_LLVM_SPIRV_TO_LLVM_H_ */