generates usable wrapper for vertex shader
authorJacob Lifshay <programmerjake@gmail.com>
Wed, 16 Aug 2017 07:38:14 +0000 (00:38 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Wed, 16 Aug 2017 07:38:14 +0000 (00:38 -0700)
src/demo/demo.cpp
src/spirv_to_llvm/spirv_to_llvm.cpp
src/spirv_to_llvm/spirv_to_llvm.h

index e5e2b48b956e41defed27add0b10f925684065c9..5c40345a9b632556857b0ea6bb62248e7d472660 100644 (file)
@@ -272,7 +272,7 @@ int test_main(int argc, char **argv)
             nullptr);
         auto function = reinterpret_cast<void *>(
             orc_jit_stack.get_symbol_address(converted_module.entry_function_name.c_str()));
-        std::cerr << "entry point: " << converted_module.entry_function_name << ": " << function
+        std::cerr << "entry point: " << converted_module.entry_function_name << ": " << reinterpret_cast<void *>(function)
                   << std::endl;
     }
     else
index 8ea497b28d348e8b281b101a4e520f85897c52e9..423507c7096b88397c4d7586d58f2dd80178d7ab 100644 (file)
@@ -632,6 +632,7 @@ private:
     std::shared_ptr<Struct_type_descriptor> inputs_struct;
     std::size_t outputs_member;
     std::shared_ptr<Struct_type_descriptor> outputs_struct;
+    std::shared_ptr<Pointer_type_descriptor> outputs_struct_pointer_type;
     Stage stage;
     Id current_function_id = 0;
     Id current_basic_block_id = 0;
@@ -640,6 +641,7 @@ private:
     std::list<std::function<void()>> function_entry_block_handlers;
     spirv::Execution_model execution_model;
     util::string_view entry_point_name;
+    Op_entry_point_state *entry_point_state_pointer = nullptr;
 
 private:
     Id_state &get_id_state(Id id)
@@ -746,6 +748,38 @@ private:
         }
         return get_prefixed_name(std::move(name), is_builtin_name);
     }
+    Op_entry_point_state &get_entry_point_state()
+    {
+        if(entry_point_state_pointer)
+            return *entry_point_state_pointer;
+        for(auto &id_state : id_states)
+        {
+            for(auto &entry_point : id_state.op_entry_points)
+            {
+                if(entry_point.entry_point.name != entry_point_name
+                   || entry_point.entry_point.execution_model != execution_model)
+                    continue;
+                if(entry_point_state_pointer)
+                    throw Parser_error(entry_point.instruction_start_index,
+                                       entry_point.instruction_start_index,
+                                       "duplicate entry point: "
+                                           + std::string(spirv::get_enumerant_name(execution_model))
+                                           + " \""
+                                           + std::string(entry_point_name)
+                                           + "\"");
+                entry_point_state_pointer = &entry_point;
+            }
+        }
+        if(entry_point_state_pointer)
+            return *entry_point_state_pointer;
+        throw Parser_error(0,
+                           0,
+                           "can't find entry point: "
+                               + std::string(spirv::get_enumerant_name(execution_model))
+                               + " \""
+                               + std::string(entry_point_name)
+                               + "\"");
+    }
 
 public:
     explicit Spirv_to_llvm(::LLVMContextRef context,
@@ -789,14 +823,20 @@ public:
                                                      target_data,
                                                      get_prefixed_name("Inputs", true).c_str(),
                                                      no_instruction_index);
-        inputs_member = io_struct->add_member(Struct_type_descriptor::Member({}, inputs_struct));
+        inputs_member = io_struct->add_member(Struct_type_descriptor::Member(
+            {},
+            std::make_shared<Pointer_type_descriptor>(
+                std::vector<Decoration_with_parameters>{}, inputs_struct, 0, target_data)));
         outputs_struct =
             std::make_shared<Struct_type_descriptor>(std::vector<Decoration_with_parameters>{},
                                                      context,
                                                      target_data,
                                                      get_prefixed_name("Outputs", true).c_str(),
                                                      no_instruction_index);
-        outputs_member = io_struct->add_member(Struct_type_descriptor::Member({}, outputs_struct));
+        outputs_struct_pointer_type = std::make_shared<Pointer_type_descriptor>(
+            std::vector<Decoration_with_parameters>{}, outputs_struct, 0, target_data);
+        outputs_member =
+            io_struct->add_member(Struct_type_descriptor::Member({}, outputs_struct_pointer_type));
     }
     std::string generate_entry_function(Op_entry_point_state &entry_point,
                                         ::LLVMValueRef main_function)
