working on implementing LLVM 7.0 shader compiler backend
[kazan.git] / shader-compiler-llvm-7 / src / backend.rs
index 74af1b523550289731a9f30d4261c484dadce264..cbb7b51365c9edb16faa12208311fb29cb4e2516 100644 (file)
@@ -1,22 +1,38 @@
 // SPDX-License-Identifier: LGPL-2.1-or-later
 // Copyright 2018 Jacob Lifshay
 use llvm_sys;
-use shader_compiler::backend::*;
+use shader_compiler::backend;
+use std::cell::Cell;
 use std::ffi::{CStr, CString};
 use std::fmt;
 use std::ops::Deref;
 use std::os::raw::{c_char, c_uint};
+use std::ptr::null_mut;
 use std::ptr::NonNull;
+use std::sync::{Once, ONCE_INIT};
+
+fn to_bool(v: llvm_sys::prelude::LLVMBool) -> bool {
+    v != 0
+}
 
 #[derive(Clone)]
-pub struct LLVM7ShaderCompilerConfig {
+pub struct LLVM7CompilerConfig {
     pub variable_vector_length_multiplier: u32,
+    pub optimization_mode: backend::OptimizationMode,
 }
 
-impl Default for LLVM7ShaderCompilerConfig {
+impl Default for LLVM7CompilerConfig {
     fn default() -> Self {
+        backend::CompilerIndependentConfig::default().into()
+    }
+}
+
+impl From<backend::CompilerIndependentConfig> for LLVM7CompilerConfig {
+    fn from(v: backend::CompilerIndependentConfig) -> Self {
+        let backend::CompilerIndependentConfig { optimization_mode } = v;
         Self {
             variable_vector_length_multiplier: 1,
+            optimization_mode,
         }
     }
 }
@@ -53,7 +69,13 @@ impl LLVM7String {
         LLVM7String(v)
     }
     unsafe fn from_ptr(v: *mut c_char) -> Option<Self> {
-        NonNull::new(v).map(LLVM7String)
+        NonNull::new(v).map(|v| Self::from_nonnull(v))
+    }
+}
+
+impl fmt::Debug for LLVM7String {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        (**self).fmt(f)
     }
 }
 
@@ -71,14 +93,16 @@ impl fmt::Debug for LLVM7Type {
     }
 }
 
-impl<'a> Type<'a> for LLVM7Type {}
+impl<'a> backend::types::Type<'a> for LLVM7Type {
+    type Context = LLVM7Context;
+}
 
 pub struct LLVM7TypeBuilder {
     context: llvm_sys::prelude::LLVMContextRef,
     variable_vector_length_multiplier: u32,
 }
 
-impl<'a> TypeBuilder<'a, LLVM7Type> for LLVM7TypeBuilder {
+impl<'a> backend::types::TypeBuilder<'a, LLVM7Type> for LLVM7TypeBuilder {
     fn build_bool(&self) -> LLVM7Type {
         unsafe { LLVM7Type(llvm_sys::core::LLVMInt1TypeInContext(self.context)) }
     }
@@ -107,10 +131,11 @@ impl<'a> TypeBuilder<'a, LLVM7Type> for LLVM7TypeBuilder {
         assert_eq!(count as u32 as usize, count);
         unsafe { LLVM7Type(llvm_sys::core::LLVMArrayType(element.0, count as u32)) }
     }
-    fn build_vector(&self, element: LLVM7Type, length: VectorLength) -> LLVM7Type {
+    fn build_vector(&self, element: LLVM7Type, length: backend::types::VectorLength) -> LLVM7Type {
+        use self::backend::types::VectorLength::*;
         let length = match length {
-            VectorLength::Fixed { length } => length,
-            VectorLength::Variable { base_length } => base_length
+            Fixed { length } => length,
+            Variable { base_length } => base_length
                 .checked_mul(self.variable_vector_length_multiplier)
                 .unwrap(),
         };
@@ -145,31 +170,128 @@ impl<'a> TypeBuilder<'a, LLVM7Type> for LLVM7TypeBuilder {
     }
 }
 
+#[derive(Clone)]
+#[repr(transparent)]
+pub struct LLVM7Value(llvm_sys::prelude::LLVMValueRef);
+
+impl fmt::Debug for LLVM7Value {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        unsafe {
+            let string = LLVM7String::from_ptr(llvm_sys::core::LLVMPrintValueToString(self.0))
+                .ok_or(fmt::Error)?;
+            f.write_str(&string.to_string_lossy())
+        }
+    }
+}
+
+impl<'a> backend::Value<'a> for LLVM7Value {
+    type Context = LLVM7Context;
+}
+
+#[derive(Clone)]
+#[repr(transparent)]
+pub struct LLVM7BasicBlock(llvm_sys::prelude::LLVMBasicBlockRef);
+
+impl fmt::Debug for LLVM7BasicBlock {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        use self::backend::BasicBlock;
+        unsafe {
+            let string =
+                LLVM7String::from_ptr(llvm_sys::core::LLVMPrintValueToString(self.as_value().0))
+                    .ok_or(fmt::Error)?;
+            f.write_str(&string.to_string_lossy())
+        }
+    }
+}
+
+impl<'a> backend::BasicBlock<'a> for LLVM7BasicBlock {
+    type Context = LLVM7Context;
+    fn as_value(&self) -> LLVM7Value {
+        unsafe { LLVM7Value(llvm_sys::core::LLVMBasicBlockAsValue(self.0)) }
+    }
+}
+
+impl<'a> backend::BuildableBasicBlock<'a> for LLVM7BasicBlock {
+    type Context = LLVM7Context;
+    fn as_basic_block(&self) -> LLVM7BasicBlock {
+        self.clone()
+    }
+}
+
+pub struct LLVM7Function {
+    context: llvm_sys::prelude::LLVMContextRef,
+    function: llvm_sys::prelude::LLVMValueRef,
+}
+
+impl fmt::Debug for LLVM7Function {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        unsafe {
+            let string =
+                LLVM7String::from_ptr(llvm_sys::core::LLVMPrintValueToString(self.function))
+                    .ok_or(fmt::Error)?;
+            f.write_str(&string.to_string_lossy())
+        }
+    }
+}
+
+impl<'a> backend::Function<'a> for LLVM7Function {
+    type Context = LLVM7Context;
+    fn as_value(&self) -> LLVM7Value {
+        LLVM7Value(self.function)
+    }
+    fn append_new_basic_block(&mut self, name: Option<&str>) -> LLVM7BasicBlock {
+        let name = CString::new(name.unwrap_or("")).unwrap();
+        unsafe {
+            LLVM7BasicBlock(llvm_sys::core::LLVMAppendBasicBlockInContext(
+                self.context,
+                self.function,
+                name.as_ptr(),
+            ))
+        }
+    }
+}
+
 pub struct LLVM7Context {
     context: llvm_sys::prelude::LLVMContextRef,
-    config: LLVM7ShaderCompilerConfig,
+    modules: Cell<Vec<llvm_sys::prelude::LLVMModuleRef>>,
+    config: LLVM7CompilerConfig,
 }
 
 impl Drop for LLVM7Context {
     fn drop(&mut self) {
         unsafe {
+            for module in self.modules.get_mut().drain(..) {
+                llvm_sys::core::LLVMDisposeModule(module);
+            }
             llvm_sys::core::LLVMContextDispose(self.context);
         }
     }
 }
 
-impl<'a> Context<'a> for LLVM7Context {
+impl<'a> backend::Context<'a> for LLVM7Context {
+    type Value = LLVM7Value;
+    type BasicBlock = LLVM7BasicBlock;
+    type BuildableBasicBlock = LLVM7BasicBlock;
+    type Function = LLVM7Function;
     type Type = LLVM7Type;
     type TypeBuilder = LLVM7TypeBuilder;
     type Module = LLVM7Module;
-    type Builder = LLVM7Builder;
+    type VerifiedModule = LLVM7Module;
+    type AttachedBuilder = LLVM7Builder;
+    type DetachedBuilder = LLVM7Builder;
     fn create_module(&self, name: &str) -> LLVM7Module {
         let name = CString::new(name).unwrap();
+        let mut modules = self.modules.take();
+        modules.reserve(1); // so we don't unwind without freeing the new module
         unsafe {
-            LLVM7Module(llvm_sys::core::LLVMModuleCreateWithNameInContext(
-                name.as_ptr(),
-                self.context,
-            ))
+            let module =
+                llvm_sys::core::LLVMModuleCreateWithNameInContext(name.as_ptr(), self.context);
+            modules.push(module);
+            self.modules.set(modules);
+            LLVM7Module {
+                context: self.context,
+                module,
+            }
         }
     }
     fn create_builder(&self) -> LLVM7Builder {
@@ -194,48 +316,230 @@ impl Drop for LLVM7Builder {
     }
 }
 
-impl<'a> Builder<'a> for LLVM7Builder {}
+impl<'a> backend::AttachedBuilder<'a> for LLVM7Builder {
+    type Context = LLVM7Context;
+    fn current_basic_block(&self) -> LLVM7BasicBlock {
+        unsafe { LLVM7BasicBlock(llvm_sys::core::LLVMGetInsertBlock(self.0)) }
+    }
+    fn build_return(self, value: Option<LLVM7Value>) -> LLVM7Builder {
+        unsafe {
+            match value {
+                Some(value) => llvm_sys::core::LLVMBuildRet(self.0, value.0),
+                None => llvm_sys::core::LLVMBuildRetVoid(self.0),
+            };
+            llvm_sys::core::LLVMClearInsertionPosition(self.0);
+        }
+        self
+    }
+}
 
-#[repr(transparent)]
-pub struct LLVM7Module(llvm_sys::prelude::LLVMModuleRef);
+impl<'a> backend::DetachedBuilder<'a> for LLVM7Builder {
+    type Context = LLVM7Context;
+    fn attach(self, basic_block: LLVM7BasicBlock) -> LLVM7Builder {
+        unsafe {
+            llvm_sys::core::LLVMPositionBuilderAtEnd(self.0, basic_block.0);
+        }
+        self
+    }
+}
 
-impl Drop for LLVM7Module {
-    fn drop(&mut self) {
+pub struct LLVM7Module {
+    context: llvm_sys::prelude::LLVMContextRef,
+    module: llvm_sys::prelude::LLVMModuleRef,
+}
+
+impl fmt::Debug for LLVM7Module {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         unsafe {
-            llvm_sys::core::LLVMDisposeModule(self.0);
+            let string =
+                LLVM7String::from_ptr(llvm_sys::core::LLVMPrintModuleToString(self.module))
+                    .ok_or(fmt::Error)?;
+            f.write_str(&string.to_string_lossy())
         }
     }
 }
 
-impl<'a> Module<'a> for LLVM7Module {
+impl<'a> backend::Module<'a> for LLVM7Module {
+    type Context = LLVM7Context;
     fn set_source_file_name(&mut self, source_file_name: &str) {
         unsafe {
             llvm_sys::core::LLVMSetSourceFileName(
-                self.0,
+                self.module,
                 source_file_name.as_ptr() as *const c_char,
                 source_file_name.len(),
             )
         }
     }
+    fn add_function(&mut self, name: &str, ty: LLVM7Type) -> LLVM7Function {
+        let name = CString::new(name).unwrap();
+        unsafe {
+            LLVM7Function {
+                context: self.context,
+                function: llvm_sys::core::LLVMAddFunction(self.module, name.as_ptr(), ty.0),
+            }
+        }
+    }
+    fn verify(self) -> Result<LLVM7Module, backend::VerificationFailure<'a, LLVM7Module>> {
+        unsafe {
+            let mut message = null_mut();
+            match to_bool(llvm_sys::analysis::LLVMVerifyModule(
+                self.module,
+                llvm_sys::analysis::LLVMVerifierFailureAction::LLVMReturnStatusAction,
+                &mut message,
+            )) {
+                broken if broken != false => {
+                    let message = LLVM7String::from_ptr(message).unwrap();
+                    let message = message.to_string_lossy();
+                    Err(backend::VerificationFailure::new(
+                        self,
+                        message.into_owned(),
+                    ))
+                }
+                _ => Ok(self),
+            }
+        }
+    }
+    unsafe fn to_verified_module_unchecked(self) -> LLVM7Module {
+        self
+    }
 }
 
-pub struct LLVM7ShaderCompiler;
+impl<'a> backend::VerifiedModule<'a> for LLVM7Module {
+    type Context = LLVM7Context;
+    fn into_module(self) -> LLVM7Module {
+        self
+    }
+}
+
+struct LLVM7TargetMachine(llvm_sys::target_machine::LLVMTargetMachineRef);
+
+impl Drop for LLVM7TargetMachine {
+    fn drop(&mut self) {
+        unsafe {
+            llvm_sys::target_machine::LLVMDisposeTargetMachine(self.0);
+        }
+    }
+}
 
-impl ShaderCompiler for LLVM7ShaderCompiler {
-    type Config = LLVM7ShaderCompilerConfig;
-    fn name() -> &'static str {
+impl LLVM7TargetMachine {
+    fn take(mut self) -> llvm_sys::target_machine::LLVMTargetMachineRef {
+        let retval = self.0;
+        self.0 = null_mut();
+        retval
+    }
+}
+
+struct LLVM7OrcJITStack(llvm_sys::orc::LLVMOrcJITStackRef);
+
+impl Drop for LLVM7OrcJITStack {
+    fn drop(&mut self) {
+        unsafe {
+            match llvm_sys::orc::LLVMOrcDisposeInstance(self.0) {
+                llvm_sys::orc::LLVMOrcErrorCode::LLVMOrcErrSuccess => {}
+                llvm_sys::orc::LLVMOrcErrorCode::LLVMOrcErrGeneric => {
+                    panic!("LLVMOrcDisposeInstance failed");
+                }
+            }
+        }
+    }
+}
+
+fn initialize_native_target() {
+    static ONCE: Once = ONCE_INIT;
+    ONCE.call_once(|| unsafe {
+        assert_eq!(llvm_sys::target::LLVM_InitializeNativeTarget(), 0);
+        assert_eq!(llvm_sys::target::LLVM_InitializeNativeAsmParser(), 0);
+    });
+}
+
+extern "C" fn symbol_resolver_fn<Void>(name: *const c_char, _lookup_context: *mut Void) -> u64 {
+    let name = unsafe { CStr::from_ptr(name) };
+    panic!("symbol_resolver_fn is unimplemented: name = {:?}", name)
+}
+
+#[derive(Copy, Clone)]
+pub struct LLVM7Compiler;
+
+impl backend::Compiler for LLVM7Compiler {
+    type Config = LLVM7CompilerConfig;
+    fn name(self) -> &'static str {
         "LLVM 7"
     }
-    fn run_with_user<SCU: ShaderCompilerUser>(
-        shader_compiler_user: SCU,
-        config: LLVM7ShaderCompilerConfig,
-    ) -> SCU::ReturnType {
-        let context = unsafe {
-            LLVM7Context {
+    fn run<U: backend::CompilerUser>(
+        self,
+        user: U,
+        config: LLVM7CompilerConfig,
+    ) -> Result<Box<dyn backend::CompiledCode<U::FunctionKey>>, U::Error> {
+        unsafe {
+            initialize_native_target();
+            let context = LLVM7Context {
                 context: llvm_sys::core::LLVMContextCreate(),
-                config,
+                modules: Vec::new().into(),
+                config: config.clone(),
+            };
+            let backend::CompileInputs {
+                module,
+                callable_functions,
+            } = user.run(&context)?;
+            for callable_function in callable_functions.values() {
+                assert_eq!(
+                    llvm_sys::core::LLVMGetGlobalParent(callable_function.function),
+                    module.module
+                );
             }
-        };
-        shader_compiler_user.run_with_context(&context)
+            let target_triple =
+                LLVM7String::from_ptr(llvm_sys::target_machine::LLVMGetDefaultTargetTriple())
+                    .unwrap();
+            let mut target = null_mut();
+            let mut error = null_mut();
+            let success = !to_bool(llvm_sys::target_machine::LLVMGetTargetFromTriple(
+                target_triple.as_ptr(),
+                &mut target,
+                &mut error,
+            ));
+            if !success {
+                let error = LLVM7String::from_ptr(error).unwrap();
+                return Err(U::create_error(error.to_string_lossy().into()));
+            }
+            if !to_bool(llvm_sys::target_machine::LLVMTargetHasJIT(target)) {
+                return Err(U::create_error(format!(
+                    "target {:?} doesn't support JIT",
+                    target_triple
+                )));
+            }
+            let host_cpu_name =
+                LLVM7String::from_ptr(llvm_sys::target_machine::LLVMGetHostCPUName()).unwrap();
+            let host_cpu_features =
+                LLVM7String::from_ptr(llvm_sys::target_machine::LLVMGetHostCPUFeatures()).unwrap();
+            let target_machine =
+                LLVM7TargetMachine(llvm_sys::target_machine::LLVMCreateTargetMachine(
+                    target,
+                    target_triple.as_ptr(),
+                    host_cpu_name.as_ptr(),
+                    host_cpu_features.as_ptr(),
+                    match config.optimization_mode {
+                        backend::OptimizationMode::NoOptimizations => {
+                            llvm_sys::target_machine::LLVMCodeGenOptLevel::LLVMCodeGenLevelNone
+                        }
+                        backend::OptimizationMode::Normal => {
+                            llvm_sys::target_machine::LLVMCodeGenOptLevel::LLVMCodeGenLevelDefault
+                        }
+                    },
+                    llvm_sys::target_machine::LLVMRelocMode::LLVMRelocDefault,
+                    llvm_sys::target_machine::LLVMCodeModel::LLVMCodeModelJITDefault,
+                ));
+            assert!(!target_machine.0.is_null());
+            let orc_jit_stack =
+                LLVM7OrcJITStack(llvm_sys::orc::LLVMOrcCreateInstance(target_machine.take()));
+            let mut orc_module_handle = 0;
+            llvm_sys::orc::LLVMOrcAddEagerlyCompiledIR(
+                orc_jit_stack.0,
+                &mut orc_module_handle,
+                module.module,
+                Some(symbol_resolver_fn),
+                null_mut(),
+            );
+            unimplemented!()
+        }
     }
 }