# Unlike most Values, Record *can* be subclassed.
class Record(Value):
@classmethod
- def like(cls, other, name=None, name_suffix=None, src_loc_at=0):
+ def like(cls, other, *, name=None, name_suffix=None, src_loc_at=0):
if name is not None:
new_name = str(name)
elif name_suffix is not None:
new_name = other.name + str(name_suffix)
else:
new_name = tracer.get_var_name(depth=2 + src_loc_at, default=None)
- return cls(other.layout, new_name)
- def __init__(self, layout, name=None, src_loc_at=0, *, fields=None):
+ def concat(a, b):
+ if a is None:
+ return b
+ return "{}__{}".format(a, b)
+
+ fields = {}
+ for field_name in other.fields:
+ field = other[field_name]
+ if isinstance(field, Record):
+ fields[field_name] = Record.like(field, name=concat(new_name, field_name),
+ src_loc_at=1 + src_loc_at)
+ else:
+ fields[field_name] = Signal.like(field, name=concat(new_name, field_name),
+ src_loc_at=1 + src_loc_at)
+
+ return cls(other.layout, new_name, fields=fields, src_loc_at=1)
+
+ def __init__(self, layout, name=None, *, fields=None, src_loc_at=0):
if name is None:
name = tracer.get_var_name(depth=2 + src_loc_at, default=None)
else:
if isinstance(field_shape, Layout):
self.fields[field_name] = Record(field_shape, name=concat(name, field_name),
- src_loc_at=src_loc_at + 1)
+ src_loc_at=1 + src_loc_at)
else:
self.fields[field_name] = Signal(field_shape, name=concat(name, field_name),
- src_loc_at=src_loc_at + 1)
+ src_loc_at=1 + src_loc_at)
def __getattr__(self, name):
return self[name]
r4 = Record.like(r1, name_suffix="foo")
self.assertEqual(r4.name, "r1foo")
+ def test_like_modifications(self):
+ r1 = Record([("a", 1), ("b", [("s", 1)])])
+ self.assertEqual(r1.a.name, "r1__a")
+ self.assertEqual(r1.b.name, "r1__b")
+ self.assertEqual(r1.b.s.name, "r1__b__s")
+ r1.a.reset = 1
+ r1.b.s.reset = 1
+ r2 = Record.like(r1)
+ self.assertEqual(r2.a.reset, 1)
+ self.assertEqual(r2.b.s.reset, 1)
+ self.assertEqual(r2.a.name, "r2__a")
+ self.assertEqual(r2.b.name, "r2__b")
+ self.assertEqual(r2.b.s.name, "r2__b__s")
+
def test_slice_tuple(self):
r1 = Record([("a", 1), ("b", 2), ("c", 3)])
r2 = r1["a", "c"]