hdl.rec: respect modifications to signals in Record.like().
authorwhitequark <whitequark@whitequark.org>
Mon, 8 Jul 2019 10:59:15 +0000 (10:59 +0000)
committerwhitequark <whitequark@whitequark.org>
Mon, 8 Jul 2019 10:59:15 +0000 (10:59 +0000)
Fixes #126.

nmigen/hdl/rec.py
nmigen/test/test_hdl_rec.py

index 7828f76fdb98c878caae79e5ba5fa421ec1cd243..d3d1b33ecf9137ae9fbf6897e0ec7f0be5a3a5a4 100644 (file)
@@ -77,16 +77,32 @@ class Layout:
 # Unlike most Values, Record *can* be subclassed.
 class Record(Value):
     @classmethod
-    def like(cls, other, name=None, name_suffix=None, src_loc_at=0):
+    def like(cls, other, *, name=None, name_suffix=None, src_loc_at=0):
         if name is not None:
             new_name = str(name)
         elif name_suffix is not None:
             new_name = other.name + str(name_suffix)
         else:
             new_name = tracer.get_var_name(depth=2 + src_loc_at, default=None)
-        return cls(other.layout, new_name)
 
-    def __init__(self, layout, name=None, src_loc_at=0, *, fields=None):
+        def concat(a, b):
+            if a is None:
+                return b
+            return "{}__{}".format(a, b)
+
+        fields = {}
+        for field_name in other.fields:
+            field = other[field_name]
+            if isinstance(field, Record):
+                fields[field_name] = Record.like(field, name=concat(new_name, field_name),
+                                                 src_loc_at=1 + src_loc_at)
+            else:
+                fields[field_name] = Signal.like(field, name=concat(new_name, field_name),
+                                                 src_loc_at=1 + src_loc_at)
+
+        return cls(other.layout, new_name, fields=fields, src_loc_at=1)
+
+    def __init__(self, layout, name=None, *, fields=None, src_loc_at=0):
         if name is None:
             name = tracer.get_var_name(depth=2 + src_loc_at, default=None)
 
@@ -111,10 +127,10 @@ class Record(Value):
             else:
                 if isinstance(field_shape, Layout):
                     self.fields[field_name] = Record(field_shape, name=concat(name, field_name),
-                                                     src_loc_at=src_loc_at + 1)
+                                                     src_loc_at=1 + src_loc_at)
                 else:
                     self.fields[field_name] = Signal(field_shape, name=concat(name, field_name),
-                                                     src_loc_at=src_loc_at + 1)
+                                                     src_loc_at=1 + src_loc_at)
 
     def __getattr__(self, name):
         return self[name]
index 7821c7d6b167e98dc8606f812b8fe29797839c98..2491587d2924c5ef9cd6d3168d5f4e8d37ce4dc1 100644 (file)
@@ -151,6 +151,20 @@ class RecordTestCase(FHDLTestCase):
         r4 = Record.like(r1, name_suffix="foo")
         self.assertEqual(r4.name, "r1foo")
 
+    def test_like_modifications(self):
+        r1 = Record([("a", 1), ("b", [("s", 1)])])
+        self.assertEqual(r1.a.name, "r1__a")
+        self.assertEqual(r1.b.name, "r1__b")
+        self.assertEqual(r1.b.s.name, "r1__b__s")
+        r1.a.reset = 1
+        r1.b.s.reset = 1
+        r2 = Record.like(r1)
+        self.assertEqual(r2.a.reset, 1)
+        self.assertEqual(r2.b.s.reset, 1)
+        self.assertEqual(r2.a.name, "r2__a")
+        self.assertEqual(r2.b.name, "r2__b")
+        self.assertEqual(r2.b.s.name, "r2__b__s")
+
     def test_slice_tuple(self):
         r1 = Record([("a", 1), ("b", 2), ("c", 3)])
         r2 = r1["a", "c"]