From: Jacob Lifshay Date: Mon, 25 Sep 2017 00:40:18 +0000 (-0700) Subject: implementing uniforms; implemented matrix multiplication X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=46d4bf1ef1060130950c5e9dcf006a57a7c6fb81;p=kazan.git implementing uniforms; implemented matrix multiplication --- diff --git a/src/demo/demo.cpp b/src/demo/demo.cpp index 400da19..673a87d 100644 --- a/src/demo/demo.cpp +++ b/src/demo/demo.cpp @@ -901,12 +901,20 @@ int test_main(int argc, char **argv) void *bindings[binding_count] = { vertexes.data(), }; - graphics_pipeline->run( - vertex_start_index, vertex_end_index, instance_id, *color_attachment, bindings); + struct Uniforms + { + }; + Uniforms uniforms{}; + graphics_pipeline->run(vertex_start_index, + vertex_end_index, + instance_id, + *color_attachment, + bindings, + &uniforms); typedef std::uint32_t Pixel_type; // check Pixel_type static_assert(std::is_voidrun_fragment_shader( - static_cast(nullptr)))>>::value, + static_cast(nullptr), nullptr))>>::value, ""); auto rgba = [](std::uint8_t r, std::uint8_t g, diff --git a/src/llvm_wrapper/llvm_wrapper.cpp b/src/llvm_wrapper/llvm_wrapper.cpp index b4faf66..777f244 100644 --- a/src/llvm_wrapper/llvm_wrapper.cpp +++ b/src/llvm_wrapper/llvm_wrapper.cpp @@ -25,11 +25,14 @@ #include #include #include +#include #include #include #include #include #include +#include +#include namespace kazan { @@ -148,6 +151,22 @@ unsigned Target_machine::get_biggest_vector_register_bit_width(::LLVMTargetMachi .getRegisterBitWidth(true); } +LLVM_intrinsic_id get_llvm_intrinsic_id(Intrinsic intrinsic) noexcept +{ + using llvm::Intrinsic::ID; + auto cvt = [](ID v) noexcept + { + return static_cast(static_cast(v)); + }; + switch(intrinsic) + { + case Intrinsic::fmuladd: + return cvt(ID::fmuladd); + } + assert(false); + return LLVM_intrinsic_id::Not_intrinsic; +} + void Module::set_target_machine(::LLVMModuleRef module, ::LLVMTargetMachineRef target_machine) { ::LLVMSetTarget(module, Target_machine::get_target_triple(target_machine).get()); @@ -163,5 +182,32 @@ void Module::set_function_target_machine(::LLVMValueRef function, ::LLVMAddTargetDependentFunctionAttr( function, "target-features", Target_machine::get_feature_string(target_machine).get()); } + +::LLVMValueRef Module::get_intrinsic_declaration(::LLVMModuleRef module, + LLVM_intrinsic_id llvm_intrinsic_id, + const ::LLVMTypeRef *types, + std::size_t type_count) +{ + auto *module_pointer = llvm::unwrap(module); + constexpr std::size_t array_size = 4; + llvm::Type *on_stack_array[array_size]; + std::unique_ptr on_heap_array; + llvm::Type **unwrapped_types; + if(type_count > array_size) + { + on_heap_array.reset(new llvm::Type *[type_count]); + unwrapped_types = on_heap_array.get(); + } + else + { + unwrapped_types = on_stack_array; + } + for(std::size_t i = 0; i < type_count; i++) + unwrapped_types[i] = llvm::unwrap(types[i]); + return llvm::wrap(llvm::Intrinsic::getDeclaration( + module_pointer, + static_cast(static_cast(llvm_intrinsic_id)), + llvm::ArrayRef(unwrapped_types, type_count))); +} } } diff --git a/src/llvm_wrapper/llvm_wrapper.h b/src/llvm_wrapper/llvm_wrapper.h index 9ac2d07..c79e286 100644 --- a/src/llvm_wrapper/llvm_wrapper.h +++ b/src/llvm_wrapper/llvm_wrapper.h @@ -37,6 +37,7 @@ #include #include #include +#include #include "util/string_view.h" #include "util/variant.h" @@ -317,6 +318,19 @@ struct Target_machine : public Wrapper<::LLVMTargetMachineRef, Target_machine_de } }; +enum class Intrinsic // doesn't match llvm::Intrinsic::ID +{ + fmuladd, +}; + +enum class LLVM_intrinsic_id : unsigned +{ + Not_intrinsic = 0, + Maximum_intrinsic_id = static_cast(-1) +}; + +LLVM_intrinsic_id get_llvm_intrinsic_id(Intrinsic intrinsic) noexcept; + struct Module_deleter { void operator()(::LLVMModuleRef module) const noexcept @@ -347,6 +361,27 @@ struct Module : public Wrapper<::LLVMModuleRef, Module_deleter> { set_target_machine(get(), target_machine); } + static ::LLVMValueRef get_intrinsic_declaration(::LLVMModuleRef module, + LLVM_intrinsic_id llvm_intrinsic_id, + const ::LLVMTypeRef *types, + std::size_t type_count); + ::LLVMValueRef get_intrinsic_declaration(LLVM_intrinsic_id llvm_intrinsic_id, + const ::LLVMTypeRef *types, + std::size_t type_count) + { + return get_intrinsic_declaration(get(), llvm_intrinsic_id, types, type_count); + } + static ::LLVMValueRef get_intrinsic_declaration(::LLVMModuleRef module, + LLVM_intrinsic_id llvm_intrinsic_id, + std::initializer_list<::LLVMTypeRef> types) + { + return get_intrinsic_declaration(module, llvm_intrinsic_id, types.begin(), types.size()); + } + ::LLVMValueRef get_intrinsic_declaration(LLVM_intrinsic_id llvm_intrinsic_id, + std::initializer_list<::LLVMTypeRef> types) + { + return get_intrinsic_declaration(get(), llvm_intrinsic_id, types); + } }; inline LLVM_string print_type_to_string(::LLVMTypeRef type) @@ -401,6 +436,32 @@ struct Builder : public Wrapper<::LLVMBuilderRef, Builder_deleter> { return build_smod(get(), lhs, rhs, result_name); } + static ::LLVMValueRef build_fmuladd(::LLVMBuilderRef builder, + ::LLVMModuleRef module, + ::LLVMValueRef factor1, + ::LLVMValueRef factor2, + ::LLVMValueRef term, + const char *result_name) + { + auto type = ::LLVMTypeOf(factor1); + assert(type == ::LLVMTypeOf(factor2)); + assert(type == ::LLVMTypeOf(term)); + auto intrinsic = Module::get_intrinsic_declaration( + module, get_llvm_intrinsic_id(Intrinsic::fmuladd), {type}); + constexpr std::size_t arg_count = 3; + ::LLVMValueRef args[arg_count] = { + factor1, factor2, term, + }; + return ::LLVMBuildCall(builder, intrinsic, args, arg_count, result_name); + } + ::LLVMValueRef build_fmuladd(::LLVMModuleRef module, + ::LLVMValueRef factor1, + ::LLVMValueRef factor2, + ::LLVMValueRef term, + const char *result_name) const + { + return build_fmuladd(get(), module, factor1, factor2, term, result_name); + } }; struct Pass_manager_deleter diff --git a/src/pipeline/pipeline.cpp b/src/pipeline/pipeline.cpp index ed9db49..064a070 100644 --- a/src/pipeline/pipeline.cpp +++ b/src/pipeline/pipeline.cpp @@ -471,7 +471,8 @@ void Graphics_pipeline::run(std::uint32_t vertex_start_index, std::uint32_t vertex_end_index, std::uint32_t instance_id, const vulkan::Vulkan_image &color_attachment, - void *const *bindings) + void *const *bindings, + void *uniforms) { typedef std::uint32_t Pixel_type; assert(color_attachment.descriptor.tiling == VK_IMAGE_TILING_LINEAR); @@ -684,7 +685,8 @@ void Graphics_pipeline::run(std::uint32_t vertex_start_index, current_vertex_start_index + chunk_size, instance_id, chunk_vertex_buffer.get(), - bindings); + bindings, + uniforms); const unsigned char *current_vertex = chunk_vertex_buffer.get() + vertex_shader_position_output_offset; triangles.clear(); @@ -932,7 +934,7 @@ void Graphics_pipeline::run(std::uint32_t vertex_start_index, static_cast(color_attachment_memory) + (static_cast(x) * color_attachment_pixel_size + static_cast(y) * color_attachment_stride)); - fs(pixel); + fs(pixel, uniforms); } } } diff --git a/src/pipeline/pipeline.h b/src/pipeline/pipeline.h index d5af527..2f5e424 100644 --- a/src/pipeline/pipeline.h +++ b/src/pipeline/pipeline.h @@ -161,33 +161,40 @@ public: std::uint32_t vertex_end_index, std::uint32_t instance_id, void *output_buffer, - void *const *bindings); - typedef void (*Fragment_shader_function)(std::uint32_t *color_attachment_pixel); + void *const *input_bindings, + void *uniforms); + typedef void (*Fragment_shader_function)(std::uint32_t *color_attachment_pixel, void *uniforms); public: void run_vertex_shader(std::uint32_t vertex_start_index, std::uint32_t vertex_end_index, std::uint32_t instance_id, void *output_buffer, - void *const *input_bindings) const noexcept + void *const *input_bindings, + void *uniforms) const noexcept { - vertex_shader_function( - vertex_start_index, vertex_end_index, instance_id, output_buffer, input_bindings); + vertex_shader_function(vertex_start_index, + vertex_end_index, + instance_id, + output_buffer, + input_bindings, + uniforms); } std::size_t get_vertex_shader_output_struct_size() const noexcept { return vertex_shader_output_struct_size; } void dump_vertex_shader_output_struct(const void *output_struct) const; - void run_fragment_shader(std::uint32_t *color_attachment_pixel) const noexcept + void run_fragment_shader(std::uint32_t *color_attachment_pixel, void *uniforms) const noexcept { - fragment_shader_function(color_attachment_pixel); + fragment_shader_function(color_attachment_pixel, uniforms); } void run(std::uint32_t vertex_start_index, std::uint32_t vertex_end_index, std::uint32_t instance_id, const vulkan::Vulkan_image &color_attachment, - void *const *bindings); + void *const *input_bindings, + void *uniforms); static std::unique_ptr create( vulkan::Vulkan_device &, Pipeline_cache *pipeline_cache, diff --git a/src/spirv_to_llvm/core_instructions.cpp b/src/spirv_to_llvm/core_instructions.cpp index 1d4b766..2e244ce 100644 --- a/src/spirv_to_llvm/core_instructions.cpp +++ b/src/spirv_to_llvm/core_instructions.cpp @@ -1056,7 +1056,7 @@ void Spirv_to_llvm::handle_instruction_op_variable(Op_variable instruction, case Stage::calculate_types: { auto &state = get_id_state(instruction.result); - bool check_decorations = true; + bool parse_decorations = true; [&]() { switch(instruction.storage_class) @@ -1077,7 +1077,7 @@ void Spirv_to_llvm::handle_instruction_op_variable(Op_variable instruction, Input_variable_state{type, inputs_struct->add_member(Struct_type_descriptor::Member( state.decorations, type))}; - check_decorations = false; + parse_decorations = false; return; } case Storage_class::uniform: @@ -1086,7 +1086,10 @@ void Spirv_to_llvm::handle_instruction_op_variable(Op_variable instruction, throw Parser_error(instruction_start_index, instruction_start_index, "shader uniform variable initializers are not implemented"); - state.variable = Uniform_variable_state{}; + auto type = get_type(instruction.result_type, + instruction_start_index) + ->get_base_type(); + state.variable = Uniform_variable_state(type); return; } case Storage_class::output: @@ -1102,7 +1105,7 @@ void Spirv_to_llvm::handle_instruction_op_variable(Op_variable instruction, Output_variable_state{type, outputs_struct->add_member(Struct_type_descriptor::Member( state.decorations, type))}; - check_decorations = false; + parse_decorations = false; return; } case Storage_class::workgroup: @@ -1143,7 +1146,7 @@ void Spirv_to_llvm::handle_instruction_op_variable(Op_variable instruction, "unimplemented OpVariable storage class: " + std::string(get_enumerant_name(instruction.storage_class))); }(); - if(check_decorations) + if(parse_decorations) { for(auto &decoration : state.decorations) { @@ -1244,9 +1247,13 @@ void Spirv_to_llvm::handle_instruction_op_variable(Op_variable instruction, break; case Decoration::binding: { + auto ¶meters = + util::get(decoration.parameters); switch(instruction.storage_class) { case spirv::Storage_class::uniform: + util::get(state.variable).binding = + parameters.binding_point; continue; #warning finish implementing Decoration::binding default: @@ -1260,9 +1267,13 @@ void Spirv_to_llvm::handle_instruction_op_variable(Op_variable instruction, } case Decoration::descriptor_set: { + auto ¶meters = util::get( + decoration.parameters); switch(instruction.storage_class) { case spirv::Storage_class::uniform: + util::get(state.variable).descriptor_set = + parameters.descriptor_set; continue; #warning finish implementing Decoration::descriptor_set default: @@ -1383,7 +1394,100 @@ void Spirv_to_llvm::handle_instruction_op_variable(Op_variable instruction, } case Storage_class::uniform: #warning finish implementing Storage_class::uniform - break; + { + if(instruction.initializer) + throw Parser_error(instruction_start_index, + instruction_start_index, + "shader uniform variable initializers are not implemented"); + auto set_value_fn = [this, instruction, &state, instruction_start_index]() + { + auto &variable = util::get(state.variable); + if(!variable.binding) + throw Parser_error(instruction_start_index, + instruction_start_index, + "shader uniform variable is missing a Binding decoration"); + if(!variable.descriptor_set) + throw Parser_error( + instruction_start_index, + instruction_start_index, + "shader uniform variable is missing a DescriptorSet decoration"); + auto binding_number = *variable.binding; + auto descriptor_set_number = *variable.descriptor_set; + if(descriptor_set_number >= pipeline_layout.descriptor_sets.size()) + throw Parser_error(instruction_start_index, + instruction_start_index, + "DescriptorSet decoration's value is out of range"); + auto &descriptor_set = pipeline_layout.descriptor_sets[descriptor_set_number]; + if(binding_number >= descriptor_set.bindings.size()) + throw Parser_error(instruction_start_index, + instruction_start_index, + "Binding decoration's value is out of range"); + auto &binding = descriptor_set.bindings[binding_number]; + auto &uniforms_struct_member = + pipeline_layout.type->get_members(true)[binding.member_index]; + auto uniform_slot_address = ::LLVMBuildStructGEP( + builder.get(), + get_id_state(current_function_id).function->entry_block->uniforms_struct, + uniforms_struct_member.llvm_member_index, + ""); + auto result_type = get_type(instruction.result_type, instruction_start_index); + ::LLVMValueRef result = nullptr; + switch(binding.base->descriptor_type) + { + case VK_DESCRIPTOR_TYPE_SAMPLER: +#warning implement VK_DESCRIPTOR_TYPE_SAMPLER uniform variables + break; + case VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER: +#warning implement VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER uniform variables + break; + case VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE: +#warning implement VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE uniform variables + break; + case VK_DESCRIPTOR_TYPE_STORAGE_IMAGE: +#warning implement VK_DESCRIPTOR_TYPE_STORAGE_IMAGE uniform variables + break; + case VK_DESCRIPTOR_TYPE_UNIFORM_TEXEL_BUFFER: +#warning implement VK_DESCRIPTOR_TYPE_UNIFORM_TEXEL_BUFFER uniform variables + break; + case VK_DESCRIPTOR_TYPE_STORAGE_TEXEL_BUFFER: +#warning implement VK_DESCRIPTOR_TYPE_STORAGE_TEXEL_BUFFER uniform variables + break; + case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER: + result = + ::LLVMBuildBitCast(builder.get(), + ::LLVMBuildLoad(builder.get(), uniform_slot_address, ""), + result_type->get_or_make_type().type, + ""); + break; + case VK_DESCRIPTOR_TYPE_STORAGE_BUFFER: +#warning implement VK_DESCRIPTOR_TYPE_STORAGE_BUFFER uniform variables + break; + case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER_DYNAMIC: +#warning implement VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER_DYNAMIC uniform variables + break; + case VK_DESCRIPTOR_TYPE_STORAGE_BUFFER_DYNAMIC: +#warning implement VK_DESCRIPTOR_TYPE_STORAGE_BUFFER_DYNAMIC uniform variables + break; + case VK_DESCRIPTOR_TYPE_INPUT_ATTACHMENT: +#warning implement VK_DESCRIPTOR_TYPE_INPUT_ATTACHMENT uniform variables + break; + case VK_DESCRIPTOR_TYPE_RANGE_SIZE: + case VK_DESCRIPTOR_TYPE_MAX_ENUM: + break; + } + if(result == nullptr) + throw Parser_error(instruction_start_index, + instruction_start_index, + "unimplemented uniform descriptor type"); + ::LLVMSetValueName(result, get_name(instruction.result).c_str()); + state.value = Value(result, std::move(result_type)); + }; + if(current_function_id) + set_value_fn(); + else + function_entry_block_handlers.push_back(set_value_fn); + return; + } case Storage_class::output: { if(instruction.initializer) @@ -1460,7 +1564,10 @@ void Spirv_to_llvm::handle_instruction_op_variable(Op_variable instruction, #warning finish implementing Storage_class::storage_buffer break; } - break; + throw Parser_error(instruction_start_index, + instruction_start_index, + "unimplemented OpVariable storage class: " + + std::string(get_enumerant_name(instruction.storage_class))); } } } @@ -1525,6 +1632,7 @@ void Spirv_to_llvm::handle_instruction_op_load(Op_load instruction, builder.get(), get_id_state(instruction.pointer).value.value().value, ""); ::LLVMSetAlignment(untransposed_value, memory_type->get_or_make_type().alignment); state.value = Value(matrix_operations::transpose(context, + module.get(), builder.get(), untransposed_value, get_name(instruction.result).c_str()), @@ -1574,10 +1682,8 @@ void Spirv_to_llvm::handle_instruction_op_store(Op_store instruction, break; case Type_descriptor::Load_store_implementation_kind::Transpose_matrix: { - auto transposed_value = matrix_operations::transpose(context, - builder.get(), - object_value.value, - ""); + auto transposed_value = matrix_operations::transpose( + context, module.get(), builder.get(), object_value.value, ""); ::LLVMSetAlignment( ::LLVMBuildStore(builder.get(), transposed_value, pointer_value.value), memory_type->get_or_make_type().alignment); @@ -2845,21 +2951,113 @@ void Spirv_to_llvm::handle_instruction_op_vector_times_matrix(Op_vector_times_ma void Spirv_to_llvm::handle_instruction_op_matrix_times_vector(Op_matrix_times_vector instruction, std::size_t instruction_start_index) { -#warning finish - throw Parser_error(instruction_start_index, - instruction_start_index, - "instruction not implemented: " - + std::string(get_enumerant_name(instruction.get_operation()))); + switch(stage) + { + case Stage::calculate_types: + break; + case Stage::generate_code: + { + auto &state = get_id_state(instruction.result); + if(!state.decorations.empty()) + throw Parser_error(instruction_start_index, + instruction_start_index, + "decorations on instruction not implemented: " + + std::string(get_enumerant_name(instruction.get_operation()))); + auto result_type = + get_type(instruction.result_type, instruction_start_index); + auto &matrix = get_id_state(instruction.matrix).value.value(); + auto &vector = get_id_state(instruction.vector).value.value(); + auto matrix_type = std::dynamic_pointer_cast(matrix.type); + if(!matrix_type) + throw Parser_error(instruction_start_index, + instruction_start_index, + "OpMatrixTimesVector matrix operand type mismatch: not a matrix"); + auto vector_type = std::dynamic_pointer_cast(vector.type); + if(!vector_type) + throw Parser_error(instruction_start_index, + instruction_start_index, + "OpMatrixTimesVector vector operand type mismatch: not a vector"); + if(matrix_type->get_row_count() != result_type->get_element_count()) + throw Parser_error(instruction_start_index, + instruction_start_index, + "OpMatrixTimesVector matrix operand type mismatch: row count " + "doesn't match result_type's element count"); + if(matrix_type->get_column_count() != vector_type->get_element_count()) + throw Parser_error(instruction_start_index, + instruction_start_index, + "OpMatrixTimesVector matrix operand type mismatch: column " + "count doesn't match vector's element count"); + state.value = + Value(matrix_operations::matrix_times_vector(context, + module.get(), + builder.get(), + matrix.value, + vector.value, + get_name(instruction.result).c_str()), + result_type); + break; + } + } } void Spirv_to_llvm::handle_instruction_op_matrix_times_matrix(Op_matrix_times_matrix instruction, std::size_t instruction_start_index) { -#warning finish - throw Parser_error(instruction_start_index, - instruction_start_index, - "instruction not implemented: " - + std::string(get_enumerant_name(instruction.get_operation()))); + switch(stage) + { + case Stage::calculate_types: + break; + case Stage::generate_code: + { + auto &state = get_id_state(instruction.result); + if(!state.decorations.empty()) + throw Parser_error(instruction_start_index, + instruction_start_index, + "decorations on instruction not implemented: " + + std::string(get_enumerant_name(instruction.get_operation()))); + auto result_type = + get_type(instruction.result_type, instruction_start_index); + auto &left_matrix = get_id_state(instruction.left_matrix).value.value(); + auto &right_matrix = get_id_state(instruction.right_matrix).value.value(); + auto left_matrix_type = std::dynamic_pointer_cast(left_matrix.type); + if(!left_matrix_type) + throw Parser_error( + instruction_start_index, + instruction_start_index, + "OpMatrixTimesMatrix left_matrix operand type mismatch: not a matrix"); + auto right_matrix_type = + std::dynamic_pointer_cast(right_matrix.type); + if(!right_matrix_type) + throw Parser_error( + instruction_start_index, + instruction_start_index, + "OpMatrixTimesMatrix right_matrix operand type mismatch: not a matrix"); + if(left_matrix_type->get_row_count() != result_type->get_row_count()) + throw Parser_error(instruction_start_index, + instruction_start_index, + "OpMatrixTimesMatrix left_matrix operand type mismatch: row count " + "doesn't match result_type's row count"); + if(right_matrix_type->get_column_count() != result_type->get_column_count()) + throw Parser_error(instruction_start_index, + instruction_start_index, + "OpMatrixTimesMatrix right_matrix operand type mismatch: column " + "count doesn't match result_type's column count"); + if(left_matrix_type->get_column_count() != right_matrix_type->get_row_count()) + throw Parser_error(instruction_start_index, + instruction_start_index, + "OpMatrixTimesMatrix left_matrix operand type mismatch: column " + "count doesn't match right_matrix's row count"); + state.value = + Value(matrix_operations::matrix_multiply(context, + module.get(), + builder.get(), + left_matrix.value, + right_matrix.value, + get_name(instruction.result).c_str()), + result_type); + break; + } + } } void Spirv_to_llvm::handle_instruction_op_outer_product(Op_outer_product instruction, @@ -3800,8 +3998,19 @@ void Spirv_to_llvm::handle_instruction_op_label(Op_label instruction, io_struct->get_members(true)[this->outputs_member].llvm_member_index, "outputs_pointer"), "outputs"); - function.entry_block = Function_state::Entry_block( - block, io_struct_value, inputs_struct_value, outputs_struct_value); + auto uniforms_struct_value = ::LLVMBuildLoad( + builder.get(), + ::LLVMBuildStructGEP( + builder.get(), + io_struct_value, + io_struct->get_members(true)[this->uniforms_member].llvm_member_index, + "uniforms_pointer"), + "uniforms"); + function.entry_block = Function_state::Entry_block(block, + io_struct_value, + inputs_struct_value, + outputs_struct_value, + uniforms_struct_value); for(auto iter = function_entry_block_handlers.begin(); iter != function_entry_block_handlers.end();) { diff --git a/src/spirv_to_llvm/fragment_entry_point.cpp b/src/spirv_to_llvm/fragment_entry_point.cpp index 0ae3979..ce7fe7d 100644 --- a/src/spirv_to_llvm/fragment_entry_point.cpp +++ b/src/spirv_to_llvm/fragment_entry_point.cpp @@ -38,8 +38,9 @@ using namespace spirv; auto llvm_vec4_type = ::LLVMVectorType(llvm_float_type, 4); auto llvm_u8vec4_type = ::LLVMVectorType(llvm_u8_type, 4); static_cast(llvm_pixel_type); - typedef void (*Fragment_shader_function)(Pixel_type *color_attachment_pixel); + typedef void (*Fragment_shader_function)(Pixel_type *color_attachment_pixel, void *uniforms); constexpr std::size_t arg_color_attachment_pixel = 0; + constexpr std::size_t arg_uniforms = 1; static_assert(std::is_same::value, "vertex shader function signature mismatch"); @@ -49,6 +50,8 @@ using namespace spirv; llvm_wrapper::Module::set_function_target_machine(entry_function, target_machine); auto color_attachment_pixel = ::LLVMGetParam(entry_function, arg_color_attachment_pixel); ::LLVMSetValueName(color_attachment_pixel, "color_attachment_pixel"); + auto uniforms = ::LLVMGetParam(entry_function, arg_uniforms); + ::LLVMSetValueName(uniforms, "uniforms"); auto entry_block = ::LLVMAppendBasicBlockInContext(context, entry_function, "entry"); ::LLVMPositionBuilderAtEnd(builder.get(), entry_block); auto io_struct_type = io_struct->get_or_make_type(); @@ -634,7 +637,8 @@ using namespace spirv; else if(member_index == uniforms_member) { #warning implement shader uniforms - assert(this->pipeline_layout.descriptor_sets.empty() && "shader uniforms not implemented"); + assert(this->pipeline_layout.descriptor_sets.empty() + && "shader uniforms not implemented"); } else { @@ -702,8 +706,9 @@ using namespace spirv; auto packed_output_color = ::LLVMBuildBitCast( builder.get(), converted_output_color, llvm_pixel_type, "packed_output_color"); ::LLVMBuildStore(builder.get(), packed_output_color, color_attachment_pixel); - static_assert( - std::is_same()(nullptr)), void>::value, ""); + static_assert(std::is_same()(nullptr, nullptr)), + void>::value, + ""); ::LLVMBuildRetVoid(builder.get()); return entry_function; } diff --git a/src/spirv_to_llvm/matrix_operations.h b/src/spirv_to_llvm/matrix_operations.h index 4b50332..5c716c8 100644 --- a/src/spirv_to_llvm/matrix_operations.h +++ b/src/spirv_to_llvm/matrix_operations.h @@ -60,7 +60,27 @@ struct Matrix_descriptor } }; +struct Vector_descriptor +{ + std::uint32_t element_count; + ::LLVMTypeRef element_type; + ::LLVMTypeRef vector_type; + explicit Vector_descriptor(::LLVMTypeRef vector_type) noexcept : vector_type(vector_type) + { + assert(::LLVMGetTypeKind(vector_type) == ::LLVMVectorTypeKind); + element_count = ::LLVMGetVectorSize(vector_type); + element_type = ::LLVMGetElementType(vector_type); + } + Vector_descriptor(::LLVMTypeRef element_type, std::uint32_t element_count) + : element_count(element_count), + element_type(element_type), + vector_type(::LLVMVectorType(element_type, element_count)) + { + } +}; + inline ::LLVMValueRef transpose(::LLVMContextRef context, + ::LLVMModuleRef module, ::LLVMBuilderRef builder, ::LLVMValueRef input_matrix, const char *output_name) @@ -100,6 +120,92 @@ inline ::LLVMValueRef transpose(::LLVMContextRef context, ::LLVMSetValueName(output_value, output_name); return output_value; } + +inline ::LLVMValueRef vector_broadcast_from_vector(::LLVMContextRef context, + ::LLVMBuilderRef builder, + ::LLVMValueRef input_vector, + std::uint32_t input_vector_index, + std::uint32_t output_vector_length, + const char *output_name) +{ + auto i32_type = llvm_wrapper::Create_llvm_type()(context); + auto index = ::LLVMConstInt(i32_type, input_vector_index, false); + std::vector<::LLVMValueRef> shuffle_arguments(output_vector_length, index); + auto shuffle_index_vector = + ::LLVMConstVector(shuffle_arguments.data(), shuffle_arguments.size()); + return ::LLVMBuildShuffleVector(builder, + input_vector, + ::LLVMGetUndef(::LLVMTypeOf(input_vector)), + shuffle_index_vector, + output_name); +} + +inline ::LLVMValueRef matrix_multiply(::LLVMContextRef context, + ::LLVMModuleRef module, + ::LLVMBuilderRef builder, + ::LLVMValueRef left_matrix, + ::LLVMValueRef right_matrix, + const char *output_name) +{ + Matrix_descriptor left_matrix_descriptor(::LLVMTypeOf(left_matrix)); + Matrix_descriptor right_matrix_descriptor(::LLVMTypeOf(right_matrix)); + assert(left_matrix_descriptor.element_type == right_matrix_descriptor.element_type); + assert(left_matrix_descriptor.columns == right_matrix_descriptor.rows); + assert(left_matrix_descriptor.columns != 0); + assert(left_matrix_descriptor.rows != 0); + assert(right_matrix_descriptor.columns != 0); + Matrix_descriptor result_matrix_descriptor(left_matrix_descriptor.element_type, + left_matrix_descriptor.rows, + right_matrix_descriptor.columns); + ::LLVMValueRef retval = ::LLVMGetUndef(result_matrix_descriptor.matrix_type); + for(std::size_t i = 0; i < right_matrix_descriptor.columns; i++) + { + ::LLVMValueRef right_matrix_column = ::LLVMBuildExtractValue(builder, right_matrix, i, ""); + ::LLVMValueRef sum{}; + for(std::size_t j = 0; j < left_matrix_descriptor.columns; j++) + { + auto factor0 = ::LLVMBuildExtractValue(builder, left_matrix, j, ""); + auto factor1 = vector_broadcast_from_vector( + context, builder, right_matrix_column, j, left_matrix_descriptor.rows, ""); + if(j == 0) + sum = ::LLVMBuildFMul(builder, factor0, factor1, ""); + else + sum = llvm_wrapper::Builder::build_fmuladd( + builder, module, factor0, factor1, sum, ""); + } + retval = ::LLVMBuildInsertValue(builder, retval, sum, i, ""); + } + ::LLVMSetValueName(retval, output_name); + return retval; +} + +inline ::LLVMValueRef matrix_times_vector(::LLVMContextRef context, + ::LLVMModuleRef module, + ::LLVMBuilderRef builder, + ::LLVMValueRef matrix, + ::LLVMValueRef input_vector, + const char *output_name) +{ + Matrix_descriptor matrix_descriptor(::LLVMTypeOf(matrix)); + Vector_descriptor input_vector_descriptor(::LLVMTypeOf(input_vector)); + assert(matrix_descriptor.element_type == input_vector_descriptor.element_type); + assert(matrix_descriptor.columns == input_vector_descriptor.element_count); + assert(matrix_descriptor.columns != 0); + ::LLVMValueRef retval{}; + for(std::size_t i = 0; i < matrix_descriptor.columns; i++) + { + auto factor0 = ::LLVMBuildExtractValue(builder, matrix, i, ""); + auto factor1 = vector_broadcast_from_vector( + context, builder, input_vector, i, matrix_descriptor.rows, ""); + if(i == 0) + retval = ::LLVMBuildFMul(builder, factor0, factor1, ""); + else + retval = + llvm_wrapper::Builder::build_fmuladd(builder, module, factor0, factor1, retval, ""); + } + ::LLVMSetValueName(retval, output_name); + return retval; +} } } } diff --git a/src/spirv_to_llvm/spirv_to_llvm.cpp b/src/spirv_to_llvm/spirv_to_llvm.cpp index 44e2134..cf6e066 100644 --- a/src/spirv_to_llvm/spirv_to_llvm.cpp +++ b/src/spirv_to_llvm/spirv_to_llvm.cpp @@ -209,11 +209,14 @@ void Struct_type_descriptor::complete_type() { std::size_t alignment; std::size_t size; + util::optional offset; ::LLVMTypeRef type; explicit Member_descriptor(std::size_t alignment, std::size_t size, + util::optional offset, ::LLVMTypeRef type) noexcept : alignment(alignment), size(size), + offset(offset), type(type) { } @@ -223,6 +226,9 @@ void Struct_type_descriptor::complete_type() std::size_t total_alignment = 1; for(auto &member : members) { + util::optional is_row_major; + util::optional offset; + util::optional matrix_stride; for(auto &decoration : member.decorations) { switch(decoration.value) @@ -240,17 +246,20 @@ void Struct_type_descriptor::complete_type() #warning finish implementing Decoration::buffer_block break; case Decoration::row_major: -#warning finish implementing Decoration::row_major - break; + is_row_major = true; + continue; case Decoration::col_major: -#warning finish implementing Decoration::col_major - break; + is_row_major = false; + continue; case Decoration::array_stride: #warning finish implementing Decoration::array_stride break; case Decoration::matrix_stride: -#warning finish implementing Decoration::matrix_stride - break; + { + auto ¶meters = util::get(decoration.parameters); + matrix_stride = parameters.matrix_stride; + continue; + } case Decoration::glsl_shared: #warning finish implementing Decoration::glsl_shared break; @@ -326,8 +335,11 @@ void Struct_type_descriptor::complete_type() #warning finish implementing Decoration::descriptor_set break; case Decoration::offset: -#warning finish implementing Decoration::offset - break; + { + auto ¶meters = util::get(decoration.parameters); + offset = parameters.byte_offset; + continue; + } case Decoration::xfb_buffer: #warning finish implementing Decoration::xfb_buffer break; @@ -382,6 +394,14 @@ void Struct_type_descriptor::complete_type() "unimplemented member decoration on OpTypeStruct: " + std::string(get_enumerant_name(decoration.value))); } + if(is_row_major) + { + if(*is_row_major) + member.type = member.type->get_row_major_type(target_data); + else + member.type = member.type->get_column_major_type(target_data); + } + assert(matrix_stride == member.type->get_matrix_stride(target_data) && "MatrixStride decoration unimplemented for non-default strides"); auto member_type = member.type->get_or_make_type(); std::size_t size = ::LLVMABISizeOfType(target_data, member_type.type); struct Member_type_visitor : public Type_descriptor::Type_visitor @@ -400,16 +420,10 @@ void Struct_type_descriptor::complete_type() virtual void visit(Matrix_type_descriptor &type) override { #warning finish implementing member type - throw Parser_error(this_->instruction_start_index, - this_->instruction_start_index, - "unimplemented member type"); } virtual void visit(Row_major_matrix_type_descriptor &type) override { #warning finish implementing member type - throw Parser_error(this_->instruction_start_index, - this_->instruction_start_index, - "unimplemented member type"); } virtual void visit(Array_type_descriptor &type) override { @@ -452,7 +466,7 @@ void Struct_type_descriptor::complete_type() if(member_type.alignment > total_alignment) total_alignment = member_type.alignment; member_descriptors.push_back( - Member_descriptor(member_type.alignment, size, member_type.type)); + Member_descriptor(member_type.alignment, size, offset, member_type.type)); } assert(member_descriptors.size() == members.size()); assert(is_power_of_2(total_alignment)); @@ -463,6 +477,17 @@ void Struct_type_descriptor::complete_type() { for(std::size_t member_index = 0; member_index < members.size(); member_index++) { + if(member_descriptors[member_index].offset) + { + assert(*member_descriptors[member_index].offset >= current_offset); + auto padding_size = *member_descriptors[member_index].offset - current_offset; + if(padding_size != 0) + { + member_types.push_back( + ::LLVMArrayType(::LLVMInt8TypeInContext(context), padding_size)); + current_offset += padding_size; + } + } members[member_index].llvm_member_index = member_types.size(); #warning finish Struct_type_descriptor::complete_type member_types.push_back(member_descriptors[member_index].type); diff --git a/src/spirv_to_llvm/spirv_to_llvm.h b/src/spirv_to_llvm/spirv_to_llvm.h index 7cc0d7c..e7dc142 100644 --- a/src/spirv_to_llvm/spirv_to_llvm.h +++ b/src/spirv_to_llvm/spirv_to_llvm.h @@ -153,6 +153,10 @@ public: { return shared_from_this(); } + virtual util::optional get_matrix_stride(::LLVMTargetDataRef target_data) const + { + return {}; + } void visit(Type_visitor &&type_visitor) { visit(type_visitor); @@ -396,6 +400,10 @@ public: column_major_type = retval; return retval; } + virtual util::optional get_matrix_stride(::LLVMTargetDataRef target_data) const override + { + return element_type->get_matrix_stride(target_data); + } const std::shared_ptr &get_element_type() const noexcept { return element_type; @@ -445,6 +453,14 @@ public: { return column_count; } + std::size_t get_row_count() const noexcept + { + return column_type->get_element_count(); + } + const std::shared_ptr &get_element_type() const noexcept + { + return column_type->get_element_type(); + } virtual std::shared_ptr get_row_major_type( ::LLVMTargetDataRef target_data) override { @@ -455,6 +471,10 @@ public: row_major_type = retval; return retval; } + virtual util::optional get_matrix_stride(::LLVMTargetDataRef target_data) const override + { + return ::LLVMABISizeOfType(target_data, column_type->get_or_make_type().type); + } }; class Row_major_matrix_type_descriptor final : public Type_descriptor @@ -496,6 +516,14 @@ public: { return row_count; } + std::size_t get_column_count() const noexcept + { + return row_type->get_element_count(); + } + const std::shared_ptr &get_element_type() const noexcept + { + return row_type->get_element_type(); + } virtual std::shared_ptr get_column_major_type( ::LLVMTargetDataRef target_data) override { diff --git a/src/spirv_to_llvm/spirv_to_llvm_implementation.h b/src/spirv_to_llvm/spirv_to_llvm_implementation.h index 83d8dfe..95f00e2 100644 --- a/src/spirv_to_llvm/spirv_to_llvm_implementation.h +++ b/src/spirv_to_llvm/spirv_to_llvm_implementation.h @@ -82,6 +82,15 @@ private: }; struct Uniform_variable_state { + std::shared_ptr type; + util::optional binding; + util::optional descriptor_set; + explicit Uniform_variable_state(std::shared_ptr type) noexcept + : type(std::move(type)), + binding(), + descriptor_set() + { + } }; typedef util::variant::value, "vertex shader function signature mismatch"); @@ -63,6 +65,7 @@ using namespace spirv; ::LLVMSetValueName(::LLVMGetParam(entry_function, arg_instance_id), "instance_id"); ::LLVMSetValueName(::LLVMGetParam(entry_function, arg_output_buffer), "output_buffer_"); ::LLVMSetValueName(::LLVMGetParam(entry_function, arg_bindings), "bindings"); + ::LLVMSetValueName(::LLVMGetParam(entry_function, arg_uniforms), "uniforms"); auto entry_block = ::LLVMAppendBasicBlockInContext(context, entry_function, "entry"); auto loop_block = ::LLVMAppendBasicBlockInContext(context, entry_function, "loop"); auto exit_block = ::LLVMAppendBasicBlockInContext(context, entry_function, "exit"); @@ -452,6 +455,26 @@ using namespace spirv; "unimplemented vertex input variable type conversion"); break; } + case VK_FORMAT_R32G32B32_SFLOAT: + { + constexpr std::size_t vector_element_count = 3; + format_type = + Vector_type_descriptor( + std::vector{}, + std::make_shared( + std::vector{}, + LLVM_type_and_alignment(llvm_float_type, + llvm_float_type_alignment)), + vector_element_count, + target_data) + .get_or_make_type(); + if(input_type.type != format_type.type) + throw Parser_error( + 0, + 0, + "unimplemented vertex input variable type conversion"); + break; + } #warning implement all required formats default: throw Parser_error(0, 0, "unimplemented vertex input format"); @@ -853,7 +876,8 @@ using namespace spirv; else if(member_index == uniforms_member) { #warning implement shader uniforms - assert(this->pipeline_layout.descriptor_sets.empty() && "shader uniforms not implemented"); + assert(this->pipeline_layout.descriptor_sets.empty() + && "shader uniforms not implemented"); } else { @@ -882,7 +906,7 @@ using namespace spirv; ::LLVMBuildCondBr(builder.get(), next_iteration_condition, loop_block, exit_block); ::LLVMPositionBuilderAtEnd(builder.get(), exit_block); static_assert( - std::is_same()(0, 0, 0, nullptr, nullptr)), + std::is_same()(0, 0, 0, nullptr, nullptr, nullptr)), void>::value, ""); ::LLVMBuildRetVoid(builder.get());