From 1ef0a7876b510c057b5d4a9e1ea88c4a9f81b06e Mon Sep 17 00:00:00 2001 From: whitequark Date: Wed, 26 Dec 2018 12:35:27 +0000 Subject: [PATCH] hdl.ir: add an API for retrieving generated values, like FSM signal. This is useful for tests. --- nmigen/hdl/ir.py | 21 +++++++++++++++++++++ nmigen/test/test_hdl_ir.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/nmigen/hdl/ir.py b/nmigen/hdl/ir.py index b03aedb..a4889b1 100644 --- a/nmigen/hdl/ir.py +++ b/nmigen/hdl/ir.py @@ -20,6 +20,7 @@ class Fragment: self.statements = [] self.domains = OrderedDict() self.subfragments = [] + self.generated = OrderedDict() def add_ports(self, *ports, dir): assert dir in ("i", "o", "io") @@ -83,6 +84,26 @@ class Fragment: assert isinstance(subfragment, Fragment) self.subfragments.append((subfragment, name)) + def find_subfragment(self, name_or_index): + if isinstance(name_or_index, int): + if name_or_index < len(self.subfragments): + subfragment, name = self.subfragments[name_or_index] + return subfragment + raise NameError("No subfragment at index #{}".format(name_or_index)) + else: + for subfragment, name in self.subfragments: + if name == name_or_index: + return subfragment + raise NameError("No subfragment with name '{}'".format(name_or_index)) + + def find_generated(self, *path): + if len(path) > 1: + path_component, *path = path + return self.find_subfragment(path_component).find_generated(*path) + else: + item, = path + return self.generated[item] + def get_fragment(self, platform): return self diff --git a/nmigen/test/test_hdl_ir.py b/nmigen/test/test_hdl_ir.py index f72b4e9..25956d4 100644 --- a/nmigen/test/test_hdl_ir.py +++ b/nmigen/test/test_hdl_ir.py @@ -7,6 +7,37 @@ from ..hdl.mem import * from .tools import * +class FragmentGeneratedTestCase(FHDLTestCase): + def test_find_subfragment(self): + f1 = Fragment() + f2 = Fragment() + f1.add_subfragment(f2, "f2") + + self.assertEqual(f1.find_subfragment(0), f2) + self.assertEqual(f1.find_subfragment("f2"), f2) + + def test_find_subfragment_wrong(self): + f1 = Fragment() + f2 = Fragment() + f1.add_subfragment(f2, "f2") + + with self.assertRaises(NameError, + msg="No subfragment at index #1"): + f1.find_subfragment(1) + with self.assertRaises(NameError, + msg="No subfragment with name 'fx'"): + f1.find_subfragment("fx") + + def test_find_generated(self): + f1 = Fragment() + f2 = Fragment() + f2.generated["sig"] = sig = Signal() + f1.add_subfragment(f2, "f2") + + self.assertEqual(SignalKey(f1.find_generated("f2", "sig")), + SignalKey(sig)) + + class FragmentDriversTestCase(FHDLTestCase): def test_empty(self): f = Fragment() -- 2.30.2