start adding graphics pipeline
[kazan.git] / src / spirv_to_llvm / spirv_to_llvm.cpp
index 560ed8ab4c9bd7a87f276e04598222208e994ae3..09ea92e758d217277f651dd7088706059f26ec3f 100644 (file)
@@ -24,6 +24,7 @@
 #include "util/optional.h"
 #include "util/variant.h"
 #include "util/enum.h"
+#include "pipeline/pipeline.h"
 #include <functional>
 #include <list>
 #include <iostream>
@@ -637,6 +638,8 @@ private:
     llvm_wrapper::Builder builder;
     util::optional<Last_merge_instruction> last_merge_instruction;
     std::list<std::function<void()>> function_entry_block_handlers;
+    spirv::Execution_model execution_model;
+    util::string_view entry_point_name;
 
 private:
     Id_state &get_id_state(Id id)
@@ -712,11 +715,11 @@ private:
         {
             auto &function = get_id_state(current_function_id).function.value();
             state.label = Label_state(::LLVMAppendBasicBlockInContext(
-                context, function.function, get_prefixed_name(get_name(id)).c_str()));
+                context, function.function, get_prefixed_name(get_name(id), false).c_str()));
         }
         return state.label->basic_block;
     }
-    std::string get_prefixed_name(std::string name) const
+    std::string get_prefixed_name(std::string name, bool is_builtin_name) const
     {
         if(!name.empty())
         {
@@ -727,11 +730,13 @@ private:
                 // ensure name doesn't conflict with names generated by get_or_make_prefixed_name
                 name.insert(0, "_");
             }
+            if(!is_builtin_name)
+                name.insert(0, "_"); // ensure user names don't conflict with builtin names
             return name_prefix_string + std::move(name);
         }
         return name;
     }
-    std::string get_or_make_prefixed_name(std::string name)
+    std::string get_or_make_prefixed_name(std::string name, bool is_builtin_name)
     {
         if(name.empty())
         {
@@ -739,14 +744,21 @@ private:
             ss << name_prefix_string << next_name_index++;
             return ss.str();
         }
-        return get_prefixed_name(std::move(name));
+        return get_prefixed_name(std::move(name), is_builtin_name);
     }
 
 public:
     explicit Spirv_to_llvm(::LLVMContextRef context,
                            ::LLVMTargetMachineRef target_machine,
-                           std::uint64_t shader_id)
-        : context(context), target_machine(target_machine), shader_id(shader_id), stage()
+                           std::uint64_t shader_id,
+                           spirv::Execution_model execution_model,
+                           util::string_view entry_point_name)
+        : context(context),
+          target_machine(target_machine),
+          shader_id(shader_id),
+          stage(),
+          execution_model(execution_model),
+          entry_point_name(entry_point_name)
     {
         {
             std::ostringstream ss;
@@ -754,7 +766,7 @@ public:
             name_prefix_string = ss.str();
         }
         module = llvm_wrapper::Module::create_with_target_machine(
-            get_prefixed_name("module").c_str(), context, target_machine);
+            get_prefixed_name("module", true).c_str(), context, target_machine);
         target_data = ::LLVMGetModuleDataLayout(module.get());
         builder = llvm_wrapper::Builder::create(context);
         constexpr std::size_t no_instruction_index = 0;
@@ -762,7 +774,7 @@ public:
             std::make_shared<Struct_type_descriptor>(std::vector<Decoration_with_parameters>{},
                                                      context,
                                                      target_data,
-                                                     get_prefixed_name("Io_struct").c_str(),
+                                                     get_prefixed_name("Io_struct", true).c_str(),
                                                      no_instruction_index);
         assert(implicit_function_arguments.size() == 1);
         static_assert(io_struct_argument_index == 0, "");
