2 * Copyright 2017 Jacob Lifshay
4 * Permission is hereby granted, free of charge, to any person obtaining a copy
5 * of this software and associated documentation files (the "Software"), to deal
6 * in the Software without restriction, including without limitation the rights
7 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 * copies of the Software, and to permit persons to whom the Software is
9 * furnished to do so, subject to the following conditions:
11 * The above copyright notice and this permission notice shall be included in all
12 * copies or substantial portions of the Software.
14 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 #ifndef SPIRV_TO_LLVM_SPIRV_TO_LLVM_H_
24 #define SPIRV_TO_LLVM_SPIRV_TO_LLVM_H_
26 #include "spirv/parser.h"
32 #include <type_traits>
34 #include "llvm_wrapper/llvm_wrapper.h"
38 namespace spirv_to_llvm
// Forward declarations of every concrete Type_descriptor subclass, so the
// visitor interface can name them before their definitions appear.
class Function_type_descriptor;
class Matrix_type_descriptor;
class Pointer_type_descriptor;
class Simple_type_descriptor;
class Struct_type_descriptor;
class Vector_type_descriptor;
48 Type_descriptor(const Type_descriptor
&) = delete;
49 Type_descriptor
&operator=(const Type_descriptor
&) = delete;
54 virtual ~Type_visitor() = default;
55 virtual void visit(Simple_type_descriptor
&type
) = 0;
56 virtual void visit(Vector_type_descriptor
&type
) = 0;
57 virtual void visit(Matrix_type_descriptor
&type
) = 0;
58 virtual void visit(Pointer_type_descriptor
&type
) = 0;
59 virtual void visit(Function_type_descriptor
&type
) = 0;
60 virtual void visit(Struct_type_descriptor
&type
) = 0;
// Decorations attached to this type. Declared const, so they cannot change
// after construction.
64 const std::vector
<spirv::Decoration_with_parameters
> decorations
;
// Takes ownership of the decoration list by moving it into `decorations`.
67 explicit Type_descriptor(std::vector
<spirv::Decoration_with_parameters
> decorations
) noexcept
68 : decorations(std::move(decorations
))
// Virtual so derived descriptors are destroyed correctly through a
// Type_descriptor pointer.
71 virtual ~Type_descriptor() = default;
// Returns the LLVM type for this descriptor. The pointer, function and struct
// descriptors in this header build their type inside this call (not in their
// constructors), and may throw spirv::Parser_error while doing so.
72 virtual ::LLVMTypeRef
get_or_make_type() = 0;
// Double dispatch: each derived class in this header implements this as
// type_visitor.visit(*this).
73 virtual void visit(Type_visitor
&type_visitor
) = 0;
// Overload so a temporary visitor object can be passed directly.
// NOTE(review): the body is not visible in this view; presumably it forwards to
// visit(Type_visitor &) — confirm.
74 void visit(Type_visitor
&&type_visitor
)
// Visits with an arbitrary callable instead of a Type_visitor subclass.
// enable_if removes this overload when Fn&& is convertible to
// const Type_visitor &, so real visitors keep using the overloads above.
// NOTE(review): the declarator (function name / parameter list) and the
// statement that runs the Visitor are not visible in this view; presumably it
// is `visit(Fn &&fn)` dispatching through visit(Visitor(fn)) — confirm.
78 template <typename Fn
>
79 typename
std::enable_if
<!std::is_convertible
<Fn
&&, const Type_visitor
&>::value
, void>::type
// Local adapter that turns `fn` into a Type_visitor. Every overload forwards the
// concrete descriptor to `fn`, so `fn` must be callable with each descriptor
// type (e.g. a generic lambda). std::forward<Fn> preserves fn's value category.
82 struct Visitor final
: public Type_visitor
85 virtual void visit(Simple_type_descriptor
&type
) override
87 std::forward
<Fn
>(fn
)(type
);
89 virtual void visit(Vector_type_descriptor
&type
) override
91 std::forward
<Fn
>(fn
)(type
);
93 virtual void visit(Matrix_type_descriptor
&type
) override
95 std::forward
<Fn
>(fn
)(type
);
97 virtual void visit(Pointer_type_descriptor
&type
) override
99 std::forward
<Fn
>(fn
)(type
);
101 virtual void visit(Function_type_descriptor
&type
) override
103 std::forward
<Fn
>(fn
)(type
);
105 virtual void visit(Struct_type_descriptor
&type
) override
107 std::forward
<Fn
>(fn
)(type
);
// Receives `fn` by lvalue reference. NOTE(review): the member declaration is not
// visible here; presumably `Fn &fn;` so the callable is not copied — confirm.
109 explicit Visitor(Fn
&fn
) noexcept
: fn(fn
)
// Forward declaration: Recursion_checker_state names it as a friend.
115 class Recursion_checker
;
// Nesting counter shared by the Recursion_checker guards of one descriptor.
// The pointer, function and struct descriptors each own one of these.
116 class Recursion_checker_state
// Friend so Recursion_checker can read and update recursion_count.
118 friend class Recursion_checker
;
// Current nesting depth as tracked by Recursion_checker; starts at 0.
121 std::size_t recursion_count
= 0;
123 class Recursion_checker
125 Recursion_checker(const Recursion_checker
&) = delete;
126 Recursion_checker
&operator=(const Recursion_checker
&) = delete;
129 Recursion_checker_state
&state
;
132 explicit Recursion_checker(Recursion_checker_state
&state
,
133 std::size_t instruction_start_index
)
136 state
.recursion_count
++;
137 if(state
.recursion_count
> 5)
138 throw spirv::Parser_error(instruction_start_index
,
139 instruction_start_index
,
140 "too many recursions making type");
144 state
.recursion_count
--;
146 std::size_t get_recursion_count() const noexcept
148 return state
.recursion_count
;
150 bool is_nested_recursion() const noexcept
152 return get_recursion_count() > 1;
// Descriptor wrapping an LLVM type that the caller passes to the constructor.
157 class Simple_type_descriptor final
: public Type_descriptor
// @param decorations moved into the Type_descriptor base
// @param type        the already-created LLVM type this descriptor stands for
// NOTE(review): the data member that stores `type` is not visible in this view.
163 explicit Simple_type_descriptor(std::vector
<spirv::Decoration_with_parameters
> decorations
,
164 ::LLVMTypeRef type
) noexcept
165 : Type_descriptor(std::move(decorations
)),
// NOTE(review): body not visible here; presumably returns the stored LLVM type.
169 virtual ::LLVMTypeRef
get_or_make_type() override
// Double dispatch to Type_visitor::visit(Simple_type_descriptor &).
173 virtual void visit(Type_visitor
&type_visitor
) override
175 type_visitor
.visit(*this);
// Descriptor for an LLVM vector of `element_count` elements whose element type
// is a Simple_type_descriptor. The LLVM type is created in the constructor.
179 class Vector_type_descriptor final
: public Type_descriptor
183 std::shared_ptr
<Simple_type_descriptor
> element_type
;
184 std::size_t element_count
;
187 explicit Vector_type_descriptor(std::vector
<spirv::Decoration_with_parameters
> decorations
,
188 std::shared_ptr
<Simple_type_descriptor
> element_type
,
189 std::size_t element_count
) noexcept
190 : Type_descriptor(std::move(decorations
)),
// Here `element_type` is still the constructor parameter; it is moved into the
// member on the next line. This relies on the `type` member being declared
// before `element_type`, because members are initialized in declaration order.
// NOTE(review): that declaration is not visible in this view — confirm.
// LLVMVectorType takes an `unsigned` count, so element_count is narrowed.
191 type(::LLVMVectorType(element_type
->get_or_make_type(), element_count
)),
192 element_type(std::move(element_type
)),
193 element_count(element_count
)
// NOTE(review): body not visible; presumably returns the type built above.
196 virtual ::LLVMTypeRef
get_or_make_type() override
// Double dispatch to Type_visitor::visit(Vector_type_descriptor &).
200 virtual void visit(Type_visitor
&type_visitor
) override
202 type_visitor
.visit(*this);
204 const std::shared_ptr
<Simple_type_descriptor
> &get_element_type() const noexcept
208 std::size_t get_element_count() const noexcept
210 return element_count
;
// Descriptor for a matrix of `column_count` columns of a vector type.
// In LLVM it is a single flat vector whose element type is the column's element
// type and whose length is (column element count) * column_count.
214 class Matrix_type_descriptor final
: public Type_descriptor
218 std::shared_ptr
<Vector_type_descriptor
> column_type
;
219 std::size_t column_count
;
222 explicit Matrix_type_descriptor(std::vector
<spirv::Decoration_with_parameters
> decorations
,
223 std::shared_ptr
<Vector_type_descriptor
> column_type
,
224 std::size_t column_count
) noexcept
225 : Type_descriptor(std::move(decorations
)),
// Here `column_type` is still the constructor parameter. This relies on `type`
// being declared before `column_type`; NOTE(review): that declaration is not
// visible in this view. The product is narrowed to LLVMVectorType's `unsigned`
// count without an overflow check.
226 type(::LLVMVectorType(column_type
->get_element_type()->get_or_make_type(),
227 column_type
->get_element_count() * column_count
)),
228 column_type(std::move(column_type
)),
229 column_count(column_count
)
// NOTE(review): body not visible; presumably returns the type built above.
232 virtual ::LLVMTypeRef
get_or_make_type() override
// Double dispatch to Type_visitor::visit(Matrix_type_descriptor &).
236 virtual void visit(Type_visitor
&type_visitor
) override
238 type_visitor
.visit(*this);
240 const std::shared_ptr
<Vector_type_descriptor
> &get_column_type() const noexcept
// NOTE(review): body not visible in this view; presumably returns column_count.
244 std::size_t get_column_count() const noexcept
// Descriptor for a pointer to another descriptor (`base`). The LLVM pointer type
// is built on demand in get_or_make_type().
250 class Pointer_type_descriptor final
: public Type_descriptor
253 std::shared_ptr
<Type_descriptor
> base
;
// SPIR-V instruction position reported in errors thrown by get_or_make_type().
254 std::size_t instruction_start_index
;
// Guards get_or_make_type() against unbounded recursion through `base`.
256 Recursion_checker_state recursion_checker_state
;
// Constructs a pointer whose pointee is already known.
259 Pointer_type_descriptor(std::vector
<spirv::Decoration_with_parameters
> decorations
,
260 std::shared_ptr
<Type_descriptor
> base
,
261 std::size_t instruction_start_index
) noexcept
262 : Type_descriptor(std::move(decorations
)),
263 base(std::move(base
)),
264 instruction_start_index(instruction_start_index
),
268 const std::shared_ptr
<Type_descriptor
> &get_base_type() const noexcept
// Supplies or replaces the pointee. Used to complete a pointer created by the
// forward-declaration constructor below.
// NOTE(review): any checks on lines not shown here are not visible.
272 void set_base_type(std::shared_ptr
<Type_descriptor
> new_base
) noexcept
276 base
= std::move(new_base
);
// Forward-declaration constructor: no pointee yet. get_or_make_type() throws
// spirv::Parser_error if it is called before a base is set.
278 explicit Pointer_type_descriptor(std::vector
<spirv::Decoration_with_parameters
> decorations
,
279 std::size_t instruction_start_index
) noexcept
280 : Type_descriptor(std::move(decorations
)),
282 instruction_start_index(instruction_start_index
),
// Builds `type` as a pointer (address space 0) to base->get_or_make_type().
// The Recursion_checker throws spirv::Parser_error if calls nest more than 5 deep.
// NOTE(review): the condition guarding the throw is not visible in this view;
// presumably it tests for a missing base — confirm.
286 virtual ::LLVMTypeRef
get_or_make_type() override
290 Recursion_checker
recursion_checker(recursion_checker_state
, instruction_start_index
);
292 throw spirv::Parser_error(
293 instruction_start_index
,
294 instruction_start_index
,
295 "attempting to create type from pointer forward declaration");
296 auto base_type
= base
->get_or_make_type();
297 constexpr unsigned default_address_space
= 0;
298 type
= ::LLVMPointerType(base_type
, default_address_space
);
// Double dispatch to Type_visitor::visit(Pointer_type_descriptor &).
302 virtual void visit(Type_visitor
&type_visitor
) override
304 type_visitor
.visit(*this);
// Descriptor for a function type (return type plus argument types). The LLVM
// function type is built on demand in get_or_make_type().
308 class Function_type_descriptor final
: public Type_descriptor
311 std::shared_ptr
<Type_descriptor
> return_type
;
312 std::vector
<std::shared_ptr
<Type_descriptor
>> args
;
// Guards get_or_make_type() against unbounded recursion through the
// return and argument types.
314 Recursion_checker_state recursion_checker_state
;
// SPIR-V instruction position reported in recursion errors.
315 std::size_t instruction_start_index
;
// @param is_var_arg passed through to LLVMFunctionType; defaults to false
319 explicit Function_type_descriptor(std::vector
<spirv::Decoration_with_parameters
> decorations
,
320 std::shared_ptr
<Type_descriptor
> return_type
,
321 std::vector
<std::shared_ptr
<Type_descriptor
>> args
,
322 std::size_t instruction_start_index
,
323 bool is_var_arg
= false) noexcept
324 : Type_descriptor(std::move(decorations
)),
325 return_type(std::move(return_type
)),
326 args(std::move(args
)),
328 instruction_start_index(instruction_start_index
),
329 is_var_arg(is_var_arg
)
// Resolves the return type first, then each argument type in order, and builds
// `type` with LLVMFunctionType. The Recursion_checker throws
// spirv::Parser_error if calls nest more than 5 deep.
332 virtual ::LLVMTypeRef
get_or_make_type() override
336 Recursion_checker
recursion_checker(recursion_checker_state
, instruction_start_index
);
337 std::vector
<::LLVMTypeRef
> llvm_args
;
338 llvm_args
.reserve(args
.size());
339 auto llvm_return_type
= return_type
->get_or_make_type();
340 for(auto &arg
: args
)
341 llvm_args
.push_back(arg
->get_or_make_type());
// LLVMFunctionType takes an `unsigned` parameter count, so size() is narrowed.
342 type
= ::LLVMFunctionType(
343 llvm_return_type
, llvm_args
.data(), llvm_args
.size(), is_var_arg
);
// Double dispatch to Type_visitor::visit(Function_type_descriptor &).
347 virtual void visit(Type_visitor
&type_visitor
) override
349 type_visitor
.visit(*this);
// Descriptor for a struct type whose members can be added after construction.
// The LLVM type is a named struct (see the constructor).
353 class Struct_type_descriptor final
: public Type_descriptor
// NOTE(review): the `Member` struct header is not visible in this view; the
// following declarations are its fields and constructor.
// Decorations attached to this member.
358 std::vector
<spirv::Decoration_with_parameters
> decorations
;
// Initialized to -1 (SIZE_MAX after conversion); presumably means "no LLVM
// member index assigned yet" — confirm.
359 std::size_t llvm_member_index
= -1;
360 std::shared_ptr
<Type_descriptor
> type
;
361 explicit Member(std::vector
<spirv::Decoration_with_parameters
> decorations
,
362 std::shared_ptr
<Type_descriptor
> type
) noexcept
363 : decorations(std::move(decorations
)),
364 type(std::move(type
))
370 std::vector
<Member
> members
;
// Maps each built-in to the index (into `members`) of the member decorated with it.
371 util::Enum_map
<spirv::Built_in
, std::size_t> builtin_members
;
374 Recursion_checker_state recursion_checker_state
;
375 std::size_t instruction_start_index
;
376 ::LLVMContextRef context
;
377 ::LLVMTargetDataRef target_data
;
// Defined out of line; not visible in this view.
378 void complete_type();
// Registers the member at `added_member_index`. Asserts the struct is not yet
// complete. For every `built_in` decoration on that member, it records the
// member's index in builtin_members; a later member with the same built-in
// overwrites the earlier entry.
379 void on_add_member(std::size_t added_member_index
) noexcept
381 assert(!is_complete
);
382 auto &member
= members
[added_member_index
];
383 for(auto &decoration
: member
.decorations
)
384 if(decoration
.value
== spirv::Decoration::built_in
)
385 builtin_members
[util::get
<spirv::Decoration_built_in_parameters
>(
386 decoration
.parameters
)
387 .built_in
] = added_member_index
;
// Appends `member` and registers its decorations.
// NOTE(review): the return statement is not visible; presumably it returns
// `index`, the new member's position in `members`.
391 std::size_t add_member(Member member
)
393 std::size_t index
= members
.size();
394 members
.push_back(std::move(member
));
395 on_add_member(index
);
// NOTE(review): the statement guarded by need_llvm_member_indexes is not
// visible; presumably it completes the LLVM type so that each
// Member::llvm_member_index is assigned before the members are returned.
398 const std::vector
<Member
> &get_members(bool need_llvm_member_indexes
)
400 if(need_llvm_member_indexes
)
// Creates a named LLVM struct via LLVMStructCreateNamed, which is initially
// opaque, and registers any members supplied up front (their built-in
// decorations go into builtin_members).
// NOTE(review): the `name` parameter used below is declared on a line that is
// not visible in this view.
404 explicit Struct_type_descriptor(std::vector
<spirv::Decoration_with_parameters
> decorations
,
405 ::LLVMContextRef context
,
406 ::LLVMTargetDataRef target_data
,
408 std::size_t instruction_start_index
,
409 std::vector
<Member
> members
= {})
410 : Type_descriptor(std::move(decorations
)),
411 members(std::move(members
)),
413 type(::LLVMStructCreateNamed(context
, name
)),
415 instruction_start_index(instruction_start_index
),
417 target_data(target_data
)
// In the constructor body, the unqualified name `members` denotes the constructor
// parameter, not the data member. That parameter was moved into the data member
// above and is (in practice) empty, so the original loop never ran and the
// initial members' built-in decorations were never recorded. Iterate over the
// data member instead.
for(std::size_t member_index = 0; member_index < this->members.size(); member_index++)
    on_add_member(member_index);
