From e8f79c5539a5988ab26c521b6e170eba6ab10a69 Mon Sep 17 00:00:00 2001 From: whitequark Date: Sat, 14 Sep 2019 21:06:12 +0000 Subject: [PATCH] hdl.ast: add Value.matches(), accepting same language as Case(). Fixes #202. --- nmigen/hdl/ast.py | 58 ++++++++++++++++++++++++++++++++++--- nmigen/test/test_hdl_ast.py | 35 ++++++++++++++++++++++ 2 files changed, 89 insertions(+), 4 deletions(-) diff --git a/nmigen/hdl/ast.py b/nmigen/hdl/ast.py index 51e1d0d..3b2a434 100644 --- a/nmigen/hdl/ast.py +++ b/nmigen/hdl/ast.py @@ -191,9 +191,9 @@ class Value(metaclass=ABCMeta): Parameters ---------- offset : Value, in - index of first selected bit + Index of first selected bit. width : int - number of selected bits + Number of selected bits. Returns ------- @@ -211,9 +211,9 @@ class Value(metaclass=ABCMeta): Parameters ---------- offset : Value, in - index of first selected word + Index of first selected word. width : int - number of selected bits + Number of selected bits. Returns ------- @@ -222,6 +222,56 @@ class Value(metaclass=ABCMeta): """ return Part(self, offset, width, stride=width, src_loc_at=1) + def matches(self, *patterns): + """Pattern matching. + + Matches against a set of patterns, which may be integers or bit strings, recognizing + the same grammar as ``Case()``. + + Parameters + ---------- + patterns : int or str + Patterns to match against. + + Returns + ------- + Value, out + ``1`` if any pattern matches the value, ``0`` otherwise. + """ + matches = [] + for pattern in patterns: + if isinstance(pattern, str) and any(bit not in "01-" for bit in pattern): + raise SyntaxError("Match pattern '{}' must consist of 0, 1, and - (don't care) " + "bits" + .format(pattern)) + if isinstance(pattern, str) and len(pattern) != len(self): + raise SyntaxError("Match pattern '{}' must have the same width as match value " + "(which is {})" + .format(pattern, len(self))) + if not isinstance(pattern, (int, str)): + raise SyntaxError("Match pattern must be an integer or a string, not {}" + .format(pattern)) + if isinstance(pattern, int) and bits_for(pattern) > len(self): + warnings.warn("Match pattern '{:b}' is wider than match value " + "(which has width {}); comparison will never be true" + .format(pattern, len(self)), + SyntaxWarning, stacklevel=3) + continue + if isinstance(pattern, int): + matches.append(self == pattern) + elif isinstance(pattern, str): + mask = int(pattern.replace("0", "1").replace("-", "0"), 2) + pattern = int(pattern.replace("-", "0"), 2) + matches.append((self & mask) == pattern) + else: + assert False + if not matches: + return Const(0) + elif len(matches) == 1: + return matches[0] + else: + return Cat(*matches).any() + def eq(self, value): """Assignment. diff --git a/nmigen/test/test_hdl_ast.py b/nmigen/test/test_hdl_ast.py index 2b85119..30d6deb 100644 --- a/nmigen/test/test_hdl_ast.py +++ b/nmigen/test/test_hdl_ast.py @@ -263,6 +263,41 @@ class OperatorTestCase(FHDLTestCase): v = Const(0b101).xor() self.assertEqual(repr(v), "(r^ (const 3'd5))") + def test_matches(self): + s = Signal(4) + self.assertRepr(s.matches(), "(const 1'd0)") + self.assertRepr(s.matches(1), """ + (== (sig s) (const 1'd1)) + """) + self.assertRepr(s.matches(0, 1), """ + (r| (cat (== (sig s) (const 1'd0)) (== (sig s) (const 1'd1)))) + """) + self.assertRepr(s.matches("10--"), """ + (== (& (sig s) (const 4'd12)) (const 4'd8)) + """) + + def test_matches_width_wrong(self): + s = Signal(4) + with self.assertRaises(SyntaxError, + msg="Match pattern '--' must have the same width as match value (which is 4)"): + s.matches("--") + with self.assertWarns(SyntaxWarning, + msg="Match pattern '10110' is wider than match value (which has width 4); " + "comparison will never be true"): + s.matches(0b10110) + + def test_matches_bits_wrong(self): + s = Signal(4) + with self.assertRaises(SyntaxError, + msg="Match pattern 'abc' must consist of 0, 1, and - (don't care) bits"): + s.matches("abc") + + def test_matches_pattern_wrong(self): + s = Signal(4) + with self.assertRaises(SyntaxError, + msg="Match pattern must be an integer or a string, not 1.0"): + s.matches(1.0) + def test_hash(self): with self.assertRaises(TypeError): hash(Const(0) + Const(0)) -- 2.30.2