hdl.ast: add Value.matches(), accepting same language as Case().
authorwhitequark <cz@m-labs.hk>
Sat, 14 Sep 2019 21:06:12 +0000 (21:06 +0000)
committerwhitequark <cz@m-labs.hk>
Sat, 14 Sep 2019 21:06:12 +0000 (21:06 +0000)
Fixes #202.

nmigen/hdl/ast.py
nmigen/test/test_hdl_ast.py

index 51e1d0d28350a2f14f426f4c85b2d1cc3873997b..3b2a434ed51591897617f96dc29199f4578858d1 100644 (file)
@@ -191,9 +191,9 @@ class Value(metaclass=ABCMeta):
         Parameters
         ----------
         offset : Value, in
-            index of first selected bit
+            Index of first selected bit.
         width : int
-            number of selected bits
+            Number of selected bits.
 
         Returns
         -------
@@ -211,9 +211,9 @@ class Value(metaclass=ABCMeta):
         Parameters
         ----------
         offset : Value, in
-            index of first selected word
+            Index of first selected word.
         width : int
-            number of selected bits
+            Number of selected bits.
 
         Returns
         -------
@@ -222,6 +222,56 @@ class Value(metaclass=ABCMeta):
         """
         return Part(self, offset, width, stride=width, src_loc_at=1)
 
+    def matches(self, *patterns):
+        """Pattern matching.
+
+        Matches against a set of patterns, which may be integers or bit strings, recognizing
+        the same grammar as ``Case()``.
+
+        Parameters
+        ----------
+        patterns : int or str
+            Patterns to match against.
+
+        Returns
+        -------
+        Value, out
+            ``1`` if any pattern matches the value, ``0`` otherwise.
+        """
+        matches = []
+        for pattern in patterns:
+            if isinstance(pattern, str) and any(bit not in "01-" for bit in pattern):
+                raise SyntaxError("Match pattern '{}' must consist of 0, 1, and - (don't care) "
+                                  "bits"
+                                  .format(pattern))
+            if isinstance(pattern, str) and len(pattern) != len(self):
+                raise SyntaxError("Match pattern '{}' must have the same width as match value "
+                                  "(which is {})"
+                                  .format(pattern, len(self)))
+            if not isinstance(pattern, (int, str)):
+                raise SyntaxError("Match pattern must be an integer or a string, not {}"
+                                  .format(pattern))
+            if isinstance(pattern, int) and bits_for(pattern) > len(self):
+                warnings.warn("Match pattern '{:b}' is wider than match value "
+                              "(which has width {}); comparison will never be true"
+                              .format(pattern, len(self)),
+                              SyntaxWarning, stacklevel=3)
+                continue
+            if isinstance(pattern, int):
+                matches.append(self == pattern)
+            elif isinstance(pattern, str):
+                mask    = int(pattern.replace("0", "1").replace("-", "0"), 2)
+                pattern = int(pattern.replace("-", "0"), 2)
+                matches.append((self & mask) == pattern)
+            else:
+                assert False
+        if not matches:
+            return Const(0)
+        elif len(matches) == 1:
+            return matches[0]
+        else:
+            return Cat(*matches).any()
+
     def eq(self, value):
         """Assignment.
 
index 2b851195d407a160c672610a150310b404c0a837..30d6deb4505e568a9033aceb52e107a579a9a262 100644 (file)
@@ -263,6 +263,41 @@ class OperatorTestCase(FHDLTestCase):
         v = Const(0b101).xor()
         self.assertEqual(repr(v), "(r^ (const 3'd5))")
 
+    def test_matches(self):
+        s = Signal(4)
+        self.assertRepr(s.matches(), "(const 1'd0)")
+        self.assertRepr(s.matches(1), """
+        (== (sig s) (const 1'd1))
+        """)
+        self.assertRepr(s.matches(0, 1), """
+        (r| (cat (== (sig s) (const 1'd0)) (== (sig s) (const 1'd1))))
+        """)
+        self.assertRepr(s.matches("10--"), """
+        (== (& (sig s) (const 4'd12)) (const 4'd8))
+        """)
+
+    def test_matches_width_wrong(self):
+        s = Signal(4)
+        with self.assertRaises(SyntaxError,
+                msg="Match pattern '--' must have the same width as match value (which is 4)"):
+            s.matches("--")
+        with self.assertWarns(SyntaxWarning,
+                msg="Match pattern '10110' is wider than match value (which has width 4); "
+                    "comparison will never be true"):
+            s.matches(0b10110)
+
+    def test_matches_bits_wrong(self):
+        s = Signal(4)
+        with self.assertRaises(SyntaxError,
+                msg="Match pattern 'abc' must consist of 0, 1, and - (don't care) bits"):
+            s.matches("abc")
+
+    def test_matches_pattern_wrong(self):
+        s = Signal(4)
+        with self.assertRaises(SyntaxError,
+                msg="Match pattern must be an integer or a string, not 1.0"):
+            s.matches(1.0)
+
     def test_hash(self):
         with self.assertRaises(TypeError):
             hash(Const(0) + Const(0))