/*
 * Copyright 2017 Jacob Lifshay
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
23 #ifndef SPIRV_TO_LLVM_SPIRV_TO_LLVM_H_
24 #define SPIRV_TO_LLVM_SPIRV_TO_LLVM_H_
26 #include "spirv/parser.h"
32 #include <type_traits>
35 #include "llvm_wrapper/llvm_wrapper.h"
36 #include "util/string_view.h"
37 #include "vulkan/vulkan.h"
41 namespace spirv_to_llvm
43 struct LLVM_type_and_alignment
46 std::size_t alignment
;
47 constexpr LLVM_type_and_alignment() noexcept
: type(nullptr), alignment(0)
50 constexpr LLVM_type_and_alignment(::LLVMTypeRef type
, std::size_t alignment
) noexcept
// Every concrete descriptor class is named up front so Type_descriptor::Type_visitor can
// declare one visit() overload per class before the classes themselves are defined.
class Simple_type_descriptor;
class Vector_type_descriptor;
class Matrix_type_descriptor;
class Array_type_descriptor;
class Pointer_type_descriptor;
class Function_type_descriptor;
class Struct_type_descriptor;
66 Type_descriptor(const Type_descriptor
&) = delete;
67 Type_descriptor
&operator=(const Type_descriptor
&) = delete;
72 virtual ~Type_visitor() = default;
73 virtual void visit(Simple_type_descriptor
&type
) = 0;
74 virtual void visit(Vector_type_descriptor
&type
) = 0;
75 virtual void visit(Matrix_type_descriptor
&type
) = 0;
76 virtual void visit(Array_type_descriptor
&type
) = 0;
77 virtual void visit(Pointer_type_descriptor
&type
) = 0;
78 virtual void visit(Function_type_descriptor
&type
) = 0;
79 virtual void visit(Struct_type_descriptor
&type
) = 0;
83 const std::vector
<spirv::Decoration_with_parameters
> decorations
;
86 explicit Type_descriptor(std::vector
<spirv::Decoration_with_parameters
> decorations
) noexcept
87 : decorations(std::move(decorations
))
90 virtual ~Type_descriptor() = default;
91 virtual LLVM_type_and_alignment
get_or_make_type() = 0;
92 virtual void visit(Type_visitor
&type_visitor
) = 0;
93 void visit(Type_visitor
&&type_visitor
)
97 template <typename Fn
>
98 typename
std::enable_if
<!std::is_convertible
<Fn
&&, const Type_visitor
&>::value
, void>::type
101 struct Visitor final
: public Type_visitor
104 virtual void visit(Simple_type_descriptor
&type
) override
106 std::forward
<Fn
>(fn
)(type
);
108 virtual void visit(Vector_type_descriptor
&type
) override
110 std::forward
<Fn
>(fn
)(type
);
112 virtual void visit(Matrix_type_descriptor
&type
) override
114 std::forward
<Fn
>(fn
)(type
);
116 virtual void visit(Array_type_descriptor
&type
) override
118 std::forward
<Fn
>(fn
)(type
);
120 virtual void visit(Pointer_type_descriptor
&type
) override
122 std::forward
<Fn
>(fn
)(type
);
124 virtual void visit(Function_type_descriptor
&type
) override
126 std::forward
<Fn
>(fn
)(type
);
128 virtual void visit(Struct_type_descriptor
&type
) override
130 std::forward
<Fn
>(fn
)(type
);
132 explicit Visitor(Fn
&fn
) noexcept
: fn(fn
)
138 class Recursion_checker
;
139 class Recursion_checker_state
141 friend class Recursion_checker
;
144 std::size_t recursion_count
= 0;
146 class Recursion_checker
148 Recursion_checker(const Recursion_checker
&) = delete;
149 Recursion_checker
&operator=(const Recursion_checker
&) = delete;
152 Recursion_checker_state
&state
;
155 explicit Recursion_checker(Recursion_checker_state
&state
,
156 std::size_t instruction_start_index
)
159 state
.recursion_count
++;
160 if(state
.recursion_count
> 5)
161 throw spirv::Parser_error(instruction_start_index
,
162 instruction_start_index
,
163 "too many recursions making type");
167 state
.recursion_count
--;
169 std::size_t get_recursion_count() const noexcept
171 return state
.recursion_count
;
173 bool is_nested_recursion() const noexcept
175 return get_recursion_count() > 1;
180 class Simple_type_descriptor final
: public Type_descriptor
183 LLVM_type_and_alignment type
;
186 explicit Simple_type_descriptor(std::vector
<spirv::Decoration_with_parameters
> decorations
,
187 LLVM_type_and_alignment type
) noexcept
188 : Type_descriptor(std::move(decorations
)),
192 virtual LLVM_type_and_alignment
get_or_make_type() override
196 virtual void visit(Type_visitor
&type_visitor
) override
198 type_visitor
.visit(*this);
202 class Vector_type_descriptor final
: public Type_descriptor
205 LLVM_type_and_alignment type
;
206 std::shared_ptr
<Simple_type_descriptor
> element_type
;
207 std::size_t element_count
;
210 explicit Vector_type_descriptor(std::vector
<spirv::Decoration_with_parameters
> decorations
,
211 std::shared_ptr
<Simple_type_descriptor
> element_type
,
212 std::size_t element_count
,
213 ::LLVMTargetDataRef target_data
) noexcept
214 : Type_descriptor(std::move(decorations
)),
215 type(make_vector_type(element_type
, element_count
, target_data
)),
216 element_type(std::move(element_type
)),
217 element_count(element_count
)
220 static LLVM_type_and_alignment
make_vector_type(
221 const std::shared_ptr
<Simple_type_descriptor
> &element_type
,
222 std::size_t element_count
,
223 ::LLVMTargetDataRef target_data
)
225 auto llvm_element_type
= element_type
->get_or_make_type();
226 auto type
= ::LLVMVectorType(llvm_element_type
.type
, element_count
);
227 std::size_t alignment
= ::LLVMPreferredAlignmentOfType(target_data
, type
);
228 constexpr std::size_t max_abi_alignment
= alignof(std::max_align_t
);
229 if(alignment
> max_abi_alignment
)
230 alignment
= max_abi_alignment
;
231 return {type
, alignment
};
233 virtual LLVM_type_and_alignment
get_or_make_type() override
237 virtual void visit(Type_visitor
&type_visitor
) override
239 type_visitor
.visit(*this);
241 const std::shared_ptr
<Simple_type_descriptor
> &get_element_type() const noexcept
245 std::size_t get_element_count() const noexcept
247 return element_count
;
251 class Matrix_type_descriptor final
: public Type_descriptor
254 LLVM_type_and_alignment type
;
255 std::shared_ptr
<Vector_type_descriptor
> column_type
;
256 std::size_t column_count
;
259 explicit Matrix_type_descriptor(std::vector
<spirv::Decoration_with_parameters
> decorations
,
260 std::shared_ptr
<Vector_type_descriptor
> column_type
,
261 std::size_t column_count
,
262 ::LLVMTargetDataRef target_data
) noexcept
263 : Type_descriptor(std::move(decorations
)),
264 type(Vector_type_descriptor::make_vector_type(column_type
->get_element_type(),
265 column_type
->get_element_count()
268 column_type(std::move(column_type
)),
269 column_count(column_count
)
272 virtual LLVM_type_and_alignment
get_or_make_type() override
276 virtual void visit(Type_visitor
&type_visitor
) override
278 type_visitor
.visit(*this);
280 const std::shared_ptr
<Vector_type_descriptor
> &get_column_type() const noexcept
284 std::size_t get_column_count() const noexcept
290 class Array_type_descriptor final
: public Type_descriptor
293 LLVM_type_and_alignment type
;
294 std::shared_ptr
<Type_descriptor
> element_type
;
295 std::size_t element_count
;
296 std::size_t instruction_start_index
;
297 Recursion_checker_state recursion_checker_state
;
300 explicit Array_type_descriptor(std::vector
<spirv::Decoration_with_parameters
> decorations
,
301 std::shared_ptr
<Type_descriptor
> element_type
,
302 std::size_t element_count
,
303 std::size_t instruction_start_index
) noexcept
304 : Type_descriptor(std::move(decorations
)),
306 element_type(std::move(element_type
)),
307 element_count(element_count
),
308 instruction_start_index(instruction_start_index
)
311 virtual LLVM_type_and_alignment
get_or_make_type() override
315 Recursion_checker
recursion_checker(recursion_checker_state
, instruction_start_index
);
316 auto llvm_element_type
= element_type
->get_or_make_type();
317 type
= LLVM_type_and_alignment(::LLVMArrayType(llvm_element_type
.type
, element_count
),
318 llvm_element_type
.alignment
);
322 virtual void visit(Type_visitor
&type_visitor
) override
324 type_visitor
.visit(*this);
326 const std::shared_ptr
<Type_descriptor
> &get_element_type() const noexcept
330 std::size_t get_element_count() const noexcept
332 return element_count
;
336 class Pointer_type_descriptor final
: public Type_descriptor
339 std::shared_ptr
<Type_descriptor
> base
;
340 std::size_t instruction_start_index
;
341 LLVM_type_and_alignment type
;
342 Recursion_checker_state recursion_checker_state
;
345 Pointer_type_descriptor(std::vector
<spirv::Decoration_with_parameters
> decorations
,
346 std::shared_ptr
<Type_descriptor
> base
,
347 std::size_t instruction_start_index
,
348 ::LLVMTargetDataRef target_data
) noexcept
349 : Type_descriptor(std::move(decorations
)),
350 base(std::move(base
)),
351 instruction_start_index(instruction_start_index
),
352 type(nullptr, llvm_wrapper::Target_data::get_pointer_alignment(target_data
))
355 const std::shared_ptr
<Type_descriptor
> &get_base_type() const noexcept
359 void set_base_type(std::shared_ptr
<Type_descriptor
> new_base
) noexcept
363 base
= std::move(new_base
);
365 explicit Pointer_type_descriptor(std::vector
<spirv::Decoration_with_parameters
> decorations
,
366 std::size_t instruction_start_index
,
367 ::LLVMTargetDataRef target_data
) noexcept
368 : Type_descriptor(std::move(decorations
)),
370 instruction_start_index(instruction_start_index
),
371 type(nullptr, llvm_wrapper::Target_data::get_pointer_alignment(target_data
))
374 virtual LLVM_type_and_alignment
get_or_make_type() override
378 Recursion_checker
recursion_checker(recursion_checker_state
, instruction_start_index
);
380 throw spirv::Parser_error(
381 instruction_start_index
,
382 instruction_start_index
,
383 "attempting to create type from pointer forward declaration");
384 auto base_type
= base
->get_or_make_type();
385 constexpr unsigned default_address_space
= 0;
386 type
.type
= ::LLVMPointerType(base_type
.type
, default_address_space
);
390 virtual void visit(Type_visitor
&type_visitor
) override
392 type_visitor
.visit(*this);
396 class Function_type_descriptor final
: public Type_descriptor
399 std::shared_ptr
<Type_descriptor
> return_type
;
400 std::vector
<std::shared_ptr
<Type_descriptor
>> args
;
401 LLVM_type_and_alignment type
;
402 Recursion_checker_state recursion_checker_state
;
403 std::size_t instruction_start_index
;
404 bool valid_for_entry_point
;
408 explicit Function_type_descriptor(std::vector
<spirv::Decoration_with_parameters
> decorations
,
409 std::shared_ptr
<Type_descriptor
> return_type
,
410 std::vector
<std::shared_ptr
<Type_descriptor
>> args
,
411 std::size_t instruction_start_index
,
412 ::LLVMTargetDataRef target_data
,
413 bool valid_for_entry_point
,
414 bool is_var_arg
) noexcept
415 : Type_descriptor(std::move(decorations
)),
416 return_type(std::move(return_type
)),
417 args(std::move(args
)),
418 type(nullptr, llvm_wrapper::Target_data::get_pointer_alignment(target_data
)),
419 instruction_start_index(instruction_start_index
),
420 valid_for_entry_point(valid_for_entry_point
),
421 is_var_arg(is_var_arg
)
424 virtual LLVM_type_and_alignment
get_or_make_type() override
428 Recursion_checker
recursion_checker(recursion_checker_state
, instruction_start_index
);
429 std::vector
<::LLVMTypeRef
> llvm_args
;
430 llvm_args
.reserve(args
.size());
431 auto llvm_return_type
= return_type
->get_or_make_type();
432 for(auto &arg
: args
)
433 llvm_args
.push_back(arg
->get_or_make_type().type
);
434 type
.type
= ::LLVMFunctionType(
435 llvm_return_type
.type
, llvm_args
.data(), llvm_args
.size(), is_var_arg
);
439 virtual void visit(Type_visitor
&type_visitor
) override
441 type_visitor
.visit(*this);
443 bool is_valid_for_entry_point() const noexcept
445 return valid_for_entry_point
;
449 class Struct_type_descriptor final
: public Type_descriptor
454 std::vector
<spirv::Decoration_with_parameters
> decorations
;
455 std::size_t llvm_member_index
= -1;
456 std::shared_ptr
<Type_descriptor
> type
;
457 explicit Member(std::vector
<spirv::Decoration_with_parameters
> decorations
,
458 std::shared_ptr
<Type_descriptor
> type
) noexcept
459 : decorations(std::move(decorations
)),
460 type(std::move(type
))
466 std::vector
<Member
> members
;
467 util::Enum_map
<spirv::Built_in
, std::size_t> builtin_members
;
468 LLVM_type_and_alignment type
;
470 Recursion_checker_state recursion_checker_state
;
471 std::size_t instruction_start_index
;
472 ::LLVMContextRef context
;
473 ::LLVMTargetDataRef target_data
;
474 void complete_type();
475 void on_add_member(std::size_t added_member_index
) noexcept
477 assert(!is_complete
);
478 auto &member
= members
[added_member_index
];
479 for(auto &decoration
: member
.decorations
)
480 if(decoration
.value
== spirv::Decoration::built_in
)
481 builtin_members
[util::get
<spirv::Decoration_built_in_parameters
>(
482 decoration
.parameters
)
483 .built_in
] = added_member_index
;
487 std::size_t add_member(Member member
)
489 std::size_t index
= members
.size();
490 members
.push_back(std::move(member
));
491 on_add_member(index
);
494 const std::vector
<Member
> &get_members(bool need_llvm_member_indexes
)
496 if(need_llvm_member_indexes
)
500 explicit Struct_type_descriptor(std::vector
<spirv::Decoration_with_parameters
> decorations
,
501 ::LLVMContextRef context
,
502 ::LLVMTargetDataRef target_data
,
504 std::size_t instruction_start_index
,
505 std::vector
<Member
> members
= {})
506 : Type_descriptor(std::move(decorations
)),
507 members(std::move(members
)),
509 type(::LLVMStructCreateNamed(context
, name
), 0),
511 instruction_start_index(instruction_start_index
),
513 target_data(target_data
)
515 for(std::size_t member_index
= 0; member_index
< members
.size(); member_index
++)
516 on_add_member(member_index
);
518 virtual LLVM_type_and_alignment
get_or_make_type() override
522 Recursion_checker
recursion_checker(recursion_checker_state
, instruction_start_index
);
523 if(!recursion_checker
.is_nested_recursion())
528 virtual void visit(Type_visitor
&type_visitor
) override
530 type_visitor
.visit(*this);
534 class Constant_descriptor
536 Constant_descriptor(const Constant_descriptor
&) = delete;
537 Constant_descriptor
&operator=(const Constant_descriptor
&) = delete;
540 const std::shared_ptr
<Type_descriptor
> type
;
543 explicit Constant_descriptor(std::shared_ptr
<Type_descriptor
> type
) noexcept
544 : type(std::move(type
))
547 ~Constant_descriptor() = default;
548 virtual ::LLVMValueRef
get_or_make_value() = 0;
551 class Simple_constant_descriptor final
: public Constant_descriptor
554 ::LLVMValueRef value
;
557 explicit Simple_constant_descriptor(std::shared_ptr
<Type_descriptor
> type
,
558 ::LLVMValueRef value
) noexcept
559 : Constant_descriptor(std::move(type
)),
563 virtual ::LLVMValueRef
get_or_make_value() override
569 struct Converted_module
571 llvm_wrapper::Module module
;
572 std::string entry_function_name
;
573 std::shared_ptr
<Struct_type_descriptor
> inputs_struct
;
574 std::shared_ptr
<Struct_type_descriptor
> outputs_struct
;
575 spirv::Execution_model execution_model
;
576 Converted_module() = default;
577 explicit Converted_module(llvm_wrapper::Module module
,
578 std::string entry_function_name
,
579 std::shared_ptr
<Struct_type_descriptor
> inputs_struct
,
580 std::shared_ptr
<Struct_type_descriptor
> outputs_struct
,
581 spirv::Execution_model execution_model
) noexcept
582 : module(std::move(module
)),
583 entry_function_name(std::move(entry_function_name
)),
584 inputs_struct(std::move(inputs_struct
)),
585 outputs_struct(std::move(outputs_struct
)),
586 execution_model(execution_model
)
591 struct Jit_symbol_resolver
593 typedef void (*Resolved_symbol
)();
594 Resolved_symbol
resolve(util::string_view name
);
595 static std::uint64_t resolve(const char *name
, void *user_data
) noexcept
597 return reinterpret_cast<std::uint64_t>(
598 static_cast<Jit_symbol_resolver
*>(user_data
)->resolve(name
));
600 static std::uintptr_t resolve(const std::string
&name
, void *user_data
) noexcept
602 return reinterpret_cast<std::uintptr_t>(
603 static_cast<Jit_symbol_resolver
*>(user_data
)->resolve(name
));
609 Converted_module
spirv_to_llvm(::LLVMContextRef context
,
610 ::LLVMTargetMachineRef target_machine
,
611 const spirv::Word
*shader_words
,
612 std::size_t shader_size
,
613 std::uint64_t shader_id
,
614 spirv::Execution_model execution_model
,
615 util::string_view entry_point_name
,
616 const VkPipelineVertexInputStateCreateInfo
*vertex_input_state
);
620 #endif /* SPIRV_TO_LLVM_SPIRV_TO_LLVM_H_ */