fhdl.dsl: use less error-prone Switch/Case two-level syntax.
authorwhitequark <cz@m-labs.hk>
Thu, 13 Dec 2018 07:11:06 +0000 (07:11 +0000)
committerwhitequark <cz@m-labs.hk>
Thu, 13 Dec 2018 07:11:06 +0000 (07:11 +0000)
examples/pmux.py
nmigen/fhdl/dsl.py
nmigen/test/test_fhdl_dsl.py

index e69068b4e209d956b5e2a3e81883d7f5810f8c2c..cda54472edc9f79476a6f0d8eec819fe0b818b9c 100644 (file)
@@ -12,14 +12,15 @@ class ParMux:
 
     def get_fragment(self, platform):
         m = Module()
-        with m.Case(self.s, "--1"):
-            m.d.comb += self.o.eq(self.a)
-        with m.Case(self.s, "-1-"):
-            m.d.comb += self.o.eq(self.b)
-        with m.Case(self.s, "1--"):
-            m.d.comb += self.o.eq(self.c)
-        with m.Case(self.s):
-            m.d.comb += self.o.eq(0)
+        with m.Switch(self.s):
+            with m.Case("--1"):
+                m.d.comb += self.o.eq(self.a)
+            with m.Case("-1-"):
+                m.d.comb += self.o.eq(self.b)
+            with m.Case("1--"):
+                m.d.comb += self.o.eq(self.c)
+            with m.Case():
+                m.d.comb += self.o.eq(0)
         return m.lower(platform)
 
 
index 0c0be8cf6786a3c64605c5f13f9b82bc4e9bc441..9ef6303fe546d2439cbd995f53b395dccfbb2115 100644 (file)
@@ -64,30 +64,6 @@ class _ModuleBuilderRoot:
                              .format(type(self).__name__, name))
 
 
-class _ModuleBuilderCase(_ModuleBuilderRoot):
-    def __init__(self, builder, depth, test, value):
-        super().__init__(builder, depth)
-        self._test  = test
-        self._value = value
-
-    def __enter__(self):
-        if self._value is None:
-            self._value = "-" * len(self._test)
-        if isinstance(self._value, str) and len(self._test) != len(self._value):
-            raise SyntaxError("Case value {} must have the same width as test {}"
-                              .format(self._value, self._test))
-        if self._builder._stmt_switch_test != ValueKey(self._test):
-            self._builder._flush()
-            self._builder._stmt_switch_test = ValueKey(self._test)
-        self._outer_case = self._builder._statements
-        self._builder._statements = []
-        return self
-
-    def __exit__(self, *args):
-        self._builder._stmt_switch_cases[self._value] = self._builder._statements
-        self._builder._statements = self._outer_case
-
-
 class _ModuleBuilderSubmodules:
     def __init__(self, builder):
         object.__setattr__(self, "_builder", builder)
@@ -106,86 +82,147 @@ class Module(_ModuleBuilderRoot):
         _ModuleBuilderRoot.__init__(self, self, depth=0)
         self.submodules = _ModuleBuilderSubmodules(self)
 
-        self._submodules        = []
-        self._driving           = ValueDict()
-        self._statements        = Statement.wrap([])
-        self._stmt_depth        = 0
+        self._submodules   = []
+        self._driving      = ValueDict()
+        self._statements   = Statement.wrap([])
+        self._ctrl_context = None
+        self._ctrl_stack   = []
         self._stmt_if_cond      = []
         self._stmt_if_bodies    = []
         self._stmt_switch_test  = None
         self._stmt_switch_cases = OrderedDict()
 
+    def _check_context(self, construct, context):
+        if self._ctrl_context != context:
+            if self._ctrl_context is None:
+                raise SyntaxError("{} is not permitted outside of {}"
+                                  .format(construct, context))
+            else:
+                raise SyntaxError("{} is not permitted inside of {}"
+                                  .format(construct, self._ctrl_context))
+
+    def _get_ctrl(self, name):
+        if self._ctrl_stack:
+            top_name, top_data = self._ctrl_stack[-1]
+            if top_name == name:
+                return top_data
+
+    def _flush_ctrl(self):
+        while len(self._ctrl_stack) > self.domain._depth:
+            self._pop_ctrl()
+
+    def _set_ctrl(self, name, data):
+        self._flush_ctrl()
+        self._ctrl_stack.append((name, data))
+        return data
+
     @contextmanager
     def If(self, cond):
-        self._flush()
+        self._check_context("If", context=None)
+        if_data = self._set_ctrl("If", {"tests": [], "bodies": []})
         try:
-            _outer_case = self._statements
-            self._statements = []
+            _outer_case, self._statements = self._statements, []
             self.domain._depth += 1
             yield
-            self._stmt_if_cond.append(cond)
-            self._stmt_if_bodies.append(self._statements)
+            self._flush_ctrl()
+            if_data["tests"].append(cond)
+            if_data["bodies"].append(self._statements)
         finally:
             self.domain._depth -= 1
             self._statements = _outer_case
 
     @contextmanager
     def Elif(self, cond):