// Returns the LLVM struct type. The Recursion_checker throws spirv::Parser_error
// if calls nest more than 5 deep. Only the outermost (non-nested) call runs the
// guarded statement. NOTE(review): that statement and the return are not visible
// in this view; presumably the outermost call runs complete_type(), and nested
// calls (e.g. through a pointer member back to this struct) get the
// still-opaque named struct — confirm.
422 virtual ::LLVMTypeRef
get_or_make_type() override
426 Recursion_checker
recursion_checker(recursion_checker_state
, instruction_start_index
);
427 if(!recursion_checker
.is_nested_recursion())
// Double dispatch to Type_visitor::visit(Struct_type_descriptor &).
432 virtual void visit(Type_visitor
&type_visitor
) override
434 type_visitor
.visit(*this);
// Abstract, non-copyable base for constants: a type descriptor plus an LLVM value
// produced by get_or_make_value().
438 class Constant_descriptor
440 Constant_descriptor(const Constant_descriptor
&) = delete;
441 Constant_descriptor
&operator=(const Constant_descriptor
&) = delete;
// Type of the constant; fixed at construction.
444 const std::shared_ptr
<Type_descriptor
> type
;
447 explicit Constant_descriptor(std::shared_ptr
<Type_descriptor
> type
) noexcept
448 : type(std::move(type
))
// The class is polymorphic (pure virtual get_or_make_value), so the destructor
// must be virtual: deleting a derived constant through a Constant_descriptor
// pointer is otherwise undefined behavior. This also matches Type_descriptor,
// whose destructor is virtual.
virtual ~Constant_descriptor() = default;
452 virtual ::LLVMValueRef
get_or_make_value() = 0;
// Constant whose LLVM value is supplied ready-made to the constructor.
455 class Simple_constant_descriptor final
: public Constant_descriptor
458 ::LLVMValueRef value
;
// @param type  moved into the Constant_descriptor base
// @param value the already-created LLVM value
461 explicit Simple_constant_descriptor(std::shared_ptr
<Type_descriptor
> type
,
462 ::LLVMValueRef value
) noexcept
463 : Constant_descriptor(std::move(type
)),
// NOTE(review): body not visible in this view; presumably returns `value`.
467 virtual ::LLVMValueRef
get_or_make_value() override
473 struct Converted_module
478 #warning finish filling in Entry_point
479 explicit Entry_point(std::string name
) noexcept
: name(std::move(name
))
483 llvm_wrapper::Module module
;
484 std::vector
<Entry_point
> entry_points
;
485 std::shared_ptr
<Struct_type_descriptor
> io_struct
;
486 std::size_t inputs_member
;
487 std::shared_ptr
<Struct_type_descriptor
> inputs_struct
;
488 std::size_t outputs_member
;
489 std::shared_ptr
<Struct_type_descriptor
> outputs_struct
;
490 Converted_module() : module(), entry_points()
493 explicit Converted_module(llvm_wrapper::Module module
,
494 std::vector
<Entry_point
> entry_points
,
495 std::shared_ptr
<Struct_type_descriptor
> io_struct
,
496 std::size_t inputs_member
,
497 std::shared_ptr
<Struct_type_descriptor
> inputs_struct
,
498 std::size_t outputs_member
,
499 std::shared_ptr
<Struct_type_descriptor
> outputs_struct
) noexcept
500 : module(std::move(module
)),
501 entry_points(std::move(entry_points
)),
502 io_struct(std::move(io_struct
)),
503 inputs_member(inputs_member
),
504 inputs_struct(std::move(inputs_struct
)),
505 outputs_member(outputs_member
),
506 outputs_struct(std::move(outputs_struct
))
513 Converted_module
spirv_to_llvm(::LLVMContextRef context
,
514 const spirv::Word
*shader_words
,
515 std::size_t shader_size
,
516 std::uint64_t shader_id
);
520 #endif /* SPIRV_TO_LLVM_SPIRV_TO_LLVM_H_ */