From 0ab0a74ec1155857f4a4764c898cfdd8a33074f3 Mon Sep 17 00:00:00 2001 From: whitequark Date: Mon, 8 Jul 2019 10:59:15 +0000 Subject: [PATCH] hdl.rec: respect modifications to signals in Record.like(). Fixes #126. --- nmigen/hdl/rec.py | 26 +++++++++++++++++++++----- nmigen/test/test_hdl_rec.py | 14 ++++++++++++++ 2 files changed, 35 insertions(+), 5 deletions(-) diff --git a/nmigen/hdl/rec.py b/nmigen/hdl/rec.py index 7828f76..d3d1b33 100644 --- a/nmigen/hdl/rec.py +++ b/nmigen/hdl/rec.py @@ -77,16 +77,32 @@ class Layout: # Unlike most Values, Record *can* be subclassed. class Record(Value): @classmethod - def like(cls, other, name=None, name_suffix=None, src_loc_at=0): + def like(cls, other, *, name=None, name_suffix=None, src_loc_at=0): if name is not None: new_name = str(name) elif name_suffix is not None: new_name = other.name + str(name_suffix) else: new_name = tracer.get_var_name(depth=2 + src_loc_at, default=None) - return cls(other.layout, new_name) - def __init__(self, layout, name=None, src_loc_at=0, *, fields=None): + def concat(a, b): + if a is None: + return b + return "{}__{}".format(a, b) + + fields = {} + for field_name in other.fields: + field = other[field_name] + if isinstance(field, Record): + fields[field_name] = Record.like(field, name=concat(new_name, field_name), + src_loc_at=1 + src_loc_at) + else: + fields[field_name] = Signal.like(field, name=concat(new_name, field_name), + src_loc_at=1 + src_loc_at) + + return cls(other.layout, new_name, fields=fields, src_loc_at=1) + + def __init__(self, layout, name=None, *, fields=None, src_loc_at=0): if name is None: name = tracer.get_var_name(depth=2 + src_loc_at, default=None) @@ -111,10 +127,10 @@ class Record(Value): else: if isinstance(field_shape, Layout): self.fields[field_name] = Record(field_shape, name=concat(name, field_name), - src_loc_at=src_loc_at + 1) + src_loc_at=1 + src_loc_at) else: self.fields[field_name] = Signal(field_shape, name=concat(name, field_name), - src_loc_at=src_loc_at + 1) + src_loc_at=1 + src_loc_at) def __getattr__(self, name): return self[name] diff --git a/nmigen/test/test_hdl_rec.py b/nmigen/test/test_hdl_rec.py index 7821c7d..2491587 100644 --- a/nmigen/test/test_hdl_rec.py +++ b/nmigen/test/test_hdl_rec.py @@ -151,6 +151,20 @@ class RecordTestCase(FHDLTestCase): r4 = Record.like(r1, name_suffix="foo") self.assertEqual(r4.name, "r1foo") + def test_like_modifications(self): + r1 = Record([("a", 1), ("b", [("s", 1)])]) + self.assertEqual(r1.a.name, "r1__a") + self.assertEqual(r1.b.name, "r1__b") + self.assertEqual(r1.b.s.name, "r1__b__s") + r1.a.reset = 1 + r1.b.s.reset = 1 + r2 = Record.like(r1) + self.assertEqual(r2.a.reset, 1) + self.assertEqual(r2.b.s.reset, 1) + self.assertEqual(r2.a.name, "r2__a") + self.assertEqual(r2.b.name, "r2__b") + self.assertEqual(r2.b.s.name, "r2__b__s") + def test_slice_tuple(self): r1 = Record([("a", 1), ("b", 2), ("c", 3)]) r2 = r1["a", "c"] -- 2.30.2