-        if not self._stmt_if_cond:
+        self._check_context("Elif", context=None)
+        if_data = self._get_ctrl("If")
+        if if_data is None:
             raise SyntaxError("Elif without preceding If")
         try:
-            _outer_case = self._statements
-            self._statements = []
+            _outer_case, self._statements = self._statements, []
             self.domain._depth += 1
             yield
-            self._stmt_if_cond.append(cond)
-            self._stmt_if_bodies.append(self._statements)
+            self._flush_ctrl()
+            if_data["tests"].append(cond)
+            if_data["bodies"].append(self._statements)
         finally:
             self.domain._depth -= 1
             self._statements = _outer_case
 
     @contextmanager
     def Else(self):
-        if not self._stmt_if_cond:
+        self._check_context("Else", context=None)
+        if_data = self._get_ctrl("If")
+        if if_data is None:
             raise SyntaxError("Else without preceding If/Elif")
         try:
-            _outer_case = self._statements
-            self._statements = []
+            _outer_case, self._statements = self._statements, []
             self.domain._depth += 1
             yield
-            self._stmt_if_bodies.append(self._statements)
+            self._flush_ctrl()
+            if_data["bodies"].append(self._statements)
         finally:
             self.domain._depth -= 1
             self._statements = _outer_case
-        self._flush()
+        self._pop_ctrl()
+
+    @contextmanager
+    def Switch(self, test):
+        self._check_context("Switch", context=None)
+        switch_data = self._set_ctrl("Switch", {"test": test, "cases": OrderedDict()})
+        try:
+            self._ctrl_context = "Switch"
+            self.domain._depth += 1
+            yield
+        finally:
+            self.domain._depth -= 1
+            self._ctrl_context = None
+        self._pop_ctrl()
 
-    def Case(self, test, value=None):
-        return _ModuleBuilderCase(self, self._stmt_depth + 1, test, value)
+    @contextmanager
+    def Case(self, value=None):
+        self._check_context("Case", context="Switch")
+        switch_data = self._get_ctrl("Switch")
+        if value is None:
+            value = "-" * len(switch_data["test"])
+        if isinstance(value, str) and len(switch_data["test"]) != len(value):
+            raise SyntaxError("Case value '{}' must have the same width as test (which is {})"
+                              .format(value, len(switch_data["test"])))
+        try:
+            _outer_case, self._statements = self._statements, []
+            self._ctrl_context = None
+            yield
+            self._flush_ctrl()
+            switch_data["cases"][value] = self._statements
+        finally:
+            self._ctrl_context = "Switch"
+            self._statements = _outer_case
+
+    def _pop_ctrl(self):
+        name, data = self._ctrl_stack.pop()
+
+        if name == "If":
+            if_tests, if_bodies = data["tests"], data["bodies"]
 
-    def _flush(self):
-        if self._stmt_if_cond:
             tests, cases = [], OrderedDict()
-            for if_cond, if_case in zip(self._stmt_if_cond + [None], self._stmt_if_bodies):
-                if if_cond is not None:
-                    if_cond = Value.wrap(if_cond)
-                    if len(if_cond) != 1:
-                        if_cond = if_cond.bool()
-                    tests.append(if_cond)
-
-                if if_cond is not None:
-                    match = ("1" + "-" * (len(tests) - 1)).rjust(len(self._stmt_if_cond), "-")
+            for if_test, if_case in zip(if_tests + [None], if_bodies):
+                if if_test is not None:
+                    if_test = Value.wrap(if_test)
+                    if len(if_test) != 1:
+                        if_test = if_test.bool()
+                    tests.append(if_test)
+
+                if if_test is not None:
+                    match = ("1" + "-" * (len(tests) - 1)).rjust(len(if_tests), "-")
                 else:
                     match = "-" * len(tests)
                 cases[match] = if_case
+
             self._statements.append(Switch(Cat(tests), cases))
 
-        if self._stmt_switch_test:
-            self._statements.append(Switch(self._stmt_switch_test.value, self._stmt_switch_cases))
+        if name == "Switch":
+            switch_test, switch_cases = data["test"], data["cases"]
 
-        self._stmt_if_cond      = []
-        self._stmt_if_bodies    = []
-        self._stmt_switch_test  = None
-        self._stmt_switch_cases = OrderedDict()
+            self._statements.append(Switch(switch_test, switch_cases))
 
     def _add_statement(self, assigns, cd_name, depth, compat_mode=False):
         def cd_human_name(cd_name):
@@ -194,9 +231,8 @@ class Module(_ModuleBuilderRoot):
             else:
                 return cd_name
 
-        if depth < self._stmt_depth:
-            self._flush()
-        self._stmt_depth = depth
+        while len(self._ctrl_stack) > self.domain._depth:
+            self._pop_ctrl()
 
         for assign in Statement.wrap(assigns):
             if not compat_mode and not isinstance(assign, Assign):
