From 4777a7b3a2c4cbf449169b23b6b138de64ec1b3f Mon Sep 17 00:00:00 2001 From: whitequark Date: Mon, 16 Sep 2019 18:59:28 +0000 Subject: [PATCH] hdl.{ast,dsl}: add Signal.enum; coerce Enum to Value; accept Enum patterns. Fixes #207. --- nmigen/hdl/ast.py | 43 +++++++++++++++++++++++++++++++---- nmigen/hdl/dsl.py | 8 ++++--- nmigen/test/test_hdl_ast.py | 45 ++++++++++++++++++++++++++++++++++++- nmigen/test/test_hdl_dsl.py | 20 ++++++++++++++++- 4 files changed, 107 insertions(+), 9 deletions(-) diff --git a/nmigen/hdl/ast.py b/nmigen/hdl/ast.py index 3b2a434..5e860fe 100644 --- a/nmigen/hdl/ast.py +++ b/nmigen/hdl/ast.py @@ -30,6 +30,21 @@ class DUID: DUID.__next_uid += 1 +def _enum_shape(enum_type): + min_value = min(member.value for member in enum_type) + max_value = max(member.value for member in enum_type) + if not isinstance(min_value, int) or not isinstance(max_value, int): + raise TypeError("Only enumerations with integer values can be converted to nMigen values") + sign = min_value < 0 or max_value < 0 + bits = max(bits_for(min_value, sign), bits_for(max_value, sign)) + return (bits, sign) + + +def _enum_to_bits(enum_value): + bits, sign = _enum_shape(type(enum_value)) + return format(enum_value.value & ((1 << bits) - 1), "b").rjust(bits, "0") + + class Value(metaclass=ABCMeta): @staticmethod def wrap(obj): @@ -39,6 +54,8 @@ class Value(metaclass=ABCMeta): return obj elif isinstance(obj, (bool, int)): return Const(obj) + elif isinstance(obj, Enum): + return Const(obj.value, _enum_shape(type(obj))) else: raise TypeError("Object '{!r}' is not an nMigen value".format(obj)) @@ -240,6 +257,10 @@ class Value(metaclass=ABCMeta): """ matches = [] for pattern in patterns: + if not isinstance(pattern, (int, str, Enum)): + raise SyntaxError("Match pattern must be an integer, a string, or an enumeration, " + "not {!r}" + .format(pattern)) if isinstance(pattern, str) and any(bit not in "01-" for bit in pattern): raise SyntaxError("Match pattern '{}' must consist of 0, 1, and - (don't care) " "bits" @@ -248,9 +269,6 @@ class Value(metaclass=ABCMeta): raise SyntaxError("Match pattern '{}' must have the same width as match value " "(which is {})" .format(pattern, len(self))) - if not isinstance(pattern, (int, str)): - raise SyntaxError("Match pattern must be an integer or a string, not {}" - .format(pattern)) if isinstance(pattern, int) and bits_for(pattern) > len(self): warnings.warn("Match pattern '{:b}' is wider than match value " "(which has width {}); comparison will never be true" @@ -259,7 +277,9 @@ class Value(metaclass=ABCMeta): continue if isinstance(pattern, int): matches.append(self == pattern) - elif isinstance(pattern, str): + elif isinstance(pattern, (str, Enum)): + if isinstance(pattern, Enum): + pattern = _enum_to_bits(pattern) mask = int(pattern.replace("0", "1").replace("-", "0"), 2) pattern = int(pattern.replace("-", "0"), 2) matches.append((self & mask) == pattern) @@ -784,6 +804,19 @@ class Signal(Value, DUID): bits_for(value_range.stop - value_range.step, signed)) return cls((nbits, signed), src_loc_at=1 + src_loc_at, **kwargs) + @classmethod + def enum(cls, enum_type, *, src_loc_at=0, **kwargs): + """Create Signal that can represent a given enumeration. + + Parameters + ---------- + enum : type (inheriting from :class:`enum.Enum`) + Enumeration to base this Signal on. + """ + if not issubclass(enum_type, Enum): + raise TypeError("Type {!r} is not an enumeration") + return cls(_enum_shape(enum_type), src_loc_at=1 + src_loc_at, decoder=enum_type, **kwargs) + @classmethod def like(cls, other, *, name=None, name_suffix=None, src_loc_at=0, **kwargs): """Create Signal based on another. @@ -1230,6 +1263,8 @@ class Switch(Statement): key = "{:0{}b}".format(key, len(self.test)) elif isinstance(key, str): pass + elif isinstance(key, Enum): + key = _enum_to_bits(key) else: raise TypeError("Object '{!r}' cannot be used as a switch key" .format(key)) diff --git a/nmigen/hdl/dsl.py b/nmigen/hdl/dsl.py index a82d1c6..024587f 100644 --- a/nmigen/hdl/dsl.py +++ b/nmigen/hdl/dsl.py @@ -1,6 +1,7 @@ from collections import OrderedDict, namedtuple from collections.abc import Iterable from contextlib import contextmanager +from enum import Enum import warnings from ..tools import flatten, bits_for, deprecated @@ -264,6 +265,10 @@ class Module(_ModuleBuilderRoot, Elaboratable): switch_data = self._get_ctrl("Switch") new_patterns = () for pattern in patterns: + if not isinstance(pattern, (int, str, Enum)): + raise SyntaxError("Case pattern must be an integer, a string, or an enumeration, " + "not {!r}" + .format(pattern)) if isinstance(pattern, str) and any(bit not in "01-" for bit in pattern): raise SyntaxError("Case pattern '{}' must consist of 0, 1, and - (don't care) bits" .format(pattern)) @@ -271,9 +276,6 @@ class Module(_ModuleBuilderRoot, Elaboratable): raise SyntaxError("Case pattern '{}' must have the same width as switch value " "(which is {})" .format(pattern, len(switch_data["test"]))) - if not isinstance(pattern, (int, str)): - raise SyntaxError("Case pattern must be an integer or a string, not {}" - .format(pattern)) if isinstance(pattern, int) and bits_for(pattern) > len(switch_data["test"]): warnings.warn("Case pattern '{:b}' is wider than switch value " "(which has width {}); comparison will never be true" diff --git a/nmigen/test/test_hdl_ast.py b/nmigen/test/test_hdl_ast.py index 30d6deb..e459b3d 100644 --- a/nmigen/test/test_hdl_ast.py +++ b/nmigen/test/test_hdl_ast.py @@ -5,6 +5,23 @@ from ..hdl.ast import * from .tools import * +class UnsignedEnum(Enum): + FOO = 1 + BAR = 2 + BAZ = 3 + + +class SignedEnum(Enum): + FOO = -1 + BAR = 0 + BAZ = +1 + + +class StringEnum(Enum): + FOO = "a" + BAR = "b" + + class ValueTestCase(FHDLTestCase): def test_wrap(self): self.assertIsInstance(Value.wrap(0), Const) @@ -15,6 +32,19 @@ class ValueTestCase(FHDLTestCase): msg="Object ''str'' is not an nMigen value"): Value.wrap("str") + def test_wrap_enum(self): + e1 = Value.wrap(UnsignedEnum.FOO) + self.assertIsInstance(e1, Const) + self.assertEqual(e1.shape(), (2, False)) + e2 = Value.wrap(SignedEnum.FOO) + self.assertIsInstance(e2, Const) + self.assertEqual(e2.shape(), (2, True)) + + def test_wrap_enum_wrong(self): + with self.assertRaises(TypeError, + msg="Only enumerations with integer values can be converted to nMigen values"): + Value.wrap(StringEnum.FOO) + def test_bool(self): with self.assertRaises(TypeError, msg="Attempted to convert nMigen value to boolean"): @@ -276,6 +306,12 @@ class OperatorTestCase(FHDLTestCase): (== (& (sig s) (const 4'd12)) (const 4'd8)) """) + def test_matches_enum(self): + s = Signal.enum(SignedEnum) + self.assertRepr(s.matches(SignedEnum.FOO), """ + (== (& (sig s) (const 2'd3)) (const 2'd3)) + """) + def test_matches_width_wrong(self): s = Signal(4) with self.assertRaises(SyntaxError, @@ -295,7 +331,7 @@ class OperatorTestCase(FHDLTestCase): def test_matches_pattern_wrong(self): s = Signal(4) with self.assertRaises(SyntaxError, - msg="Match pattern must be an integer or a string, not 1.0"): + msg="Match pattern must be an integer, a string, or an enumeration, not 1.0"): s.matches(1.0) def test_hash(self): @@ -605,6 +641,13 @@ class SignalTestCase(FHDLTestCase): self.assertEqual(s.decoder(1), "RED/1") self.assertEqual(s.decoder(3), "3") + def test_enum(self): + s1 = Signal.enum(UnsignedEnum) + self.assertEqual(s1.shape(), (2, False)) + s2 = Signal.enum(SignedEnum) + self.assertEqual(s2.shape(), (2, True)) + self.assertEqual(s2.decoder(SignedEnum.FOO), "FOO/-1") + class ClockSignalTestCase(FHDLTestCase): def test_domain(self): diff --git a/nmigen/test/test_hdl_dsl.py b/nmigen/test/test_hdl_dsl.py index d3419a3..c3355cd 100644 --- a/nmigen/test/test_hdl_dsl.py +++ b/nmigen/test/test_hdl_dsl.py @@ -1,4 +1,5 @@ from collections import OrderedDict +from enum import Enum from ..hdl.ast import * from ..hdl.cd import * @@ -355,6 +356,23 @@ class DSLTestCase(FHDLTestCase): ) """) + def test_Switch_enum(self): + class Color(Enum): + RED = 1 + BLUE = 2 + m = Module() + se = Signal.enum(Color) + with m.Switch(se): + with m.Case(Color.RED): + m.d.comb += self.c1.eq(1) + self.assertRepr(m._statements, """ + ( + (switch (sig se) + (case 01 (eq (sig c1) (const 1'd1))) + ) + ) + """) + def test_Case_width_wrong(self): m = Module() with m.Switch(self.w1): @@ -385,7 +403,7 @@ class DSLTestCase(FHDLTestCase): m = Module() with m.Switch(self.w1): with self.assertRaises(SyntaxError, - msg="Case pattern must be an integer or a string, not 1.0"): + msg="Case pattern must be an integer, a string, or an enumeration, not 1.0"): with m.Case(1.0): pass -- 2.30.2