From 32446831b4a53784874702ebd737b15e8900505a Mon Sep 17 00:00:00 2001 From: whitequark Date: Fri, 28 Jun 2019 04:37:08 +0000 Subject: [PATCH] hdl.{ast,dsl}, back.{pysim,rtlil}: allow multiple case values. This means that instead of: with m.Case(0b00): with m.Case(0b01): it is legal to write: with m.Case(0b00, 0b01): with no change in semantics, and slightly nicer RTLIL or Verilog output. Fixes #103. --- nmigen/back/pysim.py | 26 ++++++++++--------- nmigen/back/rtlil.py | 16 ++++++------ nmigen/compat/fhdl/structure.py | 6 ++--- nmigen/hdl/ast.py | 46 +++++++++++++++++++++------------ nmigen/hdl/dsl.py | 32 +++++++++++++---------- nmigen/test/test_hdl_dsl.py | 2 +- 6 files changed, 73 insertions(+), 55 deletions(-) diff --git a/nmigen/back/pysim.py b/nmigen/back/pysim.py index 42f9954..0f8e8db 100644 --- a/nmigen/back/pysim.py +++ b/nmigen/back/pysim.py @@ -318,20 +318,22 @@ class _StatementCompiler(StatementVisitor): def on_Switch(self, stmt): test = self.rrhs_compiler(stmt.test) cases = [] - for value, stmts in stmt.cases.items(): - if value is None: + for values, stmts in stmt.cases.items(): + if values == (): check = lambda test: True else: - if "-" in value: - mask = "".join("0" if b == "-" else "1" for b in value) - value = "".join("0" if b == "-" else b for b in value) - else: - mask = "1" * len(value) - mask = int(mask, 2) - value = int(value, 2) - def make_check(mask, value): - return lambda test: test & mask == value - check = make_check(mask, value) + check = lambda test: False + def make_check(mask, value, prev_check): + return lambda test: prev_check(test) or test & mask == value + for value in values: + if "-" in value: + mask = "".join("0" if b == "-" else "1" for b in value) + value = "".join("0" if b == "-" else b for b in value) + else: + mask = "1" * len(value) + mask = int(mask, 2) + value = int(value, 2) + check = make_check(mask, value, check) cases.append((check, self.on_statements(stmts))) def run(state): test_value = test(state) diff --git a/nmigen/back/rtlil.py b/nmigen/back/rtlil.py index 7537820..1fce29e 100644 --- a/nmigen/back/rtlil.py +++ b/nmigen/back/rtlil.py @@ -188,12 +188,12 @@ class _SwitchBuilder: def __exit__(self, *args): self.rtlil._append("{}end\n", " " * self.indent) - def case(self, value=None): - if value is None: + def case(self, *values): + if values == (): self.rtlil._append("{}case\n", " " * (self.indent + 1)) else: - self.rtlil._append("{}case {}'{}\n", " " * (self.indent + 1), - len(value), value) + self.rtlil._append("{}case {}\n", " " * (self.indent + 1), + ", ".join("{}'{}".format(len(value), value) for value in values)) return _CaseBuilder(self.rtlil, self.indent + 2) @@ -590,10 +590,10 @@ class _StatementCompiler(xfrm.StatementVisitor): self._has_rhs = False @contextmanager - def case(self, switch, value): + def case(self, switch, values): try: old_case = self._case - with switch.case(value) as self._case: + with switch.case(*values) as self._case: yield finally: self._case = old_case @@ -645,8 +645,8 @@ class _StatementCompiler(xfrm.StatementVisitor): test_sigspec = self._test_cache[stmt] with self._case.switch(test_sigspec) as switch: - for value, stmts in stmt.cases.items(): - with self.case(switch, value): + for values, stmts in stmt.cases.items(): + with self.case(switch, values): self.on_statements(stmts) def on_statement(self, stmt): diff --git a/nmigen/compat/fhdl/structure.py b/nmigen/compat/fhdl/structure.py index 1f374d4..26b1d34 100644 --- a/nmigen/compat/fhdl/structure.py +++ b/nmigen/compat/fhdl/structure.py @@ -106,12 +106,12 @@ class Case(ast.Switch): or choice > key): key = choice elif isinstance(key, str) and key == "default": - key = None + key = () else: - key = "{:0{}b}".format(wrap(key).value, len(self.test)) + key = ("{:0{}b}".format(wrap(key).value, len(self.test)),) stmts = self.cases[key] del self.cases[key] - self.cases[None] = stmts + self.cases[()] = stmts return self diff --git a/nmigen/hdl/ast.py b/nmigen/hdl/ast.py index 1bcbc5e..54393fe 100644 --- a/nmigen/hdl/ast.py +++ b/nmigen/hdl/ast.py @@ -1019,20 +1019,27 @@ class Switch(Statement): def __init__(self, test, cases): self.test = Value.wrap(test) self.cases = OrderedDict() - for key, stmts in cases.items(): - if isinstance(key, (bool, int)): - key = "{:0{}b}".format(key, len(self.test)) - elif isinstance(key, str): - pass - elif key is None: - pass - else: - raise TypeError("Object '{!r}' cannot be used as a switch key" - .format(key)) - assert key is None or len(key) == len(self.test) + for keys, stmts in cases.items(): + # Map: None -> (); key -> (key,); (key...) -> (key...) + if keys is None: + keys = () + if not isinstance(keys, tuple): + keys = (keys,) + # Map: 2 -> "0010"; "0010" -> "0010" + new_keys = () + for key in keys: + if isinstance(key, (bool, int)): + key = "{:0{}b}".format(key, len(self.test)) + elif isinstance(key, str): + pass + else: + raise TypeError("Object '{!r}' cannot be used as a switch key" + .format(key)) + assert len(key) == len(self.test) + new_keys = (*new_keys, key) if not isinstance(stmts, Iterable): stmts = [stmts] - self.cases[key] = Statement.wrap(stmts) + self.cases[new_keys] = Statement.wrap(stmts) def _lhs_signals(self): signals = union((s._lhs_signals() for ss in self.cases.values() for s in ss), @@ -1045,11 +1052,16 @@ class Switch(Statement): return self.test._rhs_signals() | signals def __repr__(self): - cases = ["(default {})".format(" ".join(map(repr, stmts))) - if key is None else - "(case {} {})".format(key, " ".join(map(repr, stmts))) - for key, stmts in self.cases.items()] - return "(switch {!r} {})".format(self.test, " ".join(cases)) + def case_repr(keys, stmts): + stmts_repr = " ".join(map(repr, stmts)) + if keys == (): + return "(default {})".format(stmts_repr) + elif len(keys) == 1: + return "(case {} {})".format(keys[0], stmts_repr) + else: + return "(case ({}) {})".format(" ".join(keys), stmts_repr) + case_reprs = [case_repr(keys, stmts) for keys, stmts in self.cases.items()] + return "(switch {!r} {})".format(self.test, " ".join(case_reprs)) @final diff --git a/nmigen/hdl/dsl.py b/nmigen/hdl/dsl.py index 851d721..a3beeef 100644 --- a/nmigen/hdl/dsl.py +++ b/nmigen/hdl/dsl.py @@ -214,27 +214,31 @@ class Module(_ModuleBuilderRoot, Elaboratable): self._pop_ctrl() @contextmanager - def Case(self, value=None): + def Case(self, *values): self._check_context("Case", context="Switch") switch_data = self._get_ctrl("Switch") - if value is None: - value = "-" * len(switch_data["test"]) - if isinstance(value, str) and len(value) != len(switch_data["test"]): - raise SyntaxError("Case value '{}' must have the same width as test (which is {})" - .format(value, len(switch_data["test"]))) - omit_case = False - if isinstance(value, int) and bits_for(value) > len(switch_data["test"]): - warnings.warn("Case value '{:b}' is wider than test (which has width {}); " - "comparison will never be true" - .format(value, len(switch_data["test"])), SyntaxWarning, stacklevel=3) - omit_case = True + new_values = () + for value in values: + if isinstance(value, str) and len(value) != len(switch_data["test"]): + raise SyntaxError("Case value '{}' must have the same width as test (which is {})" + .format(value, len(switch_data["test"]))) + if isinstance(value, int) and bits_for(value) > len(switch_data["test"]): + warnings.warn("Case value '{:b}' is wider than test (which has width {}); " + "comparison will never be true" + .format(value, len(switch_data["test"])), + SyntaxWarning, stacklevel=3) + continue + new_values = (*new_values, value) try: _outer_case, self._statements = self._statements, [] self._ctrl_context = None yield self._flush_ctrl() - if not omit_case: - switch_data["cases"][value] = self._statements + # If none of the provided cases can possibly be true, omit this branch completely. + # This needs to be differentiated from no cases being provided in the first place, + # which means the branch will always match. + if not (values and not new_values): + switch_data["cases"][new_values] = self._statements finally: self._ctrl_context = "Switch" self._statements = _outer_case diff --git a/nmigen/test/test_hdl_dsl.py b/nmigen/test/test_hdl_dsl.py index e15b1aa..04f7e8c 100644 --- a/nmigen/test/test_hdl_dsl.py +++ b/nmigen/test/test_hdl_dsl.py @@ -297,7 +297,7 @@ class DSLTestCase(FHDLTestCase): ( (switch (sig w1) (case 0011 (eq (sig c1) (const 1'd1))) - (case ---- (eq (sig c2) (const 1'd1))) + (default (eq (sig c2) (const 1'd1))) ) ) """) -- 2.30.2