From 28a22e440094a25e8756739ec042ccbb8842eaed Mon Sep 17 00:00:00 2001 From: awygle Date: Thu, 5 Nov 2020 17:10:39 -0800 Subject: [PATCH] hdl.rec: migrate Record from UserValue to ValueCastable. Closes #528. --- nmigen/hdl/rec.py | 35 +++++++++++++++++++++++++++-------- tests/test_hdl_rec.py | 4 ++-- 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/nmigen/hdl/rec.py b/nmigen/hdl/rec.py index 5e5687c..ddc8334 100644 --- a/nmigen/hdl/rec.py +++ b/nmigen/hdl/rec.py @@ -1,6 +1,6 @@ from enum import Enum from collections import OrderedDict -from functools import reduce +from functools import reduce, wraps from .. import tracer from .._utils import union, deprecated @@ -85,8 +85,7 @@ class Layout: return "Layout([{}])".format(", ".join(field_reprs)) -# Unlike most Values, Record *can* be subclassed. -class Record(UserValue): +class Record(ValueCastable): @staticmethod def like(other, *, name=None, name_suffix=None, src_loc_at=0): if name is not None: @@ -114,8 +113,6 @@ class Record(UserValue): return Record(other.layout, name=new_name, fields=fields, src_loc_at=1) def __init__(self, layout, *, name=None, fields=None, src_loc_at=0): - super().__init__(src_loc_at=src_loc_at) - if name is None: name = tracer.get_var_name(depth=2 + src_loc_at, default=None) @@ -146,7 +143,17 @@ class Record(UserValue): src_loc_at=1 + src_loc_at) def __getattr__(self, name): - return self[name] + # must check `getattr` before `self` - we need to hit Value methods before fields + try: + value_attr = getattr(Value, name) + if callable(value_attr): + @wraps(value_attr) + def _wrapper(*args, **kwargs): + return value_attr(self, *args, **kwargs) + return _wrapper + return value_attr + except AttributeError: + return self[name] def __getitem__(self, item): if isinstance(item, str): @@ -166,11 +173,23 @@ class Record(UserValue): if field_name in item }) else: - return super().__getitem__(item) + try: + return Value.__getitem__(self, item) + except KeyError: + if self.name is None: + reference = "Unnamed record" + else: + reference = "Record '{}'".format(self.name) + raise AttributeError("{} does not have a field '{}'. Did you mean one of: {}?" + .format(reference, item, ", ".join(self.fields))) from None - def lower(self): + @ValueCastable.lowermethod + def as_value(self): return Cat(self.fields.values()) + def __len__(self): + return len(self.as_value()) + def _lhs_signals(self): return union((f._lhs_signals() for f in self.fields.values()), start=SignalSet()) diff --git a/tests/test_hdl_rec.py b/tests/test_hdl_rec.py index 718fa4a..7e8ae53 100644 --- a/tests/test_hdl_rec.py +++ b/tests/test_hdl_rec.py @@ -135,8 +135,8 @@ class RecordTestCase(FHDLTestCase): ("stb", 1), ]) - self.assertEqual(repr(r[0]), "(slice (rec r data stb) 0:1)") - self.assertEqual(repr(r[0:3]), "(slice (rec r data stb) 0:3)") + self.assertEqual(repr(r[0]), "(slice (cat (sig r__data) (sig r__stb)) 0:1)") + self.assertEqual(repr(r[0:3]), "(slice (cat (sig r__data) (sig r__stb)) 0:3)") def test_wrong_field(self): r = Record([ -- 2.30.2