From f2a297372fe4396f3de30203efd7292e1ff3854a Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Tue, 17 Oct 2017 18:01:22 -0700 Subject: [PATCH] work in progress for adding inter-shader interfaces --- src/pipeline/pipeline.cpp | 92 ++- src/spirv_to_llvm/core_instructions.cpp | 187 +++++- src/spirv_to_llvm/fragment_entry_point.cpp | 8 +- src/spirv_to_llvm/spirv_to_llvm.cpp | 17 +- src/spirv_to_llvm/spirv_to_llvm.h | 621 +++++++++++++++++- .../spirv_to_llvm_implementation.h | 165 ++++- src/spirv_to_llvm/vertex_entry_point.cpp | 16 +- 7 files changed, 1012 insertions(+), 94 deletions(-) diff --git a/src/pipeline/pipeline.cpp b/src/pipeline/pipeline.cpp index 064a070..a80dd86 100644 --- a/src/pipeline/pipeline.cpp +++ b/src/pipeline/pipeline.cpp @@ -47,7 +47,8 @@ Instantiated_pipeline_layout::Instantiated_pipeline_layout(vulkan::Vulkan_pipeli llvm_context, target_data, "pipeline_layout", - 0)) + 0, + spirv_to_llvm::Struct_type_descriptor::Layout_kind::Default)) { auto void_pointer_type = std::make_shared( std::vector{}, @@ -942,6 +943,47 @@ void Graphics_pipeline::run(std::uint32_t vertex_start_index, } } +namespace +{ +constexpr util::optional get_shader_stage_order( + spirv::Execution_model execution_model) noexcept +{ + switch(execution_model) + { + case spirv::Execution_model::vertex: + return 0; + case spirv::Execution_model::tessellation_control: + return 1; + case spirv::Execution_model::tessellation_evaluation: + return 2; + case spirv::Execution_model::geometry: + return 3; + case spirv::Execution_model::fragment: + return 4; + case spirv::Execution_model::gl_compute: + case spirv::Execution_model::kernel: + return {}; + } + assert(!"unknown execution model"); + return {}; +} + +constexpr bool are_shader_stage_enumerants_ordered_properly() noexcept +{ + util::optional last_stage_order; + for(auto execution_model : util::Enum_traits::values) + { + auto current_stage_order = get_shader_stage_order(execution_model); + if(!current_stage_order) + continue; + if(last_stage_order >= current_stage_order) + return false; + last_stage_order = current_stage_order; + } + return true; +} +} + std::unique_ptr Graphics_pipeline::create( vulkan::Vulkan_device &, Pipeline_cache *pipeline_cache, @@ -967,7 +1009,10 @@ std::unique_ptr Graphics_pipeline::create( implementation->instantiated_pipeline_layout = std::make_unique( *pipeline_layout, implementation->llvm_context.get(), implementation->data_layout.get()); implementation->compiled_shaders.reserve(create_info.stageCount); - util::Enum_set found_shader_stages; + util::Enum_map found_shader_stages; + // the iteration order of shader stages must match the order in which + // outputs are connected to the next stage's inputs + static_assert(are_shader_stage_enumerants_ordered_properly(), ""); for(std::size_t i = 0; i < create_info.stageCount; i++) { auto &stage_info = create_info.pStages[i]; @@ -976,9 +1021,17 @@ std::unique_ptr Graphics_pipeline::create( assert(execution_models.size() == 1); auto execution_model = *execution_models.begin(); bool added_to_found_shader_stages = - std::get<1>(found_shader_stages.insert(execution_model)); + std::get<1>(found_shader_stages.emplace(execution_model, i)); if(!added_to_found_shader_stages) throw std::runtime_error("duplicate shader stage"); + } + if(found_shader_stages.count(spirv::Execution_model::vertex) == 0) + throw std::runtime_error("graphics pipeline is missing vertex shader"); + util::optional last_shader_index_by_shader_stage_order; + for(auto &found_shader_stage : found_shader_stages) + { + auto &stage_info = create_info.pStages[std::get<1>(found_shader_stage)]; + auto execution_model = std::get<0>(found_shader_stage); auto *shader_module = Shader_module::from_handle(stage_info.module); assert(shader_module); { @@ -997,6 +1050,31 @@ std::unique_ptr Graphics_pipeline::create( assert(create_info.pVertexInputState); assert(create_info.pVertexInputState->sType == VK_STRUCTURE_TYPE_PIPELINE_VERTEX_INPUT_STATE_CREATE_INFO); + const spirv_to_llvm::Shader_interface *previous_stage_output_shader_interface = nullptr; + const spirv_to_llvm::Shader_interface *previous_stage_built_in_output_shader_interface = + nullptr; + switch(execution_model) + { + case spirv::Execution_model::vertex: + case spirv::Execution_model::gl_compute: + case spirv::Execution_model::kernel: + break; + case spirv::Execution_model::tessellation_control: + case spirv::Execution_model::tessellation_evaluation: + case spirv::Execution_model::geometry: + case spirv::Execution_model::fragment: + { + assert(last_shader_index_by_shader_stage_order); + auto &previous_stage = + implementation->compiled_shaders[*last_shader_index_by_shader_stage_order]; + previous_stage_output_shader_interface = previous_stage.output_shader_interface.get(); + assert(previous_stage_output_shader_interface); + previous_stage_built_in_output_shader_interface = + previous_stage.built_in_output_shader_interface.get(); + assert(previous_stage_built_in_output_shader_interface); + break; + } + } auto compiled_shader = spirv_to_llvm::spirv_to_llvm(implementation->llvm_context.get(), llvm_target_machine.get(), @@ -1006,13 +1084,17 @@ std::unique_ptr Graphics_pipeline::create( execution_model, stage_info.pName, create_info.pVertexInputState, - *implementation->instantiated_pipeline_layout); + *implementation->instantiated_pipeline_layout, + previous_stage_output_shader_interface, + previous_stage_built_in_output_shader_interface); std::cerr << "Translation to LLVM succeeded." << std::endl; ::LLVMDumpModule(compiled_shader.module.get()); bool failed = ::LLVMVerifyModule(compiled_shader.module.get(), ::LLVMPrintMessageAction, nullptr); if(failed) throw std::runtime_error("LLVM module verification failed"); + if(get_shader_stage_order(execution_model)) + last_shader_index_by_shader_stage_order = implementation->compiled_shaders.size(); implementation->compiled_shaders.push_back(std::move(compiled_shader)); } implementation->jit_stack = @@ -1053,7 +1135,7 @@ std::unique_ptr Graphics_pipeline::create( { vertex_shader_function = reinterpret_cast(shader_entry_point_address); - implementation->vertex_shader_output_struct = compiled_shader.outputs_struct; + implementation->vertex_shader_output_struct = compiled_shader.combined_outputs_struct; auto llvm_vertex_shader_output_struct = implementation->vertex_shader_output_struct->get_or_make_type().type; vertex_shader_output_struct_size = ::LLVMABISizeOfType( diff --git a/src/spirv_to_llvm/core_instructions.cpp b/src/spirv_to_llvm/core_instructions.cpp index 2e244ce..997cb2f 100644 --- a/src/spirv_to_llvm/core_instructions.cpp +++ b/src/spirv_to_llvm/core_instructions.cpp @@ -483,6 +483,7 @@ void Spirv_to_llvm::handle_instruction_op_type_struct(Op_type_struct instruction ::LLVMGetModuleDataLayout(module.get()), get_prefixed_name(get_name(instruction.result), false).c_str(), instruction_start_index, + Struct_type_descriptor::Layout_kind::Default, std::move(members)); break; } @@ -1073,10 +1074,53 @@ void Spirv_to_llvm::handle_instruction_op_variable(Op_variable instruction, auto type = get_type(instruction.result_type, instruction_start_index) ->get_base_type(); - state.variable = - Input_variable_state{type, - inputs_struct->add_member(Struct_type_descriptor::Member( - state.decorations, type))}; + bool is_built_in = false; + if(auto struct_type = dynamic_cast(type.get())) + { + bool has_any_non_built_in_members = false; + for(auto &member : struct_type->get_members(false)) + { + bool member_is_built_in = false; + for(auto &decoration : member.decorations) + { + if(decoration.value == spirv::Decoration::built_in) + { + member_is_built_in = true; + break; + } + } + if(!member_is_built_in) + has_any_non_built_in_members = true; + else + is_built_in = true; + } + if(is_built_in && has_any_non_built_in_members) + throw Parser_error( + instruction_start_index, + instruction_start_index, + "shader interface variable has both built-in and non-built-in members"); + } + if(!is_built_in) + { + for(auto &decoration : type->decorations) + { + if(decoration.value == spirv::Decoration::built_in) + { + is_built_in = true; + break; + } + } + } + if(is_built_in) + state.variable = Built_in_input_variable_state{ + type, + built_in_inputs_struct->add_member( + Struct_type_descriptor::Member(state.decorations, type))}; + else + state.variable = Input_variable_state{ + type, + inputs_struct->add_member( + Struct_type_descriptor::Member(state.decorations, type))}; parse_decorations = false; return; } @@ -1101,10 +1145,53 @@ void Spirv_to_llvm::handle_instruction_op_variable(Op_variable instruction, auto type = get_type(instruction.result_type, instruction_start_index) ->get_base_type(); - state.variable = - Output_variable_state{type, - outputs_struct->add_member(Struct_type_descriptor::Member( - state.decorations, type))}; + bool is_built_in = false; + if(auto struct_type = dynamic_cast(type.get())) + { + bool has_any_non_built_in_members = false; + for(auto &member : struct_type->get_members(false)) + { + bool member_is_built_in = false; + for(auto &decoration : member.decorations) + { + if(decoration.value == spirv::Decoration::built_in) + { + member_is_built_in = true; + break; + } + } + if(!member_is_built_in) + has_any_non_built_in_members = true; + else + is_built_in = true; + } + if(is_built_in && has_any_non_built_in_members) + throw Parser_error( + instruction_start_index, + instruction_start_index, + "shader interface variable has both built-in and non-built-in members"); + } + if(!is_built_in) + { + for(auto &decoration : type->decorations) + { + if(decoration.value == spirv::Decoration::built_in) + { + is_built_in = true; + break; + } + } + } + if(is_built_in) + state.variable = Built_in_output_variable_state{ + type, + built_in_outputs_struct->add_member( + Struct_type_descriptor::Member(state.decorations, type))}; + else + state.variable = Output_variable_state{ + type, + outputs_struct->add_member( + Struct_type_descriptor::Member(state.decorations, type))}; parse_decorations = false; return; } @@ -1377,14 +1464,31 @@ void Spirv_to_llvm::handle_instruction_op_variable(Op_variable instruction, } auto set_value_fn = [this, instruction, &state, instruction_start_index]() { - auto &variable = util::get(state.variable); - state.value = Value( - ::LLVMBuildStructGEP( - builder.get(), - get_id_state(current_function_id).function->entry_block->inputs_struct, - inputs_struct->get_members(true)[variable.member_index].llvm_member_index, - get_name(instruction.result).c_str()), - get_type(instruction.result_type, instruction_start_index)); + if(util::holds_alternative(state.variable)) + { + auto &variable = util::get(state.variable); + state.value = + Value(::LLVMBuildStructGEP( + builder.get(), + get_id_state(current_function_id) + .function->entry_block->built_in_inputs_struct, + built_in_inputs_struct->get_members(true)[variable.member_index] + .llvm_member_index, + get_name(instruction.result).c_str()), + get_type(instruction.result_type, instruction_start_index)); + } + else + { + auto &variable = util::get(state.variable); + state.value = Value( + ::LLVMBuildStructGEP( + builder.get(), + get_id_state(current_function_id).function->entry_block->inputs_struct, + inputs_struct->get_members(true)[variable.member_index] + .llvm_member_index, + get_name(instruction.result).c_str()), + get_type(instruction.result_type, instruction_start_index)); + } }; if(current_function_id) set_value_fn(); @@ -1502,14 +1606,31 @@ void Spirv_to_llvm::handle_instruction_op_variable(Op_variable instruction, } auto set_value_fn = [this, instruction, &state, instruction_start_index]() { - auto &variable = util::get(state.variable); - state.value = Value( - ::LLVMBuildStructGEP( - builder.get(), - get_id_state(current_function_id).function->entry_block->outputs_struct, - outputs_struct->get_members(true)[variable.member_index].llvm_member_index, - get_name(instruction.result).c_str()), - get_type(instruction.result_type, instruction_start_index)); + if(util::holds_alternative(state.variable)) + { + auto &variable = util::get(state.variable); + state.value = + Value(::LLVMBuildStructGEP( + builder.get(), + get_id_state(current_function_id) + .function->entry_block->built_in_outputs_struct, + built_in_outputs_struct->get_members(true)[variable.member_index] + .llvm_member_index, + get_name(instruction.result).c_str()), + get_type(instruction.result_type, instruction_start_index)); + } + else + { + auto &variable = util::get(state.variable); + state.value = Value( + ::LLVMBuildStructGEP( + builder.get(), + get_id_state(current_function_id).function->entry_block->outputs_struct, + outputs_struct->get_members(true)[variable.member_index] + .llvm_member_index, + get_name(instruction.result).c_str()), + get_type(instruction.result_type, instruction_start_index)); + } }; if(current_function_id) set_value_fn(); @@ -3998,6 +4119,22 @@ void Spirv_to_llvm::handle_instruction_op_label(Op_label instruction, io_struct->get_members(true)[this->outputs_member].llvm_member_index, "outputs_pointer"), "outputs"); + auto built_in_inputs_struct_value = ::LLVMBuildLoad( + builder.get(), + ::LLVMBuildStructGEP( + builder.get(), + io_struct_value, + io_struct->get_members(true)[this->built_in_inputs_member].llvm_member_index, + "built_in_inputs_pointer"), + "built_in_inputs"); + auto built_in_outputs_struct_value = ::LLVMBuildLoad( + builder.get(), + ::LLVMBuildStructGEP( + builder.get(), + io_struct_value, + io_struct->get_members(true)[this->built_in_outputs_member].llvm_member_index, + "built_in_outputs_pointer"), + "built_in_outputs"); auto uniforms_struct_value = ::LLVMBuildLoad( builder.get(), ::LLVMBuildStructGEP( @@ -4010,6 +4147,8 @@ void Spirv_to_llvm::handle_instruction_op_label(Op_label instruction, io_struct_value, inputs_struct_value, outputs_struct_value, + built_in_inputs_struct_value, + built_in_outputs_struct_value, uniforms_struct_value); for(auto iter = function_entry_block_handlers.begin(); iter != function_entry_block_handlers.end();) diff --git a/src/spirv_to_llvm/fragment_entry_point.cpp b/src/spirv_to_llvm/fragment_entry_point.cpp index ce7fe7d..431e441 100644 --- a/src/spirv_to_llvm/fragment_entry_point.cpp +++ b/src/spirv_to_llvm/fragment_entry_point.cpp @@ -28,9 +28,13 @@ namespace spirv_to_llvm { using namespace spirv; -::LLVMValueRef Spirv_to_llvm::generate_fragment_entry_function(Op_entry_point_state &entry_point, - ::LLVMValueRef main_function) +::LLVMValueRef Spirv_to_llvm::generate_fragment_entry_function( + Op_entry_point_state &entry_point, + ::LLVMValueRef main_function, + Shader_interface &input_shader_interface, + Shader_interface &built_in_input_shader_interface) { +#error finish adding shader interface code typedef std::uint32_t Pixel_type; auto llvm_pixel_type = llvm_wrapper::Create_llvm_type()(context); auto llvm_float_type = llvm_wrapper::Create_llvm_type()(context); diff --git a/src/spirv_to_llvm/spirv_to_llvm.cpp b/src/spirv_to_llvm/spirv_to_llvm.cpp index cf6e066..1cde43b 100644 --- a/src/spirv_to_llvm/spirv_to_llvm.cpp +++ b/src/spirv_to_llvm/spirv_to_llvm.cpp @@ -256,7 +256,8 @@ void Struct_type_descriptor::complete_type() break; case Decoration::matrix_stride: { - auto ¶meters = util::get(decoration.parameters); + auto ¶meters = + util::get(decoration.parameters); matrix_stride = parameters.matrix_stride; continue; } @@ -336,7 +337,8 @@ void Struct_type_descriptor::complete_type() break; case Decoration::offset: { - auto ¶meters = util::get(decoration.parameters); + auto ¶meters = + util::get(decoration.parameters); offset = parameters.byte_offset; continue; } @@ -401,7 +403,8 @@ void Struct_type_descriptor::complete_type() else member.type = member.type->get_column_major_type(target_data); } - assert(matrix_stride == member.type->get_matrix_stride(target_data) && "MatrixStride decoration unimplemented for non-default strides"); + assert(matrix_stride == member.type->get_matrix_stride(target_data) + && "MatrixStride decoration unimplemented for non-default strides"); auto member_type = member.type->get_or_make_type(); std::size_t size = ::LLVMABISizeOfType(target_data, member_type.type); struct Member_type_visitor : public Type_descriptor::Type_visitor @@ -566,7 +569,9 @@ spirv_to_llvm::Converted_module spirv_to_llvm::spirv_to_llvm( spirv::Execution_model execution_model, util::string_view entry_point_name, const VkPipelineVertexInputStateCreateInfo *vertex_input_state, - pipeline::Instantiated_pipeline_layout &pipeline_layout) + pipeline::Instantiated_pipeline_layout &pipeline_layout, + const Shader_interface *previous_stage_output_shader_interface, + const Shader_interface *previous_stage_built_in_output_shader_interface) { return Spirv_to_llvm(context, target_machine, @@ -574,7 +579,9 @@ spirv_to_llvm::Converted_module spirv_to_llvm::spirv_to_llvm( execution_model, entry_point_name, vertex_input_state, - pipeline_layout) + pipeline_layout, + previous_stage_output_shader_interface, + previous_stage_built_in_output_shader_interface) .run(shader_words, shader_size); } } diff --git a/src/spirv_to_llvm/spirv_to_llvm.h b/src/spirv_to_llvm/spirv_to_llvm.h index e7dc142..3239924 100644 --- a/src/spirv_to_llvm/spirv_to_llvm.h +++ b/src/spirv_to_llvm/spirv_to_llvm.h @@ -32,6 +32,7 @@ #include #include #include +#include #include "llvm_wrapper/llvm_wrapper.h" #include "util/string_view.h" #include "vulkan/vulkan.h" @@ -62,48 +63,307 @@ struct LLVM_type_and_alignment } }; -struct Shader_interface +struct Shader_interface_position { + std::size_t value; + static constexpr int component_index_bit_width = 2; + static constexpr std::size_t component_index_count = 1ULL << component_index_bit_width; + static constexpr std::size_t component_index_mask = component_index_count - 1; + static constexpr std::size_t location_mask = ~component_index_mask; + static constexpr std::size_t location_shift_amount = component_index_bit_width; + constexpr std::uint32_t get_location() const noexcept + { + return (value & location_mask) >> location_shift_amount; + } + constexpr std::uint32_t get_component_index() const noexcept + { + return value & component_index_mask; + } + constexpr std::uint32_t get_components_left_in_current_location() const noexcept + { + return component_index_count - get_component_index(); + } + constexpr bool is_aligned_to_location() const noexcept + { + return get_component_index() == 0; + } + constexpr Shader_interface_position get_aligned_location_rounding_up() const noexcept + { + if(is_aligned_to_location()) + return *this; + return Shader_interface_position(get_location() + 1); + } + constexpr Shader_interface_position get_position_after_components(std::uint32_t count) const + noexcept + { + std::uint32_t result_component_index = get_component_index() + count; + std::uint32_t result_location = + get_location() + result_component_index / component_index_count; + result_component_index %= component_index_count; + return Shader_interface_position(result_location, result_component_index); + } + constexpr Shader_interface_position(std::uint32_t location, + std::uint8_t component_index) noexcept + : value((location << location_shift_amount) | component_index) + { + assert(location == get_location() && component_index == get_component_index()); + } + constexpr explicit Shader_interface_position(std::uint32_t location) noexcept + : Shader_interface_position(location, 0) + { + } + constexpr Shader_interface_position() noexcept : value(0) + { + } + Shader_interface_position( + spirv::Decoration_location_parameters location, + util::optional component) noexcept + : Shader_interface_position(location.location, component ? component->component : 0) + { + } + explicit Shader_interface_position( + const std::vector &decorations) + : Shader_interface_position() + { + util::optional location; + util::optional component; + for(auto &decoration : decorations) + { + switch(decoration.value) + { + case spirv::Decoration::location: + location = util::get(decoration.parameters); + break; + case spirv::Decoration::component: + component = + util::get(decoration.parameters); + break; + default: + break; + } + } + if(!location) + throw spirv::Parser_error(0, 0, "missing Location decoration"); + *this = Shader_interface_position(*location, component); + } + friend constexpr bool operator==(Shader_interface_position a, + Shader_interface_position b) noexcept + { + return a.value == b.value; + } + friend constexpr bool operator!=(Shader_interface_position a, + Shader_interface_position b) noexcept + { + return a.value != b.value; + } + friend constexpr bool operator<(Shader_interface_position a, + Shader_interface_position b) noexcept + { + return a.value < b.value; + } + friend constexpr bool operator>(Shader_interface_position a, + Shader_interface_position b) noexcept + { + return a.value > b.value; + } + friend constexpr bool operator<=(Shader_interface_position a, + Shader_interface_position b) noexcept + { + return a.value <= b.value; + } + friend constexpr bool operator>=(Shader_interface_position a, + Shader_interface_position b) noexcept + { + return a.value >= b.value; + } +}; + +/** represents the range [begin_position, end_position) */ +struct Shader_interface_range +{ + Shader_interface_position begin_position; + Shader_interface_position end_position; + constexpr bool empty() const noexcept + { + return end_position == begin_position; + } + constexpr bool overlaps(Shader_interface_range other) const noexcept + { + if(begin_position >= other.end_position) + return false; + if(other.begin_position >= end_position) + return false; + return true; + } +}; + +class Type_descriptor; + +class Shader_interface +{ +public: /** uses a single type for both signed and unsigned integer variants */ - enum class Location_type + enum class Component_type { + Int8, + Int16, Int32, + Int64, + Float16, Float32, + Float64, }; + static constexpr std::uint32_t get_type_component_count( + Component_type type, std::size_t vector_element_count) noexcept + { + std::size_t size_in_bytes = 0; + switch(type) + { + case Component_type::Int8: + size_in_bytes = sizeof(std::uint8_t); + break; + case Component_type::Int16: + size_in_bytes = sizeof(std::uint16_t); + break; + case Component_type::Int32: + size_in_bytes = sizeof(std::uint32_t); + break; + case Component_type::Int64: + size_in_bytes = sizeof(std::uint64_t); + break; + case Component_type::Float16: + size_in_bytes = sizeof(std::uint16_t); + break; + case Component_type::Float32: + size_in_bytes = sizeof(float); + break; + case Component_type::Float64: + size_in_bytes = sizeof(double); + break; + } + assert(size_in_bytes != 0); + assert(vector_element_count >= 1 && vector_element_count <= 4); + size_in_bytes *= vector_element_count; + constexpr std::size_t component_size_in_bytes = sizeof(float); + static_assert(component_size_in_bytes == 4, ""); + return (size_in_bytes + component_size_in_bytes - 1) / component_size_in_bytes; + } + static util::optional get_component_type_for_llvm_scalar_type( + ::LLVMTypeRef type) + { + util::optional component_type; + switch(::LLVMGetTypeKind(type)) + { + case ::LLVMHalfTypeKind: + return Shader_interface::Component_type::Float16; + case ::LLVMFloatTypeKind: + return Shader_interface::Component_type::Float32; + case ::LLVMDoubleTypeKind: + return Shader_interface::Component_type::Float64; + case ::LLVMIntegerTypeKind: + { + auto bit_width = ::LLVMGetIntTypeWidth(type); + switch(bit_width) + { + case 8: + return Shader_interface::Component_type::Int8; + case 16: + return Shader_interface::Component_type::Int16; + case 32: + return Shader_interface::Component_type::Int32; + case 64: + return Shader_interface::Component_type::Int64; + default: + break; + } + break; + } + case ::LLVMVoidTypeKind: + case ::LLVMX86_FP80TypeKind: + case ::LLVMFP128TypeKind: + case ::LLVMPPC_FP128TypeKind: + case ::LLVMLabelTypeKind: + case ::LLVMFunctionTypeKind: + case ::LLVMStructTypeKind: + case ::LLVMArrayTypeKind: + case ::LLVMPointerTypeKind: + case ::LLVMVectorTypeKind: + case ::LLVMMetadataTypeKind: + case ::LLVMX86_MMXTypeKind: + case ::LLVMTokenTypeKind: + break; + } + return {}; + } enum class Interpolation_kind { Perspective, Linear, Flat, }; - struct Location_descriptor + struct Variable { - Location_type type; - util::bitset<4> used_components; + Component_type type; Interpolation_kind interpolation_kind; - constexpr Location_descriptor() noexcept : type(), used_components(), interpolation_kind() + Shader_interface_range range; + std::vector indexes; + std::shared_ptr base_type; + Variable() noexcept : type(), interpolation_kind(), range(), indexes(), base_type() { } - constexpr Location_descriptor(Location_type type, - util::bitset<4> used_components, - Interpolation_kind interpolation_kind) noexcept + Variable(Component_type type, + Interpolation_kind interpolation_kind, + Shader_interface_range range, + std::vector indexes, + std::shared_ptr base_type) noexcept : type(type), - used_components(used_components), - interpolation_kind(interpolation_kind) + interpolation_kind(interpolation_kind), + range(range), + indexes(std::move(indexes)), + base_type(std::move(base_type)) { } - constexpr explicit operator bool() const noexcept + explicit operator bool() const noexcept { - return used_components.any(); + return !range.empty(); } }; - std::vector locations; - Shader_interface() noexcept : locations() + +private: + std::vector variables; + bool is_sorted; + +private: + void sort_variables() noexcept + { + std::stable_sort(variables.begin(), + variables.end(), + [](const Variable &a, const Variable &b) noexcept + { + return a.range.begin_position < b.range.begin_position; + }); + is_sorted = true; + } + +public: + Shader_interface() noexcept : variables() + { + } + explicit Shader_interface(std::vector variables) noexcept + : variables(std::move(variables)), + is_sorted(false) { } - explicit Shader_interface(std::vector locations) noexcept - : locations(std::move(locations)) + const std::vector &get_sorted_variables() noexcept { + if(!is_sorted) + sort_variables(); + return variables; + } + void add(const Variable &variable) + { + variables.push_back(variable); + is_sorted = false; } }; @@ -255,6 +515,46 @@ public: { return Load_store_implementation_kind::Simple; } + util::optional find_decoration( + spirv::Decoration decoration_id) const + { + for(auto &decoration : decorations) + if(decoration.value == decoration_id) + return decoration; + return {}; + } + struct Shader_interface_index_list_item + { + const Shader_interface_index_list_item *prev; + std::size_t index; + }; + static std::vector shader_interface_index_list_to_vector( + const Shader_interface_index_list_item *index_list) + { + std::size_t size = 0; + for(auto *p = index_list; p; p = p->prev) + size++; + std::vector retval(size); + std::size_t i = size - 1; + for(auto *p = index_list; p; p = p->prev) + retval[i--] = p->index; + return retval; + } + virtual void add_to_shader_interface( + Shader_interface &shader_interface, + util::optional ¤t_position, + Shader_interface::Interpolation_kind interpolation_kind, + const Shader_interface_index_list_item *parent_index_list, + const std::shared_ptr &base_type) = 0; + void add_to_shader_interface(Shader_interface &shader_interface) + { + util::optional current_position; + add_to_shader_interface(shader_interface, + current_position, + Shader_interface::Interpolation_kind::Perspective, + nullptr, + shared_from_this()); + } }; class Simple_type_descriptor final : public Type_descriptor @@ -277,6 +577,36 @@ public: { type_visitor.visit(*this); } + using Type_descriptor::add_to_shader_interface; + virtual void add_to_shader_interface( + Shader_interface &shader_interface, + util::optional ¤t_position, + Shader_interface::Interpolation_kind interpolation_kind, + const Shader_interface_index_list_item *parent_index_list, + const std::shared_ptr &base_type) override + { + auto component_type = Shader_interface::get_component_type_for_llvm_scalar_type(type.type); + if(!component_type) + throw spirv::Parser_error(0, 0, "invalid type in shader interface"); + if(!current_position) + throw spirv::Parser_error( + 0, 0, "no Location decoration specified for shader interface"); + auto component_count = Shader_interface::get_type_component_count(*component_type, 1); + if(component_count > current_position->get_components_left_in_current_location() + && current_position->get_component_index() != 0) + throw spirv::Parser_error(0, 0, "Component decoration too big for type"); + Shader_interface_range range = { + .begin_position = *current_position, + .end_position = current_position->get_position_after_components(component_count), + }; + current_position = range.end_position.get_aligned_location_rounding_up(); + shader_interface.add( + Shader_interface::Variable(*component_type, + interpolation_kind, + range, + shader_interface_index_list_to_vector(parent_index_list), + base_type)); + } }; class Vector_type_descriptor final : public Type_descriptor @@ -326,6 +656,38 @@ public: { return element_count; } + using Type_descriptor::add_to_shader_interface; + virtual void add_to_shader_interface( + Shader_interface &shader_interface, + util::optional ¤t_position, + Shader_interface::Interpolation_kind interpolation_kind, + const Shader_interface_index_list_item *parent_index_list, + const std::shared_ptr &base_type) override + { + auto component_type = Shader_interface::get_component_type_for_llvm_scalar_type( + ::LLVMGetElementType(type.type)); + if(!component_type) + throw spirv::Parser_error(0, 0, "invalid type in shader interface"); + if(!current_position) + throw spirv::Parser_error( + 0, 0, "no Location decoration specified for shader interface"); + auto component_count = + Shader_interface::get_type_component_count(*component_type, element_count); + if(component_count > current_position->get_components_left_in_current_location() + && current_position->get_component_index() != 0) + throw spirv::Parser_error(0, 0, "Component decoration too big for type"); + Shader_interface_range range = { + .begin_position = *current_position, + .end_position = current_position->get_position_after_components(component_count), + }; + current_position = range.end_position.get_aligned_location_rounding_up(); + shader_interface.add( + Shader_interface::Variable(*component_type, + interpolation_kind, + range, + shader_interface_index_list_to_vector(parent_index_list), + base_type)); + } }; class Array_type_descriptor final : public Type_descriptor @@ -400,7 +762,8 @@ public: column_major_type = retval; return retval; } - virtual util::optional get_matrix_stride(::LLVMTargetDataRef target_data) const override + virtual util::optional get_matrix_stride( + ::LLVMTargetDataRef target_data) const override { return element_type->get_matrix_stride(target_data); } @@ -412,6 +775,28 @@ public: { return element_count; } + using Type_descriptor::add_to_shader_interface; + virtual void add_to_shader_interface( + Shader_interface &shader_interface, + util::optional ¤t_position, + Shader_interface::Interpolation_kind interpolation_kind, + const Shader_interface_index_list_item *parent_index_list, + const std::shared_ptr &base_type) override + { + if(!current_position) + throw spirv::Parser_error( + 0, 0, "no Location decoration specified for shader interface"); + if(current_position->get_component_index() != 0) + throw spirv::Parser_error(0, 0, "Component decoration not allowed on array"); + for(std::size_t i = 0; i < element_count; i++) + { + const Shader_interface_index_list_item index_list[1] = {{ + .prev = parent_index_list, .index = i, + }}; + element_type->add_to_shader_interface( + shader_interface, current_position, interpolation_kind, index_list, base_type); + } + } }; class Matrix_type_descriptor final : public Type_descriptor @@ -471,10 +856,33 @@ public: row_major_type = retval; return retval; } - virtual util::optional get_matrix_stride(::LLVMTargetDataRef target_data) const override + virtual util::optional get_matrix_stride( + ::LLVMTargetDataRef target_data) const override { return ::LLVMABISizeOfType(target_data, column_type->get_or_make_type().type); } + using Type_descriptor::add_to_shader_interface; + virtual void add_to_shader_interface( + Shader_interface &shader_interface, + util::optional ¤t_position, + Shader_interface::Interpolation_kind interpolation_kind, + const Shader_interface_index_list_item *parent_index_list, + const std::shared_ptr &base_type) override + { + if(!current_position) + throw spirv::Parser_error( + 0, 0, "no Location decoration specified for shader interface"); + if(current_position->get_component_index() != 0) + throw spirv::Parser_error(0, 0, "Component decoration not allowed on matrix"); + for(std::size_t i = 0; i < column_count; i++) + { + const Shader_interface_index_list_item index_list[1] = {{ + .prev = parent_index_list, .index = i, + }}; + column_type->add_to_shader_interface( + shader_interface, current_position, interpolation_kind, index_list, base_type); + } + } }; class Row_major_matrix_type_descriptor final : public Type_descriptor @@ -544,6 +952,28 @@ public: { return Load_store_implementation_kind::Transpose_matrix; } + using Type_descriptor::add_to_shader_interface; + virtual void add_to_shader_interface( + Shader_interface &shader_interface, + util::optional ¤t_position, + Shader_interface::Interpolation_kind interpolation_kind, + const Shader_interface_index_list_item *parent_index_list, + const std::shared_ptr &base_type) override + { + if(!current_position) + throw spirv::Parser_error( + 0, 0, "no Location decoration specified for shader interface"); + if(current_position->get_component_index() != 0) + throw spirv::Parser_error(0, 0, "Component decoration not allowed on matrix"); + for(std::size_t i = 0; i < row_count; i++) + { + const Shader_interface_index_list_item index_list[1] = {{ + .prev = parent_index_list, .index = i, + }}; + row_type->add_to_shader_interface( + shader_interface, current_position, interpolation_kind, index_list, base_type); + } + } }; inline std::shared_ptr Matrix_type_descriptor::make_row_major_type( @@ -619,6 +1049,16 @@ public: { type_visitor.visit(*this); } + using Type_descriptor::add_to_shader_interface; + virtual void add_to_shader_interface( + Shader_interface &shader_interface, + util::optional ¤t_position, + Shader_interface::Interpolation_kind interpolation_kind, + const Shader_interface_index_list_item *parent_index_list, + const std::shared_ptr &base_type) override + { + throw spirv::Parser_error(0, 0, "pointers not allowed shader interface"); + } }; class Function_type_descriptor final : public Type_descriptor @@ -672,6 +1112,16 @@ public: { return valid_for_entry_point; } + using Type_descriptor::add_to_shader_interface; + virtual void add_to_shader_interface( + Shader_interface &shader_interface, + util::optional ¤t_position, + Shader_interface::Interpolation_kind interpolation_kind, + const Shader_interface_index_list_item *parent_index_list, + const std::shared_ptr &base_type) override + { + throw spirv::Parser_error(0, 0, "function pointers not allowed shader interface"); + } }; class Struct_type_descriptor final : public Type_descriptor @@ -688,27 +1138,50 @@ public: type(std::move(type)) { } + util::optional find_decoration( + spirv::Decoration decoration_id) const + { + for(auto &decoration : decorations) + if(decoration.value == decoration_id) + return decoration; + return {}; + } + }; + enum class Layout_kind + { + Default, + Shader_interface, }; private: std::vector members; util::Enum_map builtin_members; + std::vector non_built_in_members; LLVM_type_and_alignment type; bool is_complete; Recursion_checker_state recursion_checker_state; std::size_t instruction_start_index; ::LLVMContextRef context; ::LLVMTargetDataRef target_data; + const Layout_kind layout_kind; void complete_type(); void on_add_member(std::size_t added_member_index) noexcept { assert(!is_complete); auto &member = members[added_member_index]; + bool is_built_in = false; for(auto &decoration : member.decorations) + { if(decoration.value == spirv::Decoration::built_in) + { builtin_members[util::get( decoration.parameters) .built_in] = added_member_index; + is_built_in = true; + } + } + if(!is_built_in) + non_built_in_members.push_back(added_member_index); } public: @@ -725,11 +1198,16 @@ public: get_or_make_type(); return members; } + Layout_kind get_layout_kind() const noexcept + { + return layout_kind; + } explicit Struct_type_descriptor(std::vector decorations, ::LLVMContextRef context, ::LLVMTargetDataRef target_data, const char *name, std::size_t instruction_start_index, + Layout_kind layout_kind, std::vector members = {}) : Type_descriptor(std::move(decorations)), members(std::move(members)), @@ -738,7 +1216,8 @@ public: is_complete(false), instruction_start_index(instruction_start_index), context(context), - target_data(target_data) + target_data(target_data), + layout_kind(layout_kind) { for(std::size_t member_index = 0; member_index < members.size(); member_index++) on_add_member(member_index); @@ -757,6 +1236,40 @@ public: { type_visitor.visit(*this); } + using Type_descriptor::add_to_shader_interface; + virtual void add_to_shader_interface( + Shader_interface &shader_interface, + util::optional ¤t_position, + Shader_interface::Interpolation_kind interpolation_kind, + const Shader_interface_index_list_item *parent_index_list, + const std::shared_ptr &base_type) override + { + if(find_decoration(spirv::Decoration::location)) + current_position = Shader_interface_position(decorations); + if(!current_position) + throw spirv::Parser_error( + 0, 0, "no Location decoration specified for shader interface"); + if(current_position->get_component_index() != 0) + throw spirv::Parser_error(0, 0, "Component decoration not allowed on struct"); + for(auto &member : get_members(true)) + { + if(member.find_decoration(spirv::Decoration::location)) + current_position = Shader_interface_position(member.decorations); + auto member_interpolation_kind = Shader_interface::Interpolation_kind::Perspective; + if(member.find_decoration(spirv::Decoration::flat)) + member_interpolation_kind = Shader_interface::Interpolation_kind::Flat; + else if(member.find_decoration(spirv::Decoration::no_perspective)) + member_interpolation_kind = Shader_interface::Interpolation_kind::Linear; + const Shader_interface_index_list_item index_list[1] = {{ + .prev = parent_index_list, .index = member.llvm_member_index, + }}; + member.type->add_to_shader_interface(shader_interface, + current_position, + member_interpolation_kind, + index_list, + base_type); + } + } }; class Constant_descriptor @@ -799,19 +1312,54 @@ struct Converted_module llvm_wrapper::Module module; std::string entry_function_name; std::shared_ptr inputs_struct; + std::shared_ptr built_in_inputs_struct; std::shared_ptr outputs_struct; + std::shared_ptr built_in_outputs_struct; spirv::Execution_model execution_model; + std::unique_ptr output_shader_interface; + std::unique_ptr built_in_output_shader_interface; + std::shared_ptr combined_outputs_struct; + static std::shared_ptr make_combined_outputs_struct( + ::LLVMContextRef context, + ::LLVMTargetDataRef target_data, + const char *name, + const std::shared_ptr &outputs_struct, + const std::shared_ptr &built_in_outputs_struct) + { + return std::make_shared( + std::vector{}, + context, + target_data, + name, + 0, + Struct_type_descriptor::Layout_kind::Default, + std::vector{ + Struct_type_descriptor::Member({}, built_in_outputs_struct), + Struct_type_descriptor::Member({}, outputs_struct), + }); + } Converted_module() = default; - explicit Converted_module(llvm_wrapper::Module module, - std::string entry_function_name, - std::shared_ptr inputs_struct, - std::shared_ptr outputs_struct, - spirv::Execution_model execution_model) noexcept + explicit Converted_module( + llvm_wrapper::Module module, + std::string entry_function_name, + std::shared_ptr inputs_struct, + std::shared_ptr built_in_inputs_struct, + std::shared_ptr outputs_struct, + std::shared_ptr built_in_outputs_struct, + spirv::Execution_model execution_model, + std::unique_ptr output_shader_interface, + std::unique_ptr built_in_output_shader_interface, + std::shared_ptr combined_outputs_struct) noexcept : module(std::move(module)), entry_function_name(std::move(entry_function_name)), inputs_struct(std::move(inputs_struct)), + built_in_inputs_struct(std::move(built_in_inputs_struct)), outputs_struct(std::move(outputs_struct)), - execution_model(execution_model) + built_in_outputs_struct(std::move(built_in_outputs_struct)), + execution_model(execution_model), + output_shader_interface(std::move(output_shader_interface)), + built_in_output_shader_interface(std::move(built_in_output_shader_interface)), + combined_outputs_struct(std::move(combined_outputs_struct)) { } }; @@ -834,15 +1382,18 @@ struct Jit_symbol_resolver class Spirv_to_llvm; -Converted_module spirv_to_llvm(::LLVMContextRef context, - ::LLVMTargetMachineRef target_machine, - const spirv::Word *shader_words, - std::size_t shader_size, - std::uint64_t shader_id, - spirv::Execution_model execution_model, - util::string_view entry_point_name, - const VkPipelineVertexInputStateCreateInfo *vertex_input_state, - pipeline::Instantiated_pipeline_layout &pipeline_layout); +Converted_module spirv_to_llvm( + ::LLVMContextRef context, + ::LLVMTargetMachineRef target_machine, + const spirv::Word *shader_words, + std::size_t shader_size, + std::uint64_t shader_id, + spirv::Execution_model execution_model, + util::string_view entry_point_name, + const VkPipelineVertexInputStateCreateInfo *vertex_input_state, + pipeline::Instantiated_pipeline_layout &pipeline_layout, + const Shader_interface *previous_stage_output_shader_interface, + const Shader_interface *previous_stage_built_in_output_shader_interface); } } diff --git a/src/spirv_to_llvm/spirv_to_llvm_implementation.h b/src/spirv_to_llvm/spirv_to_llvm_implementation.h index 95f00e2..d846eff 100644 --- a/src/spirv_to_llvm/spirv_to_llvm_implementation.h +++ b/src/spirv_to_llvm/spirv_to_llvm_implementation.h @@ -80,6 +80,16 @@ private: std::shared_ptr type; std::size_t member_index; }; + struct Built_in_input_variable_state + { + std::shared_ptr type; + std::size_t member_index; + }; + struct Built_in_output_variable_state + { + std::shared_ptr type; + std::size_t member_index; + }; struct Uniform_variable_state { std::shared_ptr type; @@ -94,7 +104,9 @@ private: }; typedef util::variant Variable_state; struct Function_state { @@ -104,16 +116,22 @@ private: ::LLVMValueRef io_struct; ::LLVMValueRef inputs_struct; ::LLVMValueRef outputs_struct; + ::LLVMValueRef built_in_inputs_struct; + ::LLVMValueRef built_in_outputs_struct; ::LLVMValueRef uniforms_struct; explicit Entry_block(::LLVMBasicBlockRef entry_block, ::LLVMValueRef io_struct, ::LLVMValueRef inputs_struct, ::LLVMValueRef outputs_struct, + ::LLVMValueRef built_in_inputs_struct, + ::LLVMValueRef built_in_outputs_struct, ::LLVMValueRef uniforms_struct) noexcept : entry_block(entry_block), io_struct(io_struct), inputs_struct(inputs_struct), outputs_struct(outputs_struct), + built_in_inputs_struct(built_in_inputs_struct), + built_in_outputs_struct(built_in_outputs_struct), uniforms_struct(uniforms_struct) { } @@ -241,6 +259,11 @@ private: std::size_t outputs_member; std::shared_ptr outputs_struct; std::shared_ptr outputs_struct_pointer_type; + std::size_t built_in_inputs_member; + std::shared_ptr built_in_inputs_struct; + std::size_t built_in_outputs_member; + std::shared_ptr built_in_outputs_struct; + std::shared_ptr built_in_outputs_struct_pointer_type; std::size_t uniforms_member; std::shared_ptr uniforms_struct_pointer_type; Stage stage; @@ -254,6 +277,9 @@ private: Op_entry_point_state *entry_point_state_pointer = nullptr; const VkPipelineVertexInputStateCreateInfo *vertex_input_state; pipeline::Instantiated_pipeline_layout &pipeline_layout; + const Shader_interface *previous_stage_output_shader_interface; + const Shader_interface *previous_stage_built_in_output_shader_interface; + std::shared_ptr combined_outputs_struct; private: Id_state &get_id_state(spirv::Id id) @@ -403,7 +429,9 @@ public: spirv::Execution_model execution_model, util::string_view entry_point_name, const VkPipelineVertexInputStateCreateInfo *vertex_input_state, - pipeline::Instantiated_pipeline_layout &pipeline_layout) + pipeline::Instantiated_pipeline_layout &pipeline_layout, + const Shader_interface *previous_stage_output_shader_interface, + const Shader_interface *previous_stage_built_in_output_shader_interface) : context(context), target_machine(target_machine), shader_id(shader_id), @@ -411,7 +439,10 @@ public: execution_model(execution_model), entry_point_name(entry_point_name), vertex_input_state(vertex_input_state), - pipeline_layout(pipeline_layout) + pipeline_layout(pipeline_layout), + previous_stage_output_shader_interface(previous_stage_output_shader_interface), + previous_stage_built_in_output_shader_interface( + previous_stage_built_in_output_shader_interface) { { std::ostringstream ss; @@ -428,7 +459,8 @@ public: context, target_data, get_prefixed_name("Io_struct", true).c_str(), - no_instruction_index); + no_instruction_index, + Struct_type_descriptor::Layout_kind::Default); assert(implicit_function_arguments.size() == 1); static_assert(io_struct_argument_index == 0, ""); implicit_function_arguments[io_struct_argument_index] = @@ -441,8 +473,9 @@ public: std::vector{}, context, target_data, - get_prefixed_name("Inputs", true).c_str(), - no_instruction_index); + get_prefixed_name("inputs", true).c_str(), + no_instruction_index, + Struct_type_descriptor::Layout_kind::Shader_interface); inputs_member = io_struct->add_member(Struct_type_descriptor::Member( {}, std::make_shared( @@ -451,29 +484,79 @@ public: std::vector{}, context, target_data, - get_prefixed_name("Outputs", true).c_str(), - no_instruction_index); + get_prefixed_name("outputs", true).c_str(), + no_instruction_index, + Struct_type_descriptor::Layout_kind::Shader_interface); outputs_struct_pointer_type = std::make_shared( std::vector{}, outputs_struct, 0, target_data); outputs_member = io_struct->add_member(Struct_type_descriptor::Member({}, outputs_struct_pointer_type)); + built_in_inputs_struct = std::make_shared( + std::vector{}, + context, + target_data, + get_prefixed_name("built_in_inputs", true).c_str(), + no_instruction_index, + Struct_type_descriptor::Layout_kind::Shader_interface); + built_in_inputs_member = io_struct->add_member( + Struct_type_descriptor::Member({}, + std::make_shared( + std::vector{}, + built_in_inputs_struct, + 0, + target_data))); + built_in_outputs_struct = std::make_shared( + std::vector{}, + context, + target_data, + get_prefixed_name("built_in_outputs", true).c_str(), + no_instruction_index, + Struct_type_descriptor::Layout_kind::Shader_interface); + built_in_outputs_struct_pointer_type = std::make_shared( + std::vector{}, + built_in_outputs_struct, + 0, + target_data); + built_in_outputs_member = io_struct->add_member( + Struct_type_descriptor::Member({}, built_in_outputs_struct_pointer_type)); uniforms_struct_pointer_type = std::make_shared( std::vector{}, pipeline_layout.type, 0, target_data); uniforms_member = io_struct->add_member(Struct_type_descriptor::Member({}, uniforms_struct_pointer_type)); + combined_outputs_struct = + Converted_module::make_combined_outputs_struct(context, + target_data, + "combined_outputs_struct", + outputs_struct, + built_in_outputs_struct); } - ::LLVMValueRef generate_vertex_entry_function(Op_entry_point_state &entry_point, - ::LLVMValueRef main_function); - ::LLVMValueRef generate_fragment_entry_function(Op_entry_point_state &entry_point, - ::LLVMValueRef main_function); + ::LLVMValueRef generate_vertex_entry_function( + Op_entry_point_state &entry_point, + ::LLVMValueRef main_function, + Shader_interface &output_shader_interface, + Shader_interface &built_in_output_shader_interface); + ::LLVMValueRef generate_fragment_entry_function( + Op_entry_point_state &entry_point, + ::LLVMValueRef main_function, + Shader_interface &input_shader_interface, + Shader_interface &built_in_input_shader_interface); std::string generate_entry_function(Op_entry_point_state &entry_point, - ::LLVMValueRef main_function) + ::LLVMValueRef main_function, + Shader_interface *input_shader_interface, + Shader_interface *built_in_input_shader_interface, + Shader_interface *output_shader_interface, + Shader_interface *built_in_output_shader_interface) { ::LLVMValueRef entry_function = nullptr; switch(execution_model) { case spirv::Execution_model::vertex: - entry_function = generate_vertex_entry_function(entry_point, main_function); + assert(output_shader_interface); + assert(built_in_output_shader_interface); + entry_function = generate_vertex_entry_function(entry_point, + main_function, + *output_shader_interface, + *built_in_output_shader_interface); break; case spirv::Execution_model::tessellation_control: #warning implement execution model @@ -497,7 +580,12 @@ public: "unimplemented execution model: " + std::string(spirv::get_enumerant_name(execution_model))); case spirv::Execution_model::fragment: - entry_function = generate_fragment_entry_function(entry_point, main_function); + assert(input_shader_interface); + assert(built_in_input_shader_interface); + entry_function = generate_fragment_entry_function(entry_point, + main_function, + *input_shader_interface, + *built_in_input_shader_interface); break; case spirv::Execution_model::gl_compute: #warning implement execution model @@ -535,13 +623,56 @@ public: throw spirv::Parser_error(entry_point_state.instruction_start_index, entry_point_state.instruction_start_index, "No definition for function referenced in OpEntryPoint"); - auto entry_function_name = - generate_entry_function(entry_point_state, entry_point_id_state.function->function); + std::unique_ptr output_shader_interface; + std::unique_ptr built_in_output_shader_interface; + std::unique_ptr input_shader_interface; + std::unique_ptr built_in_input_shader_interface; + switch(execution_model) + { + case spirv::Execution_model::vertex: + output_shader_interface = std::make_unique(); + built_in_output_shader_interface = std::make_unique(); + break; + case spirv::Execution_model::tessellation_control: + case spirv::Execution_model::tessellation_evaluation: + case spirv::Execution_model::geometry: + input_shader_interface = std::make_unique(); + built_in_input_shader_interface = std::make_unique(); + output_shader_interface = std::make_unique(); + built_in_output_shader_interface = std::make_unique(); + break; + case spirv::Execution_model::fragment: + input_shader_interface = std::make_unique(); + built_in_input_shader_interface = std::make_unique(); + break; + case spirv::Execution_model::gl_compute: + case spirv::Execution_model::kernel: + break; + } + if(output_shader_interface) + outputs_struct->add_to_shader_interface(*output_shader_interface); + if(built_in_output_shader_interface) + built_in_outputs_struct->add_to_shader_interface(*built_in_output_shader_interface); + if(input_shader_interface) + inputs_struct->add_to_shader_interface(*input_shader_interface); + if(built_in_input_shader_interface) + built_in_inputs_struct->add_to_shader_interface(*built_in_input_shader_interface); + auto entry_function_name = generate_entry_function(entry_point_state, + entry_point_id_state.function->function, + input_shader_interface.get(), + built_in_input_shader_interface.get(), + output_shader_interface.get(), + built_in_output_shader_interface.get()); return Converted_module(std::move(module), std::move(entry_function_name), std::move(inputs_struct), + std::move(built_in_inputs_struct), std::move(outputs_struct), - execution_model); + std::move(built_in_outputs_struct), + execution_model, + std::move(output_shader_interface), + std::move(built_in_output_shader_interface), + std::move(combined_outputs_struct)); } virtual void handle_header(unsigned version_number_major, unsigned version_number_minor, diff --git a/src/spirv_to_llvm/vertex_entry_point.cpp b/src/spirv_to_llvm/vertex_entry_point.cpp index ee742ee..12b2543 100644 --- a/src/spirv_to_llvm/vertex_entry_point.cpp +++ b/src/spirv_to_llvm/vertex_entry_point.cpp @@ -29,9 +29,13 @@ namespace spirv_to_llvm { using namespace spirv; -::LLVMValueRef Spirv_to_llvm::generate_vertex_entry_function(Op_entry_point_state &entry_point, - ::LLVMValueRef main_function) +::LLVMValueRef Spirv_to_llvm::generate_vertex_entry_function( + Op_entry_point_state &entry_point, + ::LLVMValueRef main_function, + Shader_interface &output_shader_interface, + Shader_interface &built_in_output_shader_interface) { +#error finish adding shader interface code assert(vertex_input_state); typedef std::uint32_t Vertex_index_type; auto llvm_vertex_index_type = llvm_wrapper::Create_llvm_type()(context); @@ -905,10 +909,10 @@ using namespace spirv; "next_iteration_condition"); ::LLVMBuildCondBr(builder.get(), next_iteration_condition, loop_block, exit_block); ::LLVMPositionBuilderAtEnd(builder.get(), exit_block); - static_assert( - std::is_same()(0, 0, 0, nullptr, nullptr, nullptr)), - void>::value, - ""); + static_assert(std::is_same()( + 0, 0, 0, nullptr, nullptr, nullptr)), + void>::value, + ""); ::LLVMBuildRetVoid(builder.get()); return entry_function; } -- 2.30.2