hdl.{ast,dsl}: add Signal.enum; coerce Enum to Value; accept Enum patterns.
authorwhitequark <cz@m-labs.hk>
Mon, 16 Sep 2019 18:59:28 +0000 (18:59 +0000)
committerwhitequark <cz@m-labs.hk>
Mon, 16 Sep 2019 19:22:12 +0000 (19:22 +0000)
Fixes #207.

nmigen/hdl/ast.py
nmigen/hdl/dsl.py
nmigen/test/test_hdl_ast.py
nmigen/test/test_hdl_dsl.py

index 3b2a434ed51591897617f96dc29199f4578858d1..5e860fec8d4273888618a3081ee2547e120917a5 100644 (file)
@@ -30,6 +30,21 @@ class DUID:
         DUID.__next_uid += 1
 
 
+def _enum_shape(enum_type):
+    min_value = min(member.value for member in enum_type)
+    max_value = max(member.value for member in enum_type)
+    if not isinstance(min_value, int) or not isinstance(max_value, int):
+        raise TypeError("Only enumerations with integer values can be converted to nMigen values")
+    sign = min_value < 0 or max_value < 0
+    bits = max(bits_for(min_value, sign), bits_for(max_value, sign))
+    return (bits, sign)
+
+
+def _enum_to_bits(enum_value):
+    bits, sign = _enum_shape(type(enum_value))
+    return format(enum_value.value & ((1 << bits) - 1), "b").rjust(bits, "0")
+
+
 class Value(metaclass=ABCMeta):
     @staticmethod
     def wrap(obj):
@@ -39,6 +54,8 @@ class Value(metaclass=ABCMeta):
             return obj
         elif isinstance(obj, (bool, int)):
             return Const(obj)
+        elif isinstance(obj, Enum):
+            return Const(obj.value, _enum_shape(type(obj)))
         else:
             raise TypeError("Object '{!r}' is not an nMigen value".format(obj))
 
@@ -240,6 +257,10 @@ class Value(metaclass=ABCMeta):
         """
         matches = []
         for pattern in patterns:
+            if not isinstance(pattern, (int, str, Enum)):
+                raise SyntaxError("Match pattern must be an integer, a string, or an enumeration, "
+                                  "not {!r}"
+                                  .format(pattern))
             if isinstance(pattern, str) and any(bit not in "01-" for bit in pattern):
                 raise SyntaxError("Match pattern '{}' must consist of 0, 1, and - (don't care) "
                                   "bits"
@@ -248,9 +269,6 @@ class Value(metaclass=ABCMeta):
                 raise SyntaxError("Match pattern '{}' must have the same width as match value "
                                   "(which is {})"
                                   .format(pattern, len(self)))
-            if not isinstance(pattern, (int, str)):
-                raise SyntaxError("Match pattern must be an integer or a string, not {}"
-                                  .format(pattern))
             if isinstance(pattern, int) and bits_for(pattern) > len(self):
                 warnings.warn("Match pattern '{:b}' is wider than match value "
                               "(which has width {}); comparison will never be true"
@@ -259,7 +277,9 @@ class Value(metaclass=ABCMeta):
                 continue
             if isinstance(pattern, int):
                 matches.append(self == pattern)
-            elif isinstance(pattern, str):
+            elif isinstance(pattern, (str, Enum)):
+                if isinstance(pattern, Enum):
+                    pattern = _enum_to_bits(pattern)
                 mask    = int(pattern.replace("0", "1").replace("-", "0"), 2)
                 pattern = int(pattern.replace("-", "0"), 2)
                 matches.append((self & mask) == pattern)
@@ -784,6 +804,19 @@ class Signal(Value, DUID):
                      bits_for(value_range.stop - value_range.step, signed))
         return cls((nbits, signed), src_loc_at=1 + src_loc_at, **kwargs)
 
+    @classmethod
+    def enum(cls, enum_type, *, src_loc_at=0, **kwargs):
+        """Create Signal that can represent a given enumeration.
+
+        Parameters
+        ----------
+        enum : type (inheriting from :class:`enum.Enum`)
+            Enumeration to base this Signal on.
+        """
+        if not issubclass(enum_type, Enum):
+            raise TypeError("Type {!r} is not an enumeration")
+        return cls(_enum_shape(enum_type), src_loc_at=1 + src_loc_at, decoder=enum_type, **kwargs)
+
     @classmethod
     def like(cls, other, *, name=None, name_suffix=None, src_loc_at=0, **kwargs):
         """Create Signal based on another.
@@ -1230,6 +1263,8 @@ class Switch(Statement):
                     key = "{:0{}b}".format(key, len(self.test))
                 elif isinstance(key, str):
                     pass
+                elif isinstance(key, Enum):
+                    key = _enum_to_bits(key)
                 else:
                     raise TypeError("Object '{!r}' cannot be used as a switch key"
                                     .format(key))
index a82d1c69c6637464851eec0d4a73d8c227493604..024587fec4e8fcbbed62bb26808ae95c306d0ef7 100644 (file)
@@ -1,6 +1,7 @@
 from collections import OrderedDict, namedtuple
 from collections.abc import Iterable
 from contextlib import contextmanager
+from enum import Enum
 import warnings
 
 from ..tools import flatten, bits_for, deprecated
@@ -264,6 +265,10 @@ class Module(_ModuleBuilderRoot, Elaboratable):
         switch_data = self._get_ctrl("Switch")
         new_patterns = ()
         for pattern in patterns:
