pub(crate) label: IdRef,
pub(crate) instructions: Vec<Instruction>,
pub(crate) switch: RefCell<Weak<SwitchNode>>,
- pub(crate) target_label: IdRef,
}
impl GenericNode for SwitchFallthroughNode {
ConditionMerge(Rc<ConditionMergeNode>),
}
+impl Node {
+ pub(crate) fn instructions(&self) -> &Vec<Instruction> {
+ match self {
+ Node::Simple(v) => v.instructions(),
+ Node::Return(v) => v.instructions(),
+ Node::Discard(v) => v.instructions(),
+ Node::Switch(v) => v.instructions(),
+ Node::SwitchFallthrough(v) => v.instructions(),
+ Node::SwitchMerge(v) => v.instructions(),
+ Node::Condition(v) => v.instructions(),
+ Node::ConditionMerge(v) => v.instructions(),
+ }
+ }
+ pub(crate) fn label(&self) -> IdRef {
+ match self {
+ Node::Simple(v) => v.label(),
+ Node::Return(v) => v.label(),
+ Node::Discard(v) => v.label(),
+ Node::Switch(v) => v.label(),
+ Node::SwitchFallthrough(v) => v.label(),
+ Node::SwitchMerge(v) => v.label(),
+ Node::Condition(v) => v.label(),
+ Node::ConditionMerge(v) => v.label(),
+ }
+ }
+}
+
impl<T: GenericNode> From<Rc<T>> for Node {
fn from(v: Rc<T>) -> Node {
GenericNode::to_node(v)
}
struct ParseStateSwitch {
- fallthrough_to_default: Option<Rc<SwitchFallthroughNode>>,
- fallthroughs: Vec<Rc<SwitchFallthroughNode>>,
+ fallthrough: Option<Rc<SwitchFallthroughNode>>,
default_label: IdRef,
- next_case: Option<IdRef>,
merges: Vec<Rc<SwitchMergeNode>>,
merge_label: IdRef,
+ fallthrough_target: Option<IdRef>,
}
struct ParseState {
fn get_switch(&mut self) -> &mut ParseStateSwitch {
self.switch.as_mut().unwrap()
}
+ fn parse_switch<T>(
+ &mut self,
+ basic_blocks: &HashMap<IdRef, BasicBlock>,
+ label_id: IdRef,
+ basic_block: &BasicBlock,
+ targets: &[(T, IdRef)],
+ default_label: IdRef,
+ merge_block: IdRef,
+ ) -> Node {
+ get_basic_block(basic_blocks, merge_block).set_kind(BlockKind::SwitchMerge);
+ let mut last_target = None;
+ for &(_, target) in targets {
+ if Some(target) == last_target {
+ continue;
+ }
+ last_target = Some(target);
+ if target != merge_block {
+ get_basic_block(basic_blocks, target)
+ .set_kind(BlockKind::SwitchCase(SwitchCaseKind::Normal));
+ }
+ }
+ if default_label != merge_block {
+ get_basic_block(basic_blocks, default_label)
+ .set_kind(BlockKind::SwitchCase(SwitchCaseKind::Default));
+ }
+ let old_switch = self.push_switch(ParseStateSwitch {
+ default_label: default_label,
+ fallthrough: None,
+ merge_label: merge_block,
+ merges: vec![],
+ fallthrough_target: None,
+ });
+ let default_node = if default_label != merge_block {
+ Some(self.parse(basic_blocks, default_label))
+ } else {
+ None
+ };
+ let mut default_fallthrough = self.get_switch().fallthrough.take();
+ let mut default_fallthrough_target = self.get_switch().fallthrough_target.take();
+ let mut cases = Vec::with_capacity(targets.len());
+ struct Case {
+ node: Node,
+ fallthrough: Option<Rc<SwitchFallthroughNode>>,
+ fallthrough_target: Option<IdRef>,
+ }
+ let mut last_target = None;
+ for (index, &(_, target)) in targets.iter().enumerate() {
+ if Some(target) == last_target {
+ continue;
+ }
+ last_target = Some(target);
+ let node = self.parse(basic_blocks, target);
+ let fallthrough_target = self.get_switch().fallthrough_target.take();
+ if let Some(fallthrough_target) = fallthrough_target {
+ if default_label != fallthrough_target {
+ assert_eq!(
+ Some(fallthrough_target),
+ targets.get(index + 1).map(|v| v.1),
+ "invalid fallthrough branch"
+ );
+ }
+ }
+ cases.push(Case {
+ node,
+ fallthrough: self.get_switch().fallthrough.take(),
+ fallthrough_target,
+ });
+ }
+ let switch = self.pop_switch(old_switch);
+ let mut before_default_cases = None;
+ let mut output_cases = vec![];
+ let mut fallthroughs = vec![];
+ fallthroughs.extend(default_fallthrough);
+ for (
+ index,
+ Case {
+ node,
+ fallthrough,
+ fallthrough_target,
+ },
+ ) in cases.into_iter().enumerate()
+ {
+ if Some(node.label()) == default_fallthrough_target {
+ if before_default_cases.is_none() {
+ before_default_cases = Some(mem::replace(&mut output_cases, vec![]));
+ } else {
+ assert!(output_cases.is_empty(), "invalid fallthrough branch");
+ }
+ }
+ output_cases.push(node);
+ fallthroughs.extend(fallthrough);
+ if Some(default_label) == fallthrough_target {
+ assert!(before_default_cases.is_none());
+ before_default_cases = Some(mem::replace(&mut output_cases, vec![]));
+ }
+ }
+ let before_default_cases =
+ before_default_cases.unwrap_or_else(|| mem::replace(&mut output_cases, vec![]));
+ let default = if let Some(default_node) = default_node {
+ Some(SwitchDefault {
+ default_case: default_node,
+ after_default_cases: output_cases,
+ })
+ } else {
+ None
+ };
+ let next = self.parse(basic_blocks, merge_block);
+ let retval = Rc::new(SwitchNode {
+ label: label_id,
+ instructions: basic_block.get_instructions(),
+ before_default_cases,
+ default,
+ next,
+ });
+ for fallthrough in fallthroughs {
+ fallthrough.switch.replace(Rc::downgrade(&retval));
+ }
+ for merge in switch.merges {
+ merge.switch.replace(Rc::downgrade(&retval));
+ }
+ retval.into()
+ }
+ #[cfg_attr(feature = "cargo-clippy", allow(clippy::cyclomatic_complexity))]
fn parse(&mut self, basic_blocks: &HashMap<IdRef, BasicBlock>, label_id: IdRef) -> Node {
let basic_block = get_basic_block(basic_blocks, label_id);
let (terminating_instruction, instructions_without_terminator) = basic_block
BlockKind::LoopContinue => unimplemented!(),
BlockKind::SwitchCase(kind) => {
let mut switch = self.get_switch();
- let expected_target_label = match kind {
- SwitchCaseKind::Normal => {
- switch.next_case.unwrap_or(switch.default_label)
- }
- SwitchCaseKind::Default => switch.default_label,
- };
- assert_eq!(
- target_label, expected_target_label,
- "invalid branch to next switch case"
- );
- unimplemented!()
+ let retval = Rc::new(SwitchFallthroughNode {
+ label: label_id,
+ instructions: basic_block.get_instructions(),
+ switch: Default::default(),
+ });
+ assert!(switch.fallthrough_target.is_none());
+ assert!(switch.fallthrough.is_none());
+ switch.fallthrough_target = Some(target_label);
+ switch.fallthrough = Some(retval.clone());
+ retval.into()
}
BlockKind::SwitchMerge => {
assert_eq!(
}
(
&Instruction::Switch32 {
- default,
+ default: default_label,
target: ref targets,
..
},
Some(&Instruction::SelectionMerge { merge_block, .. }),
- ) => {
- unimplemented!();
- }
+ ) => self.parse_switch(
+ basic_blocks,
+ label_id,
+ basic_block,
+ targets,
+ default_label,
+ merge_block,
+ ),
(
&Instruction::Switch64 {
default: default_label,
..
},
Some(&Instruction::SelectionMerge { merge_block, .. }),
- ) => {
- get_basic_block(basic_blocks, merge_block).set_kind(BlockKind::SwitchMerge);
- for &(_, target) in targets {
- if target != merge_block {
- get_basic_block(basic_blocks, target)
- .set_kind(BlockKind::SwitchCase(SwitchCaseKind::Normal));
- }
- }
- if default_label != merge_block {
- get_basic_block(basic_blocks, default_label)
- .set_kind(BlockKind::SwitchCase(SwitchCaseKind::Default));
- }
- let old_switch = self.push_switch(ParseStateSwitch {
- default_label: default_label,
- fallthrough_to_default: None,
- merge_label: merge_block,
- fallthroughs: vec![],
- merges: vec![],
- next_case: None,
- });
- let default = if default_label != merge_block {
- Some(self.parse(basic_blocks, default_label))
- } else {
- None
- };
- let mut default_fallthrough = None;
- for i in self.get_switch().fallthroughs.drain(..) {
- assert!(
- default_fallthrough.is_none(),
- "multiple fallthroughs from default case"
- );
- default_fallthrough = Some(i);
- }
- let mut cases = Vec::with_capacity(targets.len());
- for (index, &(_, target)) in targets.iter().enumerate() {
- self.get_switch().next_case = targets.get(index + 1).map(|v| v.1);
- cases.push(self.parse(basic_blocks, target));
- }
- let switch = self.pop_switch(old_switch);
- let (before_default_cases, default) = if let Some(default) = default {
- if let Some(fallthrough_to_default) = &switch.fallthrough_to_default {
- // FIXME: handle default_fallthrough
- unimplemented!()
- } else if let Some(default_fallthrough) = &default_fallthrough {
- unimplemented!()
- } else {
- (
- cases,
- Some(SwitchDefault {
- default_case: default,
- after_default_cases: vec![],
- }),
- )
- }
- } else {
- (cases, None)
- };
- let next = self.parse(basic_blocks, merge_block);
- let retval = Rc::new(SwitchNode {
- label: label_id,
- instructions: basic_block.get_instructions(),
- before_default_cases,
- default,
- next,
- });
- if let Some(default_fallthrough) = default_fallthrough {
- default_fallthrough.switch.replace(Rc::downgrade(&retval));
- }
- if let Some(fallthrough_to_default) = switch.fallthrough_to_default {
- fallthrough_to_default
- .switch
- .replace(Rc::downgrade(&retval));
- }
- for fallthrough in switch.fallthroughs {
- fallthrough.switch.replace(Rc::downgrade(&retval));
- }
- for merge in switch.merges {
- merge.switch.replace(Rc::downgrade(&retval));
- }
- retval.into()
- }
+ ) => self.parse_switch(
+ basic_blocks,
+ label_id,
+ basic_block,
+ targets,
+ default_label,
+ merge_block,
+ ),
(&Instruction::Switch32 { .. }, _) => unreachable!("missing merge instruction"),
(&Instruction::Switch64 { .. }, _) => unreachable!("missing merge instruction"),
(&Instruction::Kill {}, _) => Rc::new(DiscardNode {
&[
SerializedCFGElement::Switch,
SerializedCFGElement::SwitchCase,
+ SerializedCFGElement::SwitchFallthrough,
+ SerializedCFGElement::SwitchDefaultCase,
+ SerializedCFGElement::SwitchMerge,
+ SerializedCFGElement::SwitchEnd,
+ SerializedCFGElement::Return,
+ ],
+ );
+ }
+
+ #[test]
+ fn test_cfg_switch_fallthrough_default_fallthrough_break() {
+ let mut id_factory = IdFactory::new();
+ let mut instructions = Vec::new();
+
+ let label_start = id_factory.next();
+ let label_case1 = id_factory.next();
+ let label_default = id_factory.next();
+ let label_case2 = id_factory.next();
+ let label_merge = id_factory.next();
+
+ instructions.push(Instruction::NoLine);
+ instructions.push(Instruction::Label {
+ id_result: IdResult(label_start),
+ });
+ instructions.push(Instruction::SelectionMerge {
+ merge_block: label_merge,
+ selection_control: spirv_parser::SelectionControl::default(),
+ });
+ instructions.push(Instruction::Switch64 {
+ selector: id_factory.next(),
+ default: label_default,
+ target: vec![(0, label_case1), (1, label_case1), (2, label_case2)],
+ });
+
+ instructions.push(Instruction::Label {
+ id_result: IdResult(label_case1),
+ });
+ instructions.push(Instruction::Branch {
+ target_label: label_default,
+ });
+
+ instructions.push(Instruction::Label {
+ id_result: IdResult(label_default),
+ });
+ instructions.push(Instruction::Branch {
+ target_label: label_case2,
+ });
+
+ instructions.push(Instruction::Label {
+ id_result: IdResult(label_case2),
+ });
+ instructions.push(Instruction::Branch {
+ target_label: label_merge,
+ });
+
+ instructions.push(Instruction::Label {
+ id_result: IdResult(label_merge),
+ });
+ instructions.push(Instruction::Return);
+
+ test_cfg(
+ &instructions,
+ &[
+ SerializedCFGElement::Switch,
+ SerializedCFGElement::SwitchCase,
+ SerializedCFGElement::SwitchFallthrough,
+ SerializedCFGElement::SwitchDefaultCase,
+ SerializedCFGElement::SwitchFallthrough,
+ SerializedCFGElement::SwitchCase,
+ SerializedCFGElement::SwitchMerge,
+ SerializedCFGElement::SwitchEnd,
SerializedCFGElement::Return,
+ ],
+ );
+ }
+
+ #[test]
+ fn test_cfg_switch_break_default_fallthrough_break() {
+ let mut id_factory = IdFactory::new();
+ let mut instructions = Vec::new();
+
+ let label_start = id_factory.next();
+ let label_case1 = id_factory.next();
+ let label_default = id_factory.next();
+ let label_case2 = id_factory.next();
+ let label_merge = id_factory.next();
+
+ instructions.push(Instruction::NoLine);
+ instructions.push(Instruction::Label {
+ id_result: IdResult(label_start),
+ });
+ instructions.push(Instruction::SelectionMerge {
+ merge_block: label_merge,
+ selection_control: spirv_parser::SelectionControl::default(),
+ });
+ instructions.push(Instruction::Switch32 {
+ selector: id_factory.next(),
+ default: label_default,
+ target: vec![(0, label_case1), (1, label_case1), (2, label_case2)],
+ });
+
+ instructions.push(Instruction::Label {
+ id_result: IdResult(label_case1),
+ });
+ instructions.push(Instruction::Branch {
+ target_label: label_merge,
+ });
+
+ instructions.push(Instruction::Label {
+ id_result: IdResult(label_default),
+ });
+ instructions.push(Instruction::Branch {
+ target_label: label_case2,
+ });
+
+ instructions.push(Instruction::Label {
+ id_result: IdResult(label_case2),
+ });
+ instructions.push(Instruction::Branch {
+ target_label: label_merge,
+ });
+
+ instructions.push(Instruction::Label {
+ id_result: IdResult(label_merge),
+ });
+ instructions.push(Instruction::Return);
+
+ test_cfg(
+ &instructions,
+ &[
+ SerializedCFGElement::Switch,
+ SerializedCFGElement::SwitchCase,
+ SerializedCFGElement::SwitchMerge,
SerializedCFGElement::SwitchDefaultCase,
+ SerializedCFGElement::SwitchFallthrough,
+ SerializedCFGElement::SwitchCase,
SerializedCFGElement::SwitchMerge,
SerializedCFGElement::SwitchEnd,
SerializedCFGElement::Return,