@@ -222,6 +258,10 @@ class Module(_ModuleBuilderRoot):
                             "a submodule".format(submodule))
         self._submodules.append((submodule, name))
 
+    def _flush(self):
+        while self._ctrl_stack:
+            self._pop_ctrl()
+
     def lower(self, platform):
         self._flush()
 
index 55e494cf3e8e7bc3918437bc27a061e72100c9f0..ba102dce43d67d624593358c2e5419996cf89aca 100644 (file)
@@ -21,7 +21,7 @@ class DSLTestCase(unittest.TestCase):
     def assertRaises(self, exception, msg=None):
         with super().assertRaises(exception) as cm:
             yield
-        if msg:
+        if msg is not None:
             # WTF? unittest.assertRaises is completely broken.
             self.assertEqual(str(cm.exception), msg)
 
@@ -158,6 +158,68 @@ class DSLTestCase(unittest.TestCase):
         )
         """)
 
+    def test_If_If(self):
+        m = Module()
+        with m.If(self.s1):
+            m.d.comb += self.c1.eq(1)
+        with m.If(self.s2):
+            m.d.comb += self.c2.eq(1)
+        m._flush()
+        self.assertRepr(m._statements, """
+        (
+            (switch (cat (sig s1))
+                (case 1 (eq (sig c1) (const 1'd1)))
+            )
+            (switch (cat (sig s2))
+                (case 1 (eq (sig c2) (const 1'd1)))
+            )
+        )
+        """)
+
+    def test_If_nested_If(self):
+        m = Module()
+        with m.If(self.s1):
+            m.d.comb += self.c1.eq(1)
+            with m.If(self.s2):
+                m.d.comb += self.c2.eq(1)
+        m._flush()
+        self.assertRepr(m._statements, """
+        (
+            (switch (cat (sig s1))
+                (case 1 (eq (sig c1) (const 1'd1))
+                    (switch (cat (sig s2))
+                        (case 1 (eq (sig c2) (const 1'd1)))
+                    )
+                )
+            )
+        )
+        """)
+
+    def test_If_dangling_Else(self):
+        m = Module()
+        with m.If(self.s1):
+            m.d.comb += self.c1.eq(1)
+            with m.If(self.s2):
+                m.d.comb += self.c2.eq(1)
+        with m.Else():
+            m.d.comb += self.c3.eq(1)
+        m._flush()
+        self.assertRepr(m._statements, """
+        (
+            (switch (cat (sig s1))
+                (case 1
+                    (eq (sig c1) (const 1'd1))
+                    (switch (cat (sig s2))
+                        (case 1 (eq (sig c2) (const 1'd1)))
+                    )
+                )
+                (case -
+                    (eq (sig c3) (const 1'd1))
+                )
+            )
+        )
+        """)
+
     def test_Elif_wrong(self):
         m = Module()
         with self.assertRaises(SyntaxError,
@@ -185,7 +247,64 @@ class DSLTestCase(unittest.TestCase):
         )
         """)
 
-    def test_auto_flush(self):
+    def test_Switch(self):
+        m = Module()
+        with m.Switch(self.w1):
+            with m.Case(3):
+                m.d.comb += self.c1.eq(1)
+            with m.Case("11--"):
+                m.d.comb += self.c2.eq(1)
+        m._flush()
+        self.assertRepr(m._statements, """
+        (
+            (switch (sig w1)
+                (case 0011 (eq (sig c1) (const 1'd1)))
+                (case 11-- (eq (sig c2) (const 1'd1)))
+            )
+        )
+        """)
+
+    def test_Switch_default(self):
+        m = Module()
+        with m.Switch(self.w1):
+            with m.Case(3):
+                m.d.comb += self.c1.eq(1)
+            with m.Case():
+                m.d.comb += self.c2.eq(1)
+        m._flush()
+        self.assertRepr(m._statements, """
+        (
+            (switch (sig w1)
+                (case 0011 (eq (sig c1) (const 1'd1)))
+                (case ---- (eq (sig c2) (const 1'd1)))
+            )
+        )
+        """)
+
+    def test_Case_width_wrong(self):
+        m = Module()
+        with m.Switch(self.w1):
+            with self.assertRaises(SyntaxError,
+                    msg="Case value '--' must have the same width as test (which is 4)"):
+                with m.Case("--"):
+                    pass
+
+    def test_Case_outside_Switch_wrong(self):
+        m = Module()
+        with self.assertRaises(SyntaxError,
+                msg="Case is not permitted outside of Switch"):
+            with m.Case():
+                pass
+
+    def test_If_inside_Switch_wrong(self):
+        m = Module()
+        with m.Switch(self.s1):
+            with self.assertRaises(SyntaxError,
+                    msg="If is not permitted inside of Switch"):
+                with m.If(self.s2):
+                    pass
+
+    def test_auto_pop_ctrl(self):
         m = Module()
         with m.If(self.w1):
             m.d.comb += self.c1.eq(1)