From: Jacob Lifshay Date: Fri, 22 Sep 2017 23:40:03 +0000 (-0700) Subject: added row-major matrix type X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=8ccb7df0fef7935f6b392454fec5ace380241d54;p=kazan.git added row-major matrix type --- diff --git a/src/pipeline/pipeline.cpp b/src/pipeline/pipeline.cpp index 19fcbb4..2cb457f 100644 --- a/src/pipeline/pipeline.cpp +++ b/src/pipeline/pipeline.cpp @@ -296,6 +296,12 @@ struct Graphics_pipeline::Implementation { assert(!"dumping matrix not implemented"); throw std::runtime_error("dumping matrix not implemented"); +#warning dumping matrix not implemented + } + virtual void visit(spirv_to_llvm::Row_major_matrix_type_descriptor &type) override + { + assert(!"dumping matrix not implemented"); + throw std::runtime_error("dumping matrix not implemented"); #warning dumping matrix not implemented } virtual void visit(spirv_to_llvm::Array_type_descriptor &type) override diff --git a/src/spirv_to_llvm/core_instructions.cpp b/src/spirv_to_llvm/core_instructions.cpp index 0c23e0c..eae8eb1 100644 --- a/src/spirv_to_llvm/core_instructions.cpp +++ b/src/spirv_to_llvm/core_instructions.cpp @@ -377,8 +377,7 @@ void Spirv_to_llvm::handle_instruction_op_type_matrix(Op_type_matrix instruction state.type = std::make_shared( state.decorations, get_type(instruction.column_type, instruction_start_index), - instruction.column_count, - target_data); + instruction.column_count); break; } case Stage::generate_code: @@ -1585,6 +1584,13 @@ void Spirv_to_llvm::handle_instruction_op_access_chain(Op_access_chain instructi } void operator()(Matrix_type_descriptor &) { +#warning finish + throw Parser_error(instruction_start_index, + instruction_start_index, + "unimplemented composite type for OpAccessChain"); + } + void operator()(Row_major_matrix_type_descriptor &) + { #warning finish throw Parser_error(instruction_start_index, instruction_start_index, @@ -1841,6 +1847,13 @@ void Spirv_to_llvm::handle_instruction_op_composite_construct(Op_composite_const } void operator()(Matrix_type_descriptor &) { +#warning finish + throw Parser_error(instruction_start_index, + instruction_start_index, + "unimplemented result type for OpCompositeConstruct"); + } + void operator()(Row_major_matrix_type_descriptor &) + { #warning finish throw Parser_error(instruction_start_index, instruction_start_index, @@ -1934,6 +1947,13 @@ void Spirv_to_llvm::handle_instruction_op_composite_extract(Op_composite_extract } void operator()(Matrix_type_descriptor &) { +#warning finish + throw Parser_error(instruction_start_index, + instruction_start_index, + "unimplemented composite type for OpCompositeExtract"); + } + void operator()(Row_major_matrix_type_descriptor &) + { #warning finish throw Parser_error(instruction_start_index, instruction_start_index, diff --git a/src/spirv_to_llvm/spirv_to_llvm.cpp b/src/spirv_to_llvm/spirv_to_llvm.cpp index 4864bdd..d509fcc 100644 --- a/src/spirv_to_llvm/spirv_to_llvm.cpp +++ b/src/spirv_to_llvm/spirv_to_llvm.cpp @@ -399,6 +399,13 @@ void Struct_type_descriptor::complete_type() } virtual void visit(Matrix_type_descriptor &type) override { +#warning finish implementing member type + throw Parser_error(this_->instruction_start_index, + this_->instruction_start_index, + "unimplemented member type"); + } + virtual void visit(Row_major_matrix_type_descriptor &type) override + { #warning finish implementing member type throw Parser_error(this_->instruction_start_index, this_->instruction_start_index, diff --git a/src/spirv_to_llvm/spirv_to_llvm.h b/src/spirv_to_llvm/spirv_to_llvm.h index f2a7915..f5ad1d4 100644 --- a/src/spirv_to_llvm/spirv_to_llvm.h +++ b/src/spirv_to_llvm/spirv_to_llvm.h @@ -104,11 +104,12 @@ struct Shader_interface class Simple_type_descriptor; class Vector_type_descriptor; class Matrix_type_descriptor; +class Row_major_matrix_type_descriptor; class Array_type_descriptor; class Pointer_type_descriptor; class Function_type_descriptor; class Struct_type_descriptor; -class Type_descriptor +class Type_descriptor : public std::enable_shared_from_this { Type_descriptor(const Type_descriptor &) = delete; Type_descriptor &operator=(const Type_descriptor &) = delete; @@ -120,6 +121,7 @@ public: virtual void visit(Simple_type_descriptor &type) = 0; virtual void visit(Vector_type_descriptor &type) = 0; virtual void visit(Matrix_type_descriptor &type) = 0; + virtual void visit(Row_major_matrix_type_descriptor &type) = 0; virtual void visit(Array_type_descriptor &type) = 0; virtual void visit(Pointer_type_descriptor &type) = 0; virtual void visit(Function_type_descriptor &type) = 0; @@ -137,6 +139,14 @@ public: virtual ~Type_descriptor() = default; virtual LLVM_type_and_alignment get_or_make_type() = 0; virtual void visit(Type_visitor &type_visitor) = 0; + virtual std::shared_ptr get_row_major_type(::LLVMTargetDataRef target_data) + { + return shared_from_this(); + } + virtual std::shared_ptr get_column_major_type(::LLVMTargetDataRef target_data) + { + return shared_from_this(); + } void visit(Type_visitor &&type_visitor) { visit(type_visitor); @@ -160,6 +170,10 @@ public: { std::forward(fn)(type); } + virtual void visit(Row_major_matrix_type_descriptor &type) override + { + std::forward(fn)(type); + } virtual void visit(Array_type_descriptor &type) override { std::forward(fn)(type); @@ -295,25 +309,109 @@ public: } }; +class Array_type_descriptor final : public Type_descriptor +{ +private: + LLVM_type_and_alignment type; + std::shared_ptr element_type; + std::size_t element_count; + std::size_t instruction_start_index; + Recursion_checker_state recursion_checker_state; + std::weak_ptr column_major_type; + std::weak_ptr row_major_type; + +public: + explicit Array_type_descriptor(std::vector decorations, + std::shared_ptr element_type, + std::size_t element_count, + std::size_t instruction_start_index) noexcept + : Type_descriptor(std::move(decorations)), + type(), + element_type(std::move(element_type)), + element_count(element_count), + instruction_start_index(instruction_start_index) + { + } + virtual LLVM_type_and_alignment get_or_make_type() override + { + if(!type.type) + { + Recursion_checker recursion_checker(recursion_checker_state, instruction_start_index); + auto llvm_element_type = element_type->get_or_make_type(); + type = LLVM_type_and_alignment(::LLVMArrayType(llvm_element_type.type, element_count), + llvm_element_type.alignment); + } + return type; + } + virtual void visit(Type_visitor &type_visitor) override + { + type_visitor.visit(*this); + } + virtual std::shared_ptr get_row_major_type( + ::LLVMTargetDataRef target_data) override + { + auto retval = row_major_type.lock(); + if(retval) + return retval; + auto row_major_element_type = element_type->get_row_major_type(target_data); + if(row_major_element_type == element_type) + retval = shared_from_this(); + else + retval = std::make_shared(decorations, + std::move(row_major_element_type), + element_count, + instruction_start_index); + row_major_type = retval; + return retval; + } + virtual std::shared_ptr get_column_major_type( + ::LLVMTargetDataRef target_data) override + { + auto retval = column_major_type.lock(); + if(retval) + return retval; + auto column_major_element_type = element_type->get_column_major_type(target_data); + if(column_major_element_type == element_type) + retval = shared_from_this(); + else + retval = std::make_shared(decorations, + std::move(column_major_element_type), + element_count, + instruction_start_index); + column_major_type = retval; + return retval; + } + const std::shared_ptr &get_element_type() const noexcept + { + return element_type; + } + std::size_t get_element_count() const noexcept + { + return element_count; + } +}; + class Matrix_type_descriptor final : public Type_descriptor { + friend class Row_major_matrix_type_descriptor; + private: LLVM_type_and_alignment type; std::shared_ptr column_type; std::size_t column_count; + std::weak_ptr row_major_type; + std::shared_ptr make_row_major_type(::LLVMTargetDataRef target_data); public: explicit Matrix_type_descriptor(std::vector decorations, std::shared_ptr column_type, - std::size_t column_count, - ::LLVMTargetDataRef target_data) noexcept + std::size_t column_count) noexcept : Type_descriptor(std::move(decorations)), - type(Vector_type_descriptor::make_vector_type(column_type->get_element_type(), - column_type->get_element_count() - * column_count, - target_data)), + type(::LLVMArrayType(column_type->get_or_make_type().type, column_count), + column_type->get_or_make_type().alignment), column_type(std::move(column_type)), - column_count(column_count) + column_count(column_count), + row_major_type() { } virtual LLVM_type_and_alignment get_or_make_type() override @@ -332,54 +430,88 @@ public: { return column_count; } + virtual std::shared_ptr get_row_major_type( + ::LLVMTargetDataRef target_data) override + { + auto retval = row_major_type.lock(); + if(retval) + return retval; + retval = make_row_major_type(target_data); + row_major_type = retval; + return retval; + } }; -class Array_type_descriptor final : public Type_descriptor +class Row_major_matrix_type_descriptor final : public Type_descriptor { + friend class Matrix_type_descriptor; + private: LLVM_type_and_alignment type; - std::shared_ptr element_type; - std::size_t element_count; - std::size_t instruction_start_index; - Recursion_checker_state recursion_checker_state; + std::shared_ptr row_type; + std::size_t row_count; + std::shared_ptr column_major_type; public: - explicit Array_type_descriptor(std::vector decorations, - std::shared_ptr element_type, - std::size_t element_count, - std::size_t instruction_start_index) noexcept + explicit Row_major_matrix_type_descriptor( + std::vector decorations, + std::shared_ptr row_type, + std::size_t row_count) noexcept : Type_descriptor(std::move(decorations)), - type(), - element_type(std::move(element_type)), - element_count(element_count), - instruction_start_index(instruction_start_index) + type(::LLVMArrayType(row_type->get_or_make_type().type, row_count), + row_type->get_or_make_type().alignment), + row_type(std::move(row_type)), + row_count(row_count), + column_major_type() { } virtual LLVM_type_and_alignment get_or_make_type() override { - if(!type.type) - { - Recursion_checker recursion_checker(recursion_checker_state, instruction_start_index); - auto llvm_element_type = element_type->get_or_make_type(); - type = LLVM_type_and_alignment(::LLVMArrayType(llvm_element_type.type, element_count), - llvm_element_type.alignment); - } return type; } virtual void visit(Type_visitor &type_visitor) override { type_visitor.visit(*this); } - const std::shared_ptr &get_element_type() const noexcept + const std::shared_ptr &get_row_type() const noexcept { - return element_type; + return row_type; } - std::size_t get_element_count() const noexcept + std::size_t get_row_count() const noexcept { - return element_count; + return row_count; + } + virtual std::shared_ptr get_column_major_type( + ::LLVMTargetDataRef target_data) override + { + if(column_major_type) + return column_major_type; + auto column_type = std::make_shared( + std::vector{}, + row_type->get_element_type(), + row_count, + target_data); + column_major_type = std::make_shared( + decorations, std::move(column_type), row_type->get_element_count()); + column_major_type->row_major_type = std::static_pointer_cast(shared_from_this()); + return column_major_type; } }; +inline std::shared_ptr Matrix_type_descriptor::make_row_major_type( + ::LLVMTargetDataRef target_data) +{ + auto row_type = + std::make_shared(std::vector{}, + column_type->get_element_type(), + column_count, + target_data); + auto retval = std::make_shared( + decorations, std::move(row_type), column_type->get_element_count()); + retval->column_major_type = std::static_pointer_cast(shared_from_this()); + return retval; +} + class Pointer_type_descriptor final : public Type_descriptor { private: