implemented switch CFG parsing master
authorJacob Lifshay <programmerjake@gmail.com>
Tue, 27 Nov 2018 06:57:58 +0000 (22:57 -0800)
committerJacob Lifshay <programmerjake@gmail.com>
Tue, 27 Nov 2018 06:57:58 +0000 (22:57 -0800)
shader-compiler/src/cfg.rs

index f5522f293970db9ee4c24cc52f7650b3cc4aa5f8..393df4a54baa9aacf15d42490d121bcaedab9be6 100644 (file)
@@ -65,7 +65,6 @@ pub(crate) struct SwitchFallthroughNode {
     pub(crate) label: IdRef,
     pub(crate) instructions: Vec<Instruction>,
     pub(crate) switch: RefCell<Weak<SwitchNode>>,
     pub(crate) label: IdRef,
     pub(crate) instructions: Vec<Instruction>,
     pub(crate) switch: RefCell<Weak<SwitchNode>>,
-    pub(crate) target_label: IdRef,
 }
 
 impl GenericNode for SwitchFallthroughNode {
 }
 
 impl GenericNode for SwitchFallthroughNode {
@@ -187,6 +186,33 @@ pub(crate) enum Node {
     ConditionMerge(Rc<ConditionMergeNode>),
 }
 
     ConditionMerge(Rc<ConditionMergeNode>),
 }
 
+impl Node {
+    pub(crate) fn instructions(&self) -> &Vec<Instruction> {
+        match self {
+            Node::Simple(v) => v.instructions(),
+            Node::Return(v) => v.instructions(),
+            Node::Discard(v) => v.instructions(),
+            Node::Switch(v) => v.instructions(),
+            Node::SwitchFallthrough(v) => v.instructions(),
+            Node::SwitchMerge(v) => v.instructions(),
+            Node::Condition(v) => v.instructions(),
+            Node::ConditionMerge(v) => v.instructions(),
+        }
+    }
+    pub(crate) fn label(&self) -> IdRef {
+        match self {
+            Node::Simple(v) => v.label(),
+            Node::Return(v) => v.label(),
+            Node::Discard(v) => v.label(),
+            Node::Switch(v) => v.label(),
+            Node::SwitchFallthrough(v) => v.label(),
+            Node::SwitchMerge(v) => v.label(),
+            Node::Condition(v) => v.label(),
+            Node::ConditionMerge(v) => v.label(),
+        }
+    }
+}
+
 impl<T: GenericNode> From<Rc<T>> for Node {
     fn from(v: Rc<T>) -> Node {
         GenericNode::to_node(v)
 impl<T: GenericNode> From<Rc<T>> for Node {
     fn from(v: Rc<T>) -> Node {
         GenericNode::to_node(v)
@@ -251,12 +277,11 @@ struct ParseStateCondition {
 }
 
 struct ParseStateSwitch {
 }
 
 struct ParseStateSwitch {
-    fallthrough_to_default: Option<Rc<SwitchFallthroughNode>>,
-    fallthroughs: Vec<Rc<SwitchFallthroughNode>>,
+    fallthrough: Option<Rc<SwitchFallthroughNode>>,
     default_label: IdRef,
     default_label: IdRef,
-    next_case: Option<IdRef>,
     merges: Vec<Rc<SwitchMergeNode>>,
     merge_label: IdRef,
     merges: Vec<Rc<SwitchMergeNode>>,
     merge_label: IdRef,
+    fallthrough_target: Option<IdRef>,
 }
 
 struct ParseState {
 }
 
 struct ParseState {
@@ -289,6 +314,129 @@ impl ParseState {
     fn get_switch(&mut self) -> &mut ParseStateSwitch {
         self.switch.as_mut().unwrap()
     }
     fn get_switch(&mut self) -> &mut ParseStateSwitch {
         self.switch.as_mut().unwrap()
     }
+    fn parse_switch<T>(
+        &mut self,
+        basic_blocks: &HashMap<IdRef, BasicBlock>,
+        label_id: IdRef,
+        basic_block: &BasicBlock,
+        targets: &[(T, IdRef)],
+        default_label: IdRef,
+        merge_block: IdRef,
+    ) -> Node {
+        get_basic_block(basic_blocks, merge_block).set_kind(BlockKind::SwitchMerge);
+        let mut last_target = None;
+        for &(_, target) in targets {
+            if Some(target) == last_target {
+                continue;
+            }
+            last_target = Some(target);
+            if target != merge_block {
+                get_basic_block(basic_blocks, target)
+                    .set_kind(BlockKind::SwitchCase(SwitchCaseKind::Normal));
+            }
+        }
+        if default_label != merge_block {
+            get_basic_block(basic_blocks, default_label)
+                .set_kind(BlockKind::SwitchCase(SwitchCaseKind::Default));
+        }
+        let old_switch = self.push_switch(ParseStateSwitch {
+            default_label: default_label,
+            fallthrough: None,
+            merge_label: merge_block,
+            merges: vec![],
+            fallthrough_target: None,
+        });
+        let default_node = if default_label != merge_block {
+            Some(self.parse(basic_blocks, default_label))
+        } else {
+            None
+        };
+        let mut default_fallthrough = self.get_switch().fallthrough.take();
+        let mut default_fallthrough_target = self.get_switch().fallthrough_target.take();
+        let mut cases = Vec::with_capacity(targets.len());
+        struct Case {
+            node: Node,
+            fallthrough: Option<Rc<SwitchFallthroughNode>>,
+            fallthrough_target: Option<IdRef>,
+        }
+        let mut last_target = None;
+        for (index, &(_, target)) in targets.iter().enumerate() {
+            if Some(target) == last_target {
+                continue;
+            }
+            last_target = Some(target);
+            let node = self.parse(basic_blocks, target);
+            let fallthrough_target = self.get_switch().fallthrough_target.take();
+            if let Some(fallthrough_target) = fallthrough_target {
+                if default_label != fallthrough_target {
+                    assert_eq!(
+                        Some(fallthrough_target),
+                        targets.get(index + 1).map(|v| v.1),
+                        "invalid fallthrough branch"
+                    );
+                }
+            }
+            cases.push(Case {
+                node,
+                fallthrough: self.get_switch().fallthrough.take(),
+                fallthrough_target,
+            });
+        }
+        let switch = self.pop_switch(old_switch);
+        let mut before_default_cases = None;
+        let mut output_cases = vec![];
+        let mut fallthroughs = vec![];
+        fallthroughs.extend(default_fallthrough);
+        for (
+            index,
+            Case {
+                node,
+                fallthrough,
+                fallthrough_target,
+            },
+        ) in cases.into_iter().enumerate()
+        {
+            if Some(node.label()) == default_fallthrough_target {
+                if before_default_cases.is_none() {
+                    before_default_cases = Some(mem::replace(&mut output_cases, vec![]));
+                } else {
+                    assert!(output_cases.is_empty(), "invalid fallthrough branch");
+                }
+            }
+            output_cases.push(node);
+            fallthroughs.extend(fallthrough);
+            if Some(default_label) == fallthrough_target {
+                assert!(before_default_cases.is_none());
+                before_default_cases = Some(mem::replace(&mut output_cases, vec![]));
+            }
+        }
+        let before_default_cases =
+            before_default_cases.unwrap_or_else(|| mem::replace(&mut output_cases, vec![]));
+        let default = if let Some(default_node) = default_node {
+            Some(SwitchDefault {
+                default_case: default_node,
+                after_default_cases: output_cases,
+            })
+        } else {
+            None
+        };
+        let next = self.parse(basic_blocks, merge_block);
+        let retval = Rc::new(SwitchNode {
+            label: label_id,
+            instructions: basic_block.get_instructions(),
+            before_default_cases,
+            default,
+            next,
+        });
+        for fallthrough in fallthroughs {
+            fallthrough.switch.replace(Rc::downgrade(&retval));
+        }
+        for merge in switch.merges {
+            merge.switch.replace(Rc::downgrade(&retval));
+        }
+        retval.into()
+    }
+    #[cfg_attr(feature = "cargo-clippy", allow(clippy::cyclomatic_complexity))]
     fn parse(&mut self, basic_blocks: &HashMap<IdRef, BasicBlock>, label_id: IdRef) -> Node {
         let basic_block = get_basic_block(basic_blocks, label_id);
         let (terminating_instruction, instructions_without_terminator) = basic_block
     fn parse(&mut self, basic_blocks: &HashMap<IdRef, BasicBlock>, label_id: IdRef) -> Node {
         let basic_block = get_basic_block(basic_blocks, label_id);
         let (terminating_instruction, instructions_without_terminator) = basic_block
@@ -338,17 +486,16 @@ impl ParseState {
                     BlockKind::LoopContinue => unimplemented!(),
                     BlockKind::SwitchCase(kind) => {
                         let mut switch = self.get_switch();
                     BlockKind::LoopContinue => unimplemented!(),
                     BlockKind::SwitchCase(kind) => {
                         let mut switch = self.get_switch();
-                        let expected_target_label = match kind {
-                            SwitchCaseKind::Normal => {
-                                switch.next_case.unwrap_or(switch.default_label)
-                            }
-                            SwitchCaseKind::Default => switch.default_label,
-                        };
-                        assert_eq!(
-                            target_label, expected_target_label,
-                            "invalid branch to next switch case"
-                        );
-                        unimplemented!()
+                        let retval = Rc::new(SwitchFallthroughNode {
+                            label: label_id,
+                            instructions: basic_block.get_instructions(),
+                            switch: Default::default(),
+                        });
+                        assert!(switch.fallthrough_target.is_none());
+                        assert!(switch.fallthrough.is_none());
+                        switch.fallthrough_target = Some(target_label);
+                        switch.fallthrough = Some(retval.clone());
+                        retval.into()
                     }
                     BlockKind::SwitchMerge => {
                         assert_eq!(
                     }
                     BlockKind::SwitchMerge => {
                         assert_eq!(
@@ -420,14 +567,19 @@ impl ParseState {
             }
             (
                 &Instruction::Switch32 {
             }
             (
                 &Instruction::Switch32 {
-                    default,
+                    default: default_label,
                     target: ref targets,
                     ..
                 },
                 Some(&Instruction::SelectionMerge { merge_block, .. }),
                     target: ref targets,
                     ..
                 },
                 Some(&Instruction::SelectionMerge { merge_block, .. }),
-            ) => {
-                unimplemented!();
-            }
+            ) => self.parse_switch(
+                basic_blocks,
+                label_id,
+                basic_block,
+                targets,
+                default_label,
+                merge_block,
+            ),
             (
                 &Instruction::Switch64 {
                     default: default_label,
             (
                 &Instruction::Switch64 {
                     default: default_label,
@@ -435,87 +587,14 @@ impl ParseState {
                     ..
                 },
                 Some(&Instruction::SelectionMerge { merge_block, .. }),
                     ..
                 },
                 Some(&Instruction::SelectionMerge { merge_block, .. }),
-            ) => {
-                get_basic_block(basic_blocks, merge_block).set_kind(BlockKind::SwitchMerge);
-                for &(_, target) in targets {
-                    if target != merge_block {
-                        get_basic_block(basic_blocks, target)
-                            .set_kind(BlockKind::SwitchCase(SwitchCaseKind::Normal));
-                    }
-                }
-                if default_label != merge_block {
-                    get_basic_block(basic_blocks, default_label)
-                        .set_kind(BlockKind::SwitchCase(SwitchCaseKind::Default));
-                }
-                let old_switch = self.push_switch(ParseStateSwitch {
-                    default_label: default_label,
-                    fallthrough_to_default: None,
-                    merge_label: merge_block,
-                    fallthroughs: vec![],
-                    merges: vec![],
-                    next_case: None,
-                });
-                let default = if default_label != merge_block {
-                    Some(self.parse(basic_blocks, default_label))
-                } else {
-                    None
-                };
-                let mut default_fallthrough = None;
-                for i in self.get_switch().fallthroughs.drain(..) {
-                    assert!(
-                        default_fallthrough.is_none(),
-                        "multiple fallthroughs from default case"
-                    );
-                    default_fallthrough = Some(i);
-                }
-                let mut cases = Vec::with_capacity(targets.len());
-                for (index, &(_, target)) in targets.iter().enumerate() {
-                    self.get_switch().next_case = targets.get(index + 1).map(|v| v.1);
-                    cases.push(self.parse(basic_blocks, target));
-                }
-                let switch = self.pop_switch(old_switch);
-                let (before_default_cases, default) = if let Some(default) = default {
-                    if let Some(fallthrough_to_default) = &switch.fallthrough_to_default {
-                        // FIXME: handle default_fallthrough
-                        unimplemented!()
-                    } else if let Some(default_fallthrough) = &default_fallthrough {
-                        unimplemented!()
-                    } else {
-                        (
-                            cases,
-                            Some(SwitchDefault {
-                                default_case: default,
-                                after_default_cases: vec![],
-                            }),
-                        )
-                    }
-                } else {
-                    (cases, None)
-                };
-                let next = self.parse(basic_blocks, merge_block);
-                let retval = Rc::new(SwitchNode {
-                    label: label_id,
-                    instructions: basic_block.get_instructions(),
-                    before_default_cases,
-                    default,
-                    next,
-                });
-                if let Some(default_fallthrough) = default_fallthrough {
-                    default_fallthrough.switch.replace(Rc::downgrade(&retval));
-                }
-                if let Some(fallthrough_to_default) = switch.fallthrough_to_default {
-                    fallthrough_to_default
-                        .switch
-                        .replace(Rc::downgrade(&retval));
-                }
-                for fallthrough in switch.fallthroughs {
-                    fallthrough.switch.replace(Rc::downgrade(&retval));
-                }
-                for merge in switch.merges {
-                    merge.switch.replace(Rc::downgrade(&retval));
-                }
-                retval.into()
-            }
+            ) => self.parse_switch(
+                basic_blocks,
+                label_id,
+                basic_block,
+                targets,
+                default_label,
+                merge_block,
+            ),
             (&Instruction::Switch32 { .. }, _) => unreachable!("missing merge instruction"),
             (&Instruction::Switch64 { .. }, _) => unreachable!("missing merge instruction"),
             (&Instruction::Kill {}, _) => Rc::new(DiscardNode {
             (&Instruction::Switch32 { .. }, _) => unreachable!("missing merge instruction"),
             (&Instruction::Switch64 { .. }, _) => unreachable!("missing merge instruction"),
             (&Instruction::Kill {}, _) => Rc::new(DiscardNode {
@@ -1103,8 +1182,142 @@ mod tests {
             &[
                 SerializedCFGElement::Switch,
                 SerializedCFGElement::SwitchCase,
             &[
                 SerializedCFGElement::Switch,
                 SerializedCFGElement::SwitchCase,
+                SerializedCFGElement::SwitchFallthrough,
+                SerializedCFGElement::SwitchDefaultCase,
+                SerializedCFGElement::SwitchMerge,
+                SerializedCFGElement::SwitchEnd,
+                SerializedCFGElement::Return,
+            ],
+        );
+    }
+
+    #[test]
+    fn test_cfg_switch_fallthrough_default_fallthrough_break() {
+        let mut id_factory = IdFactory::new();
+        let mut instructions = Vec::new();
+
+        let label_start = id_factory.next();
+        let label_case1 = id_factory.next();
+        let label_default = id_factory.next();
+        let label_case2 = id_factory.next();
+        let label_merge = id_factory.next();
+
+        instructions.push(Instruction::NoLine);
+        instructions.push(Instruction::Label {
+            id_result: IdResult(label_start),
+        });
+        instructions.push(Instruction::SelectionMerge {
+            merge_block: label_merge,
+            selection_control: spirv_parser::SelectionControl::default(),
+        });
+        instructions.push(Instruction::Switch64 {
+            selector: id_factory.next(),
+            default: label_default,
+            target: vec![(0, label_case1), (1, label_case1), (2, label_case2)],
+        });
+
+        instructions.push(Instruction::Label {
+            id_result: IdResult(label_case1),
+        });
+        instructions.push(Instruction::Branch {
+            target_label: label_default,
+        });
+
+        instructions.push(Instruction::Label {
+            id_result: IdResult(label_default),
+        });
+        instructions.push(Instruction::Branch {
+            target_label: label_case2,
+        });
+
+        instructions.push(Instruction::Label {
+            id_result: IdResult(label_case2),
+        });
+        instructions.push(Instruction::Branch {
+            target_label: label_merge,
+        });
+
+        instructions.push(Instruction::Label {
+            id_result: IdResult(label_merge),
+        });
+        instructions.push(Instruction::Return);
+
+        test_cfg(
+            &instructions,
+            &[
+                SerializedCFGElement::Switch,
+                SerializedCFGElement::SwitchCase,
+                SerializedCFGElement::SwitchFallthrough,
+                SerializedCFGElement::SwitchDefaultCase,
+                SerializedCFGElement::SwitchFallthrough,
+                SerializedCFGElement::SwitchCase,
+                SerializedCFGElement::SwitchMerge,
+                SerializedCFGElement::SwitchEnd,
                 SerializedCFGElement::Return,
                 SerializedCFGElement::Return,
+            ],
+        );
+    }
+
+    #[test]
+    fn test_cfg_switch_break_default_fallthrough_break() {
+        let mut id_factory = IdFactory::new();
+        let mut instructions = Vec::new();
+
+        let label_start = id_factory.next();
+        let label_case1 = id_factory.next();
+        let label_default = id_factory.next();
+        let label_case2 = id_factory.next();
+        let label_merge = id_factory.next();
+
+        instructions.push(Instruction::NoLine);
+        instructions.push(Instruction::Label {
+            id_result: IdResult(label_start),
+        });
+        instructions.push(Instruction::SelectionMerge {
+            merge_block: label_merge,
+            selection_control: spirv_parser::SelectionControl::default(),
+        });
+        instructions.push(Instruction::Switch32 {
+            selector: id_factory.next(),
+            default: label_default,
+            target: vec![(0, label_case1), (1, label_case1), (2, label_case2)],
+        });
+
+        instructions.push(Instruction::Label {
+            id_result: IdResult(label_case1),
+        });
+        instructions.push(Instruction::Branch {
+            target_label: label_merge,
+        });
+
+        instructions.push(Instruction::Label {
+            id_result: IdResult(label_default),
+        });
+        instructions.push(Instruction::Branch {
+            target_label: label_case2,
+        });
+
+        instructions.push(Instruction::Label {
+            id_result: IdResult(label_case2),
+        });
+        instructions.push(Instruction::Branch {
+            target_label: label_merge,
+        });
+
+        instructions.push(Instruction::Label {
+            id_result: IdResult(label_merge),
+        });
+        instructions.push(Instruction::Return);
+
+        test_cfg(
+            &instructions,
+            &[
+                SerializedCFGElement::Switch,
+                SerializedCFGElement::SwitchCase,
+                SerializedCFGElement::SwitchMerge,
                 SerializedCFGElement::SwitchDefaultCase,
                 SerializedCFGElement::SwitchDefaultCase,
+                SerializedCFGElement::SwitchFallthrough,
+                SerializedCFGElement::SwitchCase,
                 SerializedCFGElement::SwitchMerge,
                 SerializedCFGElement::SwitchEnd,
                 SerializedCFGElement::Return,
                 SerializedCFGElement::SwitchMerge,
                 SerializedCFGElement::SwitchEnd,
                 SerializedCFGElement::Return,