def on_Switch(self, stmt):
test = self.rrhs_compiler(stmt.test)
cases = []
- for value, stmts in stmt.cases.items():
- if value is None:
+ for values, stmts in stmt.cases.items():
+ if values == ():
check = lambda test: True
else:
- if "-" in value:
- mask = "".join("0" if b == "-" else "1" for b in value)
- value = "".join("0" if b == "-" else b for b in value)
- else:
- mask = "1" * len(value)
- mask = int(mask, 2)
- value = int(value, 2)
- def make_check(mask, value):
- return lambda test: test & mask == value
- check = make_check(mask, value)
+ check = lambda test: False
+ def make_check(mask, value, prev_check):
+ return lambda test: prev_check(test) or test & mask == value
+ for value in values:
+ if "-" in value:
+ mask = "".join("0" if b == "-" else "1" for b in value)
+ value = "".join("0" if b == "-" else b for b in value)
+ else:
+ mask = "1" * len(value)
+ mask = int(mask, 2)
+ value = int(value, 2)
+ check = make_check(mask, value, check)
cases.append((check, self.on_statements(stmts)))
def run(state):
test_value = test(state)
def __exit__(self, *args):
self.rtlil._append("{}end\n", " " * self.indent)
- def case(self, value=None):
- if value is None:
+ def case(self, *values):
+ if values == ():
self.rtlil._append("{}case\n", " " * (self.indent + 1))
else:
- self.rtlil._append("{}case {}'{}\n", " " * (self.indent + 1),
- len(value), value)
+ self.rtlil._append("{}case {}\n", " " * (self.indent + 1),
+ ", ".join("{}'{}".format(len(value), value) for value in values))
return _CaseBuilder(self.rtlil, self.indent + 2)
self._has_rhs = False
@contextmanager
- def case(self, switch, value):
+ def case(self, switch, values):
try:
old_case = self._case
- with switch.case(value) as self._case:
+ with switch.case(*values) as self._case:
yield
finally:
self._case = old_case
test_sigspec = self._test_cache[stmt]
with self._case.switch(test_sigspec) as switch:
- for value, stmts in stmt.cases.items():
- with self.case(switch, value):
+ for values, stmts in stmt.cases.items():
+ with self.case(switch, values):
self.on_statements(stmts)
def on_statement(self, stmt):
or choice > key):
key = choice
elif isinstance(key, str) and key == "default":
- key = None
+ key = ()
else:
- key = "{:0{}b}".format(wrap(key).value, len(self.test))
+ key = ("{:0{}b}".format(wrap(key).value, len(self.test)),)
stmts = self.cases[key]
del self.cases[key]
- self.cases[None] = stmts
+ self.cases[()] = stmts
return self
def __init__(self, test, cases):
self.test = Value.wrap(test)
self.cases = OrderedDict()
- for key, stmts in cases.items():
- if isinstance(key, (bool, int)):
- key = "{:0{}b}".format(key, len(self.test))
- elif isinstance(key, str):
- pass
- elif key is None:
- pass
- else:
- raise TypeError("Object '{!r}' cannot be used as a switch key"
- .format(key))
- assert key is None or len(key) == len(self.test)
+ for keys, stmts in cases.items():
+ # Map: None -> (); key -> (key,); (key...) -> (key...)
+ if keys is None:
+ keys = ()
+ if not isinstance(keys, tuple):
+ keys = (keys,)
+ # Map: 2 -> "0010"; "0010" -> "0010"
+ new_keys = ()
+ for key in keys:
+ if isinstance(key, (bool, int)):
+ key = "{:0{}b}".format(key, len(self.test))
+ elif isinstance(key, str):
+ pass
+ else:
+ raise TypeError("Object '{!r}' cannot be used as a switch key"
+ .format(key))
+ assert len(key) == len(self.test)
+ new_keys = (*new_keys, key)
if not isinstance(stmts, Iterable):
stmts = [stmts]
- self.cases[key] = Statement.wrap(stmts)
+ self.cases[new_keys] = Statement.wrap(stmts)
def _lhs_signals(self):
signals = union((s._lhs_signals() for ss in self.cases.values() for s in ss),
return self.test._rhs_signals() | signals
def __repr__(self):
- cases = ["(default {})".format(" ".join(map(repr, stmts)))
- if key is None else
- "(case {} {})".format(key, " ".join(map(repr, stmts)))
- for key, stmts in self.cases.items()]
- return "(switch {!r} {})".format(self.test, " ".join(cases))
+ def case_repr(keys, stmts):
+ stmts_repr = " ".join(map(repr, stmts))
+ if keys == ():
+ return "(default {})".format(stmts_repr)
+ elif len(keys) == 1:
+ return "(case {} {})".format(keys[0], stmts_repr)
+ else:
+ return "(case ({}) {})".format(" ".join(keys), stmts_repr)
+ case_reprs = [case_repr(keys, stmts) for keys, stmts in self.cases.items()]
+ return "(switch {!r} {})".format(self.test, " ".join(case_reprs))
@final
self._pop_ctrl()
@contextmanager
- def Case(self, value=None):
+ def Case(self, *values):
self._check_context("Case", context="Switch")
switch_data = self._get_ctrl("Switch")
- if value is None:
- value = "-" * len(switch_data["test"])
- if isinstance(value, str) and len(value) != len(switch_data["test"]):
- raise SyntaxError("Case value '{}' must have the same width as test (which is {})"
- .format(value, len(switch_data["test"])))
- omit_case = False
- if isinstance(value, int) and bits_for(value) > len(switch_data["test"]):
- warnings.warn("Case value '{:b}' is wider than test (which has width {}); "
- "comparison will never be true"
- .format(value, len(switch_data["test"])), SyntaxWarning, stacklevel=3)
- omit_case = True
+ new_values = ()
+ for value in values:
+ if isinstance(value, str) and len(value) != len(switch_data["test"]):
+ raise SyntaxError("Case value '{}' must have the same width as test (which is {})"
+ .format(value, len(switch_data["test"])))
+ if isinstance(value, int) and bits_for(value) > len(switch_data["test"]):
+ warnings.warn("Case value '{:b}' is wider than test (which has width {}); "
+ "comparison will never be true"
+ .format(value, len(switch_data["test"])),
+ SyntaxWarning, stacklevel=3)
+ continue
+ new_values = (*new_values, value)
try:
_outer_case, self._statements = self._statements, []
self._ctrl_context = None
yield
self._flush_ctrl()
- if not omit_case:
- switch_data["cases"][value] = self._statements
+ # If none of the provided cases can possibly be true, omit this branch completely.
+ # This needs to be differentiated from no cases being provided in the first place,
+ # which means the branch will always match.
+ if not (values and not new_values):
+ switch_data["cases"][new_values] = self._statements
finally:
self._ctrl_context = "Switch"
self._statements = _outer_case
(
(switch (sig w1)
(case 0011 (eq (sig c1) (const 1'd1)))
- (case ---- (eq (sig c2) (const 1'd1)))
+ (default (eq (sig c2) (const 1'd1)))
)
)
""")