implementing uniforms; implemented matrix multiplication kazan-old
authorJacob Lifshay <programmerjake@gmail.com>
Mon, 25 Sep 2017 00:40:18 +0000 (17:40 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Mon, 25 Sep 2017 00:40:18 +0000 (17:40 -0700)
12 files changed:
src/demo/demo.cpp
src/llvm_wrapper/llvm_wrapper.cpp
src/llvm_wrapper/llvm_wrapper.h
src/pipeline/pipeline.cpp
src/pipeline/pipeline.h
src/spirv_to_llvm/core_instructions.cpp
src/spirv_to_llvm/fragment_entry_point.cpp
src/spirv_to_llvm/matrix_operations.h
src/spirv_to_llvm/spirv_to_llvm.cpp
src/spirv_to_llvm/spirv_to_llvm.h
src/spirv_to_llvm/spirv_to_llvm_implementation.h
src/spirv_to_llvm/vertex_entry_point.cpp

index 400da19089540226ddc28fcf77a8396cd244a97f..673a87dc64b3bd9cde77afe85f682185d32b2424 100644 (file)
@@ -901,12 +901,20 @@ int test_main(int argc, char **argv)
         void *bindings[binding_count] = {
             vertexes.data(),
         };
-        graphics_pipeline->run(
-            vertex_start_index, vertex_end_index, instance_id, *color_attachment, bindings);
+        struct Uniforms
+        {
+        };
+        Uniforms uniforms{};
+        graphics_pipeline->run(vertex_start_index,
+                               vertex_end_index,
+                               instance_id,
+                               *color_attachment,
+                               bindings,
+                               &uniforms);
         typedef std::uint32_t Pixel_type;
         // check Pixel_type
         static_assert(std::is_void<util::void_t<decltype(graphics_pipeline->run_fragment_shader(
-                          static_cast<Pixel_type *>(nullptr)))>>::value,
+                          static_cast<Pixel_type *>(nullptr), nullptr))>>::value,
                       "");
         auto rgba = [](std::uint8_t r,
                        std::uint8_t g,
index b4faf6648d1b4f84a920c4c08aacb1ed843ffa7f..777f24417acb7fa6b7304fe1c89e003330c0770b 100644 (file)
 #include <llvm/ExecutionEngine/SectionMemoryManager.h>
 #include <llvm-c/ExecutionEngine.h>
 #include <llvm/IR/DataLayout.h>
+#include <llvm/IR/Intrinsics.h>
 #include <llvm/Target/TargetMachine.h>
 #include <llvm/Analysis/TargetTransformInfo.h>
 #include <iostream>
 #include <cstdlib>
 #include <algorithm>
+#include <vector>
+#include <memory>
 
 namespace kazan
 {
@@ -148,6 +151,22 @@ unsigned Target_machine::get_biggest_vector_register_bit_width(::LLVMTargetMachi
         .getRegisterBitWidth(true);
 }
 
+LLVM_intrinsic_id get_llvm_intrinsic_id(Intrinsic intrinsic) noexcept
+{
+    using llvm::Intrinsic::ID;
+    auto cvt = [](ID v) noexcept
+    {
+        return static_cast<LLVM_intrinsic_id>(static_cast<unsigned>(v));
+    };
+    switch(intrinsic)
+    {
+    case Intrinsic::fmuladd:
+        return cvt(ID::fmuladd);
+    }
+    assert(false);
+    return LLVM_intrinsic_id::Not_intrinsic;
+}
+
 void Module::set_target_machine(::LLVMModuleRef module, ::LLVMTargetMachineRef target_machine)
 {
     ::LLVMSetTarget(module, Target_machine::get_target_triple(target_machine).get());
@@ -163,5 +182,32 @@ void Module::set_function_target_machine(::LLVMValueRef function,
     ::LLVMAddTargetDependentFunctionAttr(
         function, "target-features", Target_machine::get_feature_string(target_machine).get());
 }
+
+::LLVMValueRef Module::get_intrinsic_declaration(::LLVMModuleRef module,
+                                                 LLVM_intrinsic_id llvm_intrinsic_id,
+                                                 const ::LLVMTypeRef *types,
+                                                 std::size_t type_count)
+{
+    auto *module_pointer = llvm::unwrap(module);
+    constexpr std::size_t array_size = 4;
+    llvm::Type *on_stack_array[array_size];
+    std::unique_ptr<llvm::Type *[]> on_heap_array;
+    llvm::Type **unwrapped_types;
+    if(type_count > array_size)
+    {
+        on_heap_array.reset(new llvm::Type *[type_count]);
+        unwrapped_types = on_heap_array.get();
+    }
+    else
+    {
+        unwrapped_types = on_stack_array;
+    }
+    for(std::size_t i = 0; i < type_count; i++)
+        unwrapped_types[i] = llvm::unwrap(types[i]);
+    return llvm::wrap(llvm::Intrinsic::getDeclaration(
+        module_pointer,
+        static_cast<llvm::Intrinsic::ID>(static_cast<unsigned>(llvm_intrinsic_id)),
+        llvm::ArrayRef<llvm::Type *>(unwrapped_types, type_count)));
+}
 }
 }
index 9ac2d077e50c048a8f03ce918cf30bb87777a86a..c79e286c1eafe8a994f6422cab71977ed163d8c3 100644 (file)
@@ -37,6 +37,7 @@
 #include <string>
 #include <cassert>
 #include <stdexcept>
+#include <initializer_list>
 #include "util/string_view.h"
 #include "util/variant.h"
 
@@ -317,6 +318,19 @@ struct Target_machine : public Wrapper<::LLVMTargetMachineRef, Target_machine_de
     }
 };
 