+            if not isinstance(pattern, (int, str, Enum)):
+                raise SyntaxError("Case pattern must be an integer, a string, or an enumeration, "
+                                  "not {!r}"
+                                  .format(pattern))
             if isinstance(pattern, str) and any(bit not in "01-" for bit in pattern):
                 raise SyntaxError("Case pattern '{}' must consist of 0, 1, and - (don't care) bits"
                                   .format(pattern))
@@ -271,9 +276,6 @@ class Module(_ModuleBuilderRoot, Elaboratable):
                 raise SyntaxError("Case pattern '{}' must have the same width as switch value "
                                   "(which is {})"
                                   .format(pattern, len(switch_data["test"])))
-            if not isinstance(pattern, (int, str)):
-                raise SyntaxError("Case pattern must be an integer or a string, not {}"
-                                  .format(pattern))
             if isinstance(pattern, int) and bits_for(pattern) > len(switch_data["test"]):
                 warnings.warn("Case pattern '{:b}' is wider than switch value "
                               "(which has width {}); comparison will never be true"
index 30d6deb4505e568a9033aceb52e107a579a9a262..e459b3dcd2a3ba2d0ed0c18267700247b9136ad7 100644 (file)
@@ -5,6 +5,23 @@ from ..hdl.ast import *
 from .tools import *
 
 
+class UnsignedEnum(Enum):
+    FOO = 1
+    BAR = 2
+    BAZ = 3
+
+
+class SignedEnum(Enum):
+    FOO = -1
+    BAR =  0
+    BAZ = +1
+
+
+class StringEnum(Enum):
+    FOO = "a"
+    BAR = "b"
+
+
 class ValueTestCase(FHDLTestCase):
     def test_wrap(self):
         self.assertIsInstance(Value.wrap(0), Const)
@@ -15,6 +32,19 @@ class ValueTestCase(FHDLTestCase):
                 msg="Object ''str'' is not an nMigen value"):
             Value.wrap("str")
 
+    def test_wrap_enum(self):
+        e1 = Value.wrap(UnsignedEnum.FOO)
+        self.assertIsInstance(e1, Const)
+        self.assertEqual(e1.shape(), (2, False))
+        e2 = Value.wrap(SignedEnum.FOO)
+        self.assertIsInstance(e2, Const)
+        self.assertEqual(e2.shape(), (2, True))
+
+    def test_wrap_enum_wrong(self):
+        with self.assertRaises(TypeError,
+                msg="Only enumerations with integer values can be converted to nMigen values"):
+            Value.wrap(StringEnum.FOO)
+
     def test_bool(self):
         with self.assertRaises(TypeError,
                 msg="Attempted to convert nMigen value to boolean"):
@@ -276,6 +306,12 @@ class OperatorTestCase(FHDLTestCase):
         (== (& (sig s) (const 4'd12)) (const 4'd8))
         """)
 
+    def test_matches_enum(self):
+        s = Signal.enum(SignedEnum)
+        self.assertRepr(s.matches(SignedEnum.FOO), """
+        (== (& (sig s) (const 2'd3)) (const 2'd3))
+        """)
+
     def test_matches_width_wrong(self):
         s = Signal(4)
         with self.assertRaises(SyntaxError,
@@ -295,7 +331,7 @@ class OperatorTestCase(FHDLTestCase):
     def test_matches_pattern_wrong(self):
         s = Signal(4)
         with self.assertRaises(SyntaxError,
-                msg="Match pattern must be an integer or a string, not 1.0"):
+                msg="Match pattern must be an integer, a string, or an enumeration, not 1.0"):
             s.matches(1.0)
 
     def test_hash(self):
@@ -605,6 +641,13 @@ class SignalTestCase(FHDLTestCase):
         self.assertEqual(s.decoder(1), "RED/1")
         self.assertEqual(s.decoder(3), "3")
 
+    def test_enum(self):
+        s1 = Signal.enum(UnsignedEnum)
+        self.assertEqual(s1.shape(), (2, False))
+        s2 = Signal.enum(SignedEnum)
+        self.assertEqual(s2.shape(), (2, True))
+        self.assertEqual(s2.decoder(SignedEnum.FOO), "FOO/-1")
+
 
 class ClockSignalTestCase(FHDLTestCase):
     def test_domain(self):
index d3419a37f08dfc49b27db092d36e657301be369c..c3355cdd9350f0d1103e4d6f3fb194aaee4a06d3 100644 (file)
@@ -1,4 +1,5 @@
 from collections import OrderedDict
+from enum import Enum
 
 from ..hdl.ast import *
 from ..hdl.cd import *
@@ -355,6 +356,23 @@ class DSLTestCase(FHDLTestCase):
         )
         """)
 
+    def test_Switch_enum(self):
+        class Color(Enum):
+            RED  = 1
+            BLUE = 2
+        m = Module()
+        se = Signal.enum(Color)
+        with m.Switch(se):
+            with m.Case(Color.RED):
+                m.d.comb += self.c1.eq(1)
+        self.assertRepr(m._statements, """
+        (
+            (switch (sig se)
+                (case 01 (eq (sig c1) (const 1'd1)))
+            )
+        )
+        """)
+
     def test_Case_width_wrong(self):
         m = Module()
         with m.Switch(self.w1):
@@ -385,7 +403,7 @@ class DSLTestCase(FHDLTestCase):
         m = Module()
         with m.Switch(self.w1):
             with self.assertRaises(SyntaxError,
-                    msg="Case pattern must be an integer or a string, not 1.0"):
+                    msg="Case pattern must be an integer, a string, or an enumeration, not 1.0"):
                 with m.Case(1.0):
                     pass