Fixes #207.
DUID.__next_uid += 1
+def _enum_shape(enum_type):
+ min_value = min(member.value for member in enum_type)
+ max_value = max(member.value for member in enum_type)
+ if not isinstance(min_value, int) or not isinstance(max_value, int):
+ raise TypeError("Only enumerations with integer values can be converted to nMigen values")
+ sign = min_value < 0 or max_value < 0
+ bits = max(bits_for(min_value, sign), bits_for(max_value, sign))
+ return (bits, sign)
+
+
+def _enum_to_bits(enum_value):
+ bits, sign = _enum_shape(type(enum_value))
+ return format(enum_value.value & ((1 << bits) - 1), "b").rjust(bits, "0")
+
+
class Value(metaclass=ABCMeta):
@staticmethod
def wrap(obj):
return obj
elif isinstance(obj, (bool, int)):
return Const(obj)
+ elif isinstance(obj, Enum):
+ return Const(obj.value, _enum_shape(type(obj)))
else:
raise TypeError("Object '{!r}' is not an nMigen value".format(obj))
"""
matches = []
for pattern in patterns:
+ if not isinstance(pattern, (int, str, Enum)):
+ raise SyntaxError("Match pattern must be an integer, a string, or an enumeration, "
+ "not {!r}"
+ .format(pattern))
if isinstance(pattern, str) and any(bit not in "01-" for bit in pattern):
raise SyntaxError("Match pattern '{}' must consist of 0, 1, and - (don't care) "
"bits"
raise SyntaxError("Match pattern '{}' must have the same width as match value "
"(which is {})"
.format(pattern, len(self)))
- if not isinstance(pattern, (int, str)):
- raise SyntaxError("Match pattern must be an integer or a string, not {}"
- .format(pattern))
if isinstance(pattern, int) and bits_for(pattern) > len(self):
warnings.warn("Match pattern '{:b}' is wider than match value "
"(which has width {}); comparison will never be true"
continue
if isinstance(pattern, int):
matches.append(self == pattern)
- elif isinstance(pattern, str):
+ elif isinstance(pattern, (str, Enum)):
+ if isinstance(pattern, Enum):
+ pattern = _enum_to_bits(pattern)
mask = int(pattern.replace("0", "1").replace("-", "0"), 2)
pattern = int(pattern.replace("-", "0"), 2)
matches.append((self & mask) == pattern)
bits_for(value_range.stop - value_range.step, signed))
return cls((nbits, signed), src_loc_at=1 + src_loc_at, **kwargs)
+ @classmethod
+ def enum(cls, enum_type, *, src_loc_at=0, **kwargs):
+ """Create Signal that can represent a given enumeration.
+
+ Parameters
+ ----------
+ enum : type (inheriting from :class:`enum.Enum`)
+ Enumeration to base this Signal on.
+ """
+ if not issubclass(enum_type, Enum):
+ raise TypeError("Type {!r} is not an enumeration")
+ return cls(_enum_shape(enum_type), src_loc_at=1 + src_loc_at, decoder=enum_type, **kwargs)
+
@classmethod
def like(cls, other, *, name=None, name_suffix=None, src_loc_at=0, **kwargs):
"""Create Signal based on another.
key = "{:0{}b}".format(key, len(self.test))
elif isinstance(key, str):
pass
+ elif isinstance(key, Enum):
+ key = _enum_to_bits(key)
else:
raise TypeError("Object '{!r}' cannot be used as a switch key"
.format(key))
from collections import OrderedDict, namedtuple
from collections.abc import Iterable
from contextlib import contextmanager
+from enum import Enum
import warnings
from ..tools import flatten, bits_for, deprecated
switch_data = self._get_ctrl("Switch")
new_patterns = ()
for pattern in patterns:
+ if not isinstance(pattern, (int, str, Enum)):
+ raise SyntaxError("Case pattern must be an integer, a string, or an enumeration, "
+ "not {!r}"
+ .format(pattern))
if isinstance(pattern, str) and any(bit not in "01-" for bit in pattern):
raise SyntaxError("Case pattern '{}' must consist of 0, 1, and - (don't care) bits"
.format(pattern))
raise SyntaxError("Case pattern '{}' must have the same width as switch value "
"(which is {})"
.format(pattern, len(switch_data["test"])))
- if not isinstance(pattern, (int, str)):
- raise SyntaxError("Case pattern must be an integer or a string, not {}"
- .format(pattern))
if isinstance(pattern, int) and bits_for(pattern) > len(switch_data["test"]):
warnings.warn("Case pattern '{:b}' is wider than switch value "
"(which has width {}); comparison will never be true"
from .tools import *
+class UnsignedEnum(Enum):
+ FOO = 1
+ BAR = 2
+ BAZ = 3
+
+
+class SignedEnum(Enum):
+ FOO = -1
+ BAR = 0
+ BAZ = +1
+
+
+class StringEnum(Enum):
+ FOO = "a"
+ BAR = "b"
+
+
class ValueTestCase(FHDLTestCase):
def test_wrap(self):
self.assertIsInstance(Value.wrap(0), Const)
msg="Object ''str'' is not an nMigen value"):
Value.wrap("str")
+ def test_wrap_enum(self):
+ e1 = Value.wrap(UnsignedEnum.FOO)
+ self.assertIsInstance(e1, Const)
+ self.assertEqual(e1.shape(), (2, False))
+ e2 = Value.wrap(SignedEnum.FOO)
+ self.assertIsInstance(e2, Const)
+ self.assertEqual(e2.shape(), (2, True))
+
+ def test_wrap_enum_wrong(self):
+ with self.assertRaises(TypeError,
+ msg="Only enumerations with integer values can be converted to nMigen values"):
+ Value.wrap(StringEnum.FOO)
+
def test_bool(self):
with self.assertRaises(TypeError,
msg="Attempted to convert nMigen value to boolean"):
(== (& (sig s) (const 4'd12)) (const 4'd8))
""")
+ def test_matches_enum(self):
+ s = Signal.enum(SignedEnum)
+ self.assertRepr(s.matches(SignedEnum.FOO), """
+ (== (& (sig s) (const 2'd3)) (const 2'd3))
+ """)
+
def test_matches_width_wrong(self):
s = Signal(4)
with self.assertRaises(SyntaxError,
def test_matches_pattern_wrong(self):
s = Signal(4)
with self.assertRaises(SyntaxError,
- msg="Match pattern must be an integer or a string, not 1.0"):
+ msg="Match pattern must be an integer, a string, or an enumeration, not 1.0"):
s.matches(1.0)
def test_hash(self):
self.assertEqual(s.decoder(1), "RED/1")
self.assertEqual(s.decoder(3), "3")
+ def test_enum(self):
+ s1 = Signal.enum(UnsignedEnum)
+ self.assertEqual(s1.shape(), (2, False))
+ s2 = Signal.enum(SignedEnum)
+ self.assertEqual(s2.shape(), (2, True))
+ self.assertEqual(s2.decoder(SignedEnum.FOO), "FOO/-1")
+
class ClockSignalTestCase(FHDLTestCase):
def test_domain(self):
from collections import OrderedDict
+from enum import Enum
from ..hdl.ast import *
from ..hdl.cd import *
)
""")
+ def test_Switch_enum(self):
+ class Color(Enum):
+ RED = 1
+ BLUE = 2
+ m = Module()
+ se = Signal.enum(Color)
+ with m.Switch(se):
+ with m.Case(Color.RED):
+ m.d.comb += self.c1.eq(1)
+ self.assertRepr(m._statements, """
+ (
+ (switch (sig se)
+ (case 01 (eq (sig c1) (const 1'd1)))
+ )
+ )
+ """)
+
def test_Case_width_wrong(self):
m = Module()
with m.Switch(self.w1):
m = Module()
with m.Switch(self.w1):
with self.assertRaises(SyntaxError,
- msg="Case pattern must be an integer or a string, not 1.0"):
+ msg="Case pattern must be an integer, a string, or an enumeration, not 1.0"):
with m.Case(1.0):
pass