hdl.{ast,dsl}, back.{pysim,rtlil}: allow multiple case values.
authorwhitequark <cz@m-labs.hk>
Fri, 28 Jun 2019 04:37:08 +0000 (04:37 +0000)
committerwhitequark <cz@m-labs.hk>
Fri, 28 Jun 2019 04:37:08 +0000 (04:37 +0000)
This means that instead of:

    with m.Case(0b00):
        <body>
    with m.Case(0b01):
        <body>

it is legal to write:

    with m.Case(0b00, 0b01):
        <body>

with no change in semantics, and slightly nicer RTLIL or Verilog
output.

Fixes #103.

nmigen/back/pysim.py
nmigen/back/rtlil.py
nmigen/compat/fhdl/structure.py
nmigen/hdl/ast.py
nmigen/hdl/dsl.py
nmigen/test/test_hdl_dsl.py

index 42f99542b29d6be764127417a46b6427bc11ca84..0f8e8db3c5171baa063a42d374b03151f8171198 100644 (file)
@@ -318,20 +318,22 @@ class _StatementCompiler(StatementVisitor):
     def on_Switch(self, stmt):
         test  = self.rrhs_compiler(stmt.test)
         cases = []
-        for value, stmts in stmt.cases.items():
-            if value is None:
+        for values, stmts in stmt.cases.items():
+            if values == ():
                 check = lambda test: True
             else:
-                if "-" in value:
-                    mask  = "".join("0" if b == "-" else "1" for b in value)
-                    value = "".join("0" if b == "-" else  b  for b in value)
-                else:
-                    mask  = "1" * len(value)
-                mask  = int(mask,  2)
-                value = int(value, 2)
-                def make_check(mask, value):
-                    return lambda test: test & mask == value
-                check = make_check(mask, value)
+                check = lambda test: False
+                def make_check(mask, value, prev_check):
+                    return lambda test: prev_check(test) or test & mask == value
+                for value in values:
+                    if "-" in value:
+                        mask  = "".join("0" if b == "-" else "1" for b in value)
+                        value = "".join("0" if b == "-" else  b  for b in value)
+                    else:
+                        mask  = "1" * len(value)
+                    mask  = int(mask,  2)
+                    value = int(value, 2)
+                    check = make_check(mask, value, check)
             cases.append((check, self.on_statements(stmts)))
         def run(state):
             test_value = test(state)
index 7537820553fd48dd75ebeef53ec04bb09cb49558..1fce29e27e7b0d315a0ea3a5ed5dca1f495ad124 100644 (file)
@@ -188,12 +188,12 @@ class _SwitchBuilder:
     def __exit__(self, *args):
         self.rtlil._append("{}end\n", "  " * self.indent)
 
-    def case(self, value=None):
-        if value is None:
+    def case(self, *values):
+        if values == ():
             self.rtlil._append("{}case\n", "  " * (self.indent + 1))
         else:
-            self.rtlil._append("{}case {}'{}\n", "  " * (self.indent + 1),
-                               len(value), value)
+            self.rtlil._append("{}case {}\n", "  " * (self.indent + 1),
+                               ", ".join("{}'{}".format(len(value), value) for value in values))
         return _CaseBuilder(self.rtlil, self.indent + 2)
 
 
@@ -590,10 +590,10 @@ class _StatementCompiler(xfrm.StatementVisitor):
         self._has_rhs     = False
 
     @contextmanager
-    def case(self, switch, value):
+    def case(self, switch, values):
         try:
             old_case = self._case
-            with switch.case(value) as self._case:
+            with switch.case(*values) as self._case:
                 yield
         finally:
             self._case = old_case
@@ -645,8 +645,8 @@ class _StatementCompiler(xfrm.StatementVisitor):
         test_sigspec = self._test_cache[stmt]
 
         with self._case.switch(test_sigspec) as switch:
-            for value, stmts in stmt.cases.items():
-                with self.case(switch, value):
+            for values, stmts in stmt.cases.items():
+                with self.case(switch, values):
                     self.on_statements(stmts)
 
     def on_statement(self, stmt):
index 1f374d4e753e1722777fc980dadb9f0cca4aa3ee..26b1d341b162dfda6ebbcab0622a3a5328e6f186 100644 (file)
@@ -106,12 +106,12 @@ class Case(ast.Switch):
                         or choice > key):
                     key = choice
         elif isinstance(key, str) and key == "default":
-            key = None
+            key = ()
         else:
-            key = "{:0{}b}".format(wrap(key).value, len(self.test))
+            key = ("{:0{}b}".format(wrap(key).value, len(self.test)),)
         stmts = self.cases[key]
         del self.cases[key]
-        self.cases[None] = stmts
+        self.cases[()] = stmts
         return self
 
 
index 1bcbc5ef91663957a3474e2f5a9c8900caf6bfca..54393fe7c09218719912490fd4b3c426fd2db3e1 100644 (file)
@@ -1019,20 +1019,27 @@ class Switch(Statement):
     def __init__(self, test, cases):
         self.test  = Value.wrap(test)
         self.cases = OrderedDict()
