From 6b843b5be6fdd786fd4caf5a80ae51abff25b984 Mon Sep 17 00:00:00 2001 From: whitequark Date: Tue, 2 Jul 2019 17:44:55 +0000 Subject: [PATCH] hdl.rec: implement slicing by component names. Fixes #121. --- nmigen/hdl/rec.py | 17 +++++++++++++++-- nmigen/test/test_hdl_rec.py | 19 +++++++++++++++++++ 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/nmigen/hdl/rec.py b/nmigen/hdl/rec.py index 2c6694c..67c77c2 100644 --- a/nmigen/hdl/rec.py +++ b/nmigen/hdl/rec.py @@ -56,8 +56,15 @@ class Layout: shape = (shape, False) self.fields[name] = (shape, direction) - def __getitem__(self, name): - return self.fields[name] + def __getitem__(self, item): + if isinstance(item, tuple): + return Layout([ + (name, shape, dir) + for (name, (shape, dir)) in self.fields.items() + if name in item + ]) + + return self.fields[item] def __iter__(self): for name, (shape, dir) in self.fields.items(): @@ -121,6 +128,12 @@ class Record(Value): reference = "Record '{}'".format(self.name) raise AttributeError("{} does not have a field '{}'. Did you mean one of: {}?" .format(reference, item, ", ".join(self.fields))) from None + elif isinstance(item, tuple): + return Record(self.layout[item], fields={ + field_name: field_value + for field_name, field_value in self.fields.items() + if field_name in item + }) else: return super().__getitem__(item) diff --git a/nmigen/test/test_hdl_rec.py b/nmigen/test/test_hdl_rec.py index 9c72a16..7821c7d 100644 --- a/nmigen/test/test_hdl_rec.py +++ b/nmigen/test/test_hdl_rec.py @@ -25,6 +25,18 @@ class LayoutTestCase(FHDLTestCase): self.assertEqual(sublayout["a"], ((1, False), DIR_NONE)) self.assertEqual(sublayout["b"], ((1, False), DIR_NONE)) + def test_slice_tuple(self): + layout = Layout.wrap([ + ("a", 1), + ("b", 2), + ("c", 3) + ]) + expect = Layout.wrap([ + ("a", 1), + ("c", 3) + ]) + self.assertEqual(layout["a", "c"], expect) + def test_wrong_field(self): with self.assertRaises(TypeError, msg="Field (1,) has invalid layout: should be either (name, shape) or " @@ -139,6 +151,13 @@ class RecordTestCase(FHDLTestCase): r4 = Record.like(r1, name_suffix="foo") self.assertEqual(r4.name, "r1foo") + def test_slice_tuple(self): + r1 = Record([("a", 1), ("b", 2), ("c", 3)]) + r2 = r1["a", "c"] + self.assertEqual(r2.layout, Layout([("a", 1), ("c", 3)])) + self.assertIs(r2.a, r1.a) + self.assertIs(r2.c, r1.c) + class ConnectTestCase(FHDLTestCase): def setUp_flat(self): -- 2.30.2