hdl.rec: proxy operators correctly.
authorawygle <awygle@gmail.com>
Mon, 9 Nov 2020 20:20:25 +0000 (12:20 -0800)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Fri, 31 Dec 2021 15:22:49 +0000 (15:22 +0000)
Commit abbebf8e used __getattr__ to proxy Value methods called on
Record. However, that did not proxy operators like __add__ because
Python looks up the special operator methods directly on the class
and does not run __getattr__ if they are missing.

Instead of using __getattr__, explicitly enumerate and wrap every
Value method that should be proxied. This also ensures backwards
compatibility if more methods are added to Value later.

Fixes #533.

nmigen/hdl/rec.py
tests/test_hdl_rec.py

index ddc833424c632a68d91fb235f3bff639e129b4f5..b72940070122db877bb389a6908db9c6a1e9fbbb 100644 (file)
@@ -143,17 +143,7 @@ class Record(ValueCastable):
                                                      src_loc_at=1 + src_loc_at)
 
     def __getattr__(self, name):
-        # must check `getattr` before `self` - we need to hit Value methods before fields
-        try:
-            value_attr = getattr(Value, name)
-            if callable(value_attr):
-                @wraps(value_attr)
-                def _wrapper(*args, **kwargs):
-                    return value_attr(self, *args, **kwargs)
-                return _wrapper
-            return value_attr
-        except AttributeError:
-            return self[name]
+        return self[name]
 
     def __getitem__(self, item):
         if isinstance(item, str):
@@ -257,3 +247,29 @@ class Record(ValueCastable):
                     stmts += [item.eq(reduce(lambda a, b: a | b, subord_items))]
 
         return stmts
+
+def _valueproxy(name):
+    value_func = getattr(Value, name)
+    @wraps(value_func)
+    def _wrapper(self, *args, **kwargs):
+        return value_func(Value.cast(self), *args, **kwargs)
+    return _wrapper
+
+for name in [
+        "__bool__",
+        "__invert__", "__neg__",
+        "__add__", "__radd__", "__sub__", "__rsub__",
+        "__mul__", "__rmul__",
+        "__mod__", "__rmod__", "__floordiv__", "__rfloordiv__",
+        "__lshift__", "__rlshift__", "__rshift__", "__rrshift__",
+        "__and__", "__rand__", "__xor__", "__rxor__", "__or__", "__ror__",
+        "__eq__", "__ne__", "__lt__", "__le__", "__gt__", "__ge__",
+        "__abs__", "__len__",
+        "as_unsigned", "as_signed", "bool", "any", "all", "xor", "implies",
+        "bit_select", "word_select", "matches",
+        "shift_left", "shift_right", "rotate_left", "rotate_right", "eq"
+        ]:
+    setattr(Record, name, _valueproxy(name))
+
+del _valueproxy
+del name
index 7e8ae53c03a920708604d1e2a9521269506caa8e..452abdd52915fee65efbdcaa57715bccfd64ca7e 100644 (file)
@@ -211,6 +211,117 @@ class RecordTestCase(FHDLTestCase):
         r1 = Record([("a", UnsignedEnum)])
         self.assertEqual(r1.a.decoder(UnsignedEnum.FOO), "FOO/1")
 