@@ -775,17 +787,86 @@ public:
             std::make_shared<Struct_type_descriptor>(std::vector<Decoration_with_parameters>{},
                                                      context,
                                                      target_data,
-                                                     get_prefixed_name("Inputs").c_str(),
+                                                     get_prefixed_name("Inputs", true).c_str(),
                                                      no_instruction_index);
         inputs_member = io_struct->add_member(Struct_type_descriptor::Member({}, inputs_struct));
         outputs_struct =
             std::make_shared<Struct_type_descriptor>(std::vector<Decoration_with_parameters>{},
                                                      context,
                                                      target_data,
-                                                     get_prefixed_name("Outputs").c_str(),
+                                                     get_prefixed_name("Outputs", true).c_str(),
                                                      no_instruction_index);
         outputs_member = io_struct->add_member(Struct_type_descriptor::Member({}, outputs_struct));
     }
+    std::string generate_entry_function(Op_entry_point_state &entry_point)
+    {
+        ::LLVMValueRef function = nullptr;
+        switch(execution_model)
+        {
+        case spirv::Execution_model::vertex:
+        {
+            typedef void (*Vertex_shader_function)(std::uint32_t vertex_start_index,
+                                                   std::uint32_t vertex_count,
+                                                   std::uint32_t instance_id,
+                                                   void *output_buffer);
+            constexpr std::size_t arg_vertex_start_index = 0;
+            constexpr std::size_t arg_vertex_count = 1;
+            constexpr std::size_t arg_instance_id = 2;
+            constexpr std::size_t arg_output_buffer = 3;
+            static_assert(std::is_same<Vertex_shader_function,
+                                       pipeline::Graphics_pipeline::Vertex_shader_function>::value,
+                          "vertex shader function signature mismatch");
+            auto function_type = llvm_wrapper::Create_llvm_type<Vertex_shader_function>()(context);
+            function = ::LLVMAddFunction(
+                module.get(), get_prefixed_name("vertex_entry_point", true).c_str(), function_type);
+            llvm_wrapper::Module::set_function_target_machine(function, target_machine);
+            static_cast<void>(arg_vertex_start_index);
+            static_cast<void>(arg_vertex_count);
+            static_cast<void>(arg_instance_id);
+            static_cast<void>(arg_output_buffer);
+#warning finish implementing vertex execution model
+            break;
+        }
+        case spirv::Execution_model::tessellation_control:
+#warning implement execution model
+            throw Parser_error(entry_point.instruction_start_index,
+                               entry_point.instruction_start_index,
+                               "unimplemented execution model: "
+                                   + std::string(spirv::get_enumerant_name(execution_model)));
+        case spirv::Execution_model::tessellation_evaluation:
+#warning implement execution model
+            throw Parser_error(entry_point.instruction_start_index,
+                               entry_point.instruction_start_index,
+                               "unimplemented execution model: "
+                                   + std::string(spirv::get_enumerant_name(execution_model)));
+        case spirv::Execution_model::geometry:
+#warning implement execution model
+            throw Parser_error(entry_point.instruction_start_index,
+                               entry_point.instruction_start_index,
+                               "unimplemented execution model: "
+                                   + std::string(spirv::get_enumerant_name(execution_model)));
+        case spirv::Execution_model::fragment:
+#warning implement execution model
+            throw Parser_error(entry_point.instruction_start_index,
+                               entry_point.instruction_start_index,
+                               "unimplemented execution model: "
+                                   + std::string(spirv::get_enumerant_name(execution_model)));
+        case spirv::Execution_model::gl_compute:
+#warning implement execution model
+            throw Parser_error(entry_point.instruction_start_index,
+                               entry_point.instruction_start_index,
+                               "unimplemented execution model: "
+                                   + std::string(spirv::get_enumerant_name(execution_model)));
+        case spirv::Execution_model::kernel:
+            // TODO: implement execution model as extension
+            throw Parser_error(entry_point.instruction_start_index,
+                               entry_point.instruction_start_index,
+                               "unimplemented execution model: "
+                                   + std::string(spirv::get_enumerant_name(execution_model)));
+        }
+        assert(function);
+        return ::LLVMGetValueName(function);
+    }
     Converted_module run(const Word *shader_words, std::size_t shader_size)
     {
         stage = Stage::calculate_types;
@@ -798,7 +879,7 @@ public:
 #warning finish Spirv_to_llvm::run
         stage = Stage::generate_code;
         spirv::parse(*this, shader_words, shader_size);
-        std::vector<Converted_module::Entry_point> entry_points;
+        std::string entry_function_name;
         for(auto &id_state : id_states)
         {
             for(auto &entry_point : id_state.op_entry_points)
@@ -807,19 +888,35 @@ public:
                     throw Parser_error(entry_point.instruction_start_index,
                                        entry_point.instruction_start_index,
                                        "No definition for function referenced in OpEntryPoint");
-                entry_points.push_back(
-                    Converted_module::Entry_point(std::string(entry_point.entry_point.name),
-                                                  id_state.function->output_function_name));
+                if(entry_point.entry_point.name != entry_point_name
+                   || entry_point.entry_point.execution_model != execution_model)
+                    continue;
+                if(!entry_function_name.empty())
+                    throw Parser_error(entry_point.instruction_start_index,
+                                       entry_point.instruction_start_index,
+                                       "duplicate entry point: "
+                                           + std::string(spirv::get_enumerant_name(execution_model))
+                                           + " \""
+                                           + std::string(entry_point_name)
+                                           + "\"");
+                entry_function_name = generate_entry_function(entry_point);
             }
         }