-        for key, stmts in cases.items():
-            if isinstance(key, (bool, int)):
-                key = "{:0{}b}".format(key, len(self.test))
-            elif isinstance(key, str):
-                pass
-            elif key is None:
-                pass
-            else:
-                raise TypeError("Object '{!r}' cannot be used as a switch key"
-                                .format(key))
-            assert key is None or len(key) == len(self.test)
+        for keys, stmts in cases.items():
+            # Map: None -> (); key -> (key,); (key...) -> (key...)
+            if keys is None:
+                keys = ()
+            if not isinstance(keys, tuple):
+                keys = (keys,)
+            # Map: 2 -> "0010"; "0010" -> "0010"
+            new_keys = ()
+            for key in keys:
+                if isinstance(key, (bool, int)):
+                    key = "{:0{}b}".format(key, len(self.test))
+                elif isinstance(key, str):
+                    pass
+                else:
+                    raise TypeError("Object '{!r}' cannot be used as a switch key"
+                                    .format(key))
+                assert len(key) == len(self.test)
+                new_keys = (*new_keys, key)
             if not isinstance(stmts, Iterable):
                 stmts = [stmts]
-            self.cases[key] = Statement.wrap(stmts)
+            self.cases[new_keys] = Statement.wrap(stmts)
 
     def _lhs_signals(self):
         signals = union((s._lhs_signals() for ss in self.cases.values() for s in ss),
@@ -1045,11 +1052,16 @@ class Switch(Statement):
         return self.test._rhs_signals() | signals
 
     def __repr__(self):
-        cases = ["(default {})".format(" ".join(map(repr, stmts)))
-                 if key is None else
-                 "(case {} {})".format(key, " ".join(map(repr, stmts)))
-                 for key, stmts in self.cases.items()]
-        return "(switch {!r} {})".format(self.test, " ".join(cases))
+        def case_repr(keys, stmts):
+            stmts_repr = " ".join(map(repr, stmts))
+            if keys == ():
+                return "(default {})".format(stmts_repr)
+            elif len(keys) == 1:
+                return "(case {} {})".format(keys[0], stmts_repr)
+            else:
+                return "(case ({}) {})".format(" ".join(keys), stmts_repr)
+        case_reprs = [case_repr(keys, stmts) for keys, stmts in self.cases.items()]
+        return "(switch {!r} {})".format(self.test, " ".join(case_reprs))
 
 
 @final
index 851d721149f9f866677b7a3ecf1a1addd89b3cce..a3beeefe9e112ea64406e3c1573b45656f0046e9 100644 (file)
@@ -214,27 +214,31 @@ class Module(_ModuleBuilderRoot, Elaboratable):
         self._pop_ctrl()
 
     @contextmanager
-    def Case(self, value=None):
+    def Case(self, *values):
         self._check_context("Case", context="Switch")
         switch_data = self._get_ctrl("Switch")
-        if value is None:
-            value = "-" * len(switch_data["test"])
-        if isinstance(value, str) and len(value) != len(switch_data["test"]):
-            raise SyntaxError("Case value '{}' must have the same width as test (which is {})"
-                              .format(value, len(switch_data["test"])))
-        omit_case = False
-        if isinstance(value, int) and bits_for(value) > len(switch_data["test"]):
-            warnings.warn("Case value '{:b}' is wider than test (which has width {}); "
-                          "comparison will never be true"
-                          .format(value, len(switch_data["test"])), SyntaxWarning, stacklevel=3)
-            omit_case = True
+        new_values = ()
+        for value in values:
+            if isinstance(value, str) and len(value) != len(switch_data["test"]):
+                raise SyntaxError("Case value '{}' must have the same width as test (which is {})"
+                                  .format(value, len(switch_data["test"])))
+            if isinstance(value, int) and bits_for(value) > len(switch_data["test"]):
+                warnings.warn("Case value '{:b}' is wider than test (which has width {}); "
+                              "comparison will never be true"
+                              .format(value, len(switch_data["test"])),
+                              SyntaxWarning, stacklevel=3)
+                continue
+            new_values = (*new_values, value)
         try:
             _outer_case, self._statements = self._statements, []
             self._ctrl_context = None
             yield
             self._flush_ctrl()
-            if not omit_case:
-                switch_data["cases"][value] = self._statements
+            # If none of the provided cases can possibly be true, omit this branch completely.
+            # This needs to be differentiated from no cases being provided in the first place,
+            # which means the branch will always match.
+            if not (values and not new_values):
+                switch_data["cases"][new_values] = self._statements
         finally:
             self._ctrl_context = "Switch"
             self._statements = _outer_case
index e15b1aacc769590edfbb62fddbcb99723005c021..04f7e8ca497f8f3521b700b0bb67bb2a503ad90f 100644 (file)
@@ -297,7 +297,7 @@ class DSLTestCase(FHDLTestCase):
         (
             (switch (sig w1)
                 (case 0011 (eq (sig c1) (const 1'd1)))
-                (case ---- (eq (sig c2) (const 1'd1)))
+                (default (eq (sig c2) (const 1'd1)))
             )
         )
         """)