+    def test_operators(self):
+        r1 = Record([("a", 1)])
+        s1 = Signal()
+
+        # __bool__
+        with self.assertRaisesRegex(TypeError,
+                r"^Attempted to convert nMigen value to Python boolean$"):
+            not r1
+
+        # __invert__, __neg__
+        self.assertEqual(repr(~r1), "(~ (cat (sig r1__a)))")
+        self.assertEqual(repr(-r1), "(- (cat (sig r1__a)))")
+
+        # __add__, __radd__, __sub__, __rsub__
+        self.assertEqual(repr(r1 + 1),  "(+ (cat (sig r1__a)) (const 1'd1))")
+        self.assertEqual(repr(r1 + s1), "(+ (cat (sig r1__a)) (sig s1))")
+        self.assertEqual(repr(1 + r1),  "(+ (const 1'd1) (cat (sig r1__a)))")
+        self.assertEqual(repr(s1 + r1), "(+ (sig s1) (cat (sig r1__a)))")
+        self.assertEqual(repr(r1 - 1),  "(- (cat (sig r1__a)) (const 1'd1))")
+        self.assertEqual(repr(r1 - s1), "(- (cat (sig r1__a)) (sig s1))")
+        self.assertEqual(repr(1 - r1),  "(- (const 1'd1) (cat (sig r1__a)))")
+        self.assertEqual(repr(s1 - r1), "(- (sig s1) (cat (sig r1__a)))")
+
+        # __mul__, __rmul__
+        self.assertEqual(repr(r1 * 1),  "(* (cat (sig r1__a)) (const 1'd1))")
+        self.assertEqual(repr(r1 * s1), "(* (cat (sig r1__a)) (sig s1))")
+        self.assertEqual(repr(1 * r1),  "(* (const 1'd1) (cat (sig r1__a)))")
+        self.assertEqual(repr(s1 * r1), "(* (sig s1) (cat (sig r1__a)))")
+
+        # __mod__, __rmod__, __floordiv__, __rfloordiv__
+        self.assertEqual(repr(r1 % 1),   "(% (cat (sig r1__a)) (const 1'd1))")
+        self.assertEqual(repr(r1 % s1),  "(% (cat (sig r1__a)) (sig s1))")
+        self.assertEqual(repr(1 % r1),   "(% (const 1'd1) (cat (sig r1__a)))")
+        self.assertEqual(repr(s1 % r1),  "(% (sig s1) (cat (sig r1__a)))")
+        self.assertEqual(repr(r1 // 1),  "(// (cat (sig r1__a)) (const 1'd1))")
+        self.assertEqual(repr(r1 // s1), "(// (cat (sig r1__a)) (sig s1))")
+        self.assertEqual(repr(1 // r1),  "(// (const 1'd1) (cat (sig r1__a)))")
+        self.assertEqual(repr(s1 // r1), "(// (sig s1) (cat (sig r1__a)))")
+
+        # __lshift__, __rlshift__, __rshift__, __rrshift__
+        self.assertEqual(repr(r1 >> 1),  "(>> (cat (sig r1__a)) (const 1'd1))")
+        self.assertEqual(repr(r1 >> s1), "(>> (cat (sig r1__a)) (sig s1))")
+        self.assertEqual(repr(1 >> r1),  "(>> (const 1'd1) (cat (sig r1__a)))")
+        self.assertEqual(repr(s1 >> r1), "(>> (sig s1) (cat (sig r1__a)))")
+        self.assertEqual(repr(r1 << 1),  "(<< (cat (sig r1__a)) (const 1'd1))")
+        self.assertEqual(repr(r1 << s1), "(<< (cat (sig r1__a)) (sig s1))")
+        self.assertEqual(repr(1 << r1),  "(<< (const 1'd1) (cat (sig r1__a)))")
+        self.assertEqual(repr(s1 << r1), "(<< (sig s1) (cat (sig r1__a)))")
+
+        # __and__, __rand__, __xor__, __rxor__, __or__, __ror__
+        self.assertEqual(repr(r1 & 1),  "(& (cat (sig r1__a)) (const 1'd1))")
+        self.assertEqual(repr(r1 & s1), "(& (cat (sig r1__a)) (sig s1))")
+        self.assertEqual(repr(1 & r1),  "(& (const 1'd1) (cat (sig r1__a)))")
+        self.assertEqual(repr(s1 & r1), "(& (sig s1) (cat (sig r1__a)))")
+        self.assertEqual(repr(r1 ^ 1),  "(^ (cat (sig r1__a)) (const 1'd1))")
+        self.assertEqual(repr(r1 ^ s1), "(^ (cat (sig r1__a)) (sig s1))")
+        self.assertEqual(repr(1 ^ r1),  "(^ (const 1'd1) (cat (sig r1__a)))")
+        self.assertEqual(repr(s1 ^ r1), "(^ (sig s1) (cat (sig r1__a)))")
+        self.assertEqual(repr(r1 | 1),  "(| (cat (sig r1__a)) (const 1'd1))")
+        self.assertEqual(repr(r1 | s1), "(| (cat (sig r1__a)) (sig s1))")
+        self.assertEqual(repr(1 | r1),  "(| (const 1'd1) (cat (sig r1__a)))")
+        self.assertEqual(repr(s1 | r1), "(| (sig s1) (cat (sig r1__a)))")
+
+        # __eq__, __ne__, __lt__, __le__, __gt__, __ge__
+        self.assertEqual(repr(r1 == 1),  "(== (cat (sig r1__a)) (const 1'd1))")
+        self.assertEqual(repr(r1 == s1), "(== (cat (sig r1__a)) (sig s1))")
+        self.assertEqual(repr(s1 == r1), "(== (sig s1) (cat (sig r1__a)))")
+        self.assertEqual(repr(r1 != 1),  "(!= (cat (sig r1__a)) (const 1'd1))")
+        self.assertEqual(repr(r1 != s1), "(!= (cat (sig r1__a)) (sig s1))")
+        self.assertEqual(repr(s1 != r1), "(!= (sig s1) (cat (sig r1__a)))")
+        self.assertEqual(repr(r1 < 1),   "(< (cat (sig r1__a)) (const 1'd1))")
+        self.assertEqual(repr(r1 < s1),  "(< (cat (sig r1__a)) (sig s1))")
+        self.assertEqual(repr(s1 < r1),  "(< (sig s1) (cat (sig r1__a)))")
+        self.assertEqual(repr(r1 <= 1),  "(<= (cat (sig r1__a)) (const 1'd1))")
+        self.assertEqual(repr(r1 <= s1), "(<= (cat (sig r1__a)) (sig s1))")
+        self.assertEqual(repr(s1 <= r1), "(<= (sig s1) (cat (sig r1__a)))")
+        self.assertEqual(repr(r1 > 1),   "(> (cat (sig r1__a)) (const 1'd1))")
+        self.assertEqual(repr(r1 > s1),  "(> (cat (sig r1__a)) (sig s1))")
+        self.assertEqual(repr(s1 > r1),  "(> (sig s1) (cat (sig r1__a)))")
+        self.assertEqual(repr(r1 >= 1),  "(>= (cat (sig r1__a)) (const 1'd1))")
+        self.assertEqual(repr(r1 >= s1), "(>= (cat (sig r1__a)) (sig s1))")
+        self.assertEqual(repr(s1 >= r1), "(>= (sig s1) (cat (sig r1__a)))")
+
+        # __abs__, __len__
+        self.assertEqual(repr(abs(r1)), "(cat (sig r1__a))")
+        self.assertEqual(len(r1), 1)
+
+        # as_unsigned, as_signed, bool, any, all, xor, implies
+        self.assertEqual(repr(r1.as_unsigned()), "(u (cat (sig r1__a)))")
+        self.assertEqual(repr(r1.as_signed()),   "(s (cat (sig r1__a)))")
+        self.assertEqual(repr(r1.bool()),        "(b (cat (sig r1__a)))")
+        self.assertEqual(repr(r1.any()),         "(r| (cat (sig r1__a)))")
+        self.assertEqual(repr(r1.all()),         "(r& (cat (sig r1__a)))")
+        self.assertEqual(repr(r1.xor()),         "(r^ (cat (sig r1__a)))")
+        self.assertEqual(repr(r1.implies(1)),    "(| (~ (cat (sig r1__a))) (const 1'd1))")
+        self.assertEqual(repr(r1.implies(s1)),   "(| (~ (cat (sig r1__a))) (sig s1))")
+
+        # bit_select, word_select, matches,
+        self.assertEqual(repr(r1.bit_select(0, 1)),  "(slice (cat (sig r1__a)) 0:1)")
+        self.assertEqual(repr(r1.word_select(0, 1)), "(slice (cat (sig r1__a)) 0:1)")
+        self.assertEqual(repr(r1.matches("1")),
+                "(== (& (cat (sig r1__a)) (const 1'd1)) (const 1'd1))")
+
+        # shift_left, shift_right, rotate_left, rotate_right, eq
+        self.assertEqual(repr(r1.shift_left(1)),  "(cat (const 1'd0) (cat (sig r1__a)))")
+        self.assertEqual(repr(r1.shift_right(1)), "(slice (cat (sig r1__a)) 1:1)")
+        self.assertEqual(repr(r1.rotate_left(1)), "(cat (slice (cat (sig r1__a)) 0:1) (slice (cat (sig r1__a)) 0:0))")
+        self.assertEqual(repr(r1.rotate_right(1)), "(cat (slice (cat (sig r1__a)) 0:1) (slice (cat (sig r1__a)) 0:0))")
+        self.assertEqual(repr(r1.eq(1)), "(eq (cat (sig r1__a)) (const 1'd1))")
+        self.assertEqual(repr(r1.eq(s1)), "(eq (cat (sig r1__a)) (sig s1))")
+
 
 class ConnectTestCase(FHDLTestCase):
     def setUp_flat(self):