-        Converted_module retval(std::move(module),
-                                std::move(entry_points),
+        if(entry_function_name.empty())
+            throw Parser_error(0,
+                               0,
+                               "can't find entry point: "
+                                   + std::string(spirv::get_enumerant_name(execution_model))
+                                   + " \""
+                                   + std::string(entry_point_name)
+                                   + "\"");
+        return Converted_module(std::move(module),
+                                std::move(entry_function_name),
                                 std::move(io_struct),
                                 inputs_member,
                                 std::move(inputs_struct),
                                 outputs_member,
                                 std::move(outputs_struct));
-        return retval;
     }
     virtual void handle_header(unsigned version_number_major,
                                unsigned version_number_minor,
@@ -2430,7 +2527,7 @@ void Spirv_to_llvm::handle_instruction_op_type_struct(Op_type_struct instruction
             state.decorations,
             context,
             ::LLVMGetModuleDataLayout(module.get()),
-            get_prefixed_name(get_name(instruction.result)).c_str(),
+            get_prefixed_name(get_name(instruction.result), false).c_str(),
             instruction_start_index,
             std::move(members));
         break;
@@ -2852,7 +2949,7 @@ void Spirv_to_llvm::handle_instruction_op_function(Op_function instruction,
         auto function_name = get_name(current_function_id);
         if(function_name.empty() && state.op_entry_points.size() == 1)
             function_name = std::string(state.op_entry_points[0].entry_point.name);
-        function_name = get_or_make_prefixed_name(std::move(function_name));
+        function_name = get_or_make_prefixed_name(std::move(function_name), false);
         auto function = ::LLVMAddFunction(
             module.get(), function_name.c_str(), function_type->get_or_make_type().type);
         llvm_wrapper::Module::set_function_target_machine(function, target_machine);
@@ -8810,11 +8907,14 @@ void Spirv_to_llvm::handle_instruction_glsl_std_450_op_n_clamp(Glsl_std_450_op_n
 
 Converted_module spirv_to_llvm(::LLVMContextRef context,
                                ::LLVMTargetMachineRef target_machine,
-                               const Word *shader_words,
+                               const spirv::Word *shader_words,
                                std::size_t shader_size,
-                               std::uint64_t shader_id)
+                               std::uint64_t shader_id,
+                               spirv::Execution_model execution_model,
+                               util::string_view entry_point_name)
 {
-    return Spirv_to_llvm(context, target_machine, shader_id).run(shader_words, shader_size);
+    return Spirv_to_llvm(context, target_machine, shader_id, execution_model, entry_point_name)
+        .run(shader_words, shader_size);
 }
 }
 }