@@ -829,7 +869,7 @@ public:
             ::LLVMSetValueName(::LLVMGetParam(entry_function, arg_vertex_end_index),
                                "vertex_end_index");
             ::LLVMSetValueName(::LLVMGetParam(entry_function, arg_instance_id), "instance_id");
-            ::LLVMSetValueName(::LLVMGetParam(entry_function, arg_output_buffer), "output_buffer");
+            ::LLVMSetValueName(::LLVMGetParam(entry_function, arg_output_buffer), "output_buffer_");
             auto entry_block = ::LLVMAppendBasicBlockInContext(context, entry_function, "entry");
             auto loop_block = ::LLVMAppendBasicBlockInContext(context, entry_function, "loop");
             auto exit_block = ::LLVMAppendBasicBlockInContext(context, entry_function, "exit");
@@ -837,10 +877,23 @@ public:
             auto io_struct_type = io_struct->get_or_make_type();
             auto io_struct_pointer =
                 ::LLVMBuildAlloca(builder.get(), io_struct_type.type, "io_struct");
+            auto inputs_struct_pointer =
+                ::LLVMBuildAlloca(builder.get(), inputs_struct->get_or_make_type().type, "inputs");
             ::LLVMSetAlignment(
                 ::LLVMBuildStore(
                     builder.get(), ::LLVMConstNull(io_struct_type.type), io_struct_pointer),
                 io_struct_type.alignment);
+            auto inputs_pointer =
+                ::LLVMBuildStructGEP(builder.get(),
+                                     io_struct_pointer,
+                                     io_struct->get_members(true)[inputs_member].llvm_member_index,
+                                     "inputs_pointer");
+            ::LLVMBuildStore(builder.get(), inputs_struct_pointer, inputs_pointer);
+            auto start_output_buffer =
+                ::LLVMBuildBitCast(builder.get(),
+                                   ::LLVMGetParam(entry_function, arg_output_buffer),
+                                   outputs_struct_pointer_type->get_or_make_type().type,
+                                   "start_output_buffer");
             auto start_loop_condition =
                 ::LLVMBuildICmp(builder.get(),
                                 ::LLVMIntULT,
@@ -853,6 +906,10 @@ public:
                 ::LLVMBuildPhi(builder.get(),
                                llvm_wrapper::Create_llvm_type<Vertex_index_type>()(context),
                                "vertex_index");
+            auto output_buffer =
+                ::LLVMBuildPhi(builder.get(),
+                               outputs_struct_pointer_type->get_or_make_type().type,
+                               "output_buffer");
             auto next_vertex_index =
                 ::LLVMBuildNUWAdd(builder.get(),
                                   vertex_index,
@@ -869,18 +926,38 @@ public:
                               vertex_index_incoming_values,
                               vertex_index_incoming_blocks,
                               vertex_index_incoming_count);
-            for(auto &member : io_struct->get_members(true))
+            ::LLVMValueRef next_output_buffer;
+            {
+                constexpr std::size_t index_count = 1;
+                ::LLVMValueRef indexes[index_count] = {::LLVMConstInt(
+                    llvm_wrapper::Create_llvm_type<std::ptrdiff_t>()(context), 1, true)};
+                next_output_buffer = ::LLVMBuildGEP(
+                    builder.get(), output_buffer, indexes, index_count, "next_output_buffer");
+            }
+            constexpr std::size_t output_buffer_incoming_count = 2;
+            ::LLVMValueRef output_buffer_incoming_values[output_buffer_incoming_count] = {
+                next_output_buffer, start_output_buffer,
+            };
+            ::LLVMBasicBlockRef output_buffer_incoming_blocks[output_buffer_incoming_count] = {
+                loop_block, entry_block,
+            };
+            ::LLVMAddIncoming(output_buffer,
+                              output_buffer_incoming_values,
+                              output_buffer_incoming_blocks,
+                              output_buffer_incoming_count);
+            auto &&members = io_struct->get_members(true);
+            for(std::size_t member_index = 0; member_index < members.size(); member_index++)
             {
-                if(member.type == inputs_struct)
+                auto &member = members[member_index];
+                if(member_index == inputs_member)
                 {
-                    auto inputs_struct_pointer = ::LLVMBuildStructGEP(
-                        builder.get(), io_struct_pointer, member.llvm_member_index, "inputs");
                     for(auto &input_member : inputs_struct->get_members(true))
                     {
                         auto input_pointer = ::LLVMBuildStructGEP(builder.get(),
                                                                   inputs_struct_pointer,
                                                                   input_member.llvm_member_index,
                                                                   "input");
+                        ::LLVMDumpType(::LLVMTypeOf(input_pointer));
                         util::optional<spirv::Built_in> built_in;
                         static_cast<void>(input_pointer);
                         for(auto &decoration : input_member.decorations)
@@ -1238,10 +1315,15 @@ public:
                         } while(false);
                     }
                 }
-                else if(member.type == outputs_struct)
+                else if(member_index == outputs_member)
                 {
-                    auto outputs_struct_pointer = ::LLVMBuildStructGEP(
-                        builder.get(), io_struct_pointer, member.llvm_member_index, "outputs");
+                    auto outputs_struct_pointer = output_buffer;
+                    ::LLVMBuildStore(builder.get(),
+                                     outputs_struct_pointer,
+                                     ::LLVMBuildStructGEP(builder.get(),
+                                                          io_struct_pointer,
+                                                          member.llvm_member_index,
+                                                          "outputs_pointer"));
                     for(auto &output_member : outputs_struct->get_members(true))
                     {
                         auto output_pointer = ::LLVMBuildStructGEP(builder.get(),
@@ -1497,44 +1579,17 @@ public:
 #warning finish Spirv_to_llvm::run
         stage = Stage::generate_code;
         spirv::parse(*this, shader_words, shader_size);
-        std::string entry_function_name;
-        for(auto &id_state : id_states)
-        {
-            for(auto &entry_point : id_state.op_entry_points)
-            {
-                if(!id_state.function)
-                    throw Parser_error(entry_point.instruction_start_index,
-                                       entry_point.instruction_start_index,
-                                       "No definition for function referenced in OpEntryPoint");
-                if(entry_point.entry_point.name != entry_point_name
-                   || entry_point.entry_point.execution_model != execution_model)
-                    continue;
-                if(!entry_function_name.empty())
-                    throw Parser_error(entry_point.instruction_start_index,
-                                       entry_point.instruction_start_index,
-                                       "duplicate entry point: "
-                                           + std::string(spirv::get_enumerant_name(execution_model))
-                                           + " \""
-                                           + std::string(entry_point_name)
-                                           + "\"");
-                entry_function_name =
-                    generate_entry_function(entry_point, id_state.function->function);
-            }
-        }
-        if(entry_function_name.empty())
-            throw Parser_error(0,
-                               0,
-                               "can't find entry point: "
-                                   + std::string(spirv::get_enumerant_name(execution_model))
-                                   + " \""
-                                   + std::string(entry_point_name)
-                                   + "\"");
+        auto &entry_point_state = get_entry_point_state();
+        auto &entry_point_id_state = get_id_state(entry_point_state.entry_point.entry_point);
+        if(!entry_point_id_state.function)
+            throw Parser_error(entry_point_state.instruction_start_index,
+                               entry_point_state.instruction_start_index,
+                               "No definition for function referenced in OpEntryPoint");
+        auto entry_function_name =
+            generate_entry_function(entry_point_state, entry_point_id_state.function->function);
         return Converted_module(std::move(module),
                                 std::move(entry_function_name),
-                                std::move(io_struct),
-                                inputs_member,
                                 std::move(inputs_struct),
-                                outputs_member,
                                 std::move(outputs_struct));
     }
     virtual void handle_header(unsigned version_number_major,
@@ -2830,6 +2885,10 @@ void Spirv_to_llvm::handle_instruction_op_entry_point(Op_entry_point instruction
 {
     if(stage == util::Enum_traits<Stage>::values[0])
     {
+        if(entry_point_state_pointer)
+            throw Parser_error(instruction_start_index,
+                               instruction_start_index,
+                               "invalid location for OpEntryPoint");
         auto &state = get_id_state(instruction.entry_point);
         state.op_entry_points.push_back(
             Op_entry_point_state{std::move(instruction), instruction_start_index});
@@ -3885,6 +3944,16 @@ void Spirv_to_llvm::handle_instruction_op_variable(Op_variable instruction,
     case Stage::generate_code:
     {
         auto &state = get_id_state(instruction.result);
+        auto &entry_point_state = get_entry_point_state();
+        bool is_part_of_entry_point_interface = false;
+        for(Id_ref id : entry_point_state.entry_point.interface)
+        {
+            if(instruction.result == id)
+            {
+                is_part_of_entry_point_interface = true;
+                break;
+            }
+        }
         switch(instruction.storage_class)
         {
         case Storage_class::uniform_constant:
@@ -3896,6 +3965,12 @@ void Spirv_to_llvm::handle_instruction_op_variable(Op_variable instruction,
                 throw Parser_error(instruction_start_index,
                                    instruction_start_index,
                                    "shader input variable initializers are not implemented");
+            if(!is_part_of_entry_point_interface)
+            {
+                auto type = get_type(instruction.result_type, instruction_start_index);
+                state.value = Value(::LLVMGetUndef(type->get_or_make_type().type), type);
+                return;
+            }
             auto set_value_fn = [this, instruction, &state, instruction_start_index]()
             {
                 auto &variable = util::get<Input_variable_state>(state.variable);
@@ -3922,6 +3997,12 @@ void Spirv_to_llvm::handle_instruction_op_variable(Op_variable instruction,
                 throw Parser_error(instruction_start_index,
                                    instruction_start_index,
                                    "shader output variable initializers are not implemented");
+            if(!is_part_of_entry_point_interface)
+            {
+                auto type = get_type(instruction.result_type, instruction_start_index);
+                state.value = Value(::LLVMGetUndef(type->get_or_make_type().type), type);
+                return;
+            }
             auto set_value_fn = [this, instruction, &state, instruction_start_index]()
             {
                 auto &variable = util::get<Output_variable_state>(state.variable);
@@ -6198,15 +6279,21 @@ void Spirv_to_llvm::handle_instruction_op_label(Op_label instruction,
         if(!function.entry_block)
         {
             auto io_struct_value = ::LLVMGetParam(function.function, io_struct_argument_index);
-            auto inputs_struct_value = ::LLVMBuildStructGEP(
+            auto inputs_struct_value = ::LLVMBuildLoad(
                 builder.get(),
-                io_struct_value,
-                io_struct->get_members(true)[this->inputs_member].llvm_member_index,
+                ::LLVMBuildStructGEP(
+                    builder.get(),
+                    io_struct_value,
+                    io_struct->get_members(true)[this->inputs_member].llvm_member_index,
+                    "inputs_pointer"),
                 "inputs");
-            auto outputs_struct_value = ::LLVMBuildStructGEP(
+            auto outputs_struct_value = ::LLVMBuildLoad(
                 builder.get(),
-                io_struct_value,
-                io_struct->get_members(true)[this->outputs_member].llvm_member_index,
+                ::LLVMBuildStructGEP(
+                    builder.get(),
+                    io_struct_value,
+                    io_struct->get_members(true)[this->outputs_member].llvm_member_index,
+                    "outputs_pointer"),
                 "outputs");
             function.entry_block = Function_state::Entry_block(
                 block, io_struct_value, inputs_struct_value, outputs_struct_value);
index 8d57b8f29c24106cc608e8a5f89e3bb46c1c7998..42ed3f45867ba97aad6eee1eff1e7050b4b45c08 100644 (file)
@@ -568,25 +568,16 @@ struct Converted_module
 {
     llvm_wrapper::Module module;
     std::string entry_function_name;
-    std::shared_ptr<Struct_type_descriptor> io_struct;
-    std::size_t inputs_member;
     std::shared_ptr<Struct_type_descriptor> inputs_struct;
-    std::size_t outputs_member;
     std::shared_ptr<Struct_type_descriptor> outputs_struct;
     Converted_module() = default;
     explicit Converted_module(llvm_wrapper::Module module,
                               std::string entry_function_name,
-                              std::shared_ptr<Struct_type_descriptor> io_struct,
-                              std::size_t inputs_member,
                               std::shared_ptr<Struct_type_descriptor> inputs_struct,
-                              std::size_t outputs_member,
                               std::shared_ptr<Struct_type_descriptor> outputs_struct) noexcept
         : module(std::move(module)),
           entry_function_name(std::move(entry_function_name)),
-          io_struct(std::move(io_struct)),
-          inputs_member(inputs_member),
           inputs_struct(std::move(inputs_struct)),
-          outputs_member(outputs_member),
           outputs_struct(std::move(outputs_struct))
     {
     }