From: whitequark Date: Fri, 28 Dec 2018 13:22:10 +0000 (+0000) Subject: hdl.rec: add basic record support. X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=71883210c263c76d351e8569c81bd287a30fca71;p=nmigen.git hdl.rec: add basic record support. --- diff --git a/doc/COMPAT_SUMMARY.md b/doc/COMPAT_SUMMARY.md index 8cd07d5..5fe1682 100644 --- a/doc/COMPAT_SUMMARY.md +++ b/doc/COMPAT_SUMMARY.md @@ -162,13 +162,15 @@ Compatibility summary - (−) `timeline` ? - (−) `WaitTimer` ? - (−) `BitSlip` ? - - (−) `record` ? - - (−) `DIR_NONE`/`DIR_S_TO_M`/`DIR_M_TO_S` ? - - (−) `set_layout_parameters` ? - - (−) `layout_len` ? - - (−) `layout_get` ? - - (−) `layout_partial` ? - - (−) `Record` ? + - (−) `record` **obs** → `.hdl.rec.Record` + - (−) `DIR_NONE` id + - (−) `DIR_M_TO_S` → `DIR_FANOUT` + - (−) `DIR_S_TO_M` → `DIR_FANIN` + - (−) `set_layout_parameters` **brk** + - (−) `layout_len` **brk** + - (−) `layout_get` **brk** + - (−) `layout_partial` **brk** + - (−) `Record` id - (−) `resetsync` ? - (−) `AsyncResetSynchronizer` ? - (−) `roundrobin` ? diff --git a/examples/gpio.py b/examples/gpio.py index 06caab1..dbd2c68 100644 --- a/examples/gpio.py +++ b/examples/gpio.py @@ -10,20 +10,19 @@ class GPIO: def get_fragment(self, platform): m = Module() - m.d.comb += self.bus.dat_r.eq(self.pins[self.bus.adr]) + m.d.comb += self.bus.r_data.eq(self.pins[self.bus.addr]) with m.If(self.bus.we): - m.d.sync += self.pins[self.bus.adr].eq(self.bus.dat_w) + m.d.sync += self.pins[self.bus.addr].eq(self.bus.w_data) return m.lower(platform) if __name__ == "__main__": - # TODO: use Record - bus = SimpleNamespace( - adr =Signal(name="adr", max=8), - dat_r=Signal(name="dat_r"), - dat_w=Signal(name="dat_w"), - we =Signal(name="we"), - ) + bus = Record([ + ("addr", 3), + ("r_data", 1), + ("w_data", 1), + ("we", 1), + ]) pins = Signal(8) gpio = GPIO(Array(pins), bus) - main(gpio, ports=[pins, bus.adr, bus.dat_r, bus.dat_w, bus.we]) + main(gpio, ports=[pins, bus.addr, bus.r_data, bus.w_data, bus.we]) diff --git a/nmigen/__init__.py b/nmigen/__init__.py index a9ed14d..220f5bb 100644 --- a/nmigen/__init__.py +++ b/nmigen/__init__.py @@ -3,6 +3,7 @@ from .hdl.dsl import Module from .hdl.cd import ClockDomain from .hdl.ir import Fragment, Instance from .hdl.mem import Memory +from .hdl.rec import Record from .hdl.xfrm import ResetInserter, CEInserter from .lib.cdc import MultiReg diff --git a/nmigen/back/pysim.py b/nmigen/back/pysim.py index 27bed92..7a5639c 100644 --- a/nmigen/back/pysim.py +++ b/nmigen/back/pysim.py @@ -72,7 +72,12 @@ class _State: normalize = Const.normalize -class _RHSValueCompiler(ValueVisitor): +class _ValueCompiler(ValueVisitor): + def on_Record(self, value): + return self(Cat(value.fields.values())) + + +class _RHSValueCompiler(_ValueCompiler): def __init__(self, signal_slots, sensitivity=None, mode="rhs"): self.signal_slots = signal_slots self.sensitivity = sensitivity @@ -202,7 +207,7 @@ class _RHSValueCompiler(ValueVisitor): return eval -class _LHSValueCompiler(ValueVisitor): +class _LHSValueCompiler(_ValueCompiler): def __init__(self, signal_slots, rhs_compiler): self.signal_slots = signal_slots self.rhs_compiler = rhs_compiler diff --git a/nmigen/back/rtlil.py b/nmigen/back/rtlil.py index 1572d3e..9ad3033 100644 --- a/nmigen/back/rtlil.py +++ b/nmigen/back/rtlil.py @@ -305,6 +305,9 @@ class _ValueCompiler(xfrm.ValueVisitor): def on_ResetSignal(self, value): raise NotImplementedError # :nocov: + def on_Record(self, value): + return self(Cat(value.fields.values())) + def on_Cat(self, value): return "{{ {} }}".format(" ".join(reversed([self(o) for o in value.parts]))) diff --git a/nmigen/hdl/rec.py b/nmigen/hdl/rec.py new file mode 100644 index 0000000..32028af --- /dev/null +++ b/nmigen/hdl/rec.py @@ -0,0 +1,114 @@ +from enum import Enum +from collections import OrderedDict + +from .. import tracer +from ..tools import union +from .ast import * + + +__all__ = ["Direction", "DIR_NONE", "DIR_FANOUT", "DIR_FANIN", "Layout", "Record"] + + +Direction = Enum('Direction', ('NONE', 'FANOUT', 'FANIN')) + +DIR_NONE = Direction.NONE +DIR_FANOUT = Direction.FANOUT +DIR_FANIN = Direction.FANIN + + +class Layout: + @staticmethod + def wrap(obj): + if isinstance(obj, Layout): + return obj + return Layout(obj) + + def __init__(self, fields): + self.fields = OrderedDict() + for field in fields: + if not isinstance(field, tuple) or len(field) not in (2, 3): + raise TypeError("Field {!r} has invalid layout: should be either " + "(name, shape) or (name, shape, direction)" + .format(field)) + if len(field) == 2: + name, shape = field + direction = DIR_NONE + if isinstance(shape, list): + shape = Layout.wrap(shape) + else: + name, shape, direction = field + if not isinstance(direction, Direction): + raise TypeError("Field {!r} has invalid direction: should be a Direction " + "instance like DIR_FANIN" + .format(field)) + if not isinstance(name, str): + raise TypeError("Field {!r} has invalid name: should be a string" + .format(field)) + if not isinstance(shape, (int, tuple, Layout)): + raise TypeError("Field {!r} has invalid shape: should be an int, tuple, or list " + "of fields of a nested record" + .format(field)) + if name in self.fields: + raise NameError("Field {!r} has a name that is already present in the layout" + .format(field)) + self.fields[name] = (shape, direction) + + def __getitem__(self, name): + return self.fields[name] + + def __iter__(self): + for name, (shape, dir) in self.fields.items(): + yield (name, shape, dir) + + +class Record(Value): + __slots__ = ("fields",) + + def __init__(self, layout, name=None): + if name is None: + try: + name = tracer.get_var_name() + except tracer.NameNotFound: + pass + self.name = name + self.src_loc = tracer.get_src_loc() + + def concat(a, b): + if a is None: + return b + return "{}_{}".format(a, b) + + self.layout = Layout.wrap(layout) + self.fields = OrderedDict() + for field_name, field_shape, field_dir in self.layout: + if isinstance(field_shape, Layout): + self.fields[field_name] = Record(field_shape, name=concat(name, field_name)) + else: + self.fields[field_name] = Signal(field_shape, name=concat(name, field_name)) + + def __getattr__(self, name): + return self.fields[name] + + def __getitem__(self, name): + return self.fields[name] + + def shape(self): + return sum(len(f) for f in self.fields.values()), False + + def _lhs_signals(self): + return union((f._lhs_signals() for f in self.fields.values()), start=SignalSet()) + + def _rhs_signals(self): + return union((f._rhs_signals() for f in self.fields.values()), start=SignalSet()) + + def __repr__(self): + fields = [] + for field_name, field in self.fields.items(): + if isinstance(field, Signal): + fields.append(field_name) + else: + fields.append(repr(field)) + name = self.name + if name is None: + name = "" + return "(rec {} {})".format(name, " ".join(fields)) diff --git a/nmigen/hdl/xfrm.py b/nmigen/hdl/xfrm.py index 2f1f937..e70f9d1 100644 --- a/nmigen/hdl/xfrm.py +++ b/nmigen/hdl/xfrm.py @@ -7,6 +7,7 @@ from .ast import * from .ast import _StatementList from .cd import * from .ir import * +from .rec import * __all__ = ["ValueVisitor", "ValueTransformer", @@ -26,6 +27,10 @@ class ValueVisitor(metaclass=ABCMeta): def on_Signal(self, value): pass # :nocov: + @abstractmethod + def on_Record(self, value): + pass # :nocov: + @abstractmethod def on_ClockSignal(self, value): pass # :nocov: @@ -66,6 +71,8 @@ class ValueVisitor(metaclass=ABCMeta): new_value = self.on_Const(value) elif type(value) is Signal: new_value = self.on_Signal(value) + elif type(value) is Record: + new_value = self.on_Record(value) elif type(value) is ClockSignal: new_value = self.on_ClockSignal(value) elif type(value) is ResetSignal: @@ -100,6 +107,9 @@ class ValueTransformer(ValueVisitor): def on_Signal(self, value): return value + def on_Record(self, value): + return value + def on_ClockSignal(self, value): return value diff --git a/nmigen/test/test_hdl_rec.py b/nmigen/test/test_hdl_rec.py new file mode 100644 index 0000000..501fda9 --- /dev/null +++ b/nmigen/test/test_hdl_rec.py @@ -0,0 +1,80 @@ +from ..hdl.ast import * +from ..hdl.rec import * +from .tools import * + + +class LayoutTestCase(FHDLTestCase): + def test_fields(self): + layout = Layout.wrap([ + ("cyc", 1), + ("data", (32, True)), + ("stb", 1, DIR_FANOUT), + ("ack", 1, DIR_FANIN), + ("info", [ + ("a", 1), + ("b", 1), + ]) + ]) + + self.assertEqual(layout["cyc"], (1, DIR_NONE)) + self.assertEqual(layout["data"], ((32, True), DIR_NONE)) + self.assertEqual(layout["stb"], (1, DIR_FANOUT)) + self.assertEqual(layout["ack"], (1, DIR_FANIN)) + sublayout = layout["info"][0] + self.assertEqual(layout["info"][1], DIR_NONE) + self.assertEqual(sublayout["a"], (1, DIR_NONE)) + self.assertEqual(sublayout["b"], (1, DIR_NONE)) + + def test_wrong_field(self): + with self.assertRaises(TypeError, + msg="Field (1,) has invalid layout: should be either (name, shape) or " + "(name, shape, direction)"): + Layout.wrap([(1,)]) + + def test_wrong_name(self): + with self.assertRaises(TypeError, + msg="Field (1, 1) has invalid name: should be a string"): + Layout.wrap([(1, 1)]) + + def test_wrong_name_duplicate(self): + with self.assertRaises(NameError, + msg="Field ('a', 2) has a name that is already present in the layout"): + Layout.wrap([("a", 1), ("a", 2)]) + + def test_wrong_direction(self): + with self.assertRaises(TypeError, + msg="Field ('a', 1, 0) has invalid direction: should be a Direction " + "instance like DIR_FANIN"): + Layout.wrap([("a", 1, 0)]) + + def test_wrong_shape(self): + with self.assertRaises(TypeError, + msg="Field ('a', 'x') has invalid shape: should be an int, tuple, or " + "list of fields of a nested record"): + Layout.wrap([("a", "x")]) + + +class RecordTestCase(FHDLTestCase): + def test_basic(self): + r = Record([ + ("stb", 1), + ("data", 32), + ("info", [ + ("a", 1), + ("b", 1), + ]) + ]) + + self.assertEqual(repr(r), "(rec r stb data (rec r_info a b))") + self.assertEqual(len(r), 35) + self.assertIsInstance(r.stb, Signal) + self.assertEqual(r.stb.name, "r_stb") + self.assertEqual(r["stb"].name, "r_stb") + + def test_unnamed(self): + r = [Record([ + ("stb", 1) + ])][0] + + self.assertEqual(repr(r), "(rec stb)") + self.assertEqual(r.stb.name, "stb") diff --git a/nmigen/test/test_sim.py b/nmigen/test/test_sim.py index 4dd7b5b..cdd83d7 100644 --- a/nmigen/test/test_sim.py +++ b/nmigen/test/test_sim.py @@ -5,6 +5,7 @@ from ..tools import flatten, union from ..hdl.ast import * from ..hdl.cd import * from ..hdl.mem import * +from ..hdl.rec import * from ..hdl.dsl import * from ..hdl.ir import * from ..back.pysim import * @@ -173,6 +174,14 @@ class SimulatorUnitTestCase(FHDLTestCase): stmt = lambda y, a: [Cat(l, m, n).eq(a), y.eq(Cat(n, m, l))] self.assertStatement(stmt, [C(0b100101110, 9)], C(0b110101100, 9)) + def test_record(self): + rec = Record([ + ("l", 1), + ("m", 2), + ]) + stmt = lambda y, a: [rec.eq(a), y.eq(rec)] + self.assertStatement(stmt, [C(0b101, 3)], C(0b101, 3)) + def test_repl(self): stmt = lambda y, a: y.eq(Repl(a, 3)) self.assertStatement(stmt, [C(0b10, 2)], C(0b101010, 6))