class Simple_type_descriptor;
class Vector_type_descriptor;
class Matrix_type_descriptor;
+class Row_major_matrix_type_descriptor;
class Array_type_descriptor;
class Pointer_type_descriptor;
class Function_type_descriptor;
class Struct_type_descriptor;
-class Type_descriptor
+class Type_descriptor : public std::enable_shared_from_this<Type_descriptor>
{
Type_descriptor(const Type_descriptor &) = delete;
Type_descriptor &operator=(const Type_descriptor &) = delete;
virtual void visit(Simple_type_descriptor &type) = 0;
virtual void visit(Vector_type_descriptor &type) = 0;
virtual void visit(Matrix_type_descriptor &type) = 0;
+ virtual void visit(Row_major_matrix_type_descriptor &type) = 0;
virtual void visit(Array_type_descriptor &type) = 0;
virtual void visit(Pointer_type_descriptor &type) = 0;
virtual void visit(Function_type_descriptor &type) = 0;
virtual ~Type_descriptor() = default;
virtual LLVM_type_and_alignment get_or_make_type() = 0;
virtual void visit(Type_visitor &type_visitor) = 0;
+ virtual std::shared_ptr<Type_descriptor> get_row_major_type(::LLVMTargetDataRef target_data)
+ {
+ return shared_from_this();
+ }
+ virtual std::shared_ptr<Type_descriptor> get_column_major_type(::LLVMTargetDataRef target_data)
+ {
+ return shared_from_this();
+ }
void visit(Type_visitor &&type_visitor)
{
visit(type_visitor);
{
std::forward<Fn>(fn)(type);
}
+ virtual void visit(Row_major_matrix_type_descriptor &type) override
+ {
+ std::forward<Fn>(fn)(type);
+ }
virtual void visit(Array_type_descriptor &type) override
{
std::forward<Fn>(fn)(type);
}
};
+class Array_type_descriptor final : public Type_descriptor
+{
+private:
+ LLVM_type_and_alignment type;
+ std::shared_ptr<Type_descriptor> element_type;
+ std::size_t element_count;
+ std::size_t instruction_start_index;
+ Recursion_checker_state recursion_checker_state;
+ std::weak_ptr<Type_descriptor> column_major_type;
+ std::weak_ptr<Type_descriptor> row_major_type;
+
+public:
+ explicit Array_type_descriptor(std::vector<spirv::Decoration_with_parameters> decorations,
+ std::shared_ptr<Type_descriptor> element_type,
+ std::size_t element_count,
+ std::size_t instruction_start_index) noexcept
+ : Type_descriptor(std::move(decorations)),
+ type(),
+ element_type(std::move(element_type)),
+ element_count(element_count),
+ instruction_start_index(instruction_start_index)
+ {
+ }
+ virtual LLVM_type_and_alignment get_or_make_type() override
+ {
+ if(!type.type)
+ {
+ Recursion_checker recursion_checker(recursion_checker_state, instruction_start_index);
+ auto llvm_element_type = element_type->get_or_make_type();
+ type = LLVM_type_and_alignment(::LLVMArrayType(llvm_element_type.type, element_count),
+ llvm_element_type.alignment);
+ }
+ return type;
+ }
+ virtual void visit(Type_visitor &type_visitor) override
+ {
+ type_visitor.visit(*this);
+ }
+ virtual std::shared_ptr<Type_descriptor> get_row_major_type(
+ ::LLVMTargetDataRef target_data) override
+ {
+ auto retval = row_major_type.lock();
+ if(retval)
+ return retval;
+ auto row_major_element_type = element_type->get_row_major_type(target_data);
+ if(row_major_element_type == element_type)
+ retval = shared_from_this();
+ else
+ retval = std::make_shared<Array_type_descriptor>(decorations,
+ std::move(row_major_element_type),
+ element_count,
+ instruction_start_index);
+ row_major_type = retval;
+ return retval;
+ }
+ virtual std::shared_ptr<Type_descriptor> get_column_major_type(
+ ::LLVMTargetDataRef target_data) override
+ {
+ auto retval = column_major_type.lock();
+ if(retval)
+ return retval;
+ auto column_major_element_type = element_type->get_column_major_type(target_data);
+ if(column_major_element_type == element_type)
+ retval = shared_from_this();
+ else
+ retval = std::make_shared<Array_type_descriptor>(decorations,
+ std::move(column_major_element_type),
+ element_count,
+ instruction_start_index);
+ column_major_type = retval;
+ return retval;
+ }
+ const std::shared_ptr<Type_descriptor> &get_element_type() const noexcept
+ {
+ return element_type;
+ }
+ std::size_t get_element_count() const noexcept
+ {
+ return element_count;
+ }
+};
+
class Matrix_type_descriptor final : public Type_descriptor
{
+ friend class Row_major_matrix_type_descriptor;
+
private:
LLVM_type_and_alignment type;
std::shared_ptr<Vector_type_descriptor> column_type;
std::size_t column_count;
+ std::weak_ptr<Type_descriptor> row_major_type;
+ std::shared_ptr<Type_descriptor> make_row_major_type(::LLVMTargetDataRef target_data);
public:
explicit Matrix_type_descriptor(std::vector<spirv::Decoration_with_parameters> decorations,
std::shared_ptr<Vector_type_descriptor> column_type,
- std::size_t column_count,
- ::LLVMTargetDataRef target_data) noexcept
+ std::size_t column_count) noexcept
: Type_descriptor(std::move(decorations)),
- type(Vector_type_descriptor::make_vector_type(column_type->get_element_type(),
- column_type->get_element_count()
- * column_count,
- target_data)),
+ type(::LLVMArrayType(column_type->get_or_make_type().type, column_count),
+ column_type->get_or_make_type().alignment),
column_type(std::move(column_type)),
- column_count(column_count)
+ column_count(column_count),
+ row_major_type()
{
}
virtual LLVM_type_and_alignment get_or_make_type() override
{
return column_count;
}
+ virtual std::shared_ptr<Type_descriptor> get_row_major_type(
+ ::LLVMTargetDataRef target_data) override
+ {
+ auto retval = row_major_type.lock();
+ if(retval)
+ return retval;
+ retval = make_row_major_type(target_data);
+ row_major_type = retval;
+ return retval;
+ }
};
-class Array_type_descriptor final : public Type_descriptor
+class Row_major_matrix_type_descriptor final : public Type_descriptor
{
+ friend class Matrix_type_descriptor;
+
private:
LLVM_type_and_alignment type;
- std::shared_ptr<Type_descriptor> element_type;
- std::size_t element_count;
- std::size_t instruction_start_index;
- Recursion_checker_state recursion_checker_state;
+ std::shared_ptr<Vector_type_descriptor> row_type;
+ std::size_t row_count;
+ std::shared_ptr<Matrix_type_descriptor> column_major_type;
public:
- explicit Array_type_descriptor(std::vector<spirv::Decoration_with_parameters> decorations,
- std::shared_ptr<Type_descriptor> element_type,
- std::size_t element_count,
- std::size_t instruction_start_index) noexcept
+ explicit Row_major_matrix_type_descriptor(
+ std::vector<spirv::Decoration_with_parameters> decorations,
+ std::shared_ptr<Vector_type_descriptor> row_type,
+ std::size_t row_count) noexcept
: Type_descriptor(std::move(decorations)),
- type(),
- element_type(std::move(element_type)),
- element_count(element_count),
- instruction_start_index(instruction_start_index)
+ type(::LLVMArrayType(row_type->get_or_make_type().type, row_count),
+ row_type->get_or_make_type().alignment),
+ row_type(std::move(row_type)),
+ row_count(row_count),
+ column_major_type()
{
}
virtual LLVM_type_and_alignment get_or_make_type() override
{
- if(!type.type)
- {
- Recursion_checker recursion_checker(recursion_checker_state, instruction_start_index);
- auto llvm_element_type = element_type->get_or_make_type();
- type = LLVM_type_and_alignment(::LLVMArrayType(llvm_element_type.type, element_count),
- llvm_element_type.alignment);
- }
return type;
}
virtual void visit(Type_visitor &type_visitor) override
{
type_visitor.visit(*this);
}
- const std::shared_ptr<Type_descriptor> &get_element_type() const noexcept
+ const std::shared_ptr<Vector_type_descriptor> &get_row_type() const noexcept
{
- return element_type;
+ return row_type;
}
- std::size_t get_element_count() const noexcept
+ std::size_t get_row_count() const noexcept
{
- return element_count;
+ return row_count;
+ }
+ virtual std::shared_ptr<Type_descriptor> get_column_major_type(
+ ::LLVMTargetDataRef target_data) override
+ {
+ if(column_major_type)
+ return column_major_type;
+ auto column_type = std::make_shared<Vector_type_descriptor>(
+ std::vector<spirv::Decoration_with_parameters>{},
+ row_type->get_element_type(),
+ row_count,
+ target_data);
+ column_major_type = std::make_shared<Matrix_type_descriptor>(
+ decorations, std::move(column_type), row_type->get_element_count());
+ column_major_type->row_major_type = std::static_pointer_cast<Row_major_matrix_type_descriptor>(shared_from_this());
+ return column_major_type;
}
};
+inline std::shared_ptr<Type_descriptor> Matrix_type_descriptor::make_row_major_type(
+ ::LLVMTargetDataRef target_data)
+{
+ auto row_type =
+ std::make_shared<Vector_type_descriptor>(std::vector<spirv::Decoration_with_parameters>{},
+ column_type->get_element_type(),
+ column_count,
+ target_data);
+ auto retval = std::make_shared<Row_major_matrix_type_descriptor>(
+ decorations, std::move(row_type), column_type->get_element_count());
+ retval->column_major_type = std::static_pointer_cast<Matrix_type_descriptor>(shared_from_this());
+ return retval;
+}
+
class Pointer_type_descriptor final : public Type_descriptor
{
private: