From 99ac094fb945300882b41ffde422a8ae308c3fac Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Thu, 25 Oct 2018 01:51:41 -0700 Subject: [PATCH] working on spirv parser generator --- spirv-parser-generator/src/ast.rs | 72 ++-- spirv-parser-generator/src/generate.rs | 467 ++++++++++++++++++++++++- 2 files changed, 483 insertions(+), 56 deletions(-) diff --git a/spirv-parser-generator/src/ast.rs b/spirv-parser-generator/src/ast.rs index 6fdafc7..5f6690b 100644 --- a/spirv-parser-generator/src/ast.rs +++ b/spirv-parser-generator/src/ast.rs @@ -7,60 +7,28 @@ use serde::de::{self, Deserialize, Deserializer}; use std::borrow::Cow; use std::fmt; use std::mem; -use std::num::ParseIntError; use util::NameFormat::*; use util::WordIterator; #[derive(Copy, Clone)] -pub struct QuotedInteger(pub T); +pub struct QuotedInteger(pub u32); -pub trait QuotedIntegerProperties: Sized { - const DIGIT_COUNT: usize; - fn from_str_radix(src: &str, radix: u32) -> Result; -} - -impl QuotedIntegerProperties for QuotedInteger { - const DIGIT_COUNT: usize = 4; - fn from_str_radix(src: &str, radix: u32) -> Result { - Ok(QuotedInteger(u16::from_str_radix(src, radix)?)) - } -} - -impl QuotedIntegerProperties for QuotedInteger { - const DIGIT_COUNT: usize = 8; - fn from_str_radix(src: &str, radix: u32) -> Result { - Ok(QuotedInteger(u32::from_str_radix(src, radix)?)) - } -} - -impl ToTokens for QuotedInteger { +impl ToTokens for QuotedInteger { fn to_tokens(&self, tokens: &mut TokenStream) { self.0.to_tokens(tokens) } } -impl fmt::Display for QuotedInteger { +impl fmt::Display for QuotedInteger { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{:#06X}", self.0) } } -impl fmt::Display for QuotedInteger { +impl fmt::Debug for QuotedInteger { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{:#010X}", self.0) - } -} - -impl fmt::Debug for QuotedInteger -where - Self: fmt::Display + Copy, -{ - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - struct 
DisplayQuotedInteger(QuotedInteger); - impl fmt::Debug for DisplayQuotedInteger - where - QuotedInteger: fmt::Display, - { + struct DisplayQuotedInteger(QuotedInteger); + impl fmt::Debug for DisplayQuotedInteger { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fmt::Display::fmt(&self.0, f) } @@ -71,10 +39,7 @@ where } } -impl<'de, T> Deserialize<'de> for QuotedInteger -where - Self: QuotedIntegerProperties, -{ +impl<'de> Deserialize<'de> for QuotedInteger { fn deserialize>(deserializer: D) -> Result { let s = String::deserialize(deserializer)?; let prefix = "0x"; @@ -91,12 +56,12 @@ where "invalid quoted integer -- not a hexadecimal digit", )); } - if digits.len() != Self::DIGIT_COUNT { + if digits.len() > 8 { return Err(de::Error::custom( - "invalid quoted integer -- wrong number of hex digits", + "invalid quoted integer -- too many hexadecimal digits", )); } - Ok(Self::from_str_radix(digits, radix).unwrap()) + Ok(QuotedInteger(u32::from_str_radix(digits, radix).unwrap())) } } @@ -166,7 +131,16 @@ pub struct InstructionOperand { impl InstructionOperand { pub fn fixup(&mut self) -> Result<(), ::Error> { - if self.name.is_none() { + if let Some(name) = self.name.take() { + let substitute_name = match &*name { + "'Member 0 type', +\n'member 1 type', +\n..." => Some("Member Types"), + "'Parameter 0 Type', +\n'Parameter 1 Type', +\n..." => Some("Parameter Types"), + "'Argument 0', +\n'Argument 1', +\n..." => Some("Arguments"), + "'Operand 1', +\n'Operand 2', +\n..." 
=> Some("Operands"), + _ => None, + }; + self.name = Some(substitute_name.map(String::from).unwrap_or(name)); + } else { self.name = Some( SnakeCase .name_from_words(WordIterator::new(self.kind.as_ref())) @@ -577,7 +551,7 @@ impl Enumerant { } } -impl Enumerant, BitwiseEnumerantParameter> { +impl Enumerant { pub fn fixup(&mut self) -> Result<(), ::Error> { for parameter in self.parameters.iter_mut() { parameter.fixup()?; @@ -742,7 +716,7 @@ impl AsRef for LiteralKind { pub enum OperandKind { BitEnum { kind: Kind, - enumerants: Vec, BitwiseEnumerantParameter>>, + enumerants: Vec>, }, ValueEnum { kind: Kind, @@ -766,7 +740,7 @@ pub enum OperandKind { #[serde(deny_unknown_fields)] pub struct CoreGrammar { pub copyright: Vec, - pub magic_number: QuotedInteger, + pub magic_number: QuotedInteger, pub major_version: u32, pub minor_version: u32, pub revision: u32, diff --git a/spirv-parser-generator/src/generate.rs b/spirv-parser-generator/src/generate.rs index b3d8884..27d2d0e 100644 --- a/spirv-parser-generator/src/generate.rs +++ b/spirv-parser-generator/src/generate.rs @@ -158,6 +158,170 @@ pub(crate) fn generate( writeln!(&mut out, "// {}", i); } } + writeln!( + &mut out, + "{}", + stringify!( + use std::result; + use std::error; + use std::fmt; + use std::mem; + use std::str::Utf8Error; + use std::string::FromUtf8Error; + + trait SPIRVParse: Sized { + fn spirv_parse<'a>(words: &'a [u32], parse_state: &mut ParseState) + -> Result<(Self, &'a [u32])>; + } + + impl SPIRVParse for Option { + fn spirv_parse<'a>( + words: &'a [u32], + parse_state: &mut ParseState, + ) -> Result<(Self, &'a [u32])> { + if words.is_empty() { + Ok((None, words)) + } else { + let (value, words) = T::spirv_parse(words, parse_state)?; + Ok((Some(value), words)) + } + } + } + + impl SPIRVParse for Vec { + fn spirv_parse<'a>( + mut words: &'a [u32], + parse_state: &mut ParseState, + ) -> Result<(Self, &'a [u32])> { + let mut retval = Vec::new(); + while !words.is_empty() { + let result = 
T::spirv_parse(words, parse_state)?; + words = result.1; + retval.push(result.0); + } + Ok((retval, words)) + } + } + + impl SPIRVParse for (A, B) { + fn spirv_parse<'a>( + words: &'a [u32], + parse_state: &mut ParseState, + ) -> Result<(Self, &'a [u32])> { + let (a, words) = A::spirv_parse(words, parse_state)?; + let (b, words) = B::spirv_parse(words, parse_state)?; + Ok(((a, b), words)) + } + } + + struct ByteIterator<'a> { + current_word: [u8; 4], + bytes_left_in_current_word: usize, + words: &'a [u32], + } + + impl<'a> ByteIterator<'a> { + fn new(words: &'a [u32]) -> Self { + Self { + current_word: [0; 4], + bytes_left_in_current_word: 0, + words, + } + } + fn take_unread_words(&mut self) -> &'a [u32] { + mem::replace(&mut self.words, &[]) + } + } + + impl<'a> Iterator for ByteIterator<'a> { + type Item = u8; + fn next(&mut self) -> Option { + if self.bytes_left_in_current_word == 0 { + let (&current_word, words) = self.words.split_first()?; + self.words = words; + self.current_word = unsafe { mem::transmute(current_word.to_le()) }; + self.bytes_left_in_current_word = self.current_word.len(); + } + let byte = self.current_word[self.bytes_left_in_current_word]; + self.bytes_left_in_current_word -= 1; + Some(byte) + } + } + + impl SPIRVParse for String { + fn spirv_parse<'a>( + words: &'a [u32], + _parse_state: &mut ParseState, + ) -> Result<(Self, &'a [u32])> { + let mut byte_count_excluding_null_terminator = None; + for (index, byte) in ByteIterator::new(words).enumerate() { + if byte == 0 { + byte_count_excluding_null_terminator = Some(index); + break; + } + } + let byte_count_excluding_null_terminator = + byte_count_excluding_null_terminator.ok_or(Error::InstructionPrematurelyEnded)?; + let mut bytes = Vec::with_capacity(byte_count_excluding_null_terminator); + let mut byte_iter = ByteIterator::new(words); + for _ in 0..byte_count_excluding_null_terminator { + let byte = byte_iter.next().unwrap(); + bytes.push(byte); + } + let _null_terminator = 
byte_iter.next().unwrap(); + let words = byte_iter.take_unread_words(); + for v in byte_iter { + if v != 0 { + return Err(Error::InvalidStringTermination); + } + } + assert_eq!(bytes.len(), byte_count_excluding_null_terminator); + Ok((String::from_utf8(bytes)?, words)) + } + } + + impl SPIRVParse for u32 { + fn spirv_parse<'a>( + words: &'a [u32], + _parse_state: &mut ParseState, + ) -> Result<(Self, &'a [u32])> { + let (&value, words) = words + .split_first() + .ok_or(Error::InstructionPrematurelyEnded)?; + Ok((value, words)) + } + } + + impl SPIRVParse for u64 { + fn spirv_parse<'a>( + words: &'a [u32], + _parse_state: &mut ParseState, + ) -> Result<(Self, &'a [u32])> { + let (&low, words) = words + .split_first() + .ok_or(Error::InstructionPrematurelyEnded)?; + let (&high, words) = words + .split_first() + .ok_or(Error::InstructionPrematurelyEnded)?; + Ok((((high as u64) << 32) | low as u64, words)) + } + } + + impl SPIRVParse for IdRef { + fn spirv_parse<'a>( + words: &'a [u32], + parse_state: &mut ParseState, + ) -> Result<(Self, &'a [u32])> { + let (value, words) = u32::spirv_parse(words, parse_state)?; + if value == 0 || value >= parse_state.bound { + Err(Error::IdOutOfBounds(value)) + } else { + Ok((IdRef(value), words)) + } + } + } + ) + )?; writeln!( &mut out, "{}", @@ -171,9 +335,11 @@ pub(crate) fn generate( for operand_kind in &operand_kinds { match operand_kind { ast::OperandKind::BitEnum { kind, enumerants } => { + let kind_id = new_id(kind, CamelCase); let mut enumerant_members = Vec::new(); let mut enumerant_member_names = Vec::new(); let mut enumerant_items = Vec::new(); + let mut enumerant_parse_operations = Vec::new(); for enumerant in enumerants { if enumerant.value.0 == 0 { continue; @@ -202,8 +368,16 @@ pub(crate) fn generate( enumerant_members.push(quote!{ pub #member_name: Option<#type_name> }); + let enumerant_value = enumerant.value; + enumerant_parse_operations.push(quote!{ + let #member_name = if (mask & #enumerant_value) == 0 { + mask 
&= !#enumerant_value; + unimplemented!() + } else { + None + }; + }) } - let kind_id = new_id(kind, CamelCase); writeln!( &mut out, "{}", @@ -212,14 +386,33 @@ pub(crate) fn generate( pub struct #kind_id { #(#enumerant_members),* } - impl #kind_id { - pub fn new() -> Self { - Self { - #(#enumerant_member_names: None,)* - } + #(#enumerant_items)* + } + )?; + let parse_body = quote!{ + let (mask, words) = words.split_first().ok_or(Error::InstructionPrematurelyEnded)?; + let mut mask = *mask; + #(#enumerant_parse_operations)* + if mask != 0 { + Err(Error::InvalidEnumValue) + } else { + Ok((Self { + #(#enumerant_member_names,)* + }, words)) + } + }; + writeln!( + &mut out, + "{}", + quote!{ + impl SPIRVParse for #kind_id { + fn spirv_parse<'a>( + words: &'a [u32], + parse_state: &mut ParseState, + ) -> Result<(Self, &'a [u32])> { + #parse_body } } - #(#enumerant_items)* } )?; } @@ -232,6 +425,18 @@ pub(crate) fn generate( generated_enumerants.push(quote!{#name}); continue; } + let parameters = enumerant.parameters.iter().map(|parameter| { + let name = new_id(parameter.name.as_ref().unwrap(), SnakeCase); + let kind = new_id(&parameter.kind, CamelCase); + quote!{ + #name: #kind, + } + }); + generated_enumerants.push(quote!{ + #name { + #(#parameters)* + } + }); } writeln!( &mut out, @@ -243,6 +448,20 @@ pub(crate) fn generate( } } )?; + writeln!( + &mut out, + "{}", + quote!{ + impl SPIRVParse for #kind_id { + fn spirv_parse<'a>( + words: &'a [u32], + parse_state: &mut ParseState, + ) -> Result<(Self, &'a [u32])> { + unimplemented!() + } + } + } + )?; } ast::OperandKind::Id { kind, doc: _ } => { let base = if *kind == ast::Kind::IdRef { @@ -260,6 +479,22 @@ pub(crate) fn generate( pub struct #kind_id(pub #base); } )?; + if *kind != ast::Kind::IdRef { + writeln!( + &mut out, + "{}", + quote!{ + impl SPIRVParse for #kind_id { + fn spirv_parse<'a>( + words: &'a [u32], + parse_state: &mut ParseState, + ) -> Result<(Self, &'a [u32])> { + IdRef::spirv_parse(words, 
parse_state).map(|(value, words)| (#kind_id(value), words)) + } + } + } + )?; + } } ast::OperandKind::Literal { kind, doc: _ } => { let kind_id = new_id(kind, CamelCase); @@ -295,8 +530,49 @@ pub(crate) fn generate( { let mut instruction_enumerants = Vec::new(); let mut spec_constant_op_instruction_enumerants = Vec::new(); + let mut instruction_parse_cases = Vec::new(); for instruction in core_instructions.iter() { + let opcode = instruction.opcode; let opname = new_id(remove_initial_op(instruction.opname.as_ref()), CamelCase); + instruction_parse_cases.push(match &instruction.opname { + ast::InstructionName::OpSpecConstantOp => { + quote!{#opcode => { + let (operation, words) = OpSpecConstantOp::spirv_parse(words, parse_state)?; + if words.is_empty() { + Ok(Instruction::#opname { operation }) + } else { + Err(Error::InstructionTooLong) + } + }} + } + _ => { + let mut parse_operations = Vec::new(); + let mut operand_names = Vec::new(); + for operand in &instruction.operands { + let kind = new_id(&operand.kind, CamelCase); + let name = new_id(operand.name.as_ref().unwrap(), SnakeCase); + let kind = match operand.quantifier { + None => quote!{#kind}, + Some(ast::Quantifier::Optional) => quote!{Option::<#kind>}, + Some(ast::Quantifier::Variadic) => quote!{Vec::<#kind>}, + }; + parse_operations.push(quote!{ + let (#name, words) = #kind::spirv_parse(words, parse_state)?; + }); + operand_names.push(name); + } + quote!{#opcode => { + #(#parse_operations)* + if words.is_empty() { + Ok(Instruction::#opname { + #(#operand_names,)* + }) + } else { + Err(Error::InstructionTooLong) + } + }} + } + }); let instruction_enumerant = if instruction.opname == ast::InstructionName::OpSpecConstantOp { quote!{ @@ -337,12 +613,189 @@ pub(crate) fn generate( pub enum OpSpecConstantOp { #(#spec_constant_op_instruction_enumerants,)* } + } + )?; + writeln!( + &mut out, + "{}", + quote!{ #[derive(Clone, Debug)] pub enum Instruction { #(#instruction_enumerants,)* } } )?; + writeln!( + &mut out, 
+ "{}", + stringify!( + #[derive(Copy, Clone, Debug)] + pub struct Header { + pub version: (u32, u32), + pub generator: u32, + pub bound: u32, + pub instruction_schema: u32, + } + + #[derive(Clone, Debug)] + pub enum Error { + MissingHeader, + InvalidHeader, + UnsupportedVersion(u32, u32), + ZeroInstructionLength, + SourcePrematurelyEnded, + UnknownOpcode(u16), + Utf8Error(Utf8Error), + InstructionPrematurelyEnded, + InvalidStringTermination, + InstructionTooLong, + InvalidEnumValue, + IdOutOfBounds(u32), + } + + impl From for Error { + fn from(v: Utf8Error) -> Self { + Error::Utf8Error(v) + } + } + + impl From for Error { + fn from(v: FromUtf8Error) -> Self { + Error::Utf8Error(v.utf8_error()) + } + } + + impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Error::MissingHeader => write!(f, "SPIR-V source is missing the file header"), + Error::InvalidHeader => write!(f, "SPIR-V source has an invalid file header"), + Error::UnsupportedVersion(major, minor) => write!( + f, + "SPIR-V source has an unsupported version: {}.{}", + major, + minor + ), + Error::ZeroInstructionLength => write!(f, "SPIR-V instruction has a length of zero"), + Error::SourcePrematurelyEnded => write!(f, "SPIR-V source prematurely ended"), + Error::UnknownOpcode(opcode) => { + write!(f, "SPIR-V instruction has an unknown opcode: {}", opcode) + } + Error::Utf8Error(error) => fmt::Display::fmt(&error, f), + Error::InstructionPrematurelyEnded => write!(f, "SPIR-V instruction prematurely ended"), + Error::InvalidStringTermination => write!(f, "SPIR-V LiteralString has an invalid termination word"), + Error::InstructionTooLong => write!(f, "SPIR-V instruction is too long"), + Error::InvalidEnumValue => write!(f, "enum has invalid value"), + Error::IdOutOfBounds(id) => write!(f, "id is out of bounds: {}", id), + } + } + } + + impl error::Error for Error {} + + type Result = result::Result; + + #[derive(Clone, Debug)] + struct ParseState { + bound: 
u32, + } + + #[derive(Clone, Debug)] + pub struct Parser<'a> { + words: &'a [u32], + header: Header, + parse_state: ParseState, + } + + fn parse_version(v: u32) -> Result<(u32, u32)> { + if (v & 0xFF0000FF) != 0 { + return Err(Error::InvalidHeader); + } + let major = (v >> 16) & 0xFF; + let minor = (v >> 8) & 0xFF; + Ok((major, minor)) + } + + impl<'a> Parser<'a> { + pub fn header(&self) -> &Header { + &self.header + } + pub fn start(mut words: &'a [u32]) -> Result { + let header = words.get(0..5).ok_or(Error::MissingHeader)?; + words = &words[5..]; + let header = match *header { + [MAGIC_NUMBER, version, generator, bound, instruction_schema @ 0] if bound >= 1 => { + let version = parse_version(version)?; + if version.0 != MAJOR_VERSION || version.1 > MINOR_VERSION { + return Err(Error::UnsupportedVersion(version.0, version.1)); + } + Header { + version, + generator, + bound, + instruction_schema, + } + } + _ => return Err(Error::InvalidHeader), + }; + Ok(Self { + words, + header, + parse_state: ParseState { + bound: header.bound, + }, + }) + } + fn next_helper(&mut self, length_and_opcode: u32) -> Result { + let length = (length_and_opcode >> 16) as usize; + let opcode = length_and_opcode as u16; + if length == 0 { + return Err(Error::ZeroInstructionLength); + } + let instruction_words = self.words.get(1..length).ok_or(Error::SourcePrematurelyEnded)?; + self.words = &self.words[length..]; + parse_instruction(opcode, instruction_words, &mut self.parse_state) + } + } + + impl<'a> Iterator for Parser<'a> { + type Item = Result; + fn next(&mut self) -> Option> { + let length_and_opcode = self.words.get(0)?; + Some(self.next_helper(*length_and_opcode)) + } + } + ) + )?; + writeln!( + &mut out, + "{}", + quote!{ + fn parse_instruction(opcode: u16, words: &[u32], parse_state: &mut ParseState) -> Result { + match opcode { + #(#instruction_parse_cases)* + opcode => Err(Error::UnknownOpcode(opcode)), + } + } + } + )?; + writeln!( + &mut out, + "{}", + quote!{ + impl 
SPIRVParse for OpSpecConstantOp { + fn spirv_parse<'a>( + words: &'a [u32], + parse_state: &mut ParseState + ) -> Result<(Self, &'a [u32])> { + let (id_result_type, words) = IdResultType::spirv_parse(words, parse_state)?; + let (id_result, words) = IdResult::spirv_parse(words, parse_state)?; + let (opcode, words) = u32::spirv_parse(words, parse_state)?; + unimplemented!() + } + } + } + )?; } let source = String::from_utf8(out).unwrap(); let source = match format_source(&options, &source) { -- 2.30.2