calculate_types stage works in Spirv_to_llvm
authorJacob Lifshay <programmerjake@gmail.com>
Thu, 20 Jul 2017 08:39:25 +0000 (01:39 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Thu, 20 Jul 2017 08:39:25 +0000 (01:39 -0700)
src/spirv_to_llvm/spirv_to_llvm.cpp
src/spirv_to_llvm/spirv_to_llvm.h

index eefbe0b46bc126969fd8fead27a2d4c121126214..9ef798dc86d80cdafb9f6f46bb06cab8689bb3f6 100644 (file)
@@ -88,12 +88,30 @@ private:
         Variable_state;
     struct Function_state
     {
+        struct Entry_block
+        {
+            ::LLVMBasicBlockRef entry_block;
+            ::LLVMValueRef io_struct;
+            ::LLVMValueRef inputs_struct;
+            ::LLVMValueRef outputs_struct;
+            explicit Entry_block(::LLVMBasicBlockRef entry_block,
+                                 ::LLVMValueRef io_struct,
+                                 ::LLVMValueRef inputs_struct,
+                                 ::LLVMValueRef outputs_struct) noexcept
+                : entry_block(entry_block),
+                  io_struct(io_struct),
+                  inputs_struct(inputs_struct),
+                  outputs_struct(outputs_struct)
+            {
+            }
+        };
         std::shared_ptr<Function_type_descriptor> type;
         ::LLVMValueRef function;
-        ::LLVMBasicBlockRef entry_block = nullptr;
+        util::optional<Entry_block> entry_block;
         explicit Function_state(std::shared_ptr<Function_type_descriptor> type,
                                 ::LLVMValueRef function) noexcept : type(std::move(type)),
-                                                                    function(function)
+                                                                    function(function),
+                                                                    entry_block()
         {
         }
     };
@@ -104,6 +122,16 @@ private:
         {
         }
     };
+    struct Value
+    {
+        ::LLVMValueRef value;
+        std::shared_ptr<Type_descriptor> type;
+        explicit Value(::LLVMValueRef value, std::shared_ptr<Type_descriptor> type) noexcept
+            : value(value),
+              type(std::move(type))
+        {
+        }
+    };
     struct Id_state
     {
         util::optional<Op_string_state> op_string;
@@ -118,6 +146,7 @@ private:
         std::shared_ptr<Constant_descriptor> constant;
         util::optional<Function_state> function;
         util::optional<Label_state> label;
+        util::optional<Value> value;
 
     private:
         template <typename Fn>
@@ -162,6 +191,17 @@ private:
         {
         }
     };
+    struct Last_merge_instruction
+    {
+        typedef util::variant<Op_selection_merge, Op_loop_merge> Instruction_variant;
+        Instruction_variant instruction;
+        std::size_t instruction_start_index;
+        explicit Last_merge_instruction(Instruction_variant instruction,
+                                        std::size_t instruction_start_index)
+            : instruction(std::move(instruction)), instruction_start_index(instruction_start_index)
+        {
+        }
+    };
 
 private:
     std::vector<Id_state> id_states;
@@ -174,6 +214,7 @@ private:
     std::string name_prefix;
     llvm_wrapper::Module module;
     std::shared_ptr<Struct_type_descriptor> io_struct;
+    static constexpr std::size_t io_struct_argument_index = 0;
     std::array<std::shared_ptr<Type_descriptor>, 1> implicit_function_arguments;
     std::size_t inputs_member;
     std::shared_ptr<Struct_type_descriptor> inputs_struct;
@@ -183,6 +224,7 @@ private:
     Id current_function_id = 0;
     Id current_basic_block_id = 0;
     llvm_wrapper::Builder builder;
+    util::optional<Last_merge_instruction> last_merge_instruction;
 
 private:
     Id_state &get_id_state(Id id)
@@ -236,7 +278,8 @@ public:
         io_struct = std::make_shared<Struct_type_descriptor>(
             context, (name_prefix + "Io_struct").c_str(), no_instruction_index);
         assert(implicit_function_arguments.size() == 1);
-        implicit_function_arguments[0] = io_struct;
+        static_assert(io_struct_argument_index == 0, "");
+        implicit_function_arguments[io_struct_argument_index] = io_struct;
         inputs_struct = std::make_shared<Struct_type_descriptor>(
             context, (name_prefix + "Inputs").c_str(), no_instruction_index);
         inputs_member = io_struct->add_member(Struct_type_descriptor::Member({}, inputs_struct));
