generating x86 code works
[kazan.git] / src / spirv_to_llvm / spirv_to_llvm.h
1 /*
2 * Copyright 2017 Jacob Lifshay
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a copy
5 * of this software and associated documentation files (the "Software"), to deal
6 * in the Software without restriction, including without limitation the rights
7 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 * copies of the Software, and to permit persons to whom the Software is
9 * furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice shall be included in all
12 * copies or substantial portions of the Software.
13 *
14 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
20 * SOFTWARE.
21 *
22 */
23 #ifndef SPIRV_TO_LLVM_SPIRV_TO_LLVM_H_
24 #define SPIRV_TO_LLVM_SPIRV_TO_LLVM_H_
25
26 #include "spirv/parser.h"
27 #include <stdexcept>
28 #include <memory>
29 #include <vector>
30 #include <string>
31 #include <cassert>
32 #include <type_traits>
33 #include <utility>
34 #include "llvm_wrapper/llvm_wrapper.h"
35
36 namespace vulkan_cpu
37 {
38 namespace spirv_to_llvm
39 {
40 class Simple_type_descriptor;
41 class Vector_type_descriptor;
42 class Matrix_type_descriptor;
43 class Pointer_type_descriptor;
44 class Function_type_descriptor;
45 class Struct_type_descriptor;
46 class Type_descriptor
47 {
48 Type_descriptor(const Type_descriptor &) = delete;
49 Type_descriptor &operator=(const Type_descriptor &) = delete;
50
51 public:
52 struct Type_visitor
53 {
54 virtual ~Type_visitor() = default;
55 virtual void visit(Simple_type_descriptor &type) = 0;
56 virtual void visit(Vector_type_descriptor &type) = 0;
57 virtual void visit(Matrix_type_descriptor &type) = 0;
58 virtual void visit(Pointer_type_descriptor &type) = 0;
59 virtual void visit(Function_type_descriptor &type) = 0;
60 virtual void visit(Struct_type_descriptor &type) = 0;
61 };
62
63 public:
64 const std::vector<spirv::Decoration_with_parameters> decorations;
65
66 public:
67 explicit Type_descriptor(std::vector<spirv::Decoration_with_parameters> decorations) noexcept
68 : decorations(std::move(decorations))
69 {
70 }
71 virtual ~Type_descriptor() = default;
72 virtual ::LLVMTypeRef get_or_make_type() = 0;
73 virtual void visit(Type_visitor &type_visitor) = 0;
74 void visit(Type_visitor &&type_visitor)
75 {
76 visit(type_visitor);
77 }
78 template <typename Fn>
79 typename std::enable_if<!std::is_convertible<Fn &&, const Type_visitor &>::value, void>::type
80 visit(Fn &&fn)
81 {
82 struct Visitor final : public Type_visitor
83 {
84 Fn &fn;
85 virtual void visit(Simple_type_descriptor &type) override
86 {
87 std::forward<Fn>(fn)(type);
88 }
89 virtual void visit(Vector_type_descriptor &type) override
90 {
91 std::forward<Fn>(fn)(type);
92 }
93 virtual void visit(Matrix_type_descriptor &type) override
94 {
95 std::forward<Fn>(fn)(type);
96 }
97 virtual void visit(Pointer_type_descriptor &type) override
98 {
99 std::forward<Fn>(fn)(type);
100 }
101 virtual void visit(Function_type_descriptor &type) override
102 {
103 std::forward<Fn>(fn)(type);
104 }
105 virtual void visit(Struct_type_descriptor &type) override
106 {
107 std::forward<Fn>(fn)(type);
108 }
109 explicit Visitor(Fn &fn) noexcept : fn(fn)
110 {
111 }
112 };
113 visit(Visitor(fn));
114 }
115 class Recursion_checker;
116 class Recursion_checker_state
117 {
118 friend class Recursion_checker;
119
120 private:
121 std::size_t recursion_count = 0;
122 };
123 class Recursion_checker
124 {
125 Recursion_checker(const Recursion_checker &) = delete;
126 Recursion_checker &operator=(const Recursion_checker &) = delete;
127
128 private:
129 Recursion_checker_state &state;
130
131 public:
132 explicit Recursion_checker(Recursion_checker_state &state,
133 std::size_t instruction_start_index)
134 : state(state)
135 {
136 state.recursion_count++;
137 if(state.recursion_count > 5)
138 throw spirv::Parser_error(instruction_start_index,
139 instruction_start_index,
140 "too many recursions making type");
141 }
142 ~Recursion_checker()
143 {
144 state.recursion_count--;
145 }
146 std::size_t get_recursion_count() const noexcept
147 {
148 return state.recursion_count;
149 }
150 bool is_nested_recursion() const noexcept
151 {
152 return get_recursion_count() > 1;
153 }
154 };
155 };
156
157 class Simple_type_descriptor final : public Type_descriptor
158 {
159 private:
160 ::LLVMTypeRef type;
161
162 public:
163 explicit Simple_type_descriptor(std::vector<spirv::Decoration_with_parameters> decorations,
164 ::LLVMTypeRef type) noexcept
165 : Type_descriptor(std::move(decorations)),
166 type(type)
167 {
168 }
169 virtual ::LLVMTypeRef get_or_make_type() override
170 {
171 return type;
172 }
173 virtual void visit(Type_visitor &type_visitor) override
174 {
175 type_visitor.visit(*this);
176 }
177 };
178
179 class Vector_type_descriptor final : public Type_descriptor
180 {
181 private:
182 ::LLVMTypeRef type;
183 std::shared_ptr<Simple_type_descriptor> element_type;
184 std::size_t element_count;
185
186 public:
187 explicit Vector_type_descriptor(std::vector<spirv::Decoration_with_parameters> decorations,
188 std::shared_ptr<Simple_type_descriptor> element_type,
189 std::size_t element_count) noexcept
190 : Type_descriptor(std::move(decorations)),
191 type(::LLVMVectorType(element_type->get_or_make_type(), element_count)),
192 element_type(std::move(element_type)),
193 element_count(element_count)
194 {
195 }
196 virtual ::LLVMTypeRef get_or_make_type() override
197 {
198 return type;
199 }
200 virtual void visit(Type_visitor &type_visitor) override
201 {
202 type_visitor.visit(*this);
203 }
204 const std::shared_ptr<Simple_type_descriptor> &get_element_type() const noexcept
205 {
206 return element_type;
207 }
208 std::size_t get_element_count() const noexcept
209 {
210 return element_count;
211 }
212 };
213
214 class Matrix_type_descriptor final : public Type_descriptor
215 {
216 private:
217 ::LLVMTypeRef type;
218 std::shared_ptr<Vector_type_descriptor> column_type;
219 std::size_t column_count;
220
221 public:
222 explicit Matrix_type_descriptor(std::vector<spirv::Decoration_with_parameters> decorations,
223 std::shared_ptr<Vector_type_descriptor> column_type,
224 std::size_t column_count) noexcept
225 : Type_descriptor(std::move(decorations)),
226 type(::LLVMVectorType(column_type->get_element_type()->get_or_make_type(),
227 column_type->get_element_count() * column_count)),
228 column_type(std::move(column_type)),
229 column_count(column_count)
230 {
231 }
232 virtual ::LLVMTypeRef get_or_make_type() override
233 {
234 return type;
235 }
236 virtual void visit(Type_visitor &type_visitor) override
237 {
238 type_visitor.visit(*this);
239 }
240 const std::shared_ptr<Vector_type_descriptor> &get_column_type() const noexcept
241 {
242 return column_type;
243 }
244 std::size_t get_column_count() const noexcept
245 {
246 return column_count;
247 }
248 };
249
250 class Pointer_type_descriptor final : public Type_descriptor
251 {
252 private:
253 std::shared_ptr<Type_descriptor> base;
254 std::size_t instruction_start_index;
255 ::LLVMTypeRef type;
256 Recursion_checker_state recursion_checker_state;
257
258 public:
259 Pointer_type_descriptor(std::vector<spirv::Decoration_with_parameters> decorations,
260 std::shared_ptr<Type_descriptor> base,
261 std::size_t instruction_start_index) noexcept
262 : Type_descriptor(std::move(decorations)),
263 base(std::move(base)),
264 instruction_start_index(instruction_start_index),
265 type(nullptr)
266 {
267 }
268 const std::shared_ptr<Type_descriptor> &get_base_type() const noexcept
269 {
270 return base;
271 }
272 void set_base_type(std::shared_ptr<Type_descriptor> new_base) noexcept
273 {
274 assert(!base);
275 assert(new_base);
276 base = std::move(new_base);
277 }
278 explicit Pointer_type_descriptor(std::vector<spirv::Decoration_with_parameters> decorations,
279 std::size_t instruction_start_index) noexcept
280 : Type_descriptor(std::move(decorations)),
281 base(nullptr),
282 instruction_start_index(instruction_start_index),
283 type(nullptr)
284 {
285 }
286 virtual ::LLVMTypeRef get_or_make_type() override
287 {
288 if(!type)
289 {
290 Recursion_checker recursion_checker(recursion_checker_state, instruction_start_index);
291 if(!base)
292 throw spirv::Parser_error(
293 instruction_start_index,
294 instruction_start_index,
295 "attempting to create type from pointer forward declaration");
296 auto base_type = base->get_or_make_type();
297 constexpr unsigned default_address_space = 0;
298 type = ::LLVMPointerType(base_type, default_address_space);
299 }
300 return type;
301 }
302 virtual void visit(Type_visitor &type_visitor) override
303 {
304 type_visitor.visit(*this);
305 }
306 };
307
308 class Function_type_descriptor final : public Type_descriptor
309 {
310 private:
311 std::shared_ptr<Type_descriptor> return_type;
312 std::vector<std::shared_ptr<Type_descriptor>> args;
313 ::LLVMTypeRef type;
314 Recursion_checker_state recursion_checker_state;
315 std::size_t instruction_start_index;
316 bool is_var_arg;
317
318 public:
319 explicit Function_type_descriptor(std::vector<spirv::Decoration_with_parameters> decorations,
320 std::shared_ptr<Type_descriptor> return_type,
321 std::vector<std::shared_ptr<Type_descriptor>> args,
322 std::size_t instruction_start_index,
323 bool is_var_arg = false) noexcept
324 : Type_descriptor(std::move(decorations)),
325 return_type(std::move(return_type)),
326 args(std::move(args)),
327 type(nullptr),
328 instruction_start_index(instruction_start_index),
329 is_var_arg(is_var_arg)
330 {
331 }
332 virtual ::LLVMTypeRef get_or_make_type() override
333 {
334 if(!type)
335 {
336 Recursion_checker recursion_checker(recursion_checker_state, instruction_start_index);
337 std::vector<::LLVMTypeRef> llvm_args;
338 llvm_args.reserve(args.size());
339 auto llvm_return_type = return_type->get_or_make_type();
340 for(auto &arg : args)
341 llvm_args.push_back(arg->get_or_make_type());
342 type = ::LLVMFunctionType(
343 llvm_return_type, llvm_args.data(), llvm_args.size(), is_var_arg);
344 }
345 return type;
346 }
347 virtual void visit(Type_visitor &type_visitor) override
348 {
349 type_visitor.visit(*this);
350 }
351 };
352
353 class Struct_type_descriptor final : public Type_descriptor
354 {
355 public:
356 struct Member
357 {
358 std::vector<spirv::Decoration_with_parameters> decorations;
359 std::size_t llvm_member_index = -1;
360 std::shared_ptr<Type_descriptor> type;
361 explicit Member(std::vector<spirv::Decoration_with_parameters> decorations,
362 std::shared_ptr<Type_descriptor> type) noexcept
363 : decorations(std::move(decorations)),
364 type(std::move(type))
365 {
366 }
367 };
368
369 private:
370 std::vector<Member> members;
371 util::Enum_map<spirv::Built_in, std::size_t> builtin_members;
372 ::LLVMTypeRef type;
373 bool is_complete;
374 Recursion_checker_state recursion_checker_state;
375 std::size_t instruction_start_index;
376 ::LLVMContextRef context;
377 ::LLVMTargetDataRef target_data;
378 void complete_type();
379 void on_add_member(std::size_t added_member_index) noexcept
380 {
381 assert(!is_complete);
382 auto &member = members[added_member_index];
383 for(auto &decoration : member.decorations)
384 if(decoration.value == spirv::Decoration::built_in)
385 builtin_members[util::get<spirv::Decoration_built_in_parameters>(
386 decoration.parameters)
387 .built_in] = added_member_index;
388 }
389
390 public:
391 std::size_t add_member(Member member)
392 {
393 std::size_t index = members.size();
394 members.push_back(std::move(member));
395 on_add_member(index);
396 return index;
397 }
398 const std::vector<Member> &get_members(bool need_llvm_member_indexes)
399 {
400 if(need_llvm_member_indexes)
401 get_or_make_type();
402 return members;
403 }
404 explicit Struct_type_descriptor(std::vector<spirv::Decoration_with_parameters> decorations,
405 ::LLVMContextRef context,
406 ::LLVMTargetDataRef target_data,
407 const char *name,
408 std::size_t instruction_start_index,
409 std::vector<Member> members = {})
410 : Type_descriptor(std::move(decorations)),
411 members(std::move(members)),
412 builtin_members{},
413 type(::LLVMStructCreateNamed(context, name)),
414 is_complete(false),
415 instruction_start_index(instruction_start_index),
416 context(context),
417 target_data(target_data)
418 {
419 for(std::size_t member_index = 0; member_index < members.size(); member_index++)
420 on_add_member(member_index);
421 }
422 virtual ::LLVMTypeRef get_or_make_type() override
423 {
424 if(!is_complete)
425 {
426 Recursion_checker recursion_checker(recursion_checker_state, instruction_start_index);
427 if(!recursion_checker.is_nested_recursion())
428 complete_type();
429 }
430 return type;
431 }
432 virtual void visit(Type_visitor &type_visitor) override
433 {
434 type_visitor.visit(*this);
435 }
436 };
437
438 class Constant_descriptor
439 {
440 Constant_descriptor(const Constant_descriptor &) = delete;
441 Constant_descriptor &operator=(const Constant_descriptor &) = delete;
442
443 public:
444 const std::shared_ptr<Type_descriptor> type;
445
446 public:
447 explicit Constant_descriptor(std::shared_ptr<Type_descriptor> type) noexcept
448 : type(std::move(type))
449 {
450 }
451 ~Constant_descriptor() = default;
452 virtual ::LLVMValueRef get_or_make_value() = 0;
453 };
454
455 class Simple_constant_descriptor final : public Constant_descriptor
456 {
457 private:
458 ::LLVMValueRef value;
459
460 public:
461 explicit Simple_constant_descriptor(std::shared_ptr<Type_descriptor> type,
462 ::LLVMValueRef value) noexcept
463 : Constant_descriptor(std::move(type)),
464 value(value)
465 {
466 }
467 virtual ::LLVMValueRef get_or_make_value() override
468 {
469 return value;
470 }
471 };
472
473 struct Converted_module
474 {
475 struct Entry_point
476 {
477 std::string name;
478 std::string entry_function_name;
479 #warning finish filling in Entry_point
480 explicit Entry_point(std::string name, std::string entry_function_name) noexcept
481 : name(std::move(name)),
482 entry_function_name(std::move(entry_function_name))
483 {
484 }
485 };
486 llvm_wrapper::Module module;
487 std::vector<Entry_point> entry_points;
488 std::shared_ptr<Struct_type_descriptor> io_struct;
489 std::size_t inputs_member;
490 std::shared_ptr<Struct_type_descriptor> inputs_struct;
491 std::size_t outputs_member;
492 std::shared_ptr<Struct_type_descriptor> outputs_struct;
493 Converted_module() : module(), entry_points()
494 {
495 }
496 explicit Converted_module(llvm_wrapper::Module module,
497 std::vector<Entry_point> entry_points,
498 std::shared_ptr<Struct_type_descriptor> io_struct,
499 std::size_t inputs_member,
500 std::shared_ptr<Struct_type_descriptor> inputs_struct,
501 std::size_t outputs_member,
502 std::shared_ptr<Struct_type_descriptor> outputs_struct) noexcept
503 : module(std::move(module)),
504 entry_points(std::move(entry_points)),
505 io_struct(std::move(io_struct)),
506 inputs_member(inputs_member),
507 inputs_struct(std::move(inputs_struct)),
508 outputs_member(outputs_member),
509 outputs_struct(std::move(outputs_struct))
510 {
511 }
512 };
513
514 class Spirv_to_llvm;
515
516 Converted_module spirv_to_llvm(::LLVMContextRef context,
517 ::LLVMTargetMachineRef target_machine,
518 const spirv::Word *shader_words,
519 std::size_t shader_size,
520 std::uint64_t shader_id);
521 }
522 }
523
524 #endif /* SPIRV_TO_LLVM_SPIRV_TO_LLVM_H_ */