From 129c61cc09b5a5dd0d7e10b2c2f744274b4ae1b3 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Wed, 16 Aug 2017 00:38:14 -0700 Subject: [PATCH] generates usable wrapper for vertex shader --- src/demo/demo.cpp | 2 +- src/spirv_to_llvm/spirv_to_llvm.cpp | 189 ++++++++++++++++++++-------- src/spirv_to_llvm/spirv_to_llvm.h | 9 -- 3 files changed, 139 insertions(+), 61 deletions(-) diff --git a/src/demo/demo.cpp b/src/demo/demo.cpp index e5e2b48..5c40345 100644 --- a/src/demo/demo.cpp +++ b/src/demo/demo.cpp @@ -272,7 +272,7 @@ int test_main(int argc, char **argv) nullptr); auto function = reinterpret_cast( orc_jit_stack.get_symbol_address(converted_module.entry_function_name.c_str())); - std::cerr << "entry point: " << converted_module.entry_function_name << ": " << function + std::cerr << "entry point: " << converted_module.entry_function_name << ": " << reinterpret_cast(function) << std::endl; } else diff --git a/src/spirv_to_llvm/spirv_to_llvm.cpp b/src/spirv_to_llvm/spirv_to_llvm.cpp index 8ea497b..423507c 100644 --- a/src/spirv_to_llvm/spirv_to_llvm.cpp +++ b/src/spirv_to_llvm/spirv_to_llvm.cpp @@ -632,6 +632,7 @@ private: std::shared_ptr inputs_struct; std::size_t outputs_member; std::shared_ptr outputs_struct; + std::shared_ptr outputs_struct_pointer_type; Stage stage; Id current_function_id = 0; Id current_basic_block_id = 0; @@ -640,6 +641,7 @@ private: std::list> function_entry_block_handlers; spirv::Execution_model execution_model; util::string_view entry_point_name; + Op_entry_point_state *entry_point_state_pointer = nullptr; private: Id_state &get_id_state(Id id) @@ -746,6 +748,38 @@ private: } return get_prefixed_name(std::move(name), is_builtin_name); } + Op_entry_point_state &get_entry_point_state() + { + if(entry_point_state_pointer) + return *entry_point_state_pointer; + for(auto &id_state : id_states) + { + for(auto &entry_point : id_state.op_entry_points) + { + if(entry_point.entry_point.name != entry_point_name + || entry_point.entry_point.execution_model != execution_model) + continue; + if(entry_point_state_pointer) + throw Parser_error(entry_point.instruction_start_index, + entry_point.instruction_start_index, + "duplicate entry point: " + + std::string(spirv::get_enumerant_name(execution_model)) + + " \"" + + std::string(entry_point_name) + + "\""); + entry_point_state_pointer = &entry_point; + } + } + if(entry_point_state_pointer) + return *entry_point_state_pointer; + throw Parser_error(0, + 0, + "can't find entry point: " + + std::string(spirv::get_enumerant_name(execution_model)) + + " \"" + + std::string(entry_point_name) + + "\""); + } public: explicit Spirv_to_llvm(::LLVMContextRef context, @@ -789,14 +823,20 @@ public: target_data, get_prefixed_name("Inputs", true).c_str(), no_instruction_index); - inputs_member = io_struct->add_member(Struct_type_descriptor::Member({}, inputs_struct)); + inputs_member = io_struct->add_member(Struct_type_descriptor::Member( + {}, + std::make_shared( + std::vector{}, inputs_struct, 0, target_data))); outputs_struct = std::make_shared(std::vector{}, context, target_data, get_prefixed_name("Outputs", true).c_str(), no_instruction_index); - outputs_member = io_struct->add_member(Struct_type_descriptor::Member({}, outputs_struct)); + outputs_struct_pointer_type = std::make_shared( + std::vector{}, outputs_struct, 0, target_data); + outputs_member = + io_struct->add_member(Struct_type_descriptor::Member({}, outputs_struct_pointer_type)); } std::string generate_entry_function(Op_entry_point_state &entry_point, ::LLVMValueRef main_function) @@ -829,7 +869,7 @@ public: ::LLVMSetValueName(::LLVMGetParam(entry_function, arg_vertex_end_index), "vertex_end_index"); ::LLVMSetValueName(::LLVMGetParam(entry_function, arg_instance_id), "instance_id"); - ::LLVMSetValueName(::LLVMGetParam(entry_function, arg_output_buffer), "output_buffer"); + ::LLVMSetValueName(::LLVMGetParam(entry_function, arg_output_buffer), "output_buffer_"); auto entry_block = ::LLVMAppendBasicBlockInContext(context, entry_function, "entry"); auto loop_block = ::LLVMAppendBasicBlockInContext(context, entry_function, "loop"); auto exit_block = ::LLVMAppendBasicBlockInContext(context, entry_function, "exit"); @@ -837,10 +877,23 @@ public: auto io_struct_type = io_struct->get_or_make_type(); auto io_struct_pointer = ::LLVMBuildAlloca(builder.get(), io_struct_type.type, "io_struct"); + auto inputs_struct_pointer = + ::LLVMBuildAlloca(builder.get(), inputs_struct->get_or_make_type().type, "inputs"); ::LLVMSetAlignment( ::LLVMBuildStore( builder.get(), ::LLVMConstNull(io_struct_type.type), io_struct_pointer), io_struct_type.alignment); + auto inputs_pointer = + ::LLVMBuildStructGEP(builder.get(), + io_struct_pointer, + io_struct->get_members(true)[inputs_member].llvm_member_index, + "inputs_pointer"); + ::LLVMBuildStore(builder.get(), inputs_struct_pointer, inputs_pointer); + auto start_output_buffer = + ::LLVMBuildBitCast(builder.get(), + ::LLVMGetParam(entry_function, arg_output_buffer), + outputs_struct_pointer_type->get_or_make_type().type, + "start_output_buffer"); auto start_loop_condition = ::LLVMBuildICmp(builder.get(), ::LLVMIntULT, @@ -853,6 +906,10 @@ public: ::LLVMBuildPhi(builder.get(), llvm_wrapper::Create_llvm_type()(context), "vertex_index"); + auto output_buffer = + ::LLVMBuildPhi(builder.get(), + outputs_struct_pointer_type->get_or_make_type().type, + "output_buffer"); auto next_vertex_index = ::LLVMBuildNUWAdd(builder.get(), vertex_index, @@ -869,18 +926,38 @@ public: vertex_index_incoming_values, vertex_index_incoming_blocks, vertex_index_incoming_count); - for(auto &member : io_struct->get_members(true)) + ::LLVMValueRef next_output_buffer; + { + constexpr std::size_t index_count = 1; + ::LLVMValueRef indexes[index_count] = {::LLVMConstInt( + llvm_wrapper::Create_llvm_type()(context), 1, true)}; + next_output_buffer = ::LLVMBuildGEP( + builder.get(), output_buffer, indexes, index_count, "next_output_buffer"); + } + constexpr std::size_t output_buffer_incoming_count = 2; + ::LLVMValueRef output_buffer_incoming_values[output_buffer_incoming_count] = { + next_output_buffer, start_output_buffer, + }; + ::LLVMBasicBlockRef output_buffer_incoming_blocks[output_buffer_incoming_count] = { + loop_block, entry_block, + }; + ::LLVMAddIncoming(output_buffer, + output_buffer_incoming_values, + output_buffer_incoming_blocks, + output_buffer_incoming_count); + auto &&members = io_struct->get_members(true); + for(std::size_t member_index = 0; member_index < members.size(); member_index++) { - if(member.type == inputs_struct) + auto &member = members[member_index]; + if(member_index == inputs_member) { - auto inputs_struct_pointer = ::LLVMBuildStructGEP( - builder.get(), io_struct_pointer, member.llvm_member_index, "inputs"); for(auto &input_member : inputs_struct->get_members(true)) { auto input_pointer = ::LLVMBuildStructGEP(builder.get(), inputs_struct_pointer, input_member.llvm_member_index, "input"); + ::LLVMDumpType(::LLVMTypeOf(input_pointer)); util::optional built_in; static_cast(input_pointer); for(auto &decoration : input_member.decorations) @@ -1238,10 +1315,15 @@ public: } while(false); } } - else if(member.type == outputs_struct) + else if(member_index == outputs_member) { - auto outputs_struct_pointer = ::LLVMBuildStructGEP( - builder.get(), io_struct_pointer, member.llvm_member_index, "outputs"); + auto outputs_struct_pointer = output_buffer; + ::LLVMBuildStore(builder.get(), + outputs_struct_pointer, + ::LLVMBuildStructGEP(builder.get(), + io_struct_pointer, + member.llvm_member_index, + "outputs_pointer")); for(auto &output_member : outputs_struct->get_members(true)) { auto output_pointer = ::LLVMBuildStructGEP(builder.get(), @@ -1497,44 +1579,17 @@ public: #warning finish Spirv_to_llvm::run stage = Stage::generate_code; spirv::parse(*this, shader_words, shader_size); - std::string entry_function_name; - for(auto &id_state : id_states) - { - for(auto &entry_point : id_state.op_entry_points) - { - if(!id_state.function) - throw Parser_error(entry_point.instruction_start_index, - entry_point.instruction_start_index, - "No definition for function referenced in OpEntryPoint"); - if(entry_point.entry_point.name != entry_point_name - || entry_point.entry_point.execution_model != execution_model) - continue; - if(!entry_function_name.empty()) - throw Parser_error(entry_point.instruction_start_index, - entry_point.instruction_start_index, - "duplicate entry point: " - + std::string(spirv::get_enumerant_name(execution_model)) - + " \"" - + std::string(entry_point_name) - + "\""); - entry_function_name = - generate_entry_function(entry_point, id_state.function->function); - } - } - if(entry_function_name.empty()) - throw Parser_error(0, - 0, - "can't find entry point: " - + std::string(spirv::get_enumerant_name(execution_model)) - + " \"" - + std::string(entry_point_name) - + "\""); + auto &entry_point_state = get_entry_point_state(); + auto &entry_point_id_state = get_id_state(entry_point_state.entry_point.entry_point); + if(!entry_point_id_state.function) + throw Parser_error(entry_point_state.instruction_start_index, + entry_point_state.instruction_start_index, + "No definition for function referenced in OpEntryPoint"); + auto entry_function_name = + generate_entry_function(entry_point_state, entry_point_id_state.function->function); return Converted_module(std::move(module), std::move(entry_function_name), - std::move(io_struct), - inputs_member, std::move(inputs_struct), - outputs_member, std::move(outputs_struct)); } virtual void handle_header(unsigned version_number_major, @@ -2830,6 +2885,10 @@ void Spirv_to_llvm::handle_instruction_op_entry_point(Op_entry_point instruction { if(stage == util::Enum_traits::values[0]) { + if(entry_point_state_pointer) + throw Parser_error(instruction_start_index, + instruction_start_index, + "invalid location for OpEntryPoint"); auto &state = get_id_state(instruction.entry_point); state.op_entry_points.push_back( Op_entry_point_state{std::move(instruction), instruction_start_index}); @@ -3885,6 +3944,16 @@ void Spirv_to_llvm::handle_instruction_op_variable(Op_variable instruction, case Stage::generate_code: { auto &state = get_id_state(instruction.result); + auto &entry_point_state = get_entry_point_state(); + bool is_part_of_entry_point_interface = false; + for(Id_ref id : entry_point_state.entry_point.interface) + { + if(instruction.result == id) + { + is_part_of_entry_point_interface = true; + break; + } + } switch(instruction.storage_class) { case Storage_class::uniform_constant: @@ -3896,6 +3965,12 @@ void Spirv_to_llvm::handle_instruction_op_variable(Op_variable instruction, throw Parser_error(instruction_start_index, instruction_start_index, "shader input variable initializers are not implemented"); + if(!is_part_of_entry_point_interface) + { + auto type = get_type(instruction.result_type, instruction_start_index); + state.value = Value(::LLVMGetUndef(type->get_or_make_type().type), type); + return; + } auto set_value_fn = [this, instruction, &state, instruction_start_index]() { auto &variable = util::get(state.variable); @@ -3922,6 +3997,12 @@ void Spirv_to_llvm::handle_instruction_op_variable(Op_variable instruction, throw Parser_error(instruction_start_index, instruction_start_index, "shader output variable initializers are not implemented"); + if(!is_part_of_entry_point_interface) + { + auto type = get_type(instruction.result_type, instruction_start_index); + state.value = Value(::LLVMGetUndef(type->get_or_make_type().type), type); + return; + } auto set_value_fn = [this, instruction, &state, instruction_start_index]() { auto &variable = util::get(state.variable); @@ -6198,15 +6279,21 @@ void Spirv_to_llvm::handle_instruction_op_label(Op_label instruction, if(!function.entry_block) { auto io_struct_value = ::LLVMGetParam(function.function, io_struct_argument_index); - auto inputs_struct_value = ::LLVMBuildStructGEP( + auto inputs_struct_value = ::LLVMBuildLoad( builder.get(), - io_struct_value, - io_struct->get_members(true)[this->inputs_member].llvm_member_index, + ::LLVMBuildStructGEP( + builder.get(), + io_struct_value, + io_struct->get_members(true)[this->inputs_member].llvm_member_index, + "inputs_pointer"), "inputs"); - auto outputs_struct_value = ::LLVMBuildStructGEP( + auto outputs_struct_value = ::LLVMBuildLoad( builder.get(), - io_struct_value, - io_struct->get_members(true)[this->outputs_member].llvm_member_index, + ::LLVMBuildStructGEP( + builder.get(), + io_struct_value, + io_struct->get_members(true)[this->outputs_member].llvm_member_index, + "outputs_pointer"), "outputs"); function.entry_block = Function_state::Entry_block( block, io_struct_value, inputs_struct_value, outputs_struct_value); diff --git a/src/spirv_to_llvm/spirv_to_llvm.h b/src/spirv_to_llvm/spirv_to_llvm.h index 8d57b8f..42ed3f4 100644 --- a/src/spirv_to_llvm/spirv_to_llvm.h +++ b/src/spirv_to_llvm/spirv_to_llvm.h @@ -568,25 +568,16 @@ struct Converted_module { llvm_wrapper::Module module; std::string entry_function_name; - std::shared_ptr io_struct; - std::size_t inputs_member; std::shared_ptr inputs_struct; - std::size_t outputs_member; std::shared_ptr outputs_struct; Converted_module() = default; explicit Converted_module(llvm_wrapper::Module module, std::string entry_function_name, - std::shared_ptr io_struct, - std::size_t inputs_member, std::shared_ptr inputs_struct, - std::size_t outputs_member, std::shared_ptr outputs_struct) noexcept : module(std::move(module)), entry_function_name(std::move(entry_function_name)), - io_struct(std::move(io_struct)), - inputs_member(inputs_member), inputs_struct(std::move(inputs_struct)), - outputs_member(outputs_member), outputs_struct(std::move(outputs_struct)) { } -- 2.30.2