From 565103a5aa685a6999dd9a4b6c78b02b7ab1d7f1 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Wed, 17 Oct 2018 01:23:25 -0700 Subject: [PATCH] add function name validation --- shader-compiler-llvm-7/src/backend.rs | 19 ++++++++++++ shader-compiler-llvm-7/src/tests.rs | 42 +++++++++++++++++++++++++++ 2 files changed, 61 insertions(+) diff --git a/shader-compiler-llvm-7/src/backend.rs b/shader-compiler-llvm-7/src/backend.rs index fb466d4..3e02ad7 100644 --- a/shader-compiler-llvm-7/src/backend.rs +++ b/shader-compiler-llvm-7/src/backend.rs @@ -4,6 +4,7 @@ use llvm; use shader_compiler::backend; use std::cell::RefCell; use std::collections::HashMap; +use std::collections::HashSet; use std::ffi::{CStr, CString}; use std::fmt; use std::hash::Hash; @@ -292,6 +293,7 @@ impl<'a> backend::Context<'a> for LLVM7Context { LLVM7Module { context: self.context.as_ref().unwrap().0, module: module_ref, + name_set: HashSet::new(), } } } @@ -379,6 +381,7 @@ impl Drop for OwnedContext { pub struct LLVM7Module { context: llvm::LLVMContextRef, module: llvm::LLVMModuleRef, + name_set: HashSet, } impl fmt::Debug for LLVM7Module { @@ -403,6 +406,22 @@ impl<'a> backend::Module<'a> for LLVM7Module { } } fn add_function(&mut self, name: &str, ty: LLVM7Type) -> LLVM7Function { + fn is_start_char(c: char) -> bool { + if c.is_ascii_alphabetic() { + true + } else { + match c { + '_' | '.' | '$' | '-' => true, + _ => false, + } + } + } + fn is_continue_char(c: char) -> bool { + is_start_char(c) || c.is_ascii_digit() + } + assert!(is_start_char(name.chars().next().unwrap())); + assert!(name.chars().all(is_continue_char)); + assert!(self.name_set.insert(name.into())); let name = CString::new(name).unwrap(); unsafe { LLVM7Function { diff --git a/shader-compiler-llvm-7/src/tests.rs b/shader-compiler-llvm-7/src/tests.rs index 4ce5036..ac532cf 100644 --- a/shader-compiler-llvm-7/src/tests.rs +++ b/shader-compiler-llvm-7/src/tests.rs @@ -55,4 +55,46 @@ mod tests { function(0); } } + + #[test] + fn test_names() { + const NAMES: &[&str] = &["main", "abc123-$._"]; + type GeneratedFunctionType = unsafe extern "C" fn(u32); + #[derive(Copy, Clone, Hash, Eq, PartialEq, Debug)] + struct Test; + impl CompilerUser for Test { + type FunctionKey = String; + type Error = String; + fn create_error(message: String) -> String { + message + } + fn run<'a, C: Context<'a>>( + self, + context: &'a C, + ) -> Result, String> { + let type_builder = context.create_type_builder(); + let mut module = context.create_module("test_module"); + let mut functions = Vec::new(); + let mut detached_builder = context.create_builder(); + for name in NAMES { + let mut function = + module.add_function(name, type_builder.build::()); + let builder = detached_builder.attach(function.append_new_basic_block(None)); + detached_builder = builder.build_return(None); + functions.push((name.to_string(), function)); + } + let module = module.verify().unwrap(); + Ok(CompileInputs { + module, + callable_functions: functions.into_iter().collect(), + }) + } + } + let compiled_code = make_compiler().run(Test, Default::default()).unwrap(); + let function = compiled_code.get(&"main".to_string()).unwrap(); + unsafe { + let function: GeneratedFunctionType = mem::transmute(function); + function(0); + } + } } -- 2.30.2