@@ -1719,10 +1762,9 @@ void Spirv_to_llvm::handle_instruction_op_type_vector(Op_type_vector instruction
     switch(stage)
     {
     case Stage::calculate_types:
-        get_id_state(instruction.result)
-            .type = std::make_shared<Simple_type_descriptor>(::LLVMVectorType(
-            get_type(instruction.component_type, instruction_start_index)->get_or_make_type(false),
-            instruction.component_count));
+        get_id_state(instruction.result).type = std::make_shared<Vector_type_descriptor>(
+            get_type<Simple_type_descriptor>(instruction.component_type, instruction_start_index),
+            instruction.component_count);
         break;
     case Stage::generate_code:
         break;
@@ -1735,18 +1777,10 @@ void Spirv_to_llvm::handle_instruction_op_type_matrix(Op_type_matrix instruction
     switch(stage)
     {
     case Stage::calculate_types:
-    {
-        auto column_type =
-            get_type(instruction.column_type, instruction_start_index)->get_or_make_type(false);
-        if(::LLVMGetTypeKind(column_type) != LLVMVectorTypeKind)
-            throw Parser_error(instruction_start_index,
-                               instruction_start_index,
-                               "column type must be a vector type");
-        get_id_state(instruction.result).type = std::make_shared<Simple_type_descriptor>(
-            ::LLVMVectorType(::LLVMGetElementType(column_type),
-                             instruction.column_count * ::LLVMGetVectorSize(column_type)));
+        get_id_state(instruction.result).type = std::make_shared<Matrix_type_descriptor>(
+            get_type<Vector_type_descriptor>(instruction.column_type, instruction_start_index),
+            instruction.column_count);
         break;
-    }
     case Stage::generate_code:
         break;
     }
@@ -2113,8 +2147,13 @@ void Spirv_to_llvm::handle_instruction_op_constant(Op_constant instruction,
         break;
     }
     case Stage::generate_code:
+    {
+        auto &state = get_id_state(instruction.result);
+        state.value = Value(state.constant->get_or_make_value(),
+                            get_type(instruction.result_type, instruction_start_index));
         break;
     }
+    }
 }
 
 void Spirv_to_llvm::handle_instruction_op_constant_composite(Op_constant_composite instruction,
@@ -2234,7 +2273,7 @@ void Spirv_to_llvm::handle_instruction_op_function_parameter(Op_function_paramet
                            + std::string(get_enumerant_name(instruction.get_operation())));
 }
 
-void Spirv_to_llvm::handle_instruction_op_function_end(Op_function_end instruction,
+void Spirv_to_llvm::handle_instruction_op_function_end([[gnu::unused]] Op_function_end instruction,
                                                        std::size_t instruction_start_index)
 {
     if(!current_function_id)
@@ -2242,11 +2281,13 @@ void Spirv_to_llvm::handle_instruction_op_function_end(Op_function_end instructi
                            instruction_start_index,
                            "OpFunctionEnd without matching OpFunction");
     current_function_id = 0;
-#warning finish
-    throw Parser_error(instruction_start_index,
-                       instruction_start_index,
-                       "instruction not implemented: "
-                           + std::string(get_enumerant_name(instruction.get_operation())));
+    switch(stage)
+    {
+    case Stage::calculate_types:
+        break;
+    case Stage::generate_code:
+        break;
+    }
 }
 
 void Spirv_to_llvm::handle_instruction_op_function_call(Op_function_call instruction,
@@ -2510,8 +2551,87 @@ void Spirv_to_llvm::handle_instruction_op_variable(Op_variable instruction,
         break;
     }
     case Stage::generate_code:
+    {
+        auto &state = get_id_state(instruction.result);
+        switch(instruction.storage_class)
+        {
+        case Storage_class::uniform_constant:
+#warning finish implementing Storage_class::uniform_constant
+            break;
+        case Storage_class::input:
+        {
+            if(instruction.initializer)
+                throw Parser_error(instruction_start_index,
+                                   instruction_start_index,
+                                   "shader input variable initializers are not implemented");
+            auto &variable = util::get<Input_variable_state>(state.variable);
+            state.value =
+                Value(::LLVMBuildStructGEP(
+                          builder.get(),
+                          get_id_state(current_function_id).function->entry_block->inputs_struct,
+                          inputs_struct->get_members(true)[variable.member_index].llvm_member_index,
+                          get_name(instruction.result).c_str()),
+                      get_type(instruction.result_type, instruction_start_index));
+            return;
+        }
+        case Storage_class::uniform:
+#warning finish implementing Storage_class::uniform
+            break;
+        case Storage_class::output:
+        {
+            if(instruction.initializer)
+                throw Parser_error(instruction_start_index,
+                                   instruction_start_index,
+                                   "shader output variable initializers are not implemented");
+            auto &variable = util::get<Output_variable_state>(state.variable);
+            state.value = Value(
+                ::LLVMBuildStructGEP(
+                    builder.get(),
+                    get_id_state(current_function_id).function->entry_block->outputs_struct,
+                    outputs_struct->get_members(true)[variable.member_index].llvm_member_index,
+                    get_name(instruction.result).c_str()),
+                get_type(instruction.result_type, instruction_start_index));
+            return;
+        }
+        case Storage_class::workgroup:
+#warning finish implementing Storage_class::workgroup
+            break;
+        case Storage_class::cross_workgroup:
+#warning finish implementing Storage_class::cross_workgroup
+            break;
+        case Storage_class::private_:
+#warning finish implementing Storage_class::private_
+            break;
+        case Storage_class::function:
+        {
+            if(!current_function_id)
+                throw Parser_error(instruction_start_index,
+                                   instruction_start_index,
+                                   "function-local variable must be inside function");
+#warning finish implementing Storage_class::function
+            throw Parser_error(instruction_start_index,
+                               instruction_start_index,
+                               "function-local variables are not implemented");
+        }
+        case Storage_class::generic:
+#warning finish implementing Storage_class::generic
+            break;
+        case Storage_class::push_constant:
+#warning finish implementing Storage_class::push_constant
+            break;
+        case Storage_class::atomic_counter:
+#warning finish implementing Storage_class::atomic_counter
+            break;
+        case Storage_class::image:
+#warning finish implementing Storage_class::image
+            break;
+        case Storage_class::storage_buffer:
+#warning finish implementing Storage_class::storage_buffer
+            break;
+        }
         break;
     }
+    }
 }
 
 void Spirv_to_llvm::handle_instruction_op_image_texel_pointer(Op_image_texel_pointer instruction,
@@ -2527,21 +2647,53 @@ void Spirv_to_llvm::handle_instruction_op_image_texel_pointer(Op_image_texel_poi
 void Spirv_to_llvm::handle_instruction_op_load(Op_load instruction,
                                                std::size_t instruction_start_index)
 {
-#warning finish
-    throw Parser_error(instruction_start_index,
-                       instruction_start_index,
-                       "instruction not implemented: "
-                           + std::string(get_enumerant_name(instruction.get_operation())));
+    switch(stage)
+    {
+    case Stage::calculate_types:
+        break;
+    case Stage::generate_code:
+    {
+        auto &state = get_id_state(instruction.result);
+        auto memory_access = instruction.memory_access.value_or(
+            Memory_access_with_parameters(Memory_access::none, {}));
+        if((memory_access.value & Memory_access::volatile_) == Memory_access::volatile_)
+            throw Parser_error(instruction_start_index,
+                               instruction_start_index,
+                               "OpLoad volatile not implemented");
+        if((memory_access.value & Memory_access::aligned) == Memory_access::aligned)
+            throw Parser_error(instruction_start_index,
+                               instruction_start_index,
+                               "OpLoad alignment not implemented");
+        if((memory_access.value & Memory_access::nontemporal) == Memory_access::nontemporal)
+            throw Parser_error(instruction_start_index,
+                               instruction_start_index,
+                               "OpLoad nontemporal not implemented");
+        state.value = Value(::LLVMBuildLoad(builder.get(),
+                                            get_id_state(instruction.pointer).value.value().value,
+                                            get_name(instruction.result).c_str()),
+                            get_type(instruction.result_type, instruction_start_index));
+        break;
+    }
+    }
 }
 
 void Spirv_to_llvm::handle_instruction_op_store(Op_store instruction,
                                                 std::size_t instruction_start_index)
 {
+    switch(stage)
+    {
+    case Stage::calculate_types:
+        break;
+    case Stage::generate_code:
+    {
 #warning finish
-    throw Parser_error(instruction_start_index,
-                       instruction_start_index,
-                       "instruction not implemented: "
-                           + std::string(get_enumerant_name(instruction.get_operation())));
+        throw Parser_error(instruction_start_index,
+                           instruction_start_index,
+                           "instruction not implemented: "
+                               + std::string(get_enumerant_name(instruction.get_operation())));
+        break;
+    }
+    }
 }
 
 void Spirv_to_llvm::handle_instruction_op_copy_memory(Op_copy_memory instruction,
@@ -2567,11 +2719,20 @@ void Spirv_to_llvm::handle_instruction_op_copy_memory_sized(Op_copy_memory_sized
 void Spirv_to_llvm::handle_instruction_op_access_chain(Op_access_chain instruction,
                                                        std::size_t instruction_start_index)
 {
+    switch(stage)
+    {
+    case Stage::calculate_types:
+        break;
+    case Stage::generate_code:
+    {
 #warning finish
-    throw Parser_error(instruction_start_index,
-                       instruction_start_index,
-                       "instruction not implemented: "
-                           + std::string(get_enumerant_name(instruction.get_operation())));
+        throw Parser_error(instruction_start_index,
+                           instruction_start_index,
+                           "instruction not implemented: "
+                               + std::string(get_enumerant_name(instruction.get_operation())));
+        break;
+    }
+    }
 }
 
 void Spirv_to_llvm::handle_instruction_op_in_bounds_access_chain(
@@ -2700,21 +2861,139 @@ void Spirv_to_llvm::handle_instruction_op_vector_shuffle(Op_vector_shuffle instr
 void Spirv_to_llvm::handle_instruction_op_composite_construct(Op_composite_construct instruction,
                                                               std::size_t instruction_start_index)
 {
+    switch(stage)
+    {
+    case Stage::calculate_types:
+        break;
+    case Stage::generate_code:
+    {
+        auto &state = get_id_state(instruction.result);
+        auto result_type = get_type(instruction.result_type, instruction_start_index);
+        ::LLVMValueRef result_value = nullptr;
+        struct Visitor
+        {
+            Op_composite_construct &instruction;
+            std::size_t instruction_start_index;
+            Id_state &state;
+            ::LLVMValueRef &result_value;
+            void operator()(Simple_type_descriptor &)
+            {
+                throw Parser_error(instruction_start_index,
+                                   instruction_start_index,
+                                   "invalid result type for OpCompositeConstruct");
+            }
+            void operator()(Vector_type_descriptor &)
+            {
 #warning finish
-    throw Parser_error(instruction_start_index,
-                       instruction_start_index,
-                       "instruction not implemented: "
-                           + std::string(get_enumerant_name(instruction.get_operation())));
+                throw Parser_error(instruction_start_index,
+                                   instruction_start_index,
+                                   "unimplemented result type for OpCompositeConstruct");
+            }
+            void operator()(Matrix_type_descriptor &)
+            {
+#warning finish
+                throw Parser_error(instruction_start_index,
+                                   instruction_start_index,
+                                   "unimplemented result type for OpCompositeConstruct");
+            }
+            void operator()(Pointer_type_descriptor &)
+            {
+                throw Parser_error(instruction_start_index,
+                                   instruction_start_index,
+                                   "invalid result type for OpCompositeConstruct");
+            }
+            void operator()(Function_type_descriptor &)
+            {
+                throw Parser_error(instruction_start_index,
+                                   instruction_start_index,
+                                   "invalid result type for OpCompositeConstruct");
+            }
+            void operator()(Struct_type_descriptor &)
+            {
+#warning finish
+                throw Parser_error(instruction_start_index,
+                                   instruction_start_index,
+                                   "unimplemented result type for OpCompositeConstruct");
+            }
+        };
+        result_type->visit(Visitor{instruction, instruction_start_index, state, result_value});
+        state.value = Value(result_value, std::move(result_type));
+        break;
+    }
+    }
 }
 
 void Spirv_to_llvm::handle_instruction_op_composite_extract(Op_composite_extract instruction,
                                                             std::size_t instruction_start_index)
 {
+    switch(stage)
+    {
+    case Stage::calculate_types:
+        break;
+    case Stage::generate_code:
+    {
+        auto &state = get_id_state(instruction.result);
+        auto result = get_id_state(instruction.composite).value.value();
+        std::string name = "";
+        for(std::size_t i = 0; i < instruction.indexes.size(); i++)
+        {
+            std::size_t index = instruction.indexes[i];
+            if(i == instruction.indexes.size() - 1)
+                name = get_name(instruction.result);
+            struct Visitor
+            {
+                std::size_t instruction_start_index;
+                Id_state &state;
+                Value &result;
+                std::string &name;
+                std::size_t index;
+                void operator()(Simple_type_descriptor &)
+                {
+                    throw Parser_error(instruction_start_index,
+                                       instruction_start_index,
+                                       "invalid composite type for OpCompositeExtract");
+                }
+                void operator()(Vector_type_descriptor &)
+                {
 #warning finish
-    throw Parser_error(instruction_start_index,
-                       instruction_start_index,
-                       "instruction not implemented: "
-                           + std::string(get_enumerant_name(instruction.get_operation())));
+                    throw Parser_error(instruction_start_index,
+                                       instruction_start_index,
+                                       "unimplemented composite type for OpCompositeExtract");
+                }
+                void operator()(Matrix_type_descriptor &)
+                {
+#warning finish
+                    throw Parser_error(instruction_start_index,
+                                       instruction_start_index,
+                                       "unimplemented composite type for OpCompositeExtract");
+                }
+                void operator()(Pointer_type_descriptor &)
+                {
+                    throw Parser_error(instruction_start_index,
+                                       instruction_start_index,
+                                       "invalid composite type for OpCompositeExtract");
+                }
+                void operator()(Function_type_descriptor &)
+                {
+                    throw Parser_error(instruction_start_index,
+                                       instruction_start_index,
+                                       "invalid composite type for OpCompositeExtract");
+                }
+                void operator()(Struct_type_descriptor &)
+                {
+#warning finish
+                    throw Parser_error(instruction_start_index,
+                                       instruction_start_index,
+                                       "unimplemented composite type for OpCompositeExtract");
+                }
+            };
+            auto *type = result.type.get();
+            type->visit(Visitor{instruction_start_index, state, result, name, index});
+        }
+        state.value = result;
+        break;
+    }
+    }
 }
 
 void Spirv_to_llvm::handle_instruction_op_composite_insert(Op_composite_insert instruction,
@@ -2970,11 +3249,20 @@ void Spirv_to_llvm::handle_instruction_op_image_query_samples(Op_image_query_sam
 void Spirv_to_llvm::handle_instruction_op_convert_f_to_u(Op_convert_f_to_u instruction,
                                                          std::size_t instruction_start_index)
 {
+    switch(stage)
+    {
+    case Stage::calculate_types:
+        break;
+    case Stage::generate_code:
+    {
 #warning finish
-    throw Parser_error(instruction_start_index,
-                       instruction_start_index,
-                       "instruction not implemented: "
-                           + std::string(get_enumerant_name(instruction.get_operation())));
+        throw Parser_error(instruction_start_index,
+                           instruction_start_index,
+                           "instruction not implemented: "
+                               + std::string(get_enumerant_name(instruction.get_operation())));
+        break;
+    }
+    }
 }
 
 void Spirv_to_llvm::handle_instruction_op_convert_f_to_s(Op_convert_f_to_s instruction,
@@ -3010,11 +3298,20 @@ void Spirv_to_llvm::handle_instruction_op_convert_u_to_f(Op_convert_u_to_f instr
 void Spirv_to_llvm::handle_instruction_op_u_convert(Op_u_convert instruction,
                                                     std::size_t instruction_start_index)
 {
+    switch(stage)
+    {
+    case Stage::calculate_types:
+        break;
+    case Stage::generate_code:
+    {
 #warning finish
-    throw Parser_error(instruction_start_index,
-                       instruction_start_index,
-                       "instruction not implemented: "
-                           + std::string(get_enumerant_name(instruction.get_operation())));
+        throw Parser_error(instruction_start_index,
+                           instruction_start_index,
+                           "instruction not implemented: "
+                               + std::string(get_enumerant_name(instruction.get_operation())));
+        break;
+    }
+    }
 }
 
 void Spirv_to_llvm::handle_instruction_op_s_convert(Op_s_convert instruction,
@@ -3120,11 +3417,20 @@ void Spirv_to_llvm::handle_instruction_op_generic_cast_to_ptr_explicit(
 void Spirv_to_llvm::handle_instruction_op_bitcast(Op_bitcast instruction,
                                                   std::size_t instruction_start_index)
 {
+    switch(stage)
+    {
+    case Stage::calculate_types:
+        break;
+    case Stage::generate_code:
+    {
 #warning finish
-    throw Parser_error(instruction_start_index,
-                       instruction_start_index,
-                       "instruction not implemented: "
-                           + std::string(get_enumerant_name(instruction.get_operation())));
+        throw Parser_error(instruction_start_index,
+                           instruction_start_index,
+                           "instruction not implemented: "
+                               + std::string(get_enumerant_name(instruction.get_operation())));
+        break;
+    }
+    }
 }
 
 void Spirv_to_llvm::handle_instruction_op_s_negate(Op_s_negate instruction,
@@ -4220,30 +4526,28 @@ void Spirv_to_llvm::handle_instruction_op_phi(Op_phi instruction,
 void Spirv_to_llvm::handle_instruction_op_loop_merge(Op_loop_merge instruction,
                                                      std::size_t instruction_start_index)
 {
-#warning finish
-    throw Parser_error(instruction_start_index,
-                       instruction_start_index,
-                       "instruction not implemented: "
-                           + std::string(get_enumerant_name(instruction.get_operation())));
+    last_merge_instruction =
+        Last_merge_instruction(std::move(instruction), instruction_start_index);
 }
 
 void Spirv_to_llvm::handle_instruction_op_selection_merge(Op_selection_merge instruction,
                                                           std::size_t instruction_start_index)
 {
-#warning finish
-    throw Parser_error(instruction_start_index,
-                       instruction_start_index,
-                       "instruction not implemented: "
-                           + std::string(get_enumerant_name(instruction.get_operation())));
+    last_merge_instruction =
+        Last_merge_instruction(std::move(instruction), instruction_start_index);
 }
 
 void Spirv_to_llvm::handle_instruction_op_label(Op_label instruction,
                                                 std::size_t instruction_start_index)
 {
     if(current_function_id == 0)
-        throw Parser_error(instruction_start_index, instruction_start_index, "OpLabel not allowed outside a function");
+        throw Parser_error(instruction_start_index,
+                           instruction_start_index,
+                           "OpLabel not allowed outside a function");
     if(current_basic_block_id != 0)
-        throw Parser_error(instruction_start_index, instruction_start_index, "missing block terminator before OpLabel");
+        throw Parser_error(instruction_start_index,
+                           instruction_start_index,
+                           "missing block terminator before OpLabel");
     current_basic_block_id = instruction.result;
     switch(stage)
     {
@@ -4256,22 +4560,42 @@ void Spirv_to_llvm::handle_instruction_op_label(Op_label instruction,
         ::LLVMPositionBuilderAtEnd(builder.get(), block);
         if(!function.entry_block)
         {
-            function.entry_block = block;
+            auto io_struct_value = ::LLVMGetParam(function.function, io_struct_argument_index);
+            auto inputs_struct_value = ::LLVMBuildStructGEP(
+                builder.get(),
+                io_struct_value,
+                io_struct->get_members(true)[this->inputs_member].llvm_member_index,
+                "inputs");
+            auto outputs_struct_value = ::LLVMBuildStructGEP(
+                builder.get(),
+                io_struct_value,
+                io_struct->get_members(true)[this->outputs_member].llvm_member_index,
+                "outputs");
 #warning finish adding function entry instructions
+            function.entry_block = Function_state::Entry_block(
+                block, io_struct_value, inputs_struct_value, outputs_struct_value);
         }
         break;
     }
     }
 }
 
-void Spirv_to_llvm::handle_instruction_op_branch(Op_branch instruction,
-                                                 std::size_t instruction_start_index)
+void Spirv_to_llvm::handle_instruction_op_branch(
+    Op_branch instruction, [[gnu::unused]] std::size_t instruction_start_index)
 {
-#warning finish
-    throw Parser_error(instruction_start_index,
-                       instruction_start_index,
-                       "instruction not implemented: "
-                           + std::string(get_enumerant_name(instruction.get_operation())));
+    auto merge = std::move(last_merge_instruction);
+    last_merge_instruction.reset();
+    current_basic_block_id = 0;
+    switch(stage)
+    {
+    case Stage::calculate_types:
+        break;
+    case Stage::generate_code:
+    {
+        ::LLVMBuildBr(builder.get(), get_or_make_label(instruction.target_label));
+        break;
+    }
+    }
 }
 
 void Spirv_to_llvm::handle_instruction_op_branch_conditional(Op_branch_conditional instruction,
@@ -4284,14 +4608,33 @@ void Spirv_to_llvm::handle_instruction_op_branch_conditional(Op_branch_condition
                            + std::string(get_enumerant_name(instruction.get_operation())));
 }
 
-void Spirv_to_llvm::handle_instruction_op_switch(Op_switch instruction,
-                                                 std::size_t instruction_start_index)
+void Spirv_to_llvm::handle_instruction_op_switch(
+    Op_switch instruction, [[gnu::unused]] std::size_t instruction_start_index)
 {
-#warning finish
-    throw Parser_error(instruction_start_index,
-                       instruction_start_index,
-                       "instruction not implemented: "
-                           + std::string(get_enumerant_name(instruction.get_operation())));
+    auto merge = std::move(last_merge_instruction.value());
+    last_merge_instruction.reset();
+    current_basic_block_id = 0;
+    switch(stage)
+    {
+    case Stage::calculate_types:
+        break;
+    case Stage::generate_code:
+    {
+        for(auto &target : instruction.target)
+            get_or_make_label(target.part_2); // create basic blocks first
+        auto selector = get_id_state(instruction.selector).value.value();
+        auto switch_instruction = ::LLVMBuildSwitch(builder.get(),
+                                                    selector.value,
+                                                    get_or_make_label(instruction.default_),
+                                                    instruction.target.size());
+        for(auto &target : instruction.target)
+            ::LLVMAddCase(
+                switch_instruction,
+                ::LLVMConstInt(selector.type->get_or_make_type(true), target.part_1, false),
+                get_or_make_label(target.part_2));
+        break;
+    }
+    }
 }
 
 void Spirv_to_llvm::handle_instruction_op_kill(Op_kill instruction,
@@ -4304,14 +4647,20 @@ void Spirv_to_llvm::handle_instruction_op_kill(Op_kill instruction,
                            + std::string(get_enumerant_name(instruction.get_operation())));
 }
 
-void Spirv_to_llvm::handle_instruction_op_return(Op_return instruction,
-                                                 std::size_t instruction_start_index)
+void Spirv_to_llvm::handle_instruction_op_return(
+    [[gnu::unused]] Op_return instruction, [[gnu::unused]] std::size_t instruction_start_index)
 {
-#warning finish
-    throw Parser_error(instruction_start_index,
-                       instruction_start_index,
-                       "instruction not implemented: "
-                           + std::string(get_enumerant_name(instruction.get_operation())));
+    current_basic_block_id = 0;
+    switch(stage)
+    {
+    case Stage::calculate_types:
+        break;
+    case Stage::generate_code:
+    {
+        ::LLVMBuildRetVoid(builder.get());
+        break;
+    }
+    }
 }
 
 void Spirv_to_llvm::handle_instruction_op_return_value(Op_return_value instruction,
index 6b377f4b93915c65eaedda842e4a1a571db9ccdd..c568323222811c9b557f127542355099034dbda4 100644 (file)
 #include <vector>
 #include <string>
 #include <cassert>
+#include <type_traits>
+#include <utility>
 #include "llvm_wrapper/llvm_wrapper.h"
 
 namespace vulkan_cpu
 {
 namespace spirv_to_llvm
 {
+class Simple_type_descriptor;
+class Vector_type_descriptor;
+class Matrix_type_descriptor;
+class Pointer_type_descriptor;
+class Function_type_descriptor;
+class Struct_type_descriptor;
 class Type_descriptor
 {
     Type_descriptor(const Type_descriptor &) = delete;
     Type_descriptor &operator=(const Type_descriptor &) = delete;
 
+public:
+    struct Type_visitor
+    {
+        virtual ~Type_visitor() = default;
+        virtual void visit(Simple_type_descriptor &type) = 0;
+        virtual void visit(Vector_type_descriptor &type) = 0;
+        virtual void visit(Matrix_type_descriptor &type) = 0;
+        virtual void visit(Pointer_type_descriptor &type) = 0;
+        virtual void visit(Function_type_descriptor &type) = 0;
+        virtual void visit(Struct_type_descriptor &type) = 0;
+    };
+
 public:
     Type_descriptor() noexcept = default;
     virtual ~Type_descriptor() = default;
     virtual ::LLVMTypeRef get_or_make_type(bool need_complete_structs) = 0;
+    virtual void visit(Type_visitor &type_visitor) = 0;
+    void visit(Type_visitor &&type_visitor)
+    {
+        visit(type_visitor);
+    }
+    template <typename Fn>
+    typename std::enable_if<!std::is_convertible<Fn &&, const Type_visitor &>::value, void>::type
+        visit(Fn &&fn)
+    {
+        struct Visitor final : public Type_visitor
+        {
+            Fn &fn;
+            virtual void visit(Simple_type_descriptor &type) override
+            {
+                std::forward<Fn>(fn)(type);
+            }
+            virtual void visit(Vector_type_descriptor &type) override
+            {
+                std::forward<Fn>(fn)(type);
+            }
+            virtual void visit(Matrix_type_descriptor &type) override
+            {
+                std::forward<Fn>(fn)(type);
+            }
+            virtual void visit(Pointer_type_descriptor &type) override
+            {
+                std::forward<Fn>(fn)(type);
+            }
+            virtual void visit(Function_type_descriptor &type) override
+            {
+                std::forward<Fn>(fn)(type);
+            }
+            virtual void visit(Struct_type_descriptor &type) override
+            {
+                std::forward<Fn>(fn)(type);
+            }
+            explicit Visitor(Fn &fn) noexcept : fn(fn)
+            {
+            }
+        };
+        visit(Visitor(fn));
+    }
     class Recursion_checker;
     class Recursion_checker_state
     {
@@ -99,6 +161,77 @@ public:
     {
         return type;
     }
+    virtual void visit(Type_visitor &type_visitor) override
+    {
+        type_visitor.visit(*this);
+    }
+};
+
+class Vector_type_descriptor final : public Type_descriptor
+{
+private:
+    ::LLVMTypeRef type;
+    std::shared_ptr<Simple_type_descriptor> element_type;
+    std::size_t element_count;
+
+public:
+    explicit Vector_type_descriptor(std::shared_ptr<Simple_type_descriptor> element_type,
+                                    std::size_t element_count) noexcept
+        : type(::LLVMVectorType(element_type->get_or_make_type(true), element_count)),
+          element_type(std::move(element_type)),
+          element_count(element_count)
+    {
+    }
+    virtual ::LLVMTypeRef get_or_make_type([[gnu::unused]] bool need_complete_structs) override
+    {
+        return type;
+    }
+    virtual void visit(Type_visitor &type_visitor) override
+    {
+        type_visitor.visit(*this);
+    }
+    const std::shared_ptr<Simple_type_descriptor> &get_element_type() const noexcept
+    {
+        return element_type;
+    }
+    std::size_t get_element_count() const noexcept
+    {
+        return element_count;
+    }
+};
+
+class Matrix_type_descriptor final : public Type_descriptor
+{
+private:
+    ::LLVMTypeRef type;
+    std::shared_ptr<Vector_type_descriptor> column_type;
+    std::size_t column_count;
+
+public:
+    explicit Matrix_type_descriptor(std::shared_ptr<Vector_type_descriptor> column_type,
+                                    std::size_t column_count) noexcept
+        : type(::LLVMVectorType(column_type->get_element_type()->get_or_make_type(true),
+                                column_type->get_element_count() * column_count)),
+          column_type(std::move(column_type)),
+          column_count(column_count)
+    {
+    }
+    virtual ::LLVMTypeRef get_or_make_type([[gnu::unused]] bool need_complete_structs) override
+    {
+        return type;
+    }
+    virtual void visit(Type_visitor &type_visitor) override
+    {
+        type_visitor.visit(*this);
+    }
+    const std::shared_ptr<Vector_type_descriptor> &get_column_type() const noexcept
+    {
+        return column_type;
+    }
+    std::size_t get_column_count() const noexcept
+    {
+        return column_count;
+    }
 };
 
 class Pointer_type_descriptor final : public Type_descriptor
@@ -149,6 +282,10 @@ public:
         }
         return type;
     }
+    virtual void visit(Type_visitor &type_visitor) override
+    {
+        type_visitor.visit(*this);
+    }
 };
 
 class Function_type_descriptor final : public Type_descriptor
@@ -188,6 +325,10 @@ public:
         }
         return type;
     }
+    virtual void visit(Type_visitor &type_visitor) override
+    {
+        type_visitor.visit(*this);
+    }
 };
 
 class Struct_type_descriptor final : public Type_descriptor
@@ -263,6 +404,10 @@ public:
         }
         return type;
     }
+    virtual void visit(Type_visitor &type_visitor) override
+    {
+        type_visitor.visit(*this);
+    }
 };
 
 class Constant_descriptor