From 5bb91e58b7df90e0fbc1e7747cabbd18062eacf2 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Fri, 18 Aug 2017 00:28:42 -0700 Subject: [PATCH] new orc compile stack complete new orc compile stack supports debugging JIT code using gdb ready to add optimizations --- src/llvm_wrapper/CMakeLists.txt | 1 + src/llvm_wrapper/llvm_wrapper.cpp | 27 +++- src/llvm_wrapper/llvm_wrapper.h | 10 +- src/llvm_wrapper/orc_compile_stack.cpp | 216 +++++++++++++++++++++---- src/llvm_wrapper/orc_compile_stack.h | 33 ++-- src/pipeline/pipeline.cpp | 32 +++- src/pipeline/pipeline.h | 4 + src/spirv_to_llvm/spirv_to_llvm.h | 8 +- 8 files changed, 282 insertions(+), 49 deletions(-) diff --git a/src/llvm_wrapper/CMakeLists.txt b/src/llvm_wrapper/CMakeLists.txt index 8a56efc..2032751 100644 --- a/src/llvm_wrapper/CMakeLists.txt +++ b/src/llvm_wrapper/CMakeLists.txt @@ -29,3 +29,4 @@ else() set(llvm_libraries LLVM) endif() target_link_libraries(vulkan_cpu_llvm_wrapper util ${llvm_libraries}) +set_source_files_properties(orc_compile_stack.cpp PROPERTIES COMPILE_FLAGS -fno-rtti) # prevent link errors with missing type info diff --git a/src/llvm_wrapper/llvm_wrapper.cpp b/src/llvm_wrapper/llvm_wrapper.cpp index 6829317..24b579d 100644 --- a/src/llvm_wrapper/llvm_wrapper.cpp +++ b/src/llvm_wrapper/llvm_wrapper.cpp @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -33,6 +34,12 @@ namespace vulkan_cpu { namespace llvm_wrapper { +// implement the unwrap functions that aren't in public llvm headers +static llvm::TargetMachine *unwrap(::LLVMTargetMachineRef v) noexcept +{ + return reinterpret_cast(v); +} + void Context::init_helper() { if(!::LLVMIsMultithreaded()) @@ -92,18 +99,34 @@ std::size_t Target_data::get_pointer_alignment(::LLVMTargetDataRef td) noexcept return llvm::unwrap(td)->getPointerABIAlignment(0); } -Target_machine Target_machine::create_native_target_machine() +Target_machine Target_machine::create_native_target_machine(::LLVMCodeGenOptLevel code_gen_level) { auto target = Target::get_native_target(); return Target_machine(::LLVMCreateTargetMachine(target.get(), Target::get_process_target_triple().get(), Target::get_host_cpu_name().get(), Target::get_host_cpu_features().get(), - ::LLVMCodeGenLevelDefault, + code_gen_level, ::LLVMRelocDefault, ::LLVMCodeModelJITDefault)); } +::LLVMCodeGenOptLevel Target_machine::get_code_gen_opt_level(::LLVMTargetMachineRef tm) noexcept +{ + switch(unwrap(tm)->getOptLevel()) + { + case llvm::CodeGenOpt::Level::None: + return ::LLVMCodeGenLevelNone; + case llvm::CodeGenOpt::Level::Less: + return ::LLVMCodeGenLevelLess; + case llvm::CodeGenOpt::Level::Default: + return ::LLVMCodeGenLevelDefault; + case llvm::CodeGenOpt::Level::Aggressive: + return ::LLVMCodeGenLevelAggressive; + } + return ::LLVMCodeGenLevelDefault; +} + void Module::set_target_machine(::LLVMModuleRef module, ::LLVMTargetMachineRef target_machine) { ::LLVMSetTarget(module, Target_machine::get_target_triple(target_machine).get()); diff --git a/src/llvm_wrapper/llvm_wrapper.h b/src/llvm_wrapper/llvm_wrapper.h index 8ab6604..e3555f8 100644 --- a/src/llvm_wrapper/llvm_wrapper.h +++ b/src/llvm_wrapper/llvm_wrapper.h @@ -259,7 +259,8 @@ struct Target_machine_deleter struct Target_machine : public Wrapper<::LLVMTargetMachineRef, Target_machine_deleter> { using Wrapper::Wrapper; - static Target_machine create_native_target_machine(); + static Target_machine create_native_target_machine( + ::LLVMCodeGenOptLevel code_gen_level = ::LLVMCodeGenLevelDefault); static Target get_target(::LLVMTargetMachineRef tm) { return Target(::LLVMGetTargetMachineTarget(tm)); @@ -300,6 +301,11 @@ struct Target_machine : public Wrapper<::LLVMTargetMachineRef, Target_machine_de { return get_feature_string(get()); } + static ::LLVMCodeGenOptLevel get_code_gen_opt_level(::LLVMTargetMachineRef tm) noexcept; + ::LLVMCodeGenOptLevel get_code_gen_opt_level() const noexcept + { + return get_code_gen_opt_level(get()); + } }; struct Module_deleter @@ -649,7 +655,7 @@ struct Create_llvm_type } }; -template +template struct Create_llvm_type { ::LLVMTypeRef operator()(::LLVMContextRef context) const diff --git a/src/llvm_wrapper/orc_compile_stack.cpp b/src/llvm_wrapper/orc_compile_stack.cpp index 3b437e9..c0f6720 100644 --- a/src/llvm_wrapper/orc_compile_stack.cpp +++ b/src/llvm_wrapper/orc_compile_stack.cpp @@ -24,9 +24,16 @@ #include #include #include +#include #include +#include +#include +#include +#include #include #include +#include +#include #if LLVM_VERSION_MAJOR != 4 || LLVM_VERSION_MINOR != 0 #error Orc compile stack is not yet implemented for this version of LLVM @@ -36,13 +43,14 @@ namespace vulkan_cpu { namespace llvm_wrapper { -namespace -{ // implement the unwrap functions that aren't in public llvm headers -llvm::TargetMachine *unwrap(::LLVMTargetMachineRef v) noexcept +static llvm::TargetMachine *unwrap(::LLVMTargetMachineRef v) noexcept { return reinterpret_cast(v); } +static ::LLVMTargetMachineRef wrap(llvm::TargetMachine *v) noexcept +{ + return reinterpret_cast<::LLVMTargetMachineRef>(v); } class Orc_compile_stack_implementation @@ -53,29 +61,180 @@ class Orc_compile_stack_implementation Orc_compile_stack_implementation &operator=(Orc_compile_stack_implementation &&) = delete; private: + typedef Orc_compile_stack::Symbol_resolver_callback Symbol_resolver_callback; + typedef Orc_compile_stack::Module_handle Module_handle; + +private: + // implement a wrapper for llvm::orc::ObjectLinkingLayer + // in order to tell GDB about the contained objects + class My_object_linking_layer + { + My_object_linking_layer(const My_object_linking_layer &) = delete; + My_object_linking_layer(My_object_linking_layer &&) = delete; + My_object_linking_layer &operator=(const My_object_linking_layer &) = delete; + My_object_linking_layer &operator=(My_object_linking_layer &&) = delete; + + private: + class On_loaded_functor + { + friend class My_object_linking_layer; + + private: + My_object_linking_layer *my_object_linking_layer; + explicit On_loaded_functor(My_object_linking_layer *my_object_linking_layer) noexcept + : my_object_linking_layer(my_object_linking_layer) + { + } + + public: + void operator()( + llvm::orc::ObjectLinkingLayerBase::ObjSetHandleT, + const std:: + vector>> + &object_set, + const std::vector> + &load_result) + { + assert(object_set.size() == load_result.size()); + for(std::size_t i = 0; i < object_set.size(); i++) + my_object_linking_layer->handle_loaded_object(*object_set[i]->getBinary(), + *load_result[i]); + } + }; + + private: + Module_handle create_module_handle() noexcept + { + return next_module_handle++; + } + static std::vector> make_jit_event_listener_list() + { + std::vector> retval; + auto static_deleter = [](llvm::JITEventListener *) + { + }; + if(auto *v = llvm::JITEventListener::createGDBRegistrationListener()) + { + // createGDBRegistrationListener returns a static object + retval.push_back(std::shared_ptr(v, static_deleter)); + } + if(auto *v = llvm::JITEventListener::createIntelJITEventListener()) + { + retval.push_back(std::shared_ptr(v)); + } + if(auto *v = llvm::JITEventListener::createOProfileJITEventListener()) + { + retval.push_back(std::shared_ptr(v)); + } + return retval; + } + void handle_loaded_object(const llvm::object::ObjectFile &object, + const llvm::RuntimeDyld::LoadedObjectInfo &load_info) + { + loaded_object_set.insert(&object); + for(auto &jit_event_listener : jit_event_listener_list) + jit_event_listener->NotifyObjectEmitted(object, load_info); + } + + public: + My_object_linking_layer() + : jit_event_listener_list(make_jit_event_listener_list()), + object_linking_layer(On_loaded_functor(this)) + { + } + ~My_object_linking_layer() + { + for(auto i = loaded_object_set.begin(); i != loaded_object_set.end();) + { + for(auto &jit_event_listener : jit_event_listener_list) + jit_event_listener->NotifyFreeingObject(**i); + i = loaded_object_set.erase(i); + } + } + typedef Module_handle ObjSetHandleT; + llvm::JITSymbol findSymbol(const std::string &name, bool exported_symbols_only) + { + return object_linking_layer.findSymbol(name, exported_symbols_only); + } + template + Module_handle addObjectSet( + std::vector>> + object_set, + std::unique_ptr memory_manager, + Symbol_resolver_pointer symbol_resolver) + { + auto retval = create_module_handle(); + auto &handle = handle_map[retval]; + handle = object_linking_layer.addObjectSet( + std::move(object_set), std::move(memory_manager), std::move(symbol_resolver)); + return retval; + } + + private: + Module_handle next_module_handle = 1; + std::vector> jit_event_listener_list; + llvm::orc::ObjectLinkingLayer object_linking_layer; + std::unordered_map handle_map; + std::unordered_multiset loaded_object_set; + }; + typedef std::function(std::unique_ptr)> + Optimize_function; + +private: + Orc_compile_stack::Optimize_function optimize_function; std::unique_ptr target_machine; - const llvm::DataLayout data_layout; - llvm::orc::ObjectLinkingLayer<> object_linking_layer; + My_object_linking_layer object_linking_layer; + llvm::orc::IRCompileLayer compile_layer; + llvm::orc::IRTransformLayer optimize_layer; public: - explicit Orc_compile_stack_implementation(Target_machine target_machine_in) - : target_machine(unwrap(target_machine_in.release())), - data_layout(target_machine->createDataLayout()) + explicit Orc_compile_stack_implementation( + Target_machine target_machine_in, Orc_compile_stack::Optimize_function optimize_function) + : optimize_function(std::move(optimize_function)), + target_machine(unwrap(target_machine_in.release())), + object_linking_layer(), + compile_layer(object_linking_layer, llvm::orc::SimpleCompiler(*target_machine)), + optimize_layer(compile_layer, + [this](std::unique_ptr module) + { + if(this->optimize_function) + { + auto rewrapped_module = Module(llvm::wrap(module.release())); + rewrapped_module = this->optimize_function( + std::move(rewrapped_module), wrap(target_machine.get())); + return std::unique_ptr( + llvm::unwrap(rewrapped_module.release())); + } + return module; + }) { -#warning finish - assert(!"finish"); } - void add_eagerly_compiled_ir(Module module, - ::LLVMOrcSymbolResolverFn symbol_resolver_callback, - void *symbol_resolver_user_data) + Module_handle add_eagerly_compiled_ir(Module module, + Symbol_resolver_callback symbol_resolver_callback, + void *symbol_resolver_user_data) { -#warning finish - assert(!"finish"); + auto resolver = llvm::orc::createLambdaResolver( + [this](const std::string &name) + { + return compile_layer.findSymbol(name, false); + }, + [symbol_resolver_user_data, symbol_resolver_callback](const std::string &name) + { + return llvm::JITSymbol(symbol_resolver_callback(name, symbol_resolver_user_data), + llvm::JITSymbolFlags::Exported); + }); + std::vector> module_set; + module_set.reserve(1); + module_set.push_back(std::unique_ptr(llvm::unwrap(module.release()))); + return optimize_layer.addModuleSet(std::move(module_set), + std::make_unique(), + std::move(resolver)); } - std::uintptr_t get_symbol_address(const char *symbol_name) + std::uintptr_t get_symbol_address(const std::string &symbol_name) { -#warning finish - assert(!"finish"); + auto symbol = compile_layer.findSymbol(symbol_name, true); + if(symbol) + return symbol.getAddress(); return 0; } }; @@ -85,24 +244,25 @@ void Orc_compile_stack_deleter::operator()(Orc_compile_stack_ref v) const noexce delete v; } -Orc_compile_stack Orc_compile_stack::create(Target_machine target_machine) +Orc_compile_stack Orc_compile_stack::create(Target_machine target_machine, + Optimize_function optimize_function) { -#warning finish - assert(!"finish"); - return {}; + return Orc_compile_stack(new Orc_compile_stack_implementation(std::move(target_machine), + std::move(optimize_function))); } -void Orc_compile_stack::add_eagerly_compiled_ir(Orc_compile_stack_ref orc_compile_stack, - Module module, - ::LLVMOrcSymbolResolverFn symbol_resolver_callback, - void *symbol_resolver_user_data) +Orc_compile_stack::Module_handle Orc_compile_stack::add_eagerly_compiled_ir( + Orc_compile_stack_ref orc_compile_stack, + Module module, + Symbol_resolver_callback symbol_resolver_callback, + void *symbol_resolver_user_data) { - orc_compile_stack->add_eagerly_compiled_ir( + return orc_compile_stack->add_eagerly_compiled_ir( std::move(module), symbol_resolver_callback, symbol_resolver_user_data); } std::uintptr_t Orc_compile_stack::get_symbol_address(Orc_compile_stack_ref orc_compile_stack, - const char *symbol_name) + const std::string &symbol_name) { return orc_compile_stack->get_symbol_address(symbol_name); } diff --git a/src/llvm_wrapper/orc_compile_stack.h b/src/llvm_wrapper/orc_compile_stack.h index 413d314..80042c8 100644 --- a/src/llvm_wrapper/orc_compile_stack.h +++ b/src/llvm_wrapper/orc_compile_stack.h @@ -24,6 +24,9 @@ #define LLVM_WRAPPER_ORC_COMPILE_STACK_H_ #include "llvm_wrapper.h" +#include +#include +#include namespace vulkan_cpu { @@ -41,31 +44,35 @@ struct Orc_compile_stack_deleter struct Orc_compile_stack : public Wrapper { using Wrapper::Wrapper; - static Orc_compile_stack create(Target_machine target_machine); - static void add_eagerly_compiled_ir(Orc_compile_stack_ref orc_compile_stack, - Module module, - ::LLVMOrcSymbolResolverFn symbol_resolver_callback, - void *symbol_resolver_user_data); - void add_eagerly_compiled_ir(Module module, - ::LLVMOrcSymbolResolverFn symbol_resolver_callback, - void *symbol_resolver_user_data) + typedef std::uintptr_t (*Symbol_resolver_callback)(const std::string &name, void *user_data); + typedef std::uint64_t Module_handle; + typedef std::function Optimize_function; + static Orc_compile_stack create(Target_machine target_machine, + Optimize_function optimize_function = nullptr); + static Module_handle add_eagerly_compiled_ir(Orc_compile_stack_ref orc_compile_stack, + Module module, + Symbol_resolver_callback symbol_resolver_callback, + void *symbol_resolver_user_data); + Module_handle add_eagerly_compiled_ir(Module module, + Symbol_resolver_callback symbol_resolver_callback, + void *symbol_resolver_user_data) { - add_eagerly_compiled_ir( + return add_eagerly_compiled_ir( get(), std::move(module), symbol_resolver_callback, symbol_resolver_user_data); } static std::uintptr_t get_symbol_address(Orc_compile_stack_ref orc_compile_stack, - const char *symbol_name); + const std::string &symbol_name); template - static T *get_symbol(Orc_compile_stack_ref orc_compile_stack, const char *symbol_name) + static T *get_symbol(Orc_compile_stack_ref orc_compile_stack, const std::string &symbol_name) { return reinterpret_cast(get_symbol_address(orc_compile_stack, symbol_name)); } - std::uintptr_t get_symbol_address(const char *symbol_name) + std::uintptr_t get_symbol_address(const std::string &symbol_name) { return get_symbol_address(get(), symbol_name); } template - T *get_symbol(const char *symbol_name) + T *get_symbol(const std::string &symbol_name) { return get_symbol(get(), symbol_name); } diff --git a/src/pipeline/pipeline.cpp b/src/pipeline/pipeline.cpp index 84182a5..70cae74 100644 --- a/src/pipeline/pipeline.cpp +++ b/src/pipeline/pipeline.cpp @@ -23,6 +23,7 @@ #include "pipeline.h" #include "spirv_to_llvm/spirv_to_llvm.h" #include "llvm_wrapper/llvm_wrapper.h" +#include "llvm_wrapper/orc_compile_stack.h" #include "vulkan/util.h" #include "util/soft_float.h" #include "json/json.h" @@ -78,11 +79,31 @@ Pipeline_layout_handle Pipeline_layout_handle::make( return Pipeline_layout_handle(new Pipeline_layout()); } +llvm_wrapper::Module Pipeline::optimize_module(llvm_wrapper::Module module, + ::LLVMTargetMachineRef target_machine) +{ + switch(llvm_wrapper::Target_machine::get_code_gen_opt_level(target_machine)) + { + case ::LLVMCodeGenLevelNone: + case ::LLVMCodeGenLevelLess: + break; + case ::LLVMCodeGenLevelDefault: + case ::LLVMCodeGenLevelAggressive: + { +#warning finish implementing module optimizations + std::cerr << "optimized module:" << std::endl; + ::LLVMDumpModule(module.get()); + break; + } + } + return module; +} + struct Graphics_pipeline::Implementation { llvm_wrapper::Context llvm_context = llvm_wrapper::Context::create(); spirv_to_llvm::Jit_symbol_resolver jit_symbol_resolver; - llvm_wrapper::Orc_jit_stack jit_stack; + llvm_wrapper::Orc_compile_stack jit_stack; llvm_wrapper::Target_data data_layout; std::vector compiled_shaders; std::shared_ptr vertex_shader_output_struct; @@ -354,7 +375,11 @@ std::unique_ptr Graphics_pipeline::make( throw std::runtime_error("creating derived pipelines is not implemented"); } auto implementation = std::make_shared(); - auto llvm_target_machine = llvm_wrapper::Target_machine::create_native_target_machine(); + auto optimization_level = ::LLVMCodeGenLevelDefault; + if(create_info.flags & VK_PIPELINE_CREATE_DISABLE_OPTIMIZATION_BIT) + optimization_level = ::LLVMCodeGenLevelNone; + auto llvm_target_machine = + llvm_wrapper::Target_machine::create_native_target_machine(optimization_level); implementation->compiled_shaders.reserve(create_info.stageCount); for(std::size_t i = 0; i < create_info.stageCount; i++) { @@ -394,7 +419,8 @@ std::unique_ptr Graphics_pipeline::make( implementation->compiled_shaders.push_back(std::move(compiled_shader)); } implementation->data_layout = llvm_target_machine.create_target_data_layout(); - implementation->jit_stack = llvm_wrapper::Orc_jit_stack::create(std::move(llvm_target_machine)); + implementation->jit_stack = + llvm_wrapper::Orc_compile_stack::create(std::move(llvm_target_machine), optimize_module); Vertex_shader_function vertex_shader_function = nullptr; std::size_t vertex_shader_output_struct_size = 0; for(auto &compiled_shader : implementation->compiled_shaders) diff --git a/src/pipeline/pipeline.h b/src/pipeline/pipeline.h index 80ccedc..d62eb5c 100644 --- a/src/pipeline/pipeline.h +++ b/src/pipeline/pipeline.h @@ -179,6 +179,10 @@ public: { return reinterpret_cast(pipeline); } + +protected: + static llvm_wrapper::Module optimize_module(llvm_wrapper::Module module, + ::LLVMTargetMachineRef target_machine); }; inline VkPipeline to_handle(Pipeline *pipeline) noexcept diff --git a/src/spirv_to_llvm/spirv_to_llvm.h b/src/spirv_to_llvm/spirv_to_llvm.h index 1291233..1171b7a 100644 --- a/src/spirv_to_llvm/spirv_to_llvm.h +++ b/src/spirv_to_llvm/spirv_to_llvm.h @@ -33,6 +33,7 @@ #include #include #include "llvm_wrapper/llvm_wrapper.h" +#include "util/string_view.h" namespace vulkan_cpu { @@ -589,7 +590,7 @@ struct Converted_module struct Jit_symbol_resolver { typedef void (*Resolved_symbol)(); - Resolved_symbol resolve(const char *name) + Resolved_symbol resolve(util::string_view name) { #warning finish implementing return nullptr; @@ -599,6 +600,11 @@ struct Jit_symbol_resolver return reinterpret_cast( static_cast(user_data)->resolve(name)); } + static std::uintptr_t resolve(const std::string &name, void *user_data) noexcept + { + return reinterpret_cast( + static_cast(user_data)->resolve(name)); + } }; class Spirv_to_llvm; -- 2.30.2