From e9d53b36929e159ad9d7f523e970fe97b649c9d0 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Thu, 20 Jul 2017 01:39:25 -0700 Subject: [PATCH] calculate_types stage works in Spirv_to_llvm --- src/spirv_to_llvm/spirv_to_llvm.cpp | 531 +++++++++++++++++++++++----- src/spirv_to_llvm/spirv_to_llvm.h | 145 ++++++++ 2 files changed, 585 insertions(+), 91 deletions(-) diff --git a/src/spirv_to_llvm/spirv_to_llvm.cpp b/src/spirv_to_llvm/spirv_to_llvm.cpp index eefbe0b..9ef798d 100644 --- a/src/spirv_to_llvm/spirv_to_llvm.cpp +++ b/src/spirv_to_llvm/spirv_to_llvm.cpp @@ -88,12 +88,30 @@ private: Variable_state; struct Function_state { + struct Entry_block + { + ::LLVMBasicBlockRef entry_block; + ::LLVMValueRef io_struct; + ::LLVMValueRef inputs_struct; + ::LLVMValueRef outputs_struct; + explicit Entry_block(::LLVMBasicBlockRef entry_block, + ::LLVMValueRef io_struct, + ::LLVMValueRef inputs_struct, + ::LLVMValueRef outputs_struct) noexcept + : entry_block(entry_block), + io_struct(io_struct), + inputs_struct(inputs_struct), + outputs_struct(outputs_struct) + { + } + }; std::shared_ptr type; ::LLVMValueRef function; - ::LLVMBasicBlockRef entry_block = nullptr; + util::optional entry_block; explicit Function_state(std::shared_ptr type, ::LLVMValueRef function) noexcept : type(std::move(type)), - function(function) + function(function), + entry_block() { } }; @@ -104,6 +122,16 @@ private: { } }; + struct Value + { + ::LLVMValueRef value; + std::shared_ptr type; + explicit Value(::LLVMValueRef value, std::shared_ptr type) noexcept + : value(value), + type(std::move(type)) + { + } + }; struct Id_state { util::optional op_string; @@ -118,6 +146,7 @@ private: std::shared_ptr constant; util::optional function; util::optional label; + util::optional value; private: template @@ -162,6 +191,17 @@ private: { } }; + struct Last_merge_instruction + { + typedef util::variant Instruction_variant; + Instruction_variant instruction; + std::size_t instruction_start_index; + explicit Last_merge_instruction(Instruction_variant instruction, + std::size_t instruction_start_index) + : instruction(std::move(instruction)), instruction_start_index(instruction_start_index) + { + } + }; private: std::vector id_states; @@ -174,6 +214,7 @@ private: std::string name_prefix; llvm_wrapper::Module module; std::shared_ptr io_struct; + static constexpr std::size_t io_struct_argument_index = 0; std::array, 1> implicit_function_arguments; std::size_t inputs_member; std::shared_ptr inputs_struct; @@ -183,6 +224,7 @@ private: Id current_function_id = 0; Id current_basic_block_id = 0; llvm_wrapper::Builder builder; + util::optional last_merge_instruction; private: Id_state &get_id_state(Id id) @@ -236,7 +278,8 @@ public: io_struct = std::make_shared( context, (name_prefix + "Io_struct").c_str(), no_instruction_index); assert(implicit_function_arguments.size() == 1); - implicit_function_arguments[0] = io_struct; + static_assert(io_struct_argument_index == 0, ""); + implicit_function_arguments[io_struct_argument_index] = io_struct; inputs_struct = std::make_shared( context, (name_prefix + "Inputs").c_str(), no_instruction_index); inputs_member = io_struct->add_member(Struct_type_descriptor::Member({}, inputs_struct)); @@ -1719,10 +1762,9 @@ void Spirv_to_llvm::handle_instruction_op_type_vector(Op_type_vector instruction switch(stage) { case Stage::calculate_types: - get_id_state(instruction.result) - .type = std::make_shared(::LLVMVectorType( - get_type(instruction.component_type, instruction_start_index)->get_or_make_type(false), - instruction.component_count)); + get_id_state(instruction.result).type = std::make_shared( + get_type(instruction.component_type, instruction_start_index), + instruction.component_count); break; case Stage::generate_code: break; @@ -1735,18 +1777,10 @@ void Spirv_to_llvm::handle_instruction_op_type_matrix(Op_type_matrix instruction switch(stage) { case Stage::calculate_types: - { - auto column_type = - get_type(instruction.column_type, instruction_start_index)->get_or_make_type(false); - if(::LLVMGetTypeKind(column_type) != LLVMVectorTypeKind) - throw Parser_error(instruction_start_index, - instruction_start_index, - "column type must be a vector type"); - get_id_state(instruction.result).type = std::make_shared( - ::LLVMVectorType(::LLVMGetElementType(column_type), - instruction.column_count * ::LLVMGetVectorSize(column_type))); + get_id_state(instruction.result).type = std::make_shared( + get_type(instruction.column_type, instruction_start_index), + instruction.column_count); break; - } case Stage::generate_code: break; } @@ -2113,8 +2147,13 @@ void Spirv_to_llvm::handle_instruction_op_constant(Op_constant instruction, break; } case Stage::generate_code: + { + auto &state = get_id_state(instruction.result); + state.value = Value(state.constant->get_or_make_value(), + get_type(instruction.result_type, instruction_start_index)); break; } + } } void Spirv_to_llvm::handle_instruction_op_constant_composite(Op_constant_composite instruction, @@ -2234,7 +2273,7 @@ void Spirv_to_llvm::handle_instruction_op_function_parameter(Op_function_paramet + std::string(get_enumerant_name(instruction.get_operation()))); } -void Spirv_to_llvm::handle_instruction_op_function_end(Op_function_end instruction, +void Spirv_to_llvm::handle_instruction_op_function_end([[gnu::unused]] Op_function_end instruction, std::size_t instruction_start_index) { if(!current_function_id) @@ -2242,11 +2281,13 @@ void Spirv_to_llvm::handle_instruction_op_function_end(Op_function_end instructi instruction_start_index, "OpFunctionEnd without matching OpFunction"); current_function_id = 0; -#warning finish - throw Parser_error(instruction_start_index, - instruction_start_index, - "instruction not implemented: " - + std::string(get_enumerant_name(instruction.get_operation()))); + switch(stage) + { + case Stage::calculate_types: + break; + case Stage::generate_code: + break; + } } void Spirv_to_llvm::handle_instruction_op_function_call(Op_function_call instruction, @@ -2510,8 +2551,87 @@ void Spirv_to_llvm::handle_instruction_op_variable(Op_variable instruction, break; } case Stage::generate_code: + { + auto &state = get_id_state(instruction.result); + switch(instruction.storage_class) + { + case Storage_class::uniform_constant: +#warning finish implementing Storage_class::uniform_constant + break; + case Storage_class::input: + { + if(instruction.initializer) + throw Parser_error(instruction_start_index, + instruction_start_index, + "shader input variable initializers are not implemented"); + auto &variable = util::get(state.variable); + state.value = + Value(::LLVMBuildStructGEP( + builder.get(), + get_id_state(current_function_id).function->entry_block->inputs_struct, + inputs_struct->get_members(true)[variable.member_index].llvm_member_index, + get_name(instruction.result).c_str()), + get_type(instruction.result_type, instruction_start_index)); + return; + } + case Storage_class::uniform: +#warning finish implementing Storage_class::uniform + break; + case Storage_class::output: + { + if(instruction.initializer) + throw Parser_error(instruction_start_index, + instruction_start_index, + "shader output variable initializers are not implemented"); + auto &variable = util::get(state.variable); + state.value = Value( + ::LLVMBuildStructGEP( + builder.get(), + get_id_state(current_function_id).function->entry_block->outputs_struct, + outputs_struct->get_members(true)[variable.member_index].llvm_member_index, + get_name(instruction.result).c_str()), + get_type(instruction.result_type, instruction_start_index)); + return; + } + case Storage_class::workgroup: +#warning finish implementing Storage_class::workgroup + break; + case Storage_class::cross_workgroup: +#warning finish implementing Storage_class::cross_workgroup + break; + case Storage_class::private_: +#warning finish implementing Storage_class::private_ + break; + case Storage_class::function: + { + if(!current_function_id) + throw Parser_error(instruction_start_index, + instruction_start_index, + "function-local variable must be inside function"); +#warning finish implementing Storage_class::function + throw Parser_error(instruction_start_index, + instruction_start_index, + "function-local variables are not implemented"); + } + case Storage_class::generic: +#warning finish implementing Storage_class::generic + break; + case Storage_class::push_constant: +#warning finish implementing Storage_class::push_constant + break; + case Storage_class::atomic_counter: +#warning finish implementing Storage_class::atomic_counter + break; + case Storage_class::image: +#warning finish implementing Storage_class::image + break; + case Storage_class::storage_buffer: +#warning finish implementing Storage_class::storage_buffer + break; + } break; } + } } void Spirv_to_llvm::handle_instruction_op_image_texel_pointer(Op_image_texel_pointer instruction, @@ -2527,21 +2647,53 @@ void Spirv_to_llvm::handle_instruction_op_image_texel_pointer(Op_image_texel_poi void Spirv_to_llvm::handle_instruction_op_load(Op_load instruction, std::size_t instruction_start_index) { -#warning finish - throw Parser_error(instruction_start_index, - instruction_start_index, - "instruction not implemented: " - + std::string(get_enumerant_name(instruction.get_operation()))); + switch(stage) + { + case Stage::calculate_types: + break; + case Stage::generate_code: + { + auto &state = get_id_state(instruction.result); + auto memory_access = instruction.memory_access.value_or( + Memory_access_with_parameters(Memory_access::none, {})); + if((memory_access.value & Memory_access::volatile_) == Memory_access::volatile_) + throw Parser_error(instruction_start_index, + instruction_start_index, + "OpLoad volatile not implemented"); + if((memory_access.value & Memory_access::aligned) == Memory_access::aligned) + throw Parser_error(instruction_start_index, + instruction_start_index, + "OpLoad alignment not implemented"); + if((memory_access.value & Memory_access::nontemporal) == Memory_access::nontemporal) + throw Parser_error(instruction_start_index, + instruction_start_index, + "OpLoad nontemporal not implemented"); + state.value = Value(::LLVMBuildLoad(builder.get(), + get_id_state(instruction.pointer).value.value().value, + get_name(instruction.result).c_str()), + get_type(instruction.result_type, instruction_start_index)); + break; + } + } } void Spirv_to_llvm::handle_instruction_op_store(Op_store instruction, std::size_t instruction_start_index) { + switch(stage) + { + case Stage::calculate_types: + break; + case Stage::generate_code: + { #warning finish - throw Parser_error(instruction_start_index, - instruction_start_index, - "instruction not implemented: " - + std::string(get_enumerant_name(instruction.get_operation()))); + throw Parser_error(instruction_start_index, + instruction_start_index, + "instruction not implemented: " + + std::string(get_enumerant_name(instruction.get_operation()))); + break; + } + } } void Spirv_to_llvm::handle_instruction_op_copy_memory(Op_copy_memory instruction, @@ -2567,11 +2719,20 @@ void Spirv_to_llvm::handle_instruction_op_copy_memory_sized(Op_copy_memory_sized void Spirv_to_llvm::handle_instruction_op_access_chain(Op_access_chain instruction, std::size_t instruction_start_index) { + switch(stage) + { + case Stage::calculate_types: + break; + case Stage::generate_code: + { #warning finish - throw Parser_error(instruction_start_index, - instruction_start_index, - "instruction not implemented: " - + std::string(get_enumerant_name(instruction.get_operation()))); + throw Parser_error(instruction_start_index, + instruction_start_index, + "instruction not implemented: " + + std::string(get_enumerant_name(instruction.get_operation()))); + break; + } + } } void Spirv_to_llvm::handle_instruction_op_in_bounds_access_chain( @@ -2700,21 +2861,139 @@ void Spirv_to_llvm::handle_instruction_op_vector_shuffle(Op_vector_shuffle instr void Spirv_to_llvm::handle_instruction_op_composite_construct(Op_composite_construct instruction, std::size_t instruction_start_index) { + switch(stage) + { + case Stage::calculate_types: + break; + case Stage::generate_code: + { + auto &state = get_id_state(instruction.result); + auto result_type = get_type(instruction.result_type, instruction_start_index); + ::LLVMValueRef result_value = nullptr; + struct Visitor + { + Op_composite_construct &instruction; + std::size_t instruction_start_index; + Id_state &state; + ::LLVMValueRef &result_value; + void operator()(Simple_type_descriptor &) + { + throw Parser_error(instruction_start_index, + instruction_start_index, + "invalid result type for OpCompositeConstruct"); + } + void operator()(Vector_type_descriptor &) + { #warning finish - throw Parser_error(instruction_start_index, - instruction_start_index, - "instruction not implemented: " - + std::string(get_enumerant_name(instruction.get_operation()))); + throw Parser_error(instruction_start_index, + instruction_start_index, + "unimplemented result type for OpCompositeConstruct"); + } + void operator()(Matrix_type_descriptor &) + { +#warning finish + throw Parser_error(instruction_start_index, + instruction_start_index, + "unimplemented result type for OpCompositeConstruct"); + } + void operator()(Pointer_type_descriptor &) + { + throw Parser_error(instruction_start_index, + instruction_start_index, + "invalid result type for OpCompositeConstruct"); + } + void operator()(Function_type_descriptor &) + { + throw Parser_error(instruction_start_index, + instruction_start_index, + "invalid result type for OpCompositeConstruct"); + } + void operator()(Struct_type_descriptor &) + { +#warning finish + throw Parser_error(instruction_start_index, + instruction_start_index, + "unimplemented result type for OpCompositeConstruct"); + } + }; + result_type->visit(Visitor{instruction, instruction_start_index, state, result_value}); + state.value = Value(result_value, std::move(result_type)); + break; + } + } } void Spirv_to_llvm::handle_instruction_op_composite_extract(Op_composite_extract instruction, std::size_t instruction_start_index) { + switch(stage) + { + case Stage::calculate_types: + break; + case Stage::generate_code: + { + auto &state = get_id_state(instruction.result); + auto result = get_id_state(instruction.composite).value.value(); + std::string name = ""; + for(std::size_t i = 0; i < instruction.indexes.size(); i++) + { + std::size_t index = instruction.indexes[i]; + if(i == instruction.indexes.size() - 1) + name = get_name(instruction.result); + struct Visitor + { + std::size_t instruction_start_index; + Id_state &state; + Value &result; + std::string &name; + std::size_t index; + void operator()(Simple_type_descriptor &) + { + throw Parser_error(instruction_start_index, + instruction_start_index, + "invalid composite type for OpCompositeExtract"); + } + void operator()(Vector_type_descriptor &) + { #warning finish - throw Parser_error(instruction_start_index, - instruction_start_index, - "instruction not implemented: " - + std::string(get_enumerant_name(instruction.get_operation()))); + throw Parser_error(instruction_start_index, + instruction_start_index, + "unimplemented composite type for OpCompositeExtract"); + } + void operator()(Matrix_type_descriptor &) + { +#warning finish + throw Parser_error(instruction_start_index, + instruction_start_index, + "unimplemented composite type for OpCompositeExtract"); + } + void operator()(Pointer_type_descriptor &) + { + throw Parser_error(instruction_start_index, + instruction_start_index, + "invalid composite type for OpCompositeExtract"); + } + void operator()(Function_type_descriptor &) + { + throw Parser_error(instruction_start_index, + instruction_start_index, + "invalid composite type for OpCompositeExtract"); + } + void operator()(Struct_type_descriptor &) + { +#warning finish + throw Parser_error(instruction_start_index, + instruction_start_index, + "unimplemented composite type for OpCompositeExtract"); + } + }; + auto *type = result.type.get(); + type->visit(Visitor{instruction_start_index, state, result, name, index}); + } + state.value = result; + break; + } + } } void Spirv_to_llvm::handle_instruction_op_composite_insert(Op_composite_insert instruction, @@ -2970,11 +3249,20 @@ void Spirv_to_llvm::handle_instruction_op_image_query_samples(Op_image_query_sam void Spirv_to_llvm::handle_instruction_op_convert_f_to_u(Op_convert_f_to_u instruction, std::size_t instruction_start_index) { + switch(stage) + { + case Stage::calculate_types: + break; + case Stage::generate_code: + { #warning finish - throw Parser_error(instruction_start_index, - instruction_start_index, - "instruction not implemented: " - + std::string(get_enumerant_name(instruction.get_operation()))); + throw Parser_error(instruction_start_index, + instruction_start_index, + "instruction not implemented: " + + std::string(get_enumerant_name(instruction.get_operation()))); + break; + } + } } void Spirv_to_llvm::handle_instruction_op_convert_f_to_s(Op_convert_f_to_s instruction, @@ -3010,11 +3298,20 @@ void Spirv_to_llvm::handle_instruction_op_convert_u_to_f(Op_convert_u_to_f instr void Spirv_to_llvm::handle_instruction_op_u_convert(Op_u_convert instruction, std::size_t instruction_start_index) { + switch(stage) + { + case Stage::calculate_types: + break; + case Stage::generate_code: + { #warning finish - throw Parser_error(instruction_start_index, - instruction_start_index, - "instruction not implemented: " - + std::string(get_enumerant_name(instruction.get_operation()))); + throw Parser_error(instruction_start_index, + instruction_start_index, + "instruction not implemented: " + + std::string(get_enumerant_name(instruction.get_operation()))); + break; + } + } } void Spirv_to_llvm::handle_instruction_op_s_convert(Op_s_convert instruction, @@ -3120,11 +3417,20 @@ void Spirv_to_llvm::handle_instruction_op_generic_cast_to_ptr_explicit( void Spirv_to_llvm::handle_instruction_op_bitcast(Op_bitcast instruction, std::size_t instruction_start_index) { + switch(stage) + { + case Stage::calculate_types: + break; + case Stage::generate_code: + { #warning finish - throw Parser_error(instruction_start_index, - instruction_start_index, - "instruction not implemented: " - + std::string(get_enumerant_name(instruction.get_operation()))); + throw Parser_error(instruction_start_index, + instruction_start_index, + "instruction not implemented: " + + std::string(get_enumerant_name(instruction.get_operation()))); + break; + } + } } void Spirv_to_llvm::handle_instruction_op_s_negate(Op_s_negate instruction, @@ -4220,30 +4526,28 @@ void Spirv_to_llvm::handle_instruction_op_phi(Op_phi instruction, void Spirv_to_llvm::handle_instruction_op_loop_merge(Op_loop_merge instruction, std::size_t instruction_start_index) { -#warning finish - throw Parser_error(instruction_start_index, - instruction_start_index, - "instruction not implemented: " - + std::string(get_enumerant_name(instruction.get_operation()))); + last_merge_instruction = + Last_merge_instruction(std::move(instruction), instruction_start_index); } void Spirv_to_llvm::handle_instruction_op_selection_merge(Op_selection_merge instruction, std::size_t instruction_start_index) { -#warning finish - throw Parser_error(instruction_start_index, - instruction_start_index, - "instruction not implemented: " - + std::string(get_enumerant_name(instruction.get_operation()))); + last_merge_instruction = + Last_merge_instruction(std::move(instruction), instruction_start_index); } void Spirv_to_llvm::handle_instruction_op_label(Op_label instruction, std::size_t instruction_start_index) { if(current_function_id == 0) - throw Parser_error(instruction_start_index, instruction_start_index, "OpLabel not allowed outside a function"); + throw Parser_error(instruction_start_index, + instruction_start_index, + "OpLabel not allowed outside a function"); if(current_basic_block_id != 0) - throw Parser_error(instruction_start_index, instruction_start_index, "missing block terminator before OpLabel"); + throw Parser_error(instruction_start_index, + instruction_start_index, + "missing block terminator before OpLabel"); current_basic_block_id = instruction.result; switch(stage) { @@ -4256,22 +4560,42 @@ void Spirv_to_llvm::handle_instruction_op_label(Op_label instruction, ::LLVMPositionBuilderAtEnd(builder.get(), block); if(!function.entry_block) { - function.entry_block = block; + auto io_struct_value = ::LLVMGetParam(function.function, io_struct_argument_index); + auto inputs_struct_value = ::LLVMBuildStructGEP( + builder.get(), + io_struct_value, + io_struct->get_members(true)[this->inputs_member].llvm_member_index, + "inputs"); + auto outputs_struct_value = ::LLVMBuildStructGEP( + builder.get(), + io_struct_value, + io_struct->get_members(true)[this->outputs_member].llvm_member_index, + "outputs"); #warning finish adding function entry instructions + function.entry_block = Function_state::Entry_block( + block, io_struct_value, inputs_struct_value, outputs_struct_value); } break; } } } -void Spirv_to_llvm::handle_instruction_op_branch(Op_branch instruction, - std::size_t instruction_start_index) +void Spirv_to_llvm::handle_instruction_op_branch( + Op_branch instruction, [[gnu::unused]] std::size_t instruction_start_index) { -#warning finish - throw Parser_error(instruction_start_index, - instruction_start_index, - "instruction not implemented: " - + std::string(get_enumerant_name(instruction.get_operation()))); + auto merge = std::move(last_merge_instruction); + last_merge_instruction.reset(); + current_basic_block_id = 0; + switch(stage) + { + case Stage::calculate_types: + break; + case Stage::generate_code: + { + ::LLVMBuildBr(builder.get(), get_or_make_label(instruction.target_label)); + break; + } + } } void Spirv_to_llvm::handle_instruction_op_branch_conditional(Op_branch_conditional instruction, @@ -4284,14 +4608,33 @@ void Spirv_to_llvm::handle_instruction_op_branch_conditional(Op_branch_condition + std::string(get_enumerant_name(instruction.get_operation()))); } -void Spirv_to_llvm::handle_instruction_op_switch(Op_switch instruction, - std::size_t instruction_start_index) +void Spirv_to_llvm::handle_instruction_op_switch( + Op_switch instruction, [[gnu::unused]] std::size_t instruction_start_index) { -#warning finish - throw Parser_error(instruction_start_index, - instruction_start_index, - "instruction not implemented: " - + std::string(get_enumerant_name(instruction.get_operation()))); + auto merge = std::move(last_merge_instruction.value()); + last_merge_instruction.reset(); + current_basic_block_id = 0; + switch(stage) + { + case Stage::calculate_types: + break; + case Stage::generate_code: + { + for(auto &target : instruction.target) + get_or_make_label(target.part_2); // create basic blocks first + auto selector = get_id_state(instruction.selector).value.value(); + auto switch_instruction = ::LLVMBuildSwitch(builder.get(), + selector.value, + get_or_make_label(instruction.default_), + instruction.target.size()); + for(auto &target : instruction.target) + ::LLVMAddCase( + switch_instruction, + ::LLVMConstInt(selector.type->get_or_make_type(true), target.part_1, false), + get_or_make_label(target.part_2)); + break; + } + } } void Spirv_to_llvm::handle_instruction_op_kill(Op_kill instruction, @@ -4304,14 +4647,20 @@ void Spirv_to_llvm::handle_instruction_op_kill(Op_kill instruction, + std::string(get_enumerant_name(instruction.get_operation()))); } -void Spirv_to_llvm::handle_instruction_op_return(Op_return instruction, - std::size_t instruction_start_index) +void Spirv_to_llvm::handle_instruction_op_return( + [[gnu::unused]] Op_return instruction, [[gnu::unused]] std::size_t instruction_start_index) { -#warning finish - throw Parser_error(instruction_start_index, - instruction_start_index, - "instruction not implemented: " - + std::string(get_enumerant_name(instruction.get_operation()))); + current_basic_block_id = 0; + switch(stage) + { + case Stage::calculate_types: + break; + case Stage::generate_code: + { + ::LLVMBuildRetVoid(builder.get()); + break; + } + } } void Spirv_to_llvm::handle_instruction_op_return_value(Op_return_value instruction, diff --git a/src/spirv_to_llvm/spirv_to_llvm.h b/src/spirv_to_llvm/spirv_to_llvm.h index 6b377f4..c568323 100644 --- a/src/spirv_to_llvm/spirv_to_llvm.h +++ b/src/spirv_to_llvm/spirv_to_llvm.h @@ -29,21 +29,83 @@ #include #include #include +#include +#include #include "llvm_wrapper/llvm_wrapper.h" namespace vulkan_cpu { namespace spirv_to_llvm { +class Simple_type_descriptor; +class Vector_type_descriptor; +class Matrix_type_descriptor; +class Pointer_type_descriptor; +class Function_type_descriptor; +class Struct_type_descriptor; class Type_descriptor { Type_descriptor(const Type_descriptor &) = delete; Type_descriptor &operator=(const Type_descriptor &) = delete; +public: + struct Type_visitor + { + virtual ~Type_visitor() = default; + virtual void visit(Simple_type_descriptor &type) = 0; + virtual void visit(Vector_type_descriptor &type) = 0; + virtual void visit(Matrix_type_descriptor &type) = 0; + virtual void visit(Pointer_type_descriptor &type) = 0; + virtual void visit(Function_type_descriptor &type) = 0; + virtual void visit(Struct_type_descriptor &type) = 0; + }; + public: Type_descriptor() noexcept = default; virtual ~Type_descriptor() = default; virtual ::LLVMTypeRef get_or_make_type(bool need_complete_structs) = 0; + virtual void visit(Type_visitor &type_visitor) = 0; + void visit(Type_visitor &&type_visitor) + { + visit(type_visitor); + } + template + typename std::enable_if::value, void>::type + visit(Fn &&fn) + { + struct Visitor final : public Type_visitor + { + Fn &fn; + virtual void visit(Simple_type_descriptor &type) override + { + std::forward(fn)(type); + } + virtual void visit(Vector_type_descriptor &type) override + { + std::forward(fn)(type); + } + virtual void visit(Matrix_type_descriptor &type) override + { + std::forward(fn)(type); + } + virtual void visit(Pointer_type_descriptor &type) override + { + std::forward(fn)(type); + } + virtual void visit(Function_type_descriptor &type) override + { + std::forward(fn)(type); + } + virtual void visit(Struct_type_descriptor &type) override + { + std::forward(fn)(type); + } + explicit Visitor(Fn &fn) noexcept : fn(fn) + { + } + }; + visit(Visitor(fn)); + } class Recursion_checker; class Recursion_checker_state { @@ -99,6 +161,77 @@ public: { return type; } + virtual void visit(Type_visitor &type_visitor) override + { + type_visitor.visit(*this); + } +}; + +class Vector_type_descriptor final : public Type_descriptor +{ +private: + ::LLVMTypeRef type; + std::shared_ptr element_type; + std::size_t element_count; + +public: + explicit Vector_type_descriptor(std::shared_ptr element_type, + std::size_t element_count) noexcept + : type(::LLVMVectorType(element_type->get_or_make_type(true), element_count)), + element_type(std::move(element_type)), + element_count(element_count) + { + } + virtual ::LLVMTypeRef get_or_make_type([[gnu::unused]] bool need_complete_structs) override + { + return type; + } + virtual void visit(Type_visitor &type_visitor) override + { + type_visitor.visit(*this); + } + const std::shared_ptr &get_element_type() const noexcept + { + return element_type; + } + std::size_t get_element_count() const noexcept + { + return element_count; + } +}; + +class Matrix_type_descriptor final : public Type_descriptor +{ +private: + ::LLVMTypeRef type; + std::shared_ptr column_type; + std::size_t column_count; + +public: + explicit Matrix_type_descriptor(std::shared_ptr column_type, + std::size_t column_count) noexcept + : type(::LLVMVectorType(column_type->get_element_type()->get_or_make_type(true), + column_type->get_element_count() * column_count)), + column_type(std::move(column_type)), + column_count(column_count) + { + } + virtual ::LLVMTypeRef get_or_make_type([[gnu::unused]] bool need_complete_structs) override + { + return type; + } + virtual void visit(Type_visitor &type_visitor) override + { + type_visitor.visit(*this); + } + const std::shared_ptr &get_column_type() const noexcept + { + return column_type; + } + std::size_t get_column_count() const noexcept + { + return column_count; + } }; class Pointer_type_descriptor final : public Type_descriptor @@ -149,6 +282,10 @@ public: } return type; } + virtual void visit(Type_visitor &type_visitor) override + { + type_visitor.visit(*this); + } }; class Function_type_descriptor final : public Type_descriptor @@ -188,6 +325,10 @@ public: } return type; } + virtual void visit(Type_visitor &type_visitor) override + { + type_visitor.visit(*this); + } }; class Struct_type_descriptor final : public Type_descriptor @@ -263,6 +404,10 @@ public: } return type; } + virtual void visit(Type_visitor &type_visitor) override + { + type_visitor.visit(*this); + } }; class Constant_descriptor -- 2.30.2