+enum class Intrinsic // doesn't match llvm::Intrinsic::ID
+{
+    fmuladd,
+};
+
+enum class LLVM_intrinsic_id : unsigned
+{
+    Not_intrinsic = 0,
+    Maximum_intrinsic_id = static_cast<unsigned>(-1)
+};
+
+LLVM_intrinsic_id get_llvm_intrinsic_id(Intrinsic intrinsic) noexcept;
+
 struct Module_deleter
 {
     void operator()(::LLVMModuleRef module) const noexcept
@@ -347,6 +361,27 @@ struct Module : public Wrapper<::LLVMModuleRef, Module_deleter>
     {
         set_target_machine(get(), target_machine);
     }
+    static ::LLVMValueRef get_intrinsic_declaration(::LLVMModuleRef module,
+                                                    LLVM_intrinsic_id llvm_intrinsic_id,
+                                                    const ::LLVMTypeRef *types,
+                                                    std::size_t type_count);
+    ::LLVMValueRef get_intrinsic_declaration(LLVM_intrinsic_id llvm_intrinsic_id,
+                                             const ::LLVMTypeRef *types,
+                                             std::size_t type_count)
+    {
+        return get_intrinsic_declaration(get(), llvm_intrinsic_id, types, type_count);
+    }
+    static ::LLVMValueRef get_intrinsic_declaration(::LLVMModuleRef module,
+                                                    LLVM_intrinsic_id llvm_intrinsic_id,
+                                                    std::initializer_list<::LLVMTypeRef> types)
+    {
+        return get_intrinsic_declaration(module, llvm_intrinsic_id, types.begin(), types.size());
+    }
+    ::LLVMValueRef get_intrinsic_declaration(LLVM_intrinsic_id llvm_intrinsic_id,
+                                             std::initializer_list<::LLVMTypeRef> types)
+    {
+        return get_intrinsic_declaration(get(), llvm_intrinsic_id, types);
+    }
 };
 
 inline LLVM_string print_type_to_string(::LLVMTypeRef type)
@@ -401,6 +436,32 @@ struct Builder : public Wrapper<::LLVMBuilderRef, Builder_deleter>
     {
         return build_smod(get(), lhs, rhs, result_name);
     }
+    static ::LLVMValueRef build_fmuladd(::LLVMBuilderRef builder,
+                                        ::LLVMModuleRef module,
+                                        ::LLVMValueRef factor1,
+                                        ::LLVMValueRef factor2,
+                                        ::LLVMValueRef term,
+                                        const char *result_name)
+    {
+        auto type = ::LLVMTypeOf(factor1);
+        assert(type == ::LLVMTypeOf(factor2));
+        assert(type == ::LLVMTypeOf(term));
+        auto intrinsic = Module::get_intrinsic_declaration(
+            module, get_llvm_intrinsic_id(Intrinsic::fmuladd), {type});
+        constexpr std::size_t arg_count = 3;
+        ::LLVMValueRef args[arg_count] = {
+            factor1, factor2, term,
+        };
+        return ::LLVMBuildCall(builder, intrinsic, args, arg_count, result_name);
+    }
+    ::LLVMValueRef build_fmuladd(::LLVMModuleRef module,
+                                 ::LLVMValueRef factor1,
+                                 ::LLVMValueRef factor2,
+                                 ::LLVMValueRef term,
+                                 const char *result_name) const
+    {
+        return build_fmuladd(get(), module, factor1, factor2, term, result_name);
+    }
 };
 
 struct Pass_manager_deleter
index ed9db49d940c4f0f2c68f586679951ee91bfa76a..064a070aa34f9b2ea81fd7b02221988cea6091be 100644 (file)
@@ -471,7 +471,8 @@ void Graphics_pipeline::run(std::uint32_t vertex_start_index,
                             std::uint32_t vertex_end_index,
                             std::uint32_t instance_id,
                             const vulkan::Vulkan_image &color_attachment,
-                            void *const *bindings)
+                            void *const *bindings,
+                            void *uniforms)
 {
     typedef std::uint32_t Pixel_type;
     assert(color_attachment.descriptor.tiling == VK_IMAGE_TILING_LINEAR);
@@ -684,7 +685,8 @@ void Graphics_pipeline::run(std::uint32_t vertex_start_index,
                           current_vertex_start_index + chunk_size,
                           instance_id,
                           chunk_vertex_buffer.get(),
-                          bindings);
+                          bindings,
+                          uniforms);
         const unsigned char *current_vertex =
             chunk_vertex_buffer.get() + vertex_shader_position_output_offset;
         triangles.clear();
@@ -932,7 +934,7 @@ void Graphics_pipeline::run(std::uint32_t vertex_start_index,
                             static_cast<unsigned char *>(color_attachment_memory)
                             + (static_cast<std::size_t>(x) * color_attachment_pixel_size
                                + static_cast<std::size_t>(y) * color_attachment_stride));
-                        fs(pixel);
+                        fs(pixel, uniforms);
                     }
                 }
             }
index d5af52748d4a390345cd3662c41e224910c33bae..2f5e42472ab5437c52face8d3ce43ff54a12bcbc 100644 (file)
@@ -161,33 +161,40 @@ public:
                                            std::uint32_t vertex_end_index,
                                            std::uint32_t instance_id,
                                            void *output_buffer,
-                                           void *const *bindings);
-    typedef void (*Fragment_shader_function)(std::uint32_t *color_attachment_pixel);
+                                           void *const *input_bindings,
+                                           void *uniforms);
+    typedef void (*Fragment_shader_function)(std::uint32_t *color_attachment_pixel, void *uniforms);
 
 public:
     void run_vertex_shader(std::uint32_t vertex_start_index,
                            std::uint32_t vertex_end_index,
                            std::uint32_t instance_id,
                            void *output_buffer,
-                           void *const *input_bindings) const noexcept
+                           void *const *input_bindings,
+                           void *uniforms) const noexcept
     {
-        vertex_shader_function(
-            vertex_start_index, vertex_end_index, instance_id, output_buffer, input_bindings);
+        vertex_shader_function(vertex_start_index,
+                               vertex_end_index,
+                               instance_id,
+                               output_buffer,
+                               input_bindings,
+                               uniforms);
     }
     std::size_t get_vertex_shader_output_struct_size() const noexcept
     {
         return vertex_shader_output_struct_size;
     }
     void dump_vertex_shader_output_struct(const void *output_struct) const;
-    void run_fragment_shader(std::uint32_t *color_attachment_pixel) const noexcept
+    void run_fragment_shader(std::uint32_t *color_attachment_pixel, void *uniforms) const noexcept
     {
-        fragment_shader_function(color_attachment_pixel);
+        fragment_shader_function(color_attachment_pixel, uniforms);
     }
     void run(std::uint32_t vertex_start_index,
              std::uint32_t vertex_end_index,
              std::uint32_t instance_id,
              const vulkan::Vulkan_image &color_attachment,
-             void *const *bindings);
+             void *const *input_bindings,
+             void *uniforms);
     static std::unique_ptr<Graphics_pipeline> create(
         vulkan::Vulkan_device &,
         Pipeline_cache *pipeline_cache,
index 1d4b7661d196fad03e8b5a5a35f57b6b40f7533b..2e244ce4c6434a9a5e31b151c7a702a438bb53d0 100644 (file)
@@ -1056,7 +1056,7 @@ void Spirv_to_llvm::handle_instruction_op_variable(Op_variable instruction,
     case Stage::calculate_types:
     {
         auto &state = get_id_state(instruction.result);
-        bool check_decorations = true;
+        bool parse_decorations = true;
         [&]()
         {
             switch(instruction.storage_class)
@@ -1077,7 +1077,7 @@ void Spirv_to_llvm::handle_instruction_op_variable(Op_variable instruction,
                     Input_variable_state{type,
                                          inputs_struct->add_member(Struct_type_descriptor::Member(
                                              state.decorations, type))};
-                check_decorations = false;
+                parse_decorations = false;
                 return;
             }
             case Storage_class::uniform:
@@ -1086,7 +1086,10 @@ void Spirv_to_llvm::handle_instruction_op_variable(Op_variable instruction,
                     throw Parser_error(instruction_start_index,
                                        instruction_start_index,
                                        "shader uniform variable initializers are not implemented");
-                state.variable = Uniform_variable_state{};
+                auto type = get_type<Pointer_type_descriptor>(instruction.result_type,
+                                                              instruction_start_index)
+                                ->get_base_type();
+                state.variable = Uniform_variable_state(type);
                 return;
             }
             case Storage_class::output:
@@ -1102,7 +1105,7 @@ void Spirv_to_llvm::handle_instruction_op_variable(Op_variable instruction,
                     Output_variable_state{type,
                                           outputs_struct->add_member(Struct_type_descriptor::Member(
                                               state.decorations, type))};
-                check_decorations = false;
+                parse_decorations = false;
                 return;
             }
             case Storage_class::workgroup:
@@ -1143,7 +1146,7 @@ void Spirv_to_llvm::handle_instruction_op_variable(Op_variable instruction,
                                "unimplemented OpVariable storage class: "
                                    + std::string(get_enumerant_name(instruction.storage_class)));
         }();
-        if(check_decorations)
+        if(parse_decorations)
         {
             for(auto &decoration : state.decorations)
             {
@@ -1244,9 +1247,13 @@ void Spirv_to_llvm::handle_instruction_op_variable(Op_variable instruction,
                     break;
                 case Decoration::binding:
                 {
+                    auto &parameters =
+                        util::get<spirv::Decoration_binding_parameters>(decoration.parameters);
                     switch(instruction.storage_class)
                     {
                     case spirv::Storage_class::uniform:
+                        util::get<Uniform_variable_state>(state.variable).binding =
+                            parameters.binding_point;
                         continue;
 #warning finish implementing Decoration::binding
                     default:
@@ -1260,9 +1267,13 @@ void Spirv_to_llvm::handle_instruction_op_variable(Op_variable instruction,
                 }
                 case Decoration::descriptor_set:
                 {
+                    auto &parameters = util::get<spirv::Decoration_descriptor_set_parameters>(
+                        decoration.parameters);
                     switch(instruction.storage_class)
                     {
                     case spirv::Storage_class::uniform:
+                        util::get<Uniform_variable_state>(state.variable).descriptor_set =
+                            parameters.descriptor_set;
                         continue;
 #warning finish implementing Decoration::descriptor_set
                     default:
@@ -1383,7 +1394,100 @@ void Spirv_to_llvm::handle_instruction_op_variable(Op_variable instruction,
         }
         case Storage_class::uniform:
 #warning finish implementing Storage_class::uniform
-            break;
+        {
+            if(instruction.initializer)
+                throw Parser_error(instruction_start_index,
+                                   instruction_start_index,
+                                   "shader uniform variable initializers are not implemented");
+            auto set_value_fn = [this, instruction, &state, instruction_start_index]()
+            {
+                auto &variable = util::get<Uniform_variable_state>(state.variable);
+                if(!variable.binding)
+                    throw Parser_error(instruction_start_index,
+                                       instruction_start_index,
+                                       "shader uniform variable is missing a Binding decoration");
+                if(!variable.descriptor_set)
+                    throw Parser_error(
+                        instruction_start_index,
+                        instruction_start_index,
+                        "shader uniform variable is missing a DescriptorSet decoration");
+                auto binding_number = *variable.binding;
+                auto descriptor_set_number = *variable.descriptor_set;
+                if(descriptor_set_number >= pipeline_layout.descriptor_sets.size())
+                    throw Parser_error(instruction_start_index,
+                                       instruction_start_index,
+                                       "DescriptorSet decoration's value is out of range");
+                auto &descriptor_set = pipeline_layout.descriptor_sets[descriptor_set_number];
+                if(binding_number >= descriptor_set.bindings.size())
+                    throw Parser_error(instruction_start_index,
+                                       instruction_start_index,
+                                       "Binding decoration's value is out of range");
+                auto &binding = descriptor_set.bindings[binding_number];
+                auto &uniforms_struct_member =
+                    pipeline_layout.type->get_members(true)[binding.member_index];
+                auto uniform_slot_address = ::LLVMBuildStructGEP(
+                    builder.get(),
+                    get_id_state(current_function_id).function->entry_block->uniforms_struct,
+                    uniforms_struct_member.llvm_member_index,
+                    "");
+                auto result_type = get_type(instruction.result_type, instruction_start_index);
+                ::LLVMValueRef result = nullptr;
+                switch(binding.base->descriptor_type)
+                {
+                case VK_DESCRIPTOR_TYPE_SAMPLER:
+#warning implement VK_DESCRIPTOR_TYPE_SAMPLER uniform variables
+                    break;
+                case VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER:
+#warning implement VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER uniform variables
+                    break;
+                case VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE:
+#warning implement VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE uniform variables
+                    break;
+                case VK_DESCRIPTOR_TYPE_STORAGE_IMAGE:
+#warning implement VK_DESCRIPTOR_TYPE_STORAGE_IMAGE uniform variables
+                    break;
+                case VK_DESCRIPTOR_TYPE_UNIFORM_TEXEL_BUFFER:
+#warning implement VK_DESCRIPTOR_TYPE_UNIFORM_TEXEL_BUFFER uniform variables
+                    break;
+                case VK_DESCRIPTOR_TYPE_STORAGE_TEXEL_BUFFER:
+#warning implement VK_DESCRIPTOR_TYPE_STORAGE_TEXEL_BUFFER uniform variables
+                    break;
+                case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER:
+                    result =
+                        ::LLVMBuildBitCast(builder.get(),
+                                           ::LLVMBuildLoad(builder.get(), uniform_slot_address, ""),
+                                           result_type->get_or_make_type().type,
+                                           "");
+                    break;
+                case VK_DESCRIPTOR_TYPE_STORAGE_BUFFER:
+#warning implement VK_DESCRIPTOR_TYPE_STORAGE_BUFFER uniform variables
+                    break;
+                case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER_DYNAMIC:
+#warning implement VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER_DYNAMIC uniform variables
+                    break;
+                case VK_DESCRIPTOR_TYPE_STORAGE_BUFFER_DYNAMIC:
+#warning implement VK_DESCRIPTOR_TYPE_STORAGE_BUFFER_DYNAMIC uniform variables
+                    break;
+                case VK_DESCRIPTOR_TYPE_INPUT_ATTACHMENT:
+#warning implement VK_DESCRIPTOR_TYPE_INPUT_ATTACHMENT uniform variables
+                    break;
+                case VK_DESCRIPTOR_TYPE_RANGE_SIZE:
+                case VK_DESCRIPTOR_TYPE_MAX_ENUM:
+                    break;
+                }
+                if(result == nullptr)
+                    throw Parser_error(instruction_start_index,
+                                       instruction_start_index,
+                                       "unimplemented uniform descriptor type");
+                ::LLVMSetValueName(result, get_name(instruction.result).c_str());
+                state.value = Value(result, std::move(result_type));
+            };
+            if(current_function_id)
+                set_value_fn();
+            else
+                function_entry_block_handlers.push_back(set_value_fn);
+            return;
+        }
         case Storage_class::output:
         {
             if(instruction.initializer)
@@ -1460,7 +1564,10 @@ void Spirv_to_llvm::handle_instruction_op_variable(Op_variable instruction,
 #warning finish implementing Storage_class::storage_buffer
             break;
         }
-        break;
+        throw Parser_error(instruction_start_index,
+                           instruction_start_index,
+                           "unimplemented OpVariable storage class: "
+                               + std::string(get_enumerant_name(instruction.storage_class)));
     }
     }
 }
@@ -1525,6 +1632,7 @@ void Spirv_to_llvm::handle_instruction_op_load(Op_load instruction,
                 builder.get(), get_id_state(instruction.pointer).value.value().value, "");
             ::LLVMSetAlignment(untransposed_value, memory_type->get_or_make_type().alignment);
             state.value = Value(matrix_operations::transpose(context,
+                                                             module.get(),
                                                              builder.get(),
                                                              untransposed_value,
                                                              get_name(instruction.result).c_str()),
@@ -1574,10 +1682,8 @@ void Spirv_to_llvm::handle_instruction_op_store(Op_store instruction,
             break;
         case Type_descriptor::Load_store_implementation_kind::Transpose_matrix:
         {
-            auto transposed_value = matrix_operations::transpose(context,
-                                                             builder.get(),
-                                                             object_value.value,
-                                                             "");
+            auto transposed_value = matrix_operations::transpose(
+                context, module.get(), builder.get(), object_value.value, "");
             ::LLVMSetAlignment(
                 ::LLVMBuildStore(builder.get(), transposed_value, pointer_value.value),
                 memory_type->get_or_make_type().alignment);
@@ -2845,21 +2951,113 @@ void Spirv_to_llvm::handle_instruction_op_vector_times_matrix(Op_vector_times_ma
 void Spirv_to_llvm::handle_instruction_op_matrix_times_vector(Op_matrix_times_vector instruction,
                                                               std::size_t instruction_start_index)
 {
-#warning finish
-    throw Parser_error(instruction_start_index,
-                       instruction_start_index,
-                       "instruction not implemented: "
-                           + std::string(get_enumerant_name(instruction.get_operation())));
+    switch(stage)
+    {
+    case Stage::calculate_types:
+        break;
+    case Stage::generate_code:
+    {
+        auto &state = get_id_state(instruction.result);
+        if(!state.decorations.empty())
+            throw Parser_error(instruction_start_index,
+                               instruction_start_index,
+                               "decorations on instruction not implemented: "
+                                   + std::string(get_enumerant_name(instruction.get_operation())));
+        auto result_type =
+            get_type<Vector_type_descriptor>(instruction.result_type, instruction_start_index);
+        auto &matrix = get_id_state(instruction.matrix).value.value();
+        auto &vector = get_id_state(instruction.vector).value.value();
+        auto matrix_type = std::dynamic_pointer_cast<Matrix_type_descriptor>(matrix.type);
+        if(!matrix_type)
+            throw Parser_error(instruction_start_index,
+                               instruction_start_index,
+                               "OpMatrixTimesVector matrix operand type mismatch: not a matrix");
+        auto vector_type = std::dynamic_pointer_cast<Vector_type_descriptor>(vector.type);
+        if(!vector_type)
+            throw Parser_error(instruction_start_index,
+                               instruction_start_index,
+                               "OpMatrixTimesVector vector operand type mismatch: not a vector");
+        if(matrix_type->get_row_count() != result_type->get_element_count())
+            throw Parser_error(instruction_start_index,
+                               instruction_start_index,
+                               "OpMatrixTimesVector matrix operand type mismatch: row count "
+                               "doesn't match result_type's element count");
+        if(matrix_type->get_column_count() != vector_type->get_element_count())
+            throw Parser_error(instruction_start_index,
+                               instruction_start_index,
+                               "OpMatrixTimesVector matrix operand type mismatch: column "
+                               "count doesn't match vector's element count");
+        state.value =
+            Value(matrix_operations::matrix_times_vector(context,
+                                                         module.get(),
+                                                         builder.get(),
+                                                         matrix.value,
+                                                         vector.value,
+                                                         get_name(instruction.result).c_str()),
+                  result_type);
+        break;
+    }
+    }
 }
 
 void Spirv_to_llvm::handle_instruction_op_matrix_times_matrix(Op_matrix_times_matrix instruction,
                                                               std::size_t instruction_start_index)
 {
-#warning finish
-    throw Parser_error(instruction_start_index,
-                       instruction_start_index,
-                       "instruction not implemented: "
-                           + std::string(get_enumerant_name(instruction.get_operation())));
+    switch(stage)
+    {
+    case Stage::calculate_types:
+        break;
+    case Stage::generate_code:
+    {
+        auto &state = get_id_state(instruction.result);
+        if(!state.decorations.empty())
+            throw Parser_error(instruction_start_index,
+                               instruction_start_index,
+                               "decorations on instruction not implemented: "
+                                   + std::string(get_enumerant_name(instruction.get_operation())));
+        auto result_type =
+            get_type<Matrix_type_descriptor>(instruction.result_type, instruction_start_index);
+        auto &left_matrix = get_id_state(instruction.left_matrix).value.value();
+        auto &right_matrix = get_id_state(instruction.right_matrix).value.value();
+        auto left_matrix_type = std::dynamic_pointer_cast<Matrix_type_descriptor>(left_matrix.type);
+        if(!left_matrix_type)
+            throw Parser_error(
+                instruction_start_index,
+                instruction_start_index,
+                "OpMatrixTimesMatrix left_matrix operand type mismatch: not a matrix");
+        auto right_matrix_type =
+            std::dynamic_pointer_cast<Matrix_type_descriptor>(right_matrix.type);
+        if(!right_matrix_type)
+            throw Parser_error(
+                instruction_start_index,
+                instruction_start_index,
+                "OpMatrixTimesMatrix right_matrix operand type mismatch: not a matrix");
+        if(left_matrix_type->get_row_count() != result_type->get_row_count())
+            throw Parser_error(instruction_start_index,
+                               instruction_start_index,
+                               "OpMatrixTimesMatrix left_matrix operand type mismatch: row count "
+                               "doesn't match result_type's row count");
+        if(right_matrix_type->get_column_count() != result_type->get_column_count())
+            throw Parser_error(instruction_start_index,
+                               instruction_start_index,
+                               "OpMatrixTimesMatrix right_matrix operand type mismatch: column "
+                               "count doesn't match result_type's column count");
+        if(left_matrix_type->get_column_count() != right_matrix_type->get_row_count())
+            throw Parser_error(instruction_start_index,
+                               instruction_start_index,
+                               "OpMatrixTimesMatrix left_matrix operand type mismatch: column "
+                               "count doesn't match right_matrix's row count");
+        state.value =
+            Value(matrix_operations::matrix_multiply(context,
+                                                     module.get(),
+                                                     builder.get(),
+                                                     left_matrix.value,
+                                                     right_matrix.value,
+                                                     get_name(instruction.result).c_str()),
+                  result_type);
+        break;
+    }
+    }
 }
 
 void Spirv_to_llvm::handle_instruction_op_outer_product(Op_outer_product instruction,
@@ -3800,8 +3998,19 @@ void Spirv_to_llvm::handle_instruction_op_label(Op_label instruction,
                     io_struct->get_members(true)[this->outputs_member].llvm_member_index,
                     "outputs_pointer"),
                 "outputs");
-            function.entry_block = Function_state::Entry_block(
-                block, io_struct_value, inputs_struct_value, outputs_struct_value);
+            auto uniforms_struct_value = ::LLVMBuildLoad(
+                builder.get(),
+                ::LLVMBuildStructGEP(
+                    builder.get(),
+                    io_struct_value,
+                    io_struct->get_members(true)[this->uniforms_member].llvm_member_index,
+                    "uniforms_pointer"),
+                "uniforms");
+            function.entry_block = Function_state::Entry_block(block,
+                                                               io_struct_value,
+                                                               inputs_struct_value,
+                                                               outputs_struct_value,
+                                                               uniforms_struct_value);
             for(auto iter = function_entry_block_handlers.begin();
                 iter != function_entry_block_handlers.end();)
             {
index 0ae3979cb2388d04e658b8157594a5eff179f41d..ce7fe7d36f4cc87f18679e965eda6b55790eb5c8 100644 (file)
@@ -38,8 +38,9 @@ using namespace spirv;
     auto llvm_vec4_type = ::LLVMVectorType(llvm_float_type, 4);
     auto llvm_u8vec4_type = ::LLVMVectorType(llvm_u8_type, 4);
     static_cast<void>(llvm_pixel_type);
-    typedef void (*Fragment_shader_function)(Pixel_type *color_attachment_pixel);
+    typedef void (*Fragment_shader_function)(Pixel_type *color_attachment_pixel, void *uniforms);
     constexpr std::size_t arg_color_attachment_pixel = 0;
+    constexpr std::size_t arg_uniforms = 1;
     static_assert(std::is_same<Fragment_shader_function,
                                pipeline::Graphics_pipeline::Fragment_shader_function>::value,
                   "vertex shader function signature mismatch");
@@ -49,6 +50,8 @@ using namespace spirv;
     llvm_wrapper::Module::set_function_target_machine(entry_function, target_machine);
     auto color_attachment_pixel = ::LLVMGetParam(entry_function, arg_color_attachment_pixel);
     ::LLVMSetValueName(color_attachment_pixel, "color_attachment_pixel");
+    auto uniforms = ::LLVMGetParam(entry_function, arg_uniforms);
+    ::LLVMSetValueName(uniforms, "uniforms");
     auto entry_block = ::LLVMAppendBasicBlockInContext(context, entry_function, "entry");
     ::LLVMPositionBuilderAtEnd(builder.get(), entry_block);
     auto io_struct_type = io_struct->get_or_make_type();
@@ -634,7 +637,8 @@ using namespace spirv;
         else if(member_index == uniforms_member)
         {
 #warning implement shader uniforms
-            assert(this->pipeline_layout.descriptor_sets.empty() && "shader uniforms not implemented");
+            assert(this->pipeline_layout.descriptor_sets.empty()
+                   && "shader uniforms not implemented");
         }
         else
         {
@@ -702,8 +706,9 @@ using namespace spirv;
     auto packed_output_color = ::LLVMBuildBitCast(
         builder.get(), converted_output_color, llvm_pixel_type, "packed_output_color");
     ::LLVMBuildStore(builder.get(), packed_output_color, color_attachment_pixel);
-    static_assert(
-        std::is_same<decltype(std::declval<Fragment_shader_function>()(nullptr)), void>::value, "");
+    static_assert(std::is_same<decltype(std::declval<Fragment_shader_function>()(nullptr, nullptr)),
+                               void>::value,
+                  "");
     ::LLVMBuildRetVoid(builder.get());
     return entry_function;
 }
index 4b50332f8791ed29545c2a519f6ff90e8a95966f..5c716c8449f74a39f14b519f975389ab8dade6e8 100644 (file)
@@ -60,7 +60,27 @@ struct Matrix_descriptor
     }
 };
 
+struct Vector_descriptor
+{
+    std::uint32_t element_count;
+    ::LLVMTypeRef element_type;
+    ::LLVMTypeRef vector_type;
+    explicit Vector_descriptor(::LLVMTypeRef vector_type) noexcept : vector_type(vector_type)
+    {
+        assert(::LLVMGetTypeKind(vector_type) == ::LLVMVectorTypeKind);
+        element_count = ::LLVMGetVectorSize(vector_type);
+        element_type = ::LLVMGetElementType(vector_type);
+    }
+    Vector_descriptor(::LLVMTypeRef element_type, std::uint32_t element_count)
+        : element_count(element_count),
+          element_type(element_type),
+          vector_type(::LLVMVectorType(element_type, element_count))
+    {
+    }
+};
+
 inline ::LLVMValueRef transpose(::LLVMContextRef context,
+                                ::LLVMModuleRef module,
                                 ::LLVMBuilderRef builder,
                                 ::LLVMValueRef input_matrix,
                                 const char *output_name)
@@ -100,6 +120,92 @@ inline ::LLVMValueRef transpose(::LLVMContextRef context,
     ::LLVMSetValueName(output_value, output_name);
     return output_value;
 }
+
+inline ::LLVMValueRef vector_broadcast_from_vector(::LLVMContextRef context,
+                                                   ::LLVMBuilderRef builder,
+                                                   ::LLVMValueRef input_vector,
+                                                   std::uint32_t input_vector_index,
+                                                   std::uint32_t output_vector_length,
+                                                   const char *output_name)
+{
+    auto i32_type = llvm_wrapper::Create_llvm_type<std::uint32_t>()(context);
+    auto index = ::LLVMConstInt(i32_type, input_vector_index, false);
+    std::vector<::LLVMValueRef> shuffle_arguments(output_vector_length, index);
+    auto shuffle_index_vector =
+        ::LLVMConstVector(shuffle_arguments.data(), shuffle_arguments.size());
+    return ::LLVMBuildShuffleVector(builder,
+                                    input_vector,
+                                    ::LLVMGetUndef(::LLVMTypeOf(input_vector)),
+                                    shuffle_index_vector,
+                                    output_name);
+}
+
+inline ::LLVMValueRef matrix_multiply(::LLVMContextRef context,
+                                      ::LLVMModuleRef module,
+                                      ::LLVMBuilderRef builder,
+                                      ::LLVMValueRef left_matrix,
+                                      ::LLVMValueRef right_matrix,
+                                      const char *output_name)
+{
+    Matrix_descriptor left_matrix_descriptor(::LLVMTypeOf(left_matrix));
+    Matrix_descriptor right_matrix_descriptor(::LLVMTypeOf(right_matrix));
+    assert(left_matrix_descriptor.element_type == right_matrix_descriptor.element_type);
+    assert(left_matrix_descriptor.columns == right_matrix_descriptor.rows);
+    assert(left_matrix_descriptor.columns != 0);
+    assert(left_matrix_descriptor.rows != 0);
+    assert(right_matrix_descriptor.columns != 0);
+    Matrix_descriptor result_matrix_descriptor(left_matrix_descriptor.element_type,
+                                               left_matrix_descriptor.rows,
+                                               right_matrix_descriptor.columns);
+    ::LLVMValueRef retval = ::LLVMGetUndef(result_matrix_descriptor.matrix_type);
+    for(std::size_t i = 0; i < right_matrix_descriptor.columns; i++)
+    {
+        ::LLVMValueRef right_matrix_column = ::LLVMBuildExtractValue(builder, right_matrix, i, "");
+        ::LLVMValueRef sum{};
+        for(std::size_t j = 0; j < left_matrix_descriptor.columns; j++)
+        {
+            auto factor0 = ::LLVMBuildExtractValue(builder, left_matrix, j, "");
+            auto factor1 = vector_broadcast_from_vector(
+                context, builder, right_matrix_column, j, left_matrix_descriptor.rows, "");
+            if(j == 0)
+                sum = ::LLVMBuildFMul(builder, factor0, factor1, "");
+            else
+                sum = llvm_wrapper::Builder::build_fmuladd(
+                    builder, module, factor0, factor1, sum, "");
+        }
+        retval = ::LLVMBuildInsertValue(builder, retval, sum, i, "");
+    }
+    ::LLVMSetValueName(retval, output_name);
+    return retval;
+}
+
+inline ::LLVMValueRef matrix_times_vector(::LLVMContextRef context,
+                                          ::LLVMModuleRef module,
+                                          ::LLVMBuilderRef builder,
+                                          ::LLVMValueRef matrix,
+                                          ::LLVMValueRef input_vector,
+                                          const char *output_name)
+{
+    Matrix_descriptor matrix_descriptor(::LLVMTypeOf(matrix));
+    Vector_descriptor input_vector_descriptor(::LLVMTypeOf(input_vector));
+    assert(matrix_descriptor.element_type == input_vector_descriptor.element_type);
+    assert(matrix_descriptor.columns == input_vector_descriptor.element_count);
+    assert(matrix_descriptor.columns != 0);
+    ::LLVMValueRef retval{};
+    for(std::size_t i = 0; i < matrix_descriptor.columns; i++)
+    {
+        auto factor0 = ::LLVMBuildExtractValue(builder, matrix, i, "");
+        auto factor1 = vector_broadcast_from_vector(
+            context, builder, input_vector, i, matrix_descriptor.rows, "");
+        if(i == 0)
+            retval = ::LLVMBuildFMul(builder, factor0, factor1, "");
+        else
+            retval =
+                llvm_wrapper::Builder::build_fmuladd(builder, module, factor0, factor1, retval, "");
+    }
+    ::LLVMSetValueName(retval, output_name);
+    return retval;
+}
 }
 }
 }
index 44e2134183b4a1be01e8e2106a397a5ff22a0910..cf6e0669e1cd7f3df6c60272c22de3979ec9f6a6 100644 (file)
@@ -209,11 +209,14 @@ void Struct_type_descriptor::complete_type()
     {
         std::size_t alignment;
         std::size_t size;
+        util::optional<std::size_t> offset;
         ::LLVMTypeRef type;
         explicit Member_descriptor(std::size_t alignment,
                                    std::size_t size,
+                                   util::optional<std::size_t> offset,
                                    ::LLVMTypeRef type) noexcept : alignment(alignment),
                                                                   size(size),
+                                                                  offset(offset),
                                                                   type(type)
         {
         }
@@ -223,6 +226,9 @@ void Struct_type_descriptor::complete_type()
     std::size_t total_alignment = 1;
     for(auto &member : members)
     {
+        util::optional<bool> is_row_major;
+        util::optional<std::size_t> offset;
+        util::optional<std::size_t> matrix_stride;
         for(auto &decoration : member.decorations)
         {
             switch(decoration.value)
@@ -240,17 +246,20 @@ void Struct_type_descriptor::complete_type()
 #warning finish implementing Decoration::buffer_block
                 break;
             case Decoration::row_major:
-#warning finish implementing Decoration::row_major
-                break;
+                is_row_major = true;
+                continue;
             case Decoration::col_major:
-#warning finish implementing Decoration::col_major
-                break;
+                is_row_major = false;
+                continue;
             case Decoration::array_stride:
 #warning finish implementing Decoration::array_stride
                 break;
             case Decoration::matrix_stride:
-#warning finish implementing Decoration::matrix_stride
-                break;
+            {
+                auto &parameters = util::get<spirv::Decoration_matrix_stride_parameters>(decoration.parameters);
+                matrix_stride = parameters.matrix_stride;
+                continue;
+            }
             case Decoration::glsl_shared:
 #warning finish implementing Decoration::glsl_shared
                 break;
@@ -326,8 +335,11 @@ void Struct_type_descriptor::complete_type()
 #warning finish implementing Decoration::descriptor_set
                 break;
             case Decoration::offset:
-#warning finish implementing Decoration::offset
-                break;
+            {
+                auto &parameters = util::get<spirv::Decoration_offset_parameters>(decoration.parameters);
+                offset = parameters.byte_offset;
+                continue;
+            }
             case Decoration::xfb_buffer:
 #warning finish implementing Decoration::xfb_buffer
                 break;
@@ -382,6 +394,14 @@ void Struct_type_descriptor::complete_type()
                                "unimplemented member decoration on OpTypeStruct: "
                                    + std::string(get_enumerant_name(decoration.value)));
         }
+        if(is_row_major)
+        {
+            if(*is_row_major)
+                member.type = member.type->get_row_major_type(target_data);
+            else
+                member.type = member.type->get_column_major_type(target_data);
+        }
+        assert(matrix_stride == member.type->get_matrix_stride(target_data) && "MatrixStride decoration unimplemented for non-default strides");
         auto member_type = member.type->get_or_make_type();
         std::size_t size = ::LLVMABISizeOfType(target_data, member_type.type);
         struct Member_type_visitor : public Type_descriptor::Type_visitor
@@ -400,16 +420,10 @@ void Struct_type_descriptor::complete_type()
             virtual void visit(Matrix_type_descriptor &type) override
             {
 #warning finish implementing member type
-                throw Parser_error(this_->instruction_start_index,
-                                   this_->instruction_start_index,
-                                   "unimplemented member type");
             }
             virtual void visit(Row_major_matrix_type_descriptor &type) override
             {
 #warning finish implementing member type
-                throw Parser_error(this_->instruction_start_index,
-                                   this_->instruction_start_index,
-                                   "unimplemented member type");
             }
             virtual void visit(Array_type_descriptor &type) override
             {
@@ -452,7 +466,7 @@ void Struct_type_descriptor::complete_type()
         if(member_type.alignment > total_alignment)
             total_alignment = member_type.alignment;
         member_descriptors.push_back(
-            Member_descriptor(member_type.alignment, size, member_type.type));
+            Member_descriptor(member_type.alignment, size, offset, member_type.type));
     }
     assert(member_descriptors.size() == members.size());
     assert(is_power_of_2(total_alignment));
@@ -463,6 +477,17 @@ void Struct_type_descriptor::complete_type()
     {
         for(std::size_t member_index = 0; member_index < members.size(); member_index++)
         {
+            if(member_descriptors[member_index].offset)
+            {
+                assert(*member_descriptors[member_index].offset >= current_offset);
+                auto padding_size = *member_descriptors[member_index].offset - current_offset;
+                if(padding_size != 0)
+                {
+                    member_types.push_back(
+                        ::LLVMArrayType(::LLVMInt8TypeInContext(context), padding_size));
+                    current_offset += padding_size;
+                }
+            }
             members[member_index].llvm_member_index = member_types.size();
 #warning finish Struct_type_descriptor::complete_type
             member_types.push_back(member_descriptors[member_index].type);
index 7cc0d7c403a5761d8b3973abd0e089bb79bcdc89..e7dc142e481fd81d78133a98113b15fa0957ee39 100644 (file)
@@ -153,6 +153,10 @@ public:
     {
         return shared_from_this();
     }
+    virtual util::optional<std::size_t> get_matrix_stride(::LLVMTargetDataRef target_data) const
+    {
+        return {};
+    }
     void visit(Type_visitor &&type_visitor)
     {
         visit(type_visitor);
@@ -396,6 +400,10 @@ public:
         column_major_type = retval;
         return retval;
     }
+    virtual util::optional<std::size_t> get_matrix_stride(::LLVMTargetDataRef target_data) const override
+    {
+        return element_type->get_matrix_stride(target_data);
+    }
     const std::shared_ptr<Type_descriptor> &get_element_type() const noexcept
     {
         return element_type;
@@ -445,6 +453,14 @@ public:
     {
         return column_count;
     }
+    std::size_t get_row_count() const noexcept
+    {
+        return column_type->get_element_count();
+    }
+    const std::shared_ptr<Simple_type_descriptor> &get_element_type() const noexcept
+    {
+        return column_type->get_element_type();
+    }
     virtual std::shared_ptr<Type_descriptor> get_row_major_type(
         ::LLVMTargetDataRef target_data) override
     {
@@ -455,6 +471,10 @@ public:
         row_major_type = retval;
         return retval;
     }
+    virtual util::optional<std::size_t> get_matrix_stride(::LLVMTargetDataRef target_data) const override
+    {
+        return ::LLVMABISizeOfType(target_data, column_type->get_or_make_type().type);
+    }
 };
 
 class Row_major_matrix_type_descriptor final : public Type_descriptor
@@ -496,6 +516,14 @@ public:
     {
         return row_count;
     }
+    std::size_t get_column_count() const noexcept
+    {
+        return row_type->get_element_count();
+    }
+    const std::shared_ptr<Simple_type_descriptor> &get_element_type() const noexcept
+    {
+        return row_type->get_element_type();
+    }
     virtual std::shared_ptr<Type_descriptor> get_column_major_type(
         ::LLVMTargetDataRef target_data) override
     {
index 83d8dfe60a32215ed76bec4ab662446b34db8986..95f00e217bca1d4aa61ad7ef881fbc5bf613ebc3 100644 (file)
@@ -82,6 +82,15 @@ private:
     };
     struct Uniform_variable_state
     {
+        std::shared_ptr<Type_descriptor> type;
+        util::optional<std::uint32_t> binding;
+        util::optional<std::uint32_t> descriptor_set;
+        explicit Uniform_variable_state(std::shared_ptr<Type_descriptor> type) noexcept
+            : type(std::move(type)),
+              binding(),
+              descriptor_set()
+        {
+        }
     };
     typedef util::variant<util::monostate,
                           Input_variable_state,
@@ -95,14 +104,17 @@ private:
             ::LLVMValueRef io_struct;
             ::LLVMValueRef inputs_struct;
             ::LLVMValueRef outputs_struct;
+            ::LLVMValueRef uniforms_struct;
             explicit Entry_block(::LLVMBasicBlockRef entry_block,
                                  ::LLVMValueRef io_struct,
                                  ::LLVMValueRef inputs_struct,
-                                 ::LLVMValueRef outputs_struct) noexcept
+                                 ::LLVMValueRef outputs_struct,
+                                 ::LLVMValueRef uniforms_struct) noexcept
                 : entry_block(entry_block),
                   io_struct(io_struct),
                   inputs_struct(inputs_struct),
-                  outputs_struct(outputs_struct)
+                  outputs_struct(outputs_struct),
+                  uniforms_struct(uniforms_struct)
             {
             }
         };
index 9843972c699b1309250e0af5bc02b2e3f2aa997f..ee742ee5ca3484ee80949e7b86cace01f09b210d 100644 (file)
@@ -44,12 +44,14 @@ using namespace spirv;
                                            Vertex_index_type vertex_end_index,
                                            std::uint32_t instance_id,
                                            void *output_buffer,
-                                           void *const *bindings);
+                                           void *const *bindings,
+                                           void *uniforms);
     constexpr std::size_t arg_vertex_start_index = 0;
     constexpr std::size_t arg_vertex_end_index = 1;
     constexpr std::size_t arg_instance_id = 2;
     constexpr std::size_t arg_output_buffer = 3;
     constexpr std::size_t arg_bindings = 4;
+    constexpr std::size_t arg_uniforms = 5;
     static_assert(std::is_same<Vertex_shader_function,
                                pipeline::Graphics_pipeline::Vertex_shader_function>::value,
                   "vertex shader function signature mismatch");
@@ -63,6 +65,7 @@ using namespace spirv;
     ::LLVMSetValueName(::LLVMGetParam(entry_function, arg_instance_id), "instance_id");
     ::LLVMSetValueName(::LLVMGetParam(entry_function, arg_output_buffer), "output_buffer_");
     ::LLVMSetValueName(::LLVMGetParam(entry_function, arg_bindings), "bindings");
+    ::LLVMSetValueName(::LLVMGetParam(entry_function, arg_uniforms), "uniforms");
     auto entry_block = ::LLVMAppendBasicBlockInContext(context, entry_function, "entry");
     auto loop_block = ::LLVMAppendBasicBlockInContext(context, entry_function, "loop");
     auto exit_block = ::LLVMAppendBasicBlockInContext(context, entry_function, "exit");
@@ -452,6 +455,26 @@ using namespace spirv;
                                         "unimplemented vertex input variable type conversion");
                                 break;
                             }
+                            case VK_FORMAT_R32G32B32_SFLOAT:
+                            {
+                                constexpr std::size_t vector_element_count = 3;
+                                format_type =
+                                    Vector_type_descriptor(
+                                        std::vector<spirv::Decoration_with_parameters>{},
+                                        std::make_shared<Simple_type_descriptor>(
+                                            std::vector<spirv::Decoration_with_parameters>{},
+                                            LLVM_type_and_alignment(llvm_float_type,
+                                                                    llvm_float_type_alignment)),
+                                        vector_element_count,
+                                        target_data)
+                                        .get_or_make_type();
+                                if(input_type.type != format_type.type)
+                                    throw Parser_error(
+                                        0,
+                                        0,
+                                        "unimplemented vertex input variable type conversion");
+                                break;
+                            }
 #warning implement all required formats
                             default:
                                 throw Parser_error(0, 0, "unimplemented vertex input format");
@@ -853,7 +876,8 @@ using namespace spirv;
         else if(member_index == uniforms_member)
         {
 #warning implement shader uniforms
-            assert(this->pipeline_layout.descriptor_sets.empty() && "shader uniforms not implemented");
+            assert(this->pipeline_layout.descriptor_sets.empty()
+                   && "shader uniforms not implemented");
         }
         else
         {
@@ -882,7 +906,7 @@ using namespace spirv;
     ::LLVMBuildCondBr(builder.get(), next_iteration_condition, loop_block, exit_block);
     ::LLVMPositionBuilderAtEnd(builder.get(), exit_block);
     static_assert(
-        std::is_same<decltype(std::declval<Vertex_shader_function>()(0, 0, 0, nullptr, nullptr)),
+        std::is_same<decltype(std::declval<Vertex_shader_function>()(0, 0, 0, nullptr, nullptr, nullptr)),
                      void>::value,
         "");
     ::LLVMBuildRetVoid(builder.get());