From 8d294f0bbb9e30f74383793737363b89af5322e6 Mon Sep 17 00:00:00 2001 From: awygle Date: Mon, 9 Nov 2020 12:20:25 -0800 Subject: [PATCH] hdl.rec: proxy operators correctly. Commit abbebf8e used __getattr__ to proxy Value methods called on Record. However, that did not proxy operators like __add__ because Python looks up the special operator methods directly on the class and does not run __getattr__ if they are missing. Instead of using __getattr__, explicitly enumerate and wrap every Value method that should be proxied. This also ensures backwards compatibility if more methods are added to Value later. Fixes #533. --- nmigen/hdl/rec.py | 38 ++++++++++----- tests/test_hdl_rec.py | 111 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 138 insertions(+), 11 deletions(-) diff --git a/nmigen/hdl/rec.py b/nmigen/hdl/rec.py index ddc8334..b729400 100644 --- a/nmigen/hdl/rec.py +++ b/nmigen/hdl/rec.py @@ -143,17 +143,7 @@ class Record(ValueCastable): src_loc_at=1 + src_loc_at) def __getattr__(self, name): - # must check `getattr` before `self` - we need to hit Value methods before fields - try: - value_attr = getattr(Value, name) - if callable(value_attr): - @wraps(value_attr) - def _wrapper(*args, **kwargs): - return value_attr(self, *args, **kwargs) - return _wrapper - return value_attr - except AttributeError: - return self[name] + return self[name] def __getitem__(self, item): if isinstance(item, str): @@ -257,3 +247,29 @@ class Record(ValueCastable): stmts += [item.eq(reduce(lambda a, b: a | b, subord_items))] return stmts + +def _valueproxy(name): + value_func = getattr(Value, name) + @wraps(value_func) + def _wrapper(self, *args, **kwargs): + return value_func(Value.cast(self), *args, **kwargs) + return _wrapper + +for name in [ + "__bool__", + "__invert__", "__neg__", + "__add__", "__radd__", "__sub__", "__rsub__", + "__mul__", "__rmul__", + "__mod__", "__rmod__", "__floordiv__", "__rfloordiv__", + "__lshift__", "__rlshift__", "__rshift__", "__rrshift__", + "__and__", "__rand__", "__xor__", "__rxor__", "__or__", "__ror__", + "__eq__", "__ne__", "__lt__", "__le__", "__gt__", "__ge__", + "__abs__", "__len__", + "as_unsigned", "as_signed", "bool", "any", "all", "xor", "implies", + "bit_select", "word_select", "matches", + "shift_left", "shift_right", "rotate_left", "rotate_right", "eq" + ]: + setattr(Record, name, _valueproxy(name)) + +del _valueproxy +del name diff --git a/tests/test_hdl_rec.py b/tests/test_hdl_rec.py index 7e8ae53..452abdd 100644 --- a/tests/test_hdl_rec.py +++ b/tests/test_hdl_rec.py @@ -211,6 +211,117 @@ class RecordTestCase(FHDLTestCase): r1 = Record([("a", UnsignedEnum)]) self.assertEqual(r1.a.decoder(UnsignedEnum.FOO), "FOO/1") + def test_operators(self): + r1 = Record([("a", 1)]) + s1 = Signal() + + # __bool__ + with self.assertRaisesRegex(TypeError, + r"^Attempted to convert nMigen value to Python boolean$"): + not r1 + + # __invert__, __neg__ + self.assertEqual(repr(~r1), "(~ (cat (sig r1__a)))") + self.assertEqual(repr(-r1), "(- (cat (sig r1__a)))") + + # __add__, __radd__, __sub__, __rsub__ + self.assertEqual(repr(r1 + 1), "(+ (cat (sig r1__a)) (const 1'd1))") + self.assertEqual(repr(r1 + s1), "(+ (cat (sig r1__a)) (sig s1))") + self.assertEqual(repr(1 + r1), "(+ (const 1'd1) (cat (sig r1__a)))") + self.assertEqual(repr(s1 + r1), "(+ (sig s1) (cat (sig r1__a)))") + self.assertEqual(repr(r1 - 1), "(- (cat (sig r1__a)) (const 1'd1))") + self.assertEqual(repr(r1 - s1), "(- (cat (sig r1__a)) (sig s1))") + self.assertEqual(repr(1 - r1), "(- (const 1'd1) (cat (sig r1__a)))") + self.assertEqual(repr(s1 - r1), "(- (sig s1) (cat (sig r1__a)))") + + # __mul__, __rmul__ + self.assertEqual(repr(r1 * 1), "(* (cat (sig r1__a)) (const 1'd1))") + self.assertEqual(repr(r1 * s1), "(* (cat (sig r1__a)) (sig s1))") + self.assertEqual(repr(1 * r1), "(* (const 1'd1) (cat (sig r1__a)))") + self.assertEqual(repr(s1 * r1), "(* (sig s1) (cat (sig r1__a)))") + + # __mod__, __rmod__, __floordiv__, __rfloordiv__ + self.assertEqual(repr(r1 % 1), "(% (cat (sig r1__a)) (const 1'd1))") + self.assertEqual(repr(r1 % s1), "(% (cat (sig r1__a)) (sig s1))") + self.assertEqual(repr(1 % r1), "(% (const 1'd1) (cat (sig r1__a)))") + self.assertEqual(repr(s1 % r1), "(% (sig s1) (cat (sig r1__a)))") + self.assertEqual(repr(r1 // 1), "(// (cat (sig r1__a)) (const 1'd1))") + self.assertEqual(repr(r1 // s1), "(// (cat (sig r1__a)) (sig s1))") + self.assertEqual(repr(1 // r1), "(// (const 1'd1) (cat (sig r1__a)))") + self.assertEqual(repr(s1 // r1), "(// (sig s1) (cat (sig r1__a)))") + + # __lshift__, __rlshift__, __rshift__, __rrshift__ + self.assertEqual(repr(r1 >> 1), "(>> (cat (sig r1__a)) (const 1'd1))") + self.assertEqual(repr(r1 >> s1), "(>> (cat (sig r1__a)) (sig s1))") + self.assertEqual(repr(1 >> r1), "(>> (const 1'd1) (cat (sig r1__a)))") + self.assertEqual(repr(s1 >> r1), "(>> (sig s1) (cat (sig r1__a)))") + self.assertEqual(repr(r1 << 1), "(<< (cat (sig r1__a)) (const 1'd1))") + self.assertEqual(repr(r1 << s1), "(<< (cat (sig r1__a)) (sig s1))") + self.assertEqual(repr(1 << r1), "(<< (const 1'd1) (cat (sig r1__a)))") + self.assertEqual(repr(s1 << r1), "(<< (sig s1) (cat (sig r1__a)))") + + # __and__, __rand__, __xor__, __rxor__, __or__, __ror__ + self.assertEqual(repr(r1 & 1), "(& (cat (sig r1__a)) (const 1'd1))") + self.assertEqual(repr(r1 & s1), "(& (cat (sig r1__a)) (sig s1))") + self.assertEqual(repr(1 & r1), "(& (const 1'd1) (cat (sig r1__a)))") + self.assertEqual(repr(s1 & r1), "(& (sig s1) (cat (sig r1__a)))") + self.assertEqual(repr(r1 ^ 1), "(^ (cat (sig r1__a)) (const 1'd1))") + self.assertEqual(repr(r1 ^ s1), "(^ (cat (sig r1__a)) (sig s1))") + self.assertEqual(repr(1 ^ r1), "(^ (const 1'd1) (cat (sig r1__a)))") + self.assertEqual(repr(s1 ^ r1), "(^ (sig s1) (cat (sig r1__a)))") + self.assertEqual(repr(r1 | 1), "(| (cat (sig r1__a)) (const 1'd1))") + self.assertEqual(repr(r1 | s1), "(| (cat (sig r1__a)) (sig s1))") + self.assertEqual(repr(1 | r1), "(| (const 1'd1) (cat (sig r1__a)))") + self.assertEqual(repr(s1 | r1), "(| (sig s1) (cat (sig r1__a)))") + + # __eq__, __ne__, __lt__, __le__, __gt__, __ge__ + self.assertEqual(repr(r1 == 1), "(== (cat (sig r1__a)) (const 1'd1))") + self.assertEqual(repr(r1 == s1), "(== (cat (sig r1__a)) (sig s1))") + self.assertEqual(repr(s1 == r1), "(== (sig s1) (cat (sig r1__a)))") + self.assertEqual(repr(r1 != 1), "(!= (cat (sig r1__a)) (const 1'd1))") + self.assertEqual(repr(r1 != s1), "(!= (cat (sig r1__a)) (sig s1))") + self.assertEqual(repr(s1 != r1), "(!= (sig s1) (cat (sig r1__a)))") + self.assertEqual(repr(r1 < 1), "(< (cat (sig r1__a)) (const 1'd1))") + self.assertEqual(repr(r1 < s1), "(< (cat (sig r1__a)) (sig s1))") + self.assertEqual(repr(s1 < r1), "(< (sig s1) (cat (sig r1__a)))") + self.assertEqual(repr(r1 <= 1), "(<= (cat (sig r1__a)) (const 1'd1))") + self.assertEqual(repr(r1 <= s1), "(<= (cat (sig r1__a)) (sig s1))") + self.assertEqual(repr(s1 <= r1), "(<= (sig s1) (cat (sig r1__a)))") + self.assertEqual(repr(r1 > 1), "(> (cat (sig r1__a)) (const 1'd1))") + self.assertEqual(repr(r1 > s1), "(> (cat (sig r1__a)) (sig s1))") + self.assertEqual(repr(s1 > r1), "(> (sig s1) (cat (sig r1__a)))") + self.assertEqual(repr(r1 >= 1), "(>= (cat (sig r1__a)) (const 1'd1))") + self.assertEqual(repr(r1 >= s1), "(>= (cat (sig r1__a)) (sig s1))") + self.assertEqual(repr(s1 >= r1), "(>= (sig s1) (cat (sig r1__a)))") + + # __abs__, __len__ + self.assertEqual(repr(abs(r1)), "(cat (sig r1__a))") + self.assertEqual(len(r1), 1) + + # as_unsigned, as_signed, bool, any, all, xor, implies + self.assertEqual(repr(r1.as_unsigned()), "(u (cat (sig r1__a)))") + self.assertEqual(repr(r1.as_signed()), "(s (cat (sig r1__a)))") + self.assertEqual(repr(r1.bool()), "(b (cat (sig r1__a)))") + self.assertEqual(repr(r1.any()), "(r| (cat (sig r1__a)))") + self.assertEqual(repr(r1.all()), "(r& (cat (sig r1__a)))") + self.assertEqual(repr(r1.xor()), "(r^ (cat (sig r1__a)))") + self.assertEqual(repr(r1.implies(1)), "(| (~ (cat (sig r1__a))) (const 1'd1))") + self.assertEqual(repr(r1.implies(s1)), "(| (~ (cat (sig r1__a))) (sig s1))") + + # bit_select, word_select, matches, + self.assertEqual(repr(r1.bit_select(0, 1)), "(slice (cat (sig r1__a)) 0:1)") + self.assertEqual(repr(r1.word_select(0, 1)), "(slice (cat (sig r1__a)) 0:1)") + self.assertEqual(repr(r1.matches("1")), + "(== (& (cat (sig r1__a)) (const 1'd1)) (const 1'd1))") + + # shift_left, shift_right, rotate_left, rotate_right, eq + self.assertEqual(repr(r1.shift_left(1)), "(cat (const 1'd0) (cat (sig r1__a)))") + self.assertEqual(repr(r1.shift_right(1)), "(slice (cat (sig r1__a)) 1:1)") + self.assertEqual(repr(r1.rotate_left(1)), "(cat (slice (cat (sig r1__a)) 0:1) (slice (cat (sig r1__a)) 0:0))") + self.assertEqual(repr(r1.rotate_right(1)), "(cat (slice (cat (sig r1__a)) 0:1) (slice (cat (sig r1__a)) 0:0))") + self.assertEqual(repr(r1.eq(1)), "(eq (cat (sig r1__a)) (const 1'd1))") + self.assertEqual(repr(r1.eq(s1)), "(eq (cat (sig r1__a)) (sig s1))") + class ConnectTestCase(FHDLTestCase): def setUp_flat(self): -- 2.30.2