2 * Copyright 2017 Jacob Lifshay
4 * Permission is hereby granted, free of charge, to any person obtaining a copy
5 * of this software and associated documentation files (the "Software"), to deal
6 * in the Software without restriction, including without limitation the rights
7 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 * copies of the Software, and to permit persons to whom the Software is
9 * furnished to do so, subject to the following conditions:
11 * The above copyright notice and this permission notice shall be included in all
12 * copies or substantial portions of the Software.
14 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 #ifndef SPIRV_TO_LLVM_SPIRV_TO_LLVM_H_
24 #define SPIRV_TO_LLVM_SPIRV_TO_LLVM_H_
26 #include "spirv/parser.h"
32 #include <type_traits>
35 #include "llvm_wrapper/llvm_wrapper.h"
39 namespace spirv_to_llvm
// Pairs an LLVM type handle with an alignment value. The default constructor
// yields a null type with alignment 0.
// NOTE(review): the declaration of the `type` member (initialized below via
// `type(nullptr)`) is not present in this view -- confirm against the full header.
41 struct LLVM_type_and_alignment
44 std::size_t alignment
;
45 constexpr LLVM_type_and_alignment() noexcept
: type(nullptr), alignment(0)
// Presumably stores `type` and `alignment` as given -- TODO confirm; the
// member-initializer list of this constructor is not present in this view.
48 constexpr LLVM_type_and_alignment(::LLVMTypeRef type
, std::size_t alignment
) noexcept
// Concrete type-descriptor classes, declared up front so that the visitor
// interface can name each of them in its per-type visit overloads before the
// full class definitions appear.
class Struct_type_descriptor;
class Function_type_descriptor;
class Pointer_type_descriptor;
class Array_type_descriptor;
class Matrix_type_descriptor;
class Vector_type_descriptor;
class Simple_type_descriptor;
// Type descriptors are non-copyable: copy construction and copy assignment are deleted.
64 Type_descriptor(const Type_descriptor
&) = delete;
65 Type_descriptor
&operator=(const Type_descriptor
&) = delete;
// Members of the Type_visitor interface (its declaration header is not in this
// view): a virtual destructor plus one pure-virtual visit overload per concrete
// type-descriptor class.
70 virtual ~Type_visitor() = default;
71 virtual void visit(Simple_type_descriptor
&type
) = 0;
72 virtual void visit(Vector_type_descriptor
&type
) = 0;
73 virtual void visit(Matrix_type_descriptor
&type
) = 0;
74 virtual void visit(Array_type_descriptor
&type
) = 0;
75 virtual void visit(Pointer_type_descriptor
&type
) = 0;
76 virtual void visit(Function_type_descriptor
&type
) = 0;
77 virtual void visit(Struct_type_descriptor
&type
) = 0;
// SPIR-V decorations attached to this type; const, so fixed after construction.
81 const std::vector
<spirv::Decoration_with_parameters
> decorations
;
// Takes the decoration list by value and moves it into `decorations`.
84 explicit Type_descriptor(std::vector
<spirv::Decoration_with_parameters
> decorations
) noexcept
85 : decorations(std::move(decorations
))
88 virtual ~Type_descriptor() = default;
// Returns the LLVM type and alignment for this descriptor (the name suggests it
// may be created on first use -- see each override).
89 virtual LLVM_type_and_alignment
get_or_make_type() = 0;
// Double dispatch: every concrete descriptor calls the matching
// Type_visitor::visit overload with *this.
90 virtual void visit(Type_visitor
&type_visitor
) = 0;
// Overload that accepts a temporary visitor.
// NOTE(review): its body is not present in this view.
91 void visit(Type_visitor
&&type_visitor
)
// Visits this descriptor with an arbitrary callable `fn`, which must be
// invocable with a reference to each concrete descriptor type. The enable_if
// removes this overload for arguments convertible to `const Type_visitor &`,
// so real visitors use the Type_visitor overloads instead. The callable is
// wrapped in the local `Visitor` adapter below.
// NOTE(review): the line naming this function and the adapter's `fn` member
// declaration are not present in this view.
95 template <typename Fn
>
96 typename
std::enable_if
<!std::is_convertible
<Fn
&&, const Type_visitor
&>::value
, void>::type
// Adapter that implements every Type_visitor::visit overload by forwarding the
// visited descriptor to the stored callable.
99 struct Visitor final
: public Type_visitor
102 virtual void visit(Simple_type_descriptor
&type
) override
104 std::forward
<Fn
>(fn
)(type
);
106 virtual void visit(Vector_type_descriptor
&type
) override
108 std::forward
<Fn
>(fn
)(type
);
110 virtual void visit(Matrix_type_descriptor
&type
) override
112 std::forward
<Fn
>(fn
)(type
);
114 virtual void visit(Array_type_descriptor
&type
) override
116 std::forward
<Fn
>(fn
)(type
);
118 virtual void visit(Pointer_type_descriptor
&type
) override
120 std::forward
<Fn
>(fn
)(type
);
122 virtual void visit(Function_type_descriptor
&type
) override
124 std::forward
<Fn
>(fn
)(type
);
126 virtual void visit(Struct_type_descriptor
&type
) override
128 std::forward
<Fn
>(fn
)(type
);
// Initializes `fn` from an lvalue reference to the caller's callable.
130 explicit Visitor(Fn
&fn
) noexcept
: fn(fn
)
136 class Recursion_checker
;
// Nesting counter shared by all Recursion_checker scopes for one descriptor;
// only Recursion_checker (declared a friend) touches it.
137 class Recursion_checker_state
139 friend class Recursion_checker
;
142 std::size_t recursion_count
= 0;
// Non-copyable guard that bounds recursion while building LLVM types.
// Construction increments the shared counter and throws spirv::Parser_error
// once it exceeds 5. The decrement further down undoes that increment.
// NOTE(review): the decrement presumably belongs to the destructor, whose
// header is not present in this view -- confirm.
144 class Recursion_checker
146 Recursion_checker(const Recursion_checker
&) = delete;
147 Recursion_checker
&operator=(const Recursion_checker
&) = delete;
150 Recursion_checker_state
&state
;
153 explicit Recursion_checker(Recursion_checker_state
&state
,
154 std::size_t instruction_start_index
)
157 state
.recursion_count
++;
// Hard-coded limit: more than 5 simultaneously active scopes is an error,
// reported against the instruction that declared the type.
158 if(state
.recursion_count
> 5)
159 throw spirv::Parser_error(instruction_start_index
,
160 instruction_start_index
,
161 "too many recursions making type");
165 state
.recursion_count
--;
// Number of currently active scopes on the shared state (including this one).
167 std::size_t get_recursion_count() const noexcept
169 return state
.recursion_count
;
// True when this scope is nested inside another active scope on the same state.
171 bool is_nested_recursion() const noexcept
173 return get_recursion_count() > 1;
// Descriptor wrapping an LLVM type (with alignment) that is supplied fully
// formed to the constructor.
178 class Simple_type_descriptor final
: public Type_descriptor
181 LLVM_type_and_alignment type
;
// NOTE(review): the initializer for the `type` member is not present in this view.
184 explicit Simple_type_descriptor(std::vector
<spirv::Decoration_with_parameters
> decorations
,
185 LLVM_type_and_alignment type
) noexcept
186 : Type_descriptor(std::move(decorations
)),
// Body not present in this view; presumably returns the stored `type`.
190 virtual LLVM_type_and_alignment
get_or_make_type() override
194 virtual void visit(Type_visitor
&type_visitor
) override
196 type_visitor
.visit(*this);
// Descriptor for a vector of `element_count` elements of a Simple_type_descriptor
// element type. The LLVM type is computed in the constructor via make_vector_type.
200 class Vector_type_descriptor final
: public Type_descriptor
203 LLVM_type_and_alignment type
;
204 std::shared_ptr
<Simple_type_descriptor
> element_type
;
205 std::size_t element_count
;
208 explicit Vector_type_descriptor(std::vector
<spirv::Decoration_with_parameters
> decorations
,
209 std::shared_ptr
<Simple_type_descriptor
> element_type
,
210 std::size_t element_count
,
211 ::LLVMTargetDataRef target_data
) noexcept
// `type` is declared before `element_type`, so its initializer reads the
// `element_type` parameter before that parameter is moved into the member.
212 : Type_descriptor(std::move(decorations
)),
213 type(make_vector_type(element_type
, element_count
, target_data
)),
214 element_type(std::move(element_type
)),
215 element_count(element_count
)
// Builds the LLVM vector type of `element_count` elements of `element_type`'s
// LLVM type. The alignment is the target's preferred alignment for that vector
// type, capped at alignof(std::max_align_t). Static; also called from
// Matrix_type_descriptor's constructor.
218 static LLVM_type_and_alignment
make_vector_type(
219 const std::shared_ptr
<Simple_type_descriptor
> &element_type
,
220 std::size_t element_count
,
221 ::LLVMTargetDataRef target_data
)
223 auto llvm_element_type
= element_type
->get_or_make_type();
// NOTE(review): ::LLVMVectorType takes an `unsigned` element count; the
// std::size_t `element_count` is narrowed implicitly here.
224 auto type
= ::LLVMVectorType(llvm_element_type
.type
, element_count
);
225 std::size_t alignment
= ::LLVMPreferredAlignmentOfType(target_data
, type
);
// NOTE(review): the cap presumably keeps the alignment within what ordinary
// allocations guarantee (alignof(std::max_align_t)) -- confirm the intent.
226 constexpr std::size_t max_abi_alignment
= alignof(std::max_align_t
);
227 if(alignment
> max_abi_alignment
)
228 alignment
= max_abi_alignment
;
229 return {type
, alignment
};
// Body not present in this view; presumably returns the `type` computed in the constructor.
231 virtual LLVM_type_and_alignment
get_or_make_type() override
235 virtual void visit(Type_visitor
&type_visitor
) override
237 type_visitor
.visit(*this);
// Accessor; its body is not present in this view.
239 const std::shared_ptr
<Simple_type_descriptor
> &get_element_type() const noexcept
243 std::size_t get_element_count() const noexcept
245 return element_count
;
// Descriptor for a matrix of `column_count` columns, each described by
// `column_type` (a Vector_type_descriptor). Its LLVM type is built in the
// constructor by Vector_type_descriptor::make_vector_type from the column's
// element type.
249 class Matrix_type_descriptor final
: public Type_descriptor
252 LLVM_type_and_alignment type
;
253 std::shared_ptr
<Vector_type_descriptor
> column_type
;
254 std::size_t column_count
;
257 explicit Matrix_type_descriptor(std::vector
<spirv::Decoration_with_parameters
> decorations
,
258 std::shared_ptr
<Vector_type_descriptor
> column_type
,
259 std::size_t column_count
,
260 ::LLVMTargetDataRef target_data
) noexcept
// `type` is declared before `column_type`, so its initializer reads the
// `column_type` parameter before that parameter is moved into the member.
// NOTE(review): the rest of the make_vector_type call is not present in this
// view (presumably the total element count and `target_data`) -- confirm.
261 : Type_descriptor(std::move(decorations
)),
262 type(Vector_type_descriptor::make_vector_type(column_type
->get_element_type(),
263 column_type
->get_element_count()
266 column_type(std::move(column_type
)),
267 column_count(column_count
)
// Body not present in this view; presumably returns the `type` computed in the constructor.
270 virtual LLVM_type_and_alignment
get_or_make_type() override
274 virtual void visit(Type_visitor
&type_visitor
) override
276 type_visitor
.visit(*this);
// Accessors; their bodies are not present in this view.
278 const std::shared_ptr
<Vector_type_descriptor
> &get_column_type() const noexcept
282 std::size_t get_column_count() const noexcept
// Descriptor for an array of `element_count` elements of an arbitrary element
// type. The LLVM type is built in get_or_make_type inside a Recursion_checker
// scope, so deeply nested or self-referential types raise spirv::Parser_error.
288 class Array_type_descriptor final
: public Type_descriptor
291 LLVM_type_and_alignment type
;
292 std::shared_ptr
<Type_descriptor
> element_type
;
293 std::size_t element_count
;
// Index of the declaring SPIR-V instruction; used for error reporting.
294 std::size_t instruction_start_index
;
295 Recursion_checker_state recursion_checker_state
;
// NOTE(review): no initializer for `type` is present in this view.
298 explicit Array_type_descriptor(std::vector
<spirv::Decoration_with_parameters
> decorations
,
299 std::shared_ptr
<Type_descriptor
> element_type
,
300 std::size_t element_count
,
301 std::size_t instruction_start_index
) noexcept
302 : Type_descriptor(std::move(decorations
)),
304 element_type(std::move(element_type
)),
305 element_count(element_count
),
306 instruction_start_index(instruction_start_index
)
// NOTE(review): parts of this body (before and after the lines shown) are not
// present in this view.
309 virtual LLVM_type_and_alignment
get_or_make_type() override
313 Recursion_checker
recursion_checker(recursion_checker_state
, instruction_start_index
);
314 auto llvm_element_type
= element_type
->get_or_make_type();
// The array inherits its element type's alignment.
// NOTE(review): ::LLVMArrayType takes an `unsigned` count; the std::size_t
// `element_count` is narrowed implicitly here.
315 type
= LLVM_type_and_alignment(::LLVMArrayType(llvm_element_type
.type
, element_count
),
316 llvm_element_type
.alignment
);
320 virtual void visit(Type_visitor
&type_visitor
) override
322 type_visitor
.visit(*this);
// Accessor; its body is not present in this view.
324 const std::shared_ptr
<Type_descriptor
> &get_element_type() const noexcept
328 std::size_t get_element_count() const noexcept
330 return element_count
;
// Descriptor for a pointer type. The pointee (`base`) is either passed to the
// first constructor or, for a forward declaration (second constructor, which
// takes no base argument), supplied later through set_base_type. The alignment
// is the target's pointer alignment, fixed at construction; the LLVM type
// starts out null and is filled in by get_or_make_type.
334 class Pointer_type_descriptor final
: public Type_descriptor
337 std::shared_ptr
<Type_descriptor
> base
;
338 std::size_t instruction_start_index
;
339 LLVM_type_and_alignment type
;
340 Recursion_checker_state recursion_checker_state
;
343 Pointer_type_descriptor(std::vector
<spirv::Decoration_with_parameters
> decorations
,
344 std::shared_ptr
<Type_descriptor
> base
,
345 std::size_t instruction_start_index
,
346 ::LLVMTargetDataRef target_data
) noexcept
347 : Type_descriptor(std::move(decorations
)),
348 base(std::move(base
)),
349 instruction_start_index(instruction_start_index
),
350 type(nullptr, llvm_wrapper::Target_data::get_pointer_alignment(target_data
))
// Accessor; its body is not present in this view.
353 const std::shared_ptr
<Type_descriptor
> &get_base_type() const noexcept
// Replaces the pointee descriptor.
// NOTE(review): some lines of this body are not present in this view
// (possibly precondition checks) -- confirm.
357 void set_base_type(std::shared_ptr
<Type_descriptor
> new_base
) noexcept
361 base
= std::move(new_base
);
// Forward-declaration constructor: no pointee yet.
363 explicit Pointer_type_descriptor(std::vector
<spirv::Decoration_with_parameters
> decorations
,
364 std::size_t instruction_start_index
,
365 ::LLVMTargetDataRef target_data
) noexcept
366 : Type_descriptor(std::move(decorations
)),
368 instruction_start_index(instruction_start_index
),
369 type(nullptr, llvm_wrapper::Target_data::get_pointer_alignment(target_data
))
372 virtual LLVM_type_and_alignment
get_or_make_type() override
376 Recursion_checker
recursion_checker(recursion_checker_state
, instruction_start_index
);
// NOTE(review): as shown, this throw is unconditional, which would make the
// lines after it unreachable. A guarding condition (presumably "no base type
// set yet", matching the message) seems to be missing from this view -- confirm.
378 throw spirv::Parser_error(
379 instruction_start_index
,
380 instruction_start_index
,
381 "attempting to create type from pointer forward declaration");
382 auto base_type
= base
->get_or_make_type();
// Pointers are created in LLVM address space 0.
383 constexpr unsigned default_address_space
= 0;
384 type
.type
= ::LLVMPointerType(base_type
.type
, default_address_space
);
388 virtual void visit(Type_visitor
&type_visitor
) override
390 type_visitor
.visit(*this);
// Descriptor for a function type. Its alignment is the target's pointer
// alignment (set in the constructor). The LLVM function type is built in
// get_or_make_type from the return-type and argument descriptors.
394 class Function_type_descriptor final
: public Type_descriptor
397 std::shared_ptr
<Type_descriptor
> return_type
;
398 std::vector
<std::shared_ptr
<Type_descriptor
>> args
;
399 LLVM_type_and_alignment type
;
400 Recursion_checker_state recursion_checker_state
;
401 std::size_t instruction_start_index
;
// NOTE(review): the declaration of the `is_var_arg` member initialized below is
// not present in this view.
405 explicit Function_type_descriptor(std::vector
<spirv::Decoration_with_parameters
> decorations
,
406 std::shared_ptr
<Type_descriptor
> return_type
,
407 std::vector
<std::shared_ptr
<Type_descriptor
>> args
,
408 std::size_t instruction_start_index
,
409 ::LLVMTargetDataRef target_data
,
410 bool is_var_arg
= false) noexcept
411 : Type_descriptor(std::move(decorations
)),
412 return_type(std::move(return_type
)),
413 args(std::move(args
)),
414 type(nullptr, llvm_wrapper::Target_data::get_pointer_alignment(target_data
)),
415 instruction_start_index(instruction_start_index
),
416 is_var_arg(is_var_arg
)
// Converts the return type and each argument type, under a Recursion_checker
// scope, then stores the resulting ::LLVMFunctionType in `type.type`.
// NOTE(review): parts of this body are not present in this view.
419 virtual LLVM_type_and_alignment
get_or_make_type() override
423 Recursion_checker
recursion_checker(recursion_checker_state
, instruction_start_index
);
424 std::vector
<::LLVMTypeRef
> llvm_args
;
425 llvm_args
.reserve(args
.size());
426 auto llvm_return_type
= return_type
->get_or_make_type();
427 for(auto &arg
: args
)
428 llvm_args
.push_back(arg
->get_or_make_type().type
);
// NOTE(review): ::LLVMFunctionType takes an `unsigned` parameter count;
// `llvm_args.size()` (std::size_t) is narrowed implicitly here.
429 type
.type
= ::LLVMFunctionType(
430 llvm_return_type
.type
, llvm_args
.data(), llvm_args
.size(), is_var_arg
);
434 virtual void visit(Type_visitor
&type_visitor
) override
436 type_visitor
.visit(*this);
// Descriptor for a named LLVM struct type. The type is created with
// ::LLVMStructCreateNamed in the constructor, with alignment initially 0.
// Members are passed to the constructor or appended with add_member. For each
// member decorated with a SPIR-V BuiltIn, `builtin_members` maps that built-in
// to the member's index.
440 class Struct_type_descriptor final
: public Type_descriptor
// Fields of the nested Member type (its declaration header is not in this view):
// the member's SPIR-V decorations, its index in the LLVM struct, and its type.
// llvm_member_index starts at -1, i.e. SIZE_MAX after conversion; presumably this
// means "not assigned yet" -- confirm.
445 std::vector
<spirv::Decoration_with_parameters
> decorations
;
446 std::size_t llvm_member_index
= -1;
447 std::shared_ptr
<Type_descriptor
> type
;
448 explicit Member(std::vector
<spirv::Decoration_with_parameters
> decorations
,
449 std::shared_ptr
<Type_descriptor
> type
) noexcept
450 : decorations(std::move(decorations
)),
451 type(std::move(type
))
457 std::vector
<Member
> members
;
458 util::Enum_map
<spirv::Built_in
, std::size_t> builtin_members
;
459 LLVM_type_and_alignment type
;
// NOTE(review): the declaration of `is_complete` (asserted in on_add_member) is
// not present in this view.
461 Recursion_checker_state recursion_checker_state
;
462 std::size_t instruction_start_index
;
463 ::LLVMContextRef context
;
464 ::LLVMTargetDataRef target_data
;
465 void complete_type();
// Records the member at `added_member_index` in `builtin_members` for every
// BuiltIn decoration it carries. Must not be called once the type is complete
// (asserted).
466 void on_add_member(std::size_t added_member_index
) noexcept
468 assert(!is_complete
);
469 auto &member
= members
[added_member_index
];
470 for(auto &decoration
: member
.decorations
)
471 if(decoration
.value
== spirv::Decoration::built_in
)
472 builtin_members
[util::get
<spirv::Decoration_built_in_parameters
>(
473 decoration
.parameters
)
474 .built_in
] = added_member_index
;
// Appends `member`, registers any built-in decorations, and computes its index.
// NOTE(review): the rest of this body (presumably `return index;`) is not in this view.
478 std::size_t add_member(Member member
)
480 std::size_t index
= members
.size();
481 members
.push_back(std::move(member
));
482 on_add_member(index
);
// Returns the member list. When `need_llvm_member_indexes` is true, extra work
// is done first; those lines are not in this view (presumably completing the
// type so llvm_member_index is filled in -- confirm).
485 const std::vector
<Member
> &get_members(bool need_llvm_member_indexes
)
487 if(need_llvm_member_indexes
)
// NOTE(review): the declaration of the `name` parameter passed to
// ::LLVMStructCreateNamed is not present in this view.
491 explicit Struct_type_descriptor(std::vector
<spirv::Decoration_with_parameters
> decorations
,
492 ::LLVMContextRef context
,
493 ::LLVMTargetDataRef target_data
,
495 std::size_t instruction_start_index
,
496 std::vector
<Member
> members
= {})
497 : Type_descriptor(std::move(decorations
)),
498 members(std::move(members
)),
500 type(::LLVMStructCreateNamed(context
, name
), 0),
502 instruction_start_index(instruction_start_index
),
504 target_data(target_data
)
// NOTE(review): likely BUG. Inside the constructor body, `members` names the
// constructor parameter, not the data member, and that parameter has already
// been moved into the member by `members(std::move(members))` above. The loop
// therefore runs over a moved-from vector (empty in practice), so BuiltIn
// decorations of members passed to the constructor are never recorded in
// `builtin_members`. `this->members.size()` was presumably intended.
506 for(std::size_t member_index
= 0; member_index
< members
.size(); member_index
++)
507 on_add_member(member_index
);
// Runs under a Recursion_checker; is_nested_recursion() distinguishes the
// outermost call from nested ones.
// NOTE(review): most of this body is not present in this view.
509 virtual LLVM_type_and_alignment
get_or_make_type() override
513 Recursion_checker
recursion_checker(recursion_checker_state
, instruction_start_index
);
514 if(!recursion_checker
.is_nested_recursion())
519 virtual void visit(Type_visitor
&type_visitor
) override
521 type_visitor
.visit(*this);
525 class Constant_descriptor
527 Constant_descriptor(const Constant_descriptor
&) = delete;
528 Constant_descriptor
&operator=(const Constant_descriptor
&) = delete;
531 const std::shared_ptr
<Type_descriptor
> type
;
534 explicit Constant_descriptor(std::shared_ptr
<Type_descriptor
> type
) noexcept
535 : type(std::move(type
))
538 ~Constant_descriptor() = default;
539 virtual ::LLVMValueRef
get_or_make_value() = 0;
// Constant whose LLVM value is supplied to the constructor.
// NOTE(review): the initializer for the `value` member and the body of
// get_or_make_value are not present in this view; presumably it returns the
// stored `value`.
542 class Simple_constant_descriptor final
: public Constant_descriptor
545 ::LLVMValueRef value
;
548 explicit Simple_constant_descriptor(std::shared_ptr
<Type_descriptor
> type
,
549 ::LLVMValueRef value
) noexcept
550 : Constant_descriptor(std::move(type
)),
554 virtual ::LLVMValueRef
get_or_make_value() override
560 struct Converted_module
562 llvm_wrapper::Module module
;
563 std::string entry_function_name
;
564 std::shared_ptr
<Struct_type_descriptor
> io_struct
;
565 std::size_t inputs_member
;
566 std::shared_ptr
<Struct_type_descriptor
> inputs_struct
;
567 std::size_t outputs_member
;
568 std::shared_ptr
<Struct_type_descriptor
> outputs_struct
;
569 Converted_module() = default;
570 explicit Converted_module(llvm_wrapper::Module module
,
571 std::string entry_function_name
,
572 std::shared_ptr
<Struct_type_descriptor
> io_struct
,
573 std::size_t inputs_member
,
574 std::shared_ptr
<Struct_type_descriptor
> inputs_struct
,
575 std::size_t outputs_member
,
576 std::shared_ptr
<Struct_type_descriptor
> outputs_struct
) noexcept
577 : module(std::move(module
)),
578 entry_function_name(std::move(entry_function_name
)),
579 io_struct(std::move(io_struct
)),
580 inputs_member(inputs_member
),
581 inputs_struct(std::move(inputs_struct
)),
582 outputs_member(outputs_member
),
583 outputs_struct(std::move(outputs_struct
))
590 Converted_module
spirv_to_llvm(::LLVMContextRef context
,
591 ::LLVMTargetMachineRef target_machine
,
592 const spirv::Word
*shader_words
,
593 std::size_t shader_size
,
594 std::uint64_t shader_id
,
595 spirv::Execution_model execution_model
,
596 util::string_view entry_point_name
);
600 #endif /* SPIRV_TO_LLVM_SPIRV_TO_LLVM_H_ */