added row-major matrix type
authorJacob Lifshay <programmerjake@gmail.com>
Fri, 22 Sep 2017 23:40:03 +0000 (16:40 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Fri, 22 Sep 2017 23:40:03 +0000 (16:40 -0700)
src/pipeline/pipeline.cpp
src/spirv_to_llvm/core_instructions.cpp
src/spirv_to_llvm/spirv_to_llvm.cpp
src/spirv_to_llvm/spirv_to_llvm.h

index 19fcbb4e109b2248c60a4ca878eb50db7f1bb37c..2cb457f7d64d4296a5d1b8af83bb543827c146ac 100644 (file)
@@ -296,6 +296,12 @@ struct Graphics_pipeline::Implementation
             {
                 assert(!"dumping matrix not implemented");
                 throw std::runtime_error("dumping matrix not implemented");
+#warning dumping matrix not implemented
+            }
+            virtual void visit(spirv_to_llvm::Row_major_matrix_type_descriptor &type) override
+            {
+                assert(!"dumping matrix not implemented");
+                throw std::runtime_error("dumping matrix not implemented");
 #warning dumping matrix not implemented
             }
             virtual void visit(spirv_to_llvm::Array_type_descriptor &type) override
index 0c23e0c0f715eb87850c3a3ab36340c858ea84ce..eae8eb1befe7f3bd3bda921e47fb77d98dabc948 100644 (file)
@@ -377,8 +377,7 @@ void Spirv_to_llvm::handle_instruction_op_type_matrix(Op_type_matrix instruction
         state.type = std::make_shared<Matrix_type_descriptor>(
             state.decorations,
             get_type<Vector_type_descriptor>(instruction.column_type, instruction_start_index),
-            instruction.column_count,
-            target_data);
+            instruction.column_count);
         break;
     }
     case Stage::generate_code:
@@ -1585,6 +1584,13 @@ void Spirv_to_llvm::handle_instruction_op_access_chain(Op_access_chain instructi
                 }
                 void operator()(Matrix_type_descriptor &)
                 {
+#warning finish
+                    throw Parser_error(instruction_start_index,
+                                       instruction_start_index,
+                                       "unimplemented composite type for OpAccessChain");
+                }
+                void operator()(Row_major_matrix_type_descriptor &)
+                {
 #warning finish
                     throw Parser_error(instruction_start_index,
                                        instruction_start_index,
@@ -1841,6 +1847,13 @@ void Spirv_to_llvm::handle_instruction_op_composite_construct(Op_composite_const
             }
             void operator()(Matrix_type_descriptor &)
             {
+#warning finish
+                throw Parser_error(instruction_start_index,
+                                   instruction_start_index,
+                                   "unimplemented result type for OpCompositeConstruct");
+            }
+            void operator()(Row_major_matrix_type_descriptor &)
+            {
 #warning finish
                 throw Parser_error(instruction_start_index,
                                    instruction_start_index,
@@ -1934,6 +1947,13 @@ void Spirv_to_llvm::handle_instruction_op_composite_extract(Op_composite_extract
                 }
                 void operator()(Matrix_type_descriptor &)
                 {
+#warning finish
+                    throw Parser_error(instruction_start_index,
+                                       instruction_start_index,
+                                       "unimplemented composite type for OpCompositeExtract");
+                }
+                void operator()(Row_major_matrix_type_descriptor &)
+                {
 #warning finish
                     throw Parser_error(instruction_start_index,
                                        instruction_start_index,
index 4864bdd931fb76f4d2e22e915a050a60d70b371e..d509fccdddb4e28dd75facd03d1b342158980f8f 100644 (file)
@@ -399,6 +399,13 @@ void Struct_type_descriptor::complete_type()
             }
             virtual void visit(Matrix_type_descriptor &type) override
             {
+#warning finish implementing member type
+                throw Parser_error(this_->instruction_start_index,
+                                   this_->instruction_start_index,
+                                   "unimplemented member type");
+            }
+            virtual void visit(Row_major_matrix_type_descriptor &type) override
+            {
 #warning finish implementing member type
                 throw Parser_error(this_->instruction_start_index,
                                    this_->instruction_start_index,
index f2a791512f592ce994aa125b334b3f2e9df376d1..f5ad1d43ebff7d170e746b6a218e7bde56bbf368 100644 (file)
@@ -104,11 +104,12 @@ struct Shader_interface
 class Simple_type_descriptor;
 class Vector_type_descriptor;
 class Matrix_type_descriptor;
+class Row_major_matrix_type_descriptor;
 class Array_type_descriptor;
 class Pointer_type_descriptor;
 class Function_type_descriptor;
 class Struct_type_descriptor;
-class Type_descriptor
+class Type_descriptor : public std::enable_shared_from_this<Type_descriptor>
 {
     Type_descriptor(const Type_descriptor &) = delete;
     Type_descriptor &operator=(const Type_descriptor &) = delete;
@@ -120,6 +121,7 @@ public:
         virtual void visit(Simple_type_descriptor &type) = 0;
         virtual void visit(Vector_type_descriptor &type) = 0;
         virtual void visit(Matrix_type_descriptor &type) = 0;
+        virtual void visit(Row_major_matrix_type_descriptor &type) = 0;
         virtual void visit(Array_type_descriptor &type) = 0;
         virtual void visit(Pointer_type_descriptor &type) = 0;
         virtual void visit(Function_type_descriptor &type) = 0;
@@ -137,6 +139,14 @@ public:
     virtual ~Type_descriptor() = default;
     virtual LLVM_type_and_alignment get_or_make_type() = 0;
     virtual void visit(Type_visitor &type_visitor) = 0;
+    virtual std::shared_ptr<Type_descriptor> get_row_major_type(::LLVMTargetDataRef target_data)
+    {
+        return shared_from_this();
+    }
+    virtual std::shared_ptr<Type_descriptor> get_column_major_type(::LLVMTargetDataRef target_data)
+    {
+        return shared_from_this();
+    }
     void visit(Type_visitor &&type_visitor)
     {
         visit(type_visitor);
@@ -160,6 +170,10 @@ public:
             {
                 std::forward<Fn>(fn)(type);
             }
+            virtual void visit(Row_major_matrix_type_descriptor &type) override
+            {
+                std::forward<Fn>(fn)(type);
+            }
             virtual void visit(Array_type_descriptor &type) override
             {
                 std::forward<Fn>(fn)(type);
@@ -295,25 +309,109 @@ public:
     }
 };
 
+class Array_type_descriptor final : public Type_descriptor
+{
+private:
+    LLVM_type_and_alignment type;
+    std::shared_ptr<Type_descriptor> element_type;
+    std::size_t element_count;
+    std::size_t instruction_start_index;
+    Recursion_checker_state recursion_checker_state;
+    std::weak_ptr<Type_descriptor> column_major_type;
+    std::weak_ptr<Type_descriptor> row_major_type;
+
+public:
+    explicit Array_type_descriptor(std::vector<spirv::Decoration_with_parameters> decorations,
+                                   std::shared_ptr<Type_descriptor> element_type,
+                                   std::size_t element_count,
+                                   std::size_t instruction_start_index) noexcept
+        : Type_descriptor(std::move(decorations)),
+          type(),
+          element_type(std::move(element_type)),
+          element_count(element_count),
+          instruction_start_index(instruction_start_index)
+    {
+    }
+    virtual LLVM_type_and_alignment get_or_make_type() override
+    {
+        if(!type.type)
+        {
+            Recursion_checker recursion_checker(recursion_checker_state, instruction_start_index);
+            auto llvm_element_type = element_type->get_or_make_type();
+            type = LLVM_type_and_alignment(::LLVMArrayType(llvm_element_type.type, element_count),
+                                           llvm_element_type.alignment);
+        }
+        return type;
+    }
+    virtual void visit(Type_visitor &type_visitor) override
+    {
+        type_visitor.visit(*this);
+    }
+    virtual std::shared_ptr<Type_descriptor> get_row_major_type(
+        ::LLVMTargetDataRef target_data) override
+    {
+        auto retval = row_major_type.lock();
+        if(retval)
+            return retval;
+        auto row_major_element_type = element_type->get_row_major_type(target_data);
+        if(row_major_element_type == element_type)
+            retval = shared_from_this();
+        else
+            retval = std::make_shared<Array_type_descriptor>(decorations,
+                                                             std::move(row_major_element_type),
+                                                             element_count,
+                                                             instruction_start_index);
+        row_major_type = retval;
+        return retval;
+    }
+    virtual std::shared_ptr<Type_descriptor> get_column_major_type(
+        ::LLVMTargetDataRef target_data) override
+    {
+        auto retval = column_major_type.lock();
+        if(retval)
+            return retval;
+        auto column_major_element_type = element_type->get_column_major_type(target_data);
+        if(column_major_element_type == element_type)
+            retval = shared_from_this();
+        else
+            retval = std::make_shared<Array_type_descriptor>(decorations,
+                                                             std::move(column_major_element_type),
+                                                             element_count,
+                                                             instruction_start_index);
+        column_major_type = retval;
+        return retval;
+    }
+    const std::shared_ptr<Type_descriptor> &get_element_type() const noexcept
+    {
+        return element_type;
+    }
+    std::size_t get_element_count() const noexcept
+    {
+        return element_count;
+    }
+};
+
 class Matrix_type_descriptor final : public Type_descriptor
 {
+    friend class Row_major_matrix_type_descriptor;
+
 private:
     LLVM_type_and_alignment type;
     std::shared_ptr<Vector_type_descriptor> column_type;
     std::size_t column_count;
+    std::weak_ptr<Type_descriptor> row_major_type;
+    std::shared_ptr<Type_descriptor> make_row_major_type(::LLVMTargetDataRef target_data);
 
 public:
     explicit Matrix_type_descriptor(std::vector<spirv::Decoration_with_parameters> decorations,
                                     std::shared_ptr<Vector_type_descriptor> column_type,
-                                    std::size_t column_count,
-                                    ::LLVMTargetDataRef target_data) noexcept
+                                    std::size_t column_count) noexcept
         : Type_descriptor(std::move(decorations)),
-          type(Vector_type_descriptor::make_vector_type(column_type->get_element_type(),
-                                                        column_type->get_element_count()
-                                                            * column_count,
-                                                        target_data)),
+          type(::LLVMArrayType(column_type->get_or_make_type().type, column_count),
+               column_type->get_or_make_type().alignment),
           column_type(std::move(column_type)),
-          column_count(column_count)
+          column_count(column_count),
+          row_major_type()
     {
     }
     virtual LLVM_type_and_alignment get_or_make_type() override
@@ -332,54 +430,88 @@ public:
     {
         return column_count;
     }
+    virtual std::shared_ptr<Type_descriptor> get_row_major_type(
+        ::LLVMTargetDataRef target_data) override
+    {
+        auto retval = row_major_type.lock();
+        if(retval)
+            return retval;
+        retval = make_row_major_type(target_data);
+        row_major_type = retval;
+        return retval;
+    }
 };
 
-class Array_type_descriptor final : public Type_descriptor
+class Row_major_matrix_type_descriptor final : public Type_descriptor
 {
+    friend class Matrix_type_descriptor;
+
 private:
     LLVM_type_and_alignment type;
-    std::shared_ptr<Type_descriptor> element_type;
-    std::size_t element_count;
-    std::size_t instruction_start_index;
-    Recursion_checker_state recursion_checker_state;
+    std::shared_ptr<Vector_type_descriptor> row_type;
+    std::size_t row_count;
+    std::shared_ptr<Matrix_type_descriptor> column_major_type;
 
 public:
-    explicit Array_type_descriptor(std::vector<spirv::Decoration_with_parameters> decorations,
-                                   std::shared_ptr<Type_descriptor> element_type,
-                                   std::size_t element_count,
-                                   std::size_t instruction_start_index) noexcept
+    explicit Row_major_matrix_type_descriptor(
+        std::vector<spirv::Decoration_with_parameters> decorations,
+        std::shared_ptr<Vector_type_descriptor> row_type,
+        std::size_t row_count) noexcept
         : Type_descriptor(std::move(decorations)),
-          type(),
-          element_type(std::move(element_type)),
-          element_count(element_count),
-          instruction_start_index(instruction_start_index)
+          type(::LLVMArrayType(row_type->get_or_make_type().type, row_count),
+               row_type->get_or_make_type().alignment),
+          row_type(std::move(row_type)),
+          row_count(row_count),
+          column_major_type()
     {
     }
     virtual LLVM_type_and_alignment get_or_make_type() override
     {
-        if(!type.type)
-        {
-            Recursion_checker recursion_checker(recursion_checker_state, instruction_start_index);
-            auto llvm_element_type = element_type->get_or_make_type();
-            type = LLVM_type_and_alignment(::LLVMArrayType(llvm_element_type.type, element_count),
-                                           llvm_element_type.alignment);
-        }
         return type;
     }
     virtual void visit(Type_visitor &type_visitor) override
     {
         type_visitor.visit(*this);
     }
-    const std::shared_ptr<Type_descriptor> &get_element_type() const noexcept
+    const std::shared_ptr<Vector_type_descriptor> &get_row_type() const noexcept
     {
-        return element_type;
+        return row_type;
     }
-    std::size_t get_element_count() const noexcept
+    std::size_t get_row_count() const noexcept
     {
-        return element_count;
+        return row_count;
+    }
+    virtual std::shared_ptr<Type_descriptor> get_column_major_type(
+        ::LLVMTargetDataRef target_data) override
+    {
+        if(column_major_type)
+            return column_major_type;
+        auto column_type = std::make_shared<Vector_type_descriptor>(
+            std::vector<spirv::Decoration_with_parameters>{},
+            row_type->get_element_type(),
+            row_count,
+            target_data);
+        column_major_type = std::make_shared<Matrix_type_descriptor>(
+            decorations, std::move(column_type), row_type->get_element_count());
+        column_major_type->row_major_type = std::static_pointer_cast<Row_major_matrix_type_descriptor>(shared_from_this());
+        return column_major_type;
     }
 };
 
+inline std::shared_ptr<Type_descriptor> Matrix_type_descriptor::make_row_major_type(
+    ::LLVMTargetDataRef target_data)
+{
+    auto row_type =
+        std::make_shared<Vector_type_descriptor>(std::vector<spirv::Decoration_with_parameters>{},
+                                                 column_type->get_element_type(),
+                                                 column_count,
+                                                 target_data);
+    auto retval = std::make_shared<Row_major_matrix_type_descriptor>(
+        decorations, std::move(row_type), column_type->get_element_count());
+    retval->column_major_type = std::static_pointer_cast<Matrix_type_descriptor>(shared_from_this());
+    return retval;
+}
+
 class Pointer_type_descriptor final : public Type_descriptor
 {
 private: