speed up ==, hash, <, >, <=, and >= for plain_data master
author: Jacob Lifshay <programmerjake@gmail.com>
Wed, 16 Nov 2022 06:02:12 +0000 (22:02 -0800)
committer: Jacob Lifshay <programmerjake@gmail.com>
Wed, 16 Nov 2022 06:02:12 +0000 (22:02 -0800)
this makes test_register_allocate_spread finish in under half as much time for:
https://git.libre-soc.org/?p=bigint-presentation-code.git;a=commit;h=a369418056bf51137af0fc6bdcfc0799697df583

60 files changed:
.gitignore
.gitlab-ci.yml
src/nmutil/byterev.py
src/nmutil/clz.py
src/nmutil/concurrentunit.py
src/nmutil/deduped.py [new file with mode: 0644]
src/nmutil/divmod.py
src/nmutil/dynamicpipe.py
src/nmutil/extend.py
src/nmutil/formal/__init__.py [new file with mode: 0644]
src/nmutil/formal/proof_clz.py [deleted file]
src/nmutil/formal/test_byterev.py [new file with mode: 0644]
src/nmutil/formal/test_clz.py [new file with mode: 0644]
src/nmutil/formal/test_picker.py [new file with mode: 0644]
src/nmutil/formal/test_plru.py [new file with mode: 0644]
src/nmutil/formal/test_queue.py [new file with mode: 0644]
src/nmutil/formaltest.py
src/nmutil/get_test_path.py [new file with mode: 0644]
src/nmutil/grev.py [new file with mode: 0644]
src/nmutil/iocontrol.py
src/nmutil/latch.py
src/nmutil/lut.py [new file with mode: 0644]
src/nmutil/mask.py
src/nmutil/multipipe.py
src/nmutil/nmoperator.py
src/nmutil/noconflict.py
src/nmutil/p_lru.txt [deleted file]
src/nmutil/picker.py
src/nmutil/pipeline.py
src/nmutil/pipemodbase.py
src/nmutil/plain_data.py [new file with mode: 0644]
src/nmutil/plain_data.pyi [new file with mode: 0644]
src/nmutil/plru.py
src/nmutil/plru.txt [new file with mode: 0644]
src/nmutil/plru2.py [new file with mode: 0644]
src/nmutil/popcount.py [new file with mode: 0644]
src/nmutil/prefix_sum.py [new file with mode: 0644]
src/nmutil/queue.py
src/nmutil/ripple.py
src/nmutil/sim_tmp_alternative.py
src/nmutil/sim_util.py [new file with mode: 0644]
src/nmutil/singlepipe.py
src/nmutil/stageapi.py
src/nmutil/test/example_buf_pipe.py
src/nmutil/test/example_gtkwave.py
src/nmutil/test/test_buf_pipe.py
src/nmutil/test/test_clz.py
src/nmutil/test/test_deduped.py [new file with mode: 0644]
src/nmutil/test/test_grev.py [new file with mode: 0644]
src/nmutil/test/test_inout_feedback_pipe.py
src/nmutil/test/test_inout_mux_pipe.py
src/nmutil/test/test_inout_unary_mux_cancel_pipe.py
src/nmutil/test/test_lut.py [new file with mode: 0644]
src/nmutil/test/test_outmux_pipe.py
src/nmutil/test/test_plain_data.py [new file with mode: 0644]
src/nmutil/test/test_prefix_sum.py [new file with mode: 0644]
src/nmutil/test/test_prioritymux_pipe.py
src/nmutil/test/test_reservation_stations.py [new file with mode: 0644]
src/nmutil/toolchain.py [new file with mode: 0644]
src/nmutil/util.py

index 77d4f5d07fac9f86a1828aab5f9cccb47a3dd928..8248bf2236c6a86e922d826d5cef61b71ce7824e 100644 (file)
@@ -9,3 +9,5 @@ __pycache__
 .eggs
 *.egg-info
 *.gtkw
+/sim_test_out
+/formal_test_temp
\ No newline at end of file
index d6d41f247466586b86a95b4d6d5e8a88a7d59718..beb4f46cb8a3d3187c8dc4f482a413a052e6f39d 100644 (file)
@@ -1,6 +1,7 @@
 image: debian:10
 
 cache:
+    when: always
     paths:
         - ccache
 
@@ -8,22 +9,43 @@ build:
     stage: build
     before_script:
         - apt-get update
+        # one package per line to simplify sorting, git diff, etc.
         - >-
             apt-get -y install
-            build-essential git python3-dev python3-pip
-            python3-setuptools python3-wheel pkg-config tcl-dev
-            libreadline-dev bison flex libffi-dev ccache
-        - export PATH="/usr/lib/ccache:$PATH"
+            autoconf
+            bison
+            build-essential
+            ccache
+            clang
+            cmake
+            curl
+            flex
+            gawk
+            git
+            gperf
+            libboost-program-options-dev
+            libffi-dev
+            libftdi-dev
+            libgmp-dev
+            libreadline-dev
+            mercurial
+            pkg-config
+            python
+            python3
+            python3-dev
+            python3-pip
+            python3-setuptools
+            python3-wheel
+            tcl-dev
+        - export PATH="$HOME/.local/bin:/usr/lib/ccache:$PATH"
         - export CCACHE_BASEDIR="$PWD"
         - export CCACHE_DIR="$PWD/ccache"
         - export CCACHE_COMPILERCHECK=content
         - ccache --zero-stats || true
         - ccache --show-stats || true
-    after_script:
-        - export CCACHE_DIR="$PWD/ccache"
-        - ccache --show-stats
+        - python3 -m pip install --user pytest-xdist
     script:
-        - git clone --depth 1 https://github.com/YosysHQ/yosys.git yosys
+        - git clone --depth 1 -b yosys-0.17 https://github.com/YosysHQ/yosys.git yosys
         - pushd yosys
         - make config-gcc
         - make -j$(nproc)
@@ -31,11 +53,34 @@ build:
         - popd
         - yosys -V
 
-        - git clone --depth 1 https://github.com/nmigen/nmigen.git nmigen
+        - git clone https://github.com/YosysHQ/SymbiYosys.git SymbiYosys
+        - pushd SymbiYosys
+        - git checkout d10e472edf4ea9be3aa6347b264ba575fbea933a
+        - make install
+        - popd
+
+        - git clone --depth 1 -b Yices-2.6.4 https://github.com/SRI-CSL/yices2.git yices2
+        - pushd yices2
+        - autoconf
+        - ./configure
+        - make -j$(nproc)
+        - make install
+        - popd
+
+        - git clone --depth 1 -b z3-4.8.17 https://github.com/Z3Prover/z3.git z3
+        - pushd z3
+        - python scripts/mk_make.py
+        - cd build
+        - make -j$(nproc)
+        - make install
+        - popd
+
+        - git clone --depth 1 https://gitlab.com/nmigen/nmigen.git nmigen
         - pushd nmigen
+        - git rev-parse HEAD
         - python3 setup.py develop
         - popd
 
         - python3 setup.py develop
 
-        - python3 setup.py test
+        - pytest -n auto src/nmutil
index 7dad4508ac12584056d5c8837fa6d043234d21e8..a7f0a68aeacc1fd78609dcc9b60734c79e105381 100644 (file)
@@ -1,3 +1,4 @@
+# SPDX-License-Identifier: LGPL-3-or-later
 """
     This work is funded through NLnet under Grant 2019-02-012
 
@@ -7,11 +8,15 @@
 from nmigen import Signal, Cat
 
 # TODO: turn this into a module?
+
+
 def byte_reverse(m, name, data, length):
     """byte_reverse: unlike nmigen word_select this takes a dynamic length
 
     nmigen Signal.word_select may only take a fixed length.  we need
     bigendian byte-reverse, half-word reverse, word and dword reverse.
+
+    This only outputs the first `length` bytes, higher bytes are zeroed.
     """
     comb = m.d.comb
     data_r = Signal.like(data, name=name)
@@ -28,11 +33,10 @@ def byte_reverse(m, name, data, length):
 
     # Switch statement needed: dynamic length had better be = 1,2,4 or 8
     with m.Switch(length):
-        for j in [1,2,4,8]:
+        for j in [1, 2, 4, 8]:
             with m.Case(j):
                 rev = []
                 for i in range(j):
                     rev.append(data.word_select(j-1-i, 8))
                 comb += data_r.eq(Cat(*rev))
     return data_r
-
index 2fda8c2b578a5a335cc33bcdca5fac3083c4d483..70e0f513c45b52034c9aac5aac8d86dd55d76a38 100644 (file)
@@ -1,3 +1,4 @@
+# SPDX-License-Identifier: LGPL-3-or-later
 from nmigen import Module, Signal, Elaboratable, Cat, Repl
 import math
 """ This module is much more efficient than PriorityEncoder
@@ -10,6 +11,15 @@ import math
 
 """
 
+
+def clz(v, width):
+    """count leading zeros."""
+    assert isinstance(width, int) and 0 <= width
+    max_v = (1 << width) - 1
+    assert isinstance(v, int) and 0 <= v <= max_v
+    return max_v.bit_length() - v.bit_length()
+
+
 class CLZ(Elaboratable):
     def __init__(self, width):
         self.width = width
@@ -86,4 +96,3 @@ class CLZ(Elaboratable):
         comb += self.lz.eq(pairs[0][0])
 
         return m
-
index 6d2ff3d56b814b4bf50b3890e58a5678d3160b5d..f317fbcd52ed1c8ee16630841bf23f2ba3681b33 100644 (file)
@@ -1,3 +1,4 @@
+# SPDX-License-Identifier: LGPL-3-or-later
 """ concurrent unit from mitch alsup augmentations to 6600 scoreboard
 
     This work is funded through NLnet under Grant 2019-02-012
 """
 
 from math import log
-from nmigen import Module, Elaboratable, Signal
+from nmigen import Module, Elaboratable, Signal, Cat
 from nmigen.asserts import Assert
+from nmigen.lib.coding import PriorityEncoder
 from nmigen.cli import main, verilog
 
 from nmutil.singlepipe import PassThroughStage
 from nmutil.multipipe import CombMuxOutPipe
 from nmutil.multipipe import PriorityCombMuxInPipe
+from nmutil.iocontrol import NextControl, PrevControl
+from nmutil import nmoperator
 
 
 def num_bits(n):
@@ -93,6 +97,18 @@ class MuxOutPipe(CombMuxOutPipe):
                                 maskwid=maskwid)
 
 
+class ALUProxy:
+    """ALUProxy: create a series of ALUs that look like the ALU being
+    sandwiched in between the fan-in and fan-out.  One ALU looks like
+    it is multiple concurrent ALUs
+    """
+
+    def __init__(self, alu, p, n):
+        self.alu = alu
+        self.p = p
+        self.n = n
+
+
 class ReservationStations(Elaboratable):
     """ Reservation-Station pipeline
 
@@ -107,16 +123,25 @@ class ReservationStations(Elaboratable):
 
         Fan-in and Fan-out are combinatorial.
     """
+
     def __init__(self, num_rows, maskwid=0, feedback_width=None):
         self.num_rows = nr = num_rows
         self.feedback_width = feedback_width
         self.inpipe = InMuxPipe(nr, self.i_specfn, maskwid)   # fan-in
-        self.outpipe = MuxOutPipe(nr, self.o_specfn, maskwid) # fan-out
+        self.outpipe = MuxOutPipe(nr, self.o_specfn, maskwid)  # fan-out
 
         self.p = self.inpipe.p  # kinda annoying,
-        self.n = self.outpipe.n # use pipe in/out as this class in/out
+        self.n = self.outpipe.n  # use pipe in/out as this class in/out
         self._ports = self.inpipe.ports() + self.outpipe.ports()
 
+    def setup_pseudoalus(self):
+        """setup_pseudoalus: establishes a suite of pseudo-alus
+        that look to all pipeline-intents-and-purposes just like the original
+        """
+        self.pseudoalus = []
+        for i in range(self.num_rows):
+            self.pseudoalus.append(ALUProxy(self.alu, self.p[i], self.n[i]))
+
     def elaborate(self, platform):
         m = Module()
         m.submodules.inpipe = self.inpipe
@@ -147,3 +172,167 @@ class ReservationStations(Elaboratable):
 
     def o_specfn(self):
         return self.alu.ospec()
+
+
+class ReservationStations2(Elaboratable):
+    """ Reservation-Station pipeline.  Manages an ALU and makes it look like
+        there are multiple of them, presenting the same ready/valid API
+
+        Input:
+
+        :alu: - an ALU to be "managed" by these ReservationStations
+        :num_rows: - number of input and output Reservation Stations
+
+        Note that the ALU data (in and out specs) right the way down the
+        entire chain *must* have a "muxid" data member.  this is picked
+        up and used to route data correctly from input RS to output RS.
+
+        It is the responsibility of the USER of the ReservationStations
+        class to correctly set that muxid in each data packet to the
+        correct constant.  this could change in future.
+
+        FAILING TO SET THE MUXID IS GUARANTEED TO RESULT IN CORRUPTED DATA.
+    """
+
+    def __init__(self, alu, num_rows, alu_name=None):
+        if alu_name is None:
+            alu_name = "alu"
+        self.num_rows = nr = num_rows
+        id_wid = num_rows.bit_length()
+        self.p = []
+        self.n = []
+        self.alu = alu
+        self.alu_name = alu_name
+        # create prev and next ready/valid and add replica of ALU data specs
+        for i in range(num_rows):
+            suffix = "_%d" % i
+            p = PrevControl(name=suffix)
+            n = NextControl(name=suffix)
+            p.i_data, n.o_data = self.alu.new_specs("rs_%d" % i)
+            self.p.append(p)
+            self.n.append(n)
+
+        self.pipe = self  # for Arbiter to select the incoming prevcontrols
+
+        # set up pseudo-alus that look like a standard pipeline
+        self.pseudoalus = []
+        for i in range(self.num_rows):
+            self.pseudoalus.append(ALUProxy(self.alu, self.p[i], self.n[i]))
+
+    def __iter__(self):
+        for p in self.p:
+            yield from p
+        for n in self.n:
+            yield from n
+
+    def ports(self):
+        return list(self)
+
+    def elaborate(self, platform):
+        m = Module()
+        pe = PriorityEncoder(self.num_rows)  # input priority picker
+        m.submodules[self.alu_name] = self.alu
+        m.submodules.selector = pe
+        for i, (p, n) in enumerate(zip(self.p, self.n)):
+            m.submodules["rs_p_%d" % i] = p
+            m.submodules["rs_n_%d" % i] = n
+
+        # Priority picker for one RS
+        self.active = Signal()
+        self.m_id = Signal.like(pe.o)
+
+        # ReservationStation status information, progressively updated in FSM
+        rsvd = Signal(self.num_rows)  # indicates RS data in flight
+        sent = Signal(self.num_rows)  # sent indicates data in pipeline
+        wait = Signal(self.num_rows)  # the outputs are waiting for accept
+
+        # pick first non-reserved ReservationStation with data not already
+        # sent into the ALU
+        m.d.comb += pe.i.eq(rsvd & ~sent)
+        m.d.comb += self.active.eq(~pe.n)   # encoder active (one input valid)
+        m.d.comb += self.m_id.eq(pe.o)       # output one active input
+
+        # mux in and mux out ids.  note that all data *must* have a muxid
+        mid = self.m_id                   # input mux selector
+        o_muxid = self.alu.n.o_data.muxid  # output mux selector
+
+        # technically speaking this could be set permanently "HI".
+        # when all the ReservationStations outputs are waiting,
+        # the ALU cannot obviously accept any more data.  as the
+        # ALU is effectively "decoupled" from (managed by) the RSes,
+        # as long as there is sufficient RS allocation this should not
+        # be necessary, i.e. at no time should the ALU be given more inputs
+        # than there are outputs to accept (!) but just in case...
+        m.d.comb += self.alu.n.i_ready.eq(~wait.all())
+
+        #####
+        # input side
+        #####
+
+        # first, establish input: select one input to pass data to (p_mux)
+        for i in range(self.num_rows):
+            i_buf, o_buf = self.alu.new_specs("buf%d" % i)  # buffers
+            with m.FSM():
+                # indicate ready to accept data, and accept it if incoming
+                # BUT, if there is an opportunity to send on immediately
+                # to the ALU, take it early (combinatorial)
+                with m.State("ACCEPTING%d" % i):
+                    m.d.comb += self.p[i].o_ready.eq(1)  # ready indicator
+                    with m.If(self.p[i].i_valid):  # valid data incoming
+                        m.d.sync += rsvd[i].eq(1)  # now reserved
+                        # a unique opportunity: the ALU happens to be free
+                        with m.If(mid == i):  # picker selected us
+                            with m.If(self.alu.p.o_ready):  # ALU can accept
+                                # transfer
+                                m.d.comb += self.alu.p.i_valid.eq(1)
+                                m.d.comb += nmoperator.eq(self.alu.p.i_data,
+                                                          self.p[i].i_data)
+                                m.d.sync += sent[i].eq(1)  # now reserved
+                                m.next = "WAITOUT%d" % i  # move to "wait output"
+                        with m.Else():
+                            # nope. ALU wasn't free. try next cycle(s)
+                            m.d.sync += nmoperator.eq(i_buf, self.p[i].i_data)
+                            m.next = "ACCEPTED%d" % i  # move to "accepted"
+
+                # now try to deliver to the ALU, but only if we are "picked"
+                with m.State("ACCEPTED%d" % i):
+                    with m.If(mid == i):  # picker selected us
+                        with m.If(self.alu.p.o_ready):  # ALU can accept
+                            m.d.comb += self.alu.p.i_valid.eq(1)  # transfer
+                            m.d.comb += nmoperator.eq(self.alu.p.i_data, i_buf)
+                            m.d.sync += sent[i].eq(1)  # now reserved
+                            m.next = "WAITOUT%d" % i  # move to "wait output"
+
+                # waiting for output to appear on the ALU, take a copy
+                # BUT, again, if there is an opportunity to send on
+                # immediately, take it (combinatorial)
+                with m.State("WAITOUT%d" % i):
+                    with m.If(o_muxid == i):  # when ALU output matches our RS
+                        with m.If(self.alu.n.o_valid):  # ALU can accept
+                            # second unique opportunity: the RS is ready
+                            with m.If(self.n[i].i_ready):  # ready to receive
+                                m.d.comb += self.n[i].o_valid.eq(1)  # valid
+                                m.d.comb += nmoperator.eq(self.n[i].o_data,
+                                                          self.alu.n.o_data)
+                                m.d.sync += wait[i].eq(0)  # clear waiting
+                                m.d.sync += sent[i].eq(0)  # and sending
+                                m.d.sync += rsvd[i].eq(0)  # and reserved
+                                m.next = "ACCEPTING%d" % i  # back to "accepting"
+                            with m.Else():
+                                # nope. RS wasn't ready. try next cycles
+                                m.d.sync += wait[i].eq(1)  # now waiting
+                                m.d.sync += nmoperator.eq(o_buf,
+                                                          self.alu.n.o_data)
+                                m.next = "SENDON%d" % i  # move to "send data on"
+
+                # waiting for "valid" indicator on RS output: deliver it
+                with m.State("SENDON%d" % i):
+                    with m.If(self.n[i].i_ready):  # user is ready to receive
+                        m.d.comb += self.n[i].o_valid.eq(1)  # indicate valid
+                        m.d.comb += nmoperator.eq(self.n[i].o_data, o_buf)
+                        m.d.sync += wait[i].eq(0)  # clear waiting
+                        m.d.sync += sent[i].eq(0)  # and sending
+                        m.d.sync += rsvd[i].eq(0)  # and reserved
+                        m.next = "ACCEPTING%d" % i  # and back to "accepting"
+
+        return m
diff --git a/src/nmutil/deduped.py b/src/nmutil/deduped.py
new file mode 100644 (file)
index 0000000..d6fca1c
--- /dev/null
@@ -0,0 +1,83 @@
+import functools
+import weakref
+
+
+class _KeyBuilder:
+    def __init__(self, do_delete):
+        self.__keys = []
+        self.__refs = {}
+        self.__do_delete = do_delete
+
+    def add_ref(self, v):
+        v_id = id(v)
+        if v_id in self.__refs:
+            return
+        try:
+            v = weakref.ref(v, callback=self.__do_delete)
+        except TypeError:
+            pass
+        self.__refs[v_id] = v
+
+    def add(self, k, v):
+        self.__keys.append(id(k))
+        self.__keys.append(id(v))
+        self.add_ref(k)
+        self.add_ref(v)
+
+    def finish(self):
+        return tuple(self.__keys), tuple(self.__refs.values())
+
+
+def deduped(*, global_keys=()):
+    """decorator that causes functions to deduplicate their results based on
+    their input args and the requested globals. For each set of arguments, it
+    will always return the exact same object, by storing it internally.
+    Arguments are compared by their identity, so they don't need to be
+    hashable.
+
+    Usage:
+    ```
+    # for functions that don't depend on global variables
+    @deduped()
+    def my_fn1(a, b, *, c=1):
+        return a + b * c
+
+    my_global = 23
+
+    # for functions that depend on global variables
+    @deduped(global_keys=[lambda: my_global])
+    def my_fn2(a, b, *, c=2):
+        return a + b * c + my_global
+    ```
+    """
+    global_keys = tuple(global_keys)
+    assert all(map(callable, global_keys))
+
+    def decorator(f):
+        if isinstance(f, (staticmethod, classmethod)):
+            raise TypeError("@staticmethod or @classmethod should be applied "
+                            "to the result of @deduped, not the other way"
+                            " around")
+        assert callable(f)
+
+        map = {}
+
+        @functools.wraps(f)
+        def wrapper(*args, **kwargs):
+            key_builder = _KeyBuilder(lambda _: map.pop(key, None))
+            for arg in args:
+                key_builder.add(None, arg)
+            for k, v in kwargs.items():
+                key_builder.add(k, v)
+            for global_key in global_keys:
+                key_builder.add(None, global_key())
+            key, refs = key_builder.finish()
+            if key in map:
+                return map[key][0]
+            retval = f(*args, **kwargs)
+            # keep reference to stuff used for key to avoid ids
+            # getting reused for something else.
+            map[key] = retval, refs
+            return retval
+        return wrapper
+    return decorator
index 5a8207eea424f5330891cb5ba7dea78278d56854..7c503e6199992462824c8646f1ef62d697d6a536 100644 (file)
@@ -1,3 +1,4 @@
+# SPDX-License-Identifier: LGPL-3-or-later
 """
     This work is funded through NLnet under Grant 2019-02-012
 
index 0187cf29cc594429a587e1abc7a2a0609b721bfc..95f58e5fe3f275ebe74ea4af6565ec0db053eda0 100644 (file)
@@ -27,6 +27,7 @@ import threading
 # list post:
 # http://lists.libre-riscv.org/pipermail/libre-riscv-dev/2019-July/002259.html
 
+
 class Meta(ABCMeta):
     registry = {}
     recursing = threading.local()
@@ -38,11 +39,11 @@ class Meta(ABCMeta):
         if mcls.recursing.check:
             return super().__call__(*args, **kw)
         spec = args[0]
-        base = spec.pipekls # pick up the dynamic class from PipelineSpec, HERE
+        base = spec.pipekls  # pick up the dynamic class from PipelineSpec, HERE
 
         if (cls, base) not in mcls.registry:
-            print ("__call__", args, kw, cls, base,
-                   base.__bases__, cls.__bases__)
+            print("__call__", args, kw, cls, base,
+                  base.__bases__, cls.__bases__)
             mcls.registry[cls, base] = type(
                 cls.__name__,
                 (cls, base) + cls.__bases__[1:],
@@ -74,7 +75,7 @@ class Meta(ABCMeta):
 
 class DynamicPipe(metaclass=Meta):
     def __init__(self, *args):
-        print ("DynamicPipe init", super(), args)
+        print("DynamicPipe init", super(), args)
         super().__init__(self, *args)
 
 
@@ -84,7 +85,7 @@ class DynamicPipe(metaclass=Meta):
 # could hypothetically be passed through the pspec.
 class SimpleHandshakeRedir(SimpleHandshake):
     def __init__(self, mod, *args):
-        print ("redir", mod, args)
+        print("redir", mod, args)
         stage = self
         if args and args[0].stage:
             stage = args[0].stage
@@ -97,6 +98,5 @@ class MaskCancellableRedir(MaskCancellable):
         maskwid = args[0].maskwid
         if args[0].stage:
             stage = args[0].stage
-        print ("redir mask", mod, args, maskwid)
+        print("redir mask", mod, args, maskwid)
         MaskCancellable.__init__(self, stage, maskwid)
-
index 38b5e7dd27b64701aa8f09f8837a74f42975543b..7b7d1acb78cdc52422040e85479ad8c680091184 100644 (file)
@@ -1,24 +1,41 @@
+# SPDX-License-Identifier: LGPL-2-or-later
+# Copyright (C) Luke Kenneth Casson Leighton 2020,2021 <lkcl@lkcl.net>
 """
-    This work is funded through NLnet under Grant 2019-02-012
-
-    License: LGPLv3+
-
+Provides sign/unsigned extension/truncation utility functions.
 
+This work is funded through NLnet under Grant 2019-02-012
 """
 from nmigen import Repl, Cat, Const
 
 
 def exts(exts_data, width, fullwidth):
+    diff = fullwidth-width
+    if diff == 0:
+        return exts_data
     exts_data = exts_data[0:width]
+    if diff <= 0:
+        return exts_data[:fullwidth]
     topbit = exts_data[-1]
-    signbits = Repl(topbit, fullwidth-width)
+    signbits = Repl(topbit, diff)
     return Cat(exts_data, signbits)
 
 
-def extz(exts_data, width, fullwidth):
-    exts_data = exts_data[0:width]
+def extz(extz_data, width, fullwidth):
+    diff = fullwidth-width
+    if diff == 0:
+        return extz_data
+    extz_data = extz_data[0:width]
+    if diff <= 0:
+        return extz_data[:fullwidth]
     topbit = Const(0)
-    signbits = Repl(topbit, fullwidth-width)
-    return Cat(exts_data, signbits)
+    signbits = Repl(topbit, diff)
+    return Cat(extz_data, signbits)
 
 
+def ext(data, shape, newwidth):
+    """extend/truncate data to new width, preserving sign
+    """
+    width, signed = shape
+    if signed:
+        return exts(data, width, newwidth)
+    return extz(data, width, newwidth)
diff --git a/src/nmutil/formal/__init__.py b/src/nmutil/formal/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/src/nmutil/formal/proof_clz.py b/src/nmutil/formal/proof_clz.py
deleted file mode 100644 (file)
index aac25d3..0000000
+++ /dev/null
@@ -1,62 +0,0 @@
-from nmigen import Module, Signal, Elaboratable, Mux, Const
-from nmigen.asserts import Assert, AnyConst, Assume
-from nmutil.formaltest import FHDLTestCase
-from nmigen.cli import rtlil
-
-from nmutil.clz import CLZ
-import unittest
-
-
-# This defines a module to drive the device under test and assert
-# properties about its outputs
-class Driver(Elaboratable):
-    def __init__(self):
-        # inputs and outputs
-        pass
-
-    def elaborate(self, platform):
-        m = Module()
-        comb = m.d.comb
-        width = 10
-
-        m.submodules.dut = dut = CLZ(width)
-        sig_in = Signal.like(dut.sig_in)
-        count = Signal.like(dut.lz)
-
-
-        m.d.comb += [
-            sig_in.eq(AnyConst(width)),
-            dut.sig_in.eq(sig_in),
-            count.eq(dut.lz)]
-
-        result = Const(width)
-        for i in range(width):
-            print(result)
-            result_next = Signal.like(count, name="count_%d" % i)
-            with m.If(sig_in[i] == 1):
-                comb += result_next.eq(width-i-1)
-            with m.Else():
-                comb += result_next.eq(result)
-            result = result_next
-
-        result_sig = Signal.like(count)
-        comb += result_sig.eq(result)
-
-        comb += Assert(result_sig == count)
-        
-        # setup the inputs and outputs of the DUT as anyconst
-
-        return m
-
-class CLZTestCase(FHDLTestCase):
-    def test_proof(self):
-        module = Driver()
-        self.assertFormal(module, mode="bmc", depth=4)
-    def test_ilang(self):
-        dut = Driver()
-        vl = rtlil.convert(dut, ports=[])
-        with open("clz.il", "w") as f:
-            f.write(vl)
-
-if __name__ == '__main__':
-    unittest.main()
diff --git a/src/nmutil/formal/test_byterev.py b/src/nmutil/formal/test_byterev.py
new file mode 100644 (file)
index 0000000..6df71bd
--- /dev/null
@@ -0,0 +1,113 @@
+# SPDX-License-Identifier: LGPL-3-or-later
+# Copyright 2022 Jacob Lifshay
+
+import unittest
+from nmigen.hdl.ast import AnyConst, Assert, Signal, Assume
+from nmigen.hdl.dsl import Module
+from nmutil.formaltest import FHDLTestCase
+from nmutil.byterev import byte_reverse
+from nmutil.grev import grev
+
+
+VALID_BYTE_REVERSE_LENGTHS = tuple(1 << i for i in range(4))
+LOG2_BYTE_SIZE = 3
+
+
+class TestByteReverse(FHDLTestCase):
+    def tst(self, log2_width, rev_length=None):
+        assert isinstance(log2_width, int) and log2_width >= LOG2_BYTE_SIZE
+        assert rev_length is None or rev_length in VALID_BYTE_REVERSE_LENGTHS
+        m = Module()
+        width = 1 << log2_width
+        inp = Signal(width)
+        m.d.comb += inp.eq(AnyConst(width))
+        length_sig = Signal(range(max(VALID_BYTE_REVERSE_LENGTHS) + 1))
+        m.d.comb += length_sig.eq(AnyConst(length_sig.shape()))
+
+        if rev_length is None:
+            rev_length = length_sig
+        else:
+            m.d.comb += Assume(length_sig == rev_length)
+
+        with m.Switch(length_sig):
+            for l in VALID_BYTE_REVERSE_LENGTHS:
+                with m.Case(l):
+                    m.d.comb += Assume(width >= l << LOG2_BYTE_SIZE)
+            with m.Default():
+                m.d.comb += Assume(False)
+
+        out = byte_reverse(m, name="out", data=inp, length=rev_length)
+
+        expected = Signal(width)
+        for log2_chunk_size in range(LOG2_BYTE_SIZE, log2_width + 1):
+            chunk_size = 1 << log2_chunk_size
+            chunk_byte_size = chunk_size >> LOG2_BYTE_SIZE
+            chunk_sizes = chunk_size - 8
+            with m.If(rev_length == chunk_byte_size):
+                m.d.comb += expected.eq(grev(inp, chunk_sizes, log2_width)
+                                        & ((1 << chunk_size) - 1))
+
+        m.d.comb += Assert(expected == out)
+
+        self.assertFormal(m)
+
+    def test_8_len_1(self):
+        self.tst(log2_width=3, rev_length=1)
+
+    def test_8(self):
+        self.tst(log2_width=3)
+
+    def test_16_len_1(self):
+        self.tst(log2_width=4, rev_length=1)
+
+    def test_16_len_2(self):
+        self.tst(log2_width=4, rev_length=2)
+
+    def test_16(self):
+        self.tst(log2_width=4)
+
+    def test_32_len_1(self):
+        self.tst(log2_width=5, rev_length=1)
+
+    def test_32_len_2(self):
+        self.tst(log2_width=5, rev_length=2)
+
+    def test_32_len_4(self):
+        self.tst(log2_width=5, rev_length=4)
+
+    def test_32(self):
+        self.tst(log2_width=5)
+
+    def test_64_len_1(self):
+        self.tst(log2_width=6, rev_length=1)
+
+    def test_64_len_2(self):
+        self.tst(log2_width=6, rev_length=2)
+
+    def test_64_len_4(self):
+        self.tst(log2_width=6, rev_length=4)
+
+    def test_64_len_8(self):
+        self.tst(log2_width=6, rev_length=8)
+
+    def test_64(self):
+        self.tst(log2_width=6)
+
+    def test_128_len_1(self):
+        self.tst(log2_width=7, rev_length=1)
+
+    def test_128_len_2(self):
+        self.tst(log2_width=7, rev_length=2)
+
+    def test_128_len_4(self):
+        self.tst(log2_width=7, rev_length=4)
+
+    def test_128_len_8(self):
+        self.tst(log2_width=7, rev_length=8)
+
+    def test_128(self):
+        self.tst(log2_width=7)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/src/nmutil/formal/test_clz.py b/src/nmutil/formal/test_clz.py
new file mode 100644 (file)
index 0000000..2bfcfe6
--- /dev/null
@@ -0,0 +1,64 @@
+from nmigen import Module, Signal, Elaboratable, Mux, Const
+from nmigen.asserts import Assert, AnyConst, Assume
+from nmutil.formaltest import FHDLTestCase
+from nmigen.cli import rtlil
+
+from nmutil.clz import CLZ
+import unittest
+
+
+# This defines a module to drive the device under test and assert
+# properties about its outputs
+class Driver(Elaboratable):
+    def __init__(self):
+        # inputs and outputs
+        pass
+
+    def elaborate(self, platform):
+        m = Module()
+        comb = m.d.comb
+        width = 10
+
+        m.submodules.dut = dut = CLZ(width)
+        sig_in = Signal.like(dut.sig_in)
+        count = Signal.like(dut.lz)
+
+        m.d.comb += [
+            sig_in.eq(AnyConst(width)),
+            dut.sig_in.eq(sig_in),
+            count.eq(dut.lz)]
+
+        result = Const(width)
+        for i in range(width):
+            print(result)
+            result_next = Signal.like(count, name="count_%d" % i)
+            with m.If(sig_in[i] == 1):
+                comb += result_next.eq(width-i-1)
+            with m.Else():
+                comb += result_next.eq(result)
+            result = result_next
+
+        result_sig = Signal.like(count)
+        comb += result_sig.eq(result)
+
+        comb += Assert(result_sig == count)
+
+        # setup the inputs and outputs of the DUT as anyconst
+
+        return m
+
+
+class CLZTestCase(FHDLTestCase):
+    def test_proof(self):
+        module = Driver()
+        self.assertFormal(module, mode="bmc", depth=4)
+
+    def test_ilang(self):
+        dut = Driver()
+        vl = rtlil.convert(dut, ports=[])
+        with open("clz.il", "w") as f:
+            f.write(vl)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/src/nmutil/formal/test_picker.py b/src/nmutil/formal/test_picker.py
new file mode 100644 (file)
index 0000000..caaf007
--- /dev/null
@@ -0,0 +1,485 @@
+# SPDX-License-Identifier: LGPL-3-or-later
+# Copyright 2022 Jacob Lifshay
+
+from functools import reduce
+import operator
+import unittest
+from nmigen.hdl.ast import AnyConst, Assert, Signal, Const, Array, Shape, Mux
+from nmigen.hdl.dsl import Module
+from nmutil.formaltest import FHDLTestCase
+from nmutil.picker import (BetterMultiPriorityPicker, PriorityPicker,
+                           MultiPriorityPicker)
+from nmutil.sim_util import write_il
+
+
class TestPriorityPicker(FHDLTestCase):
    """Formal equivalence tests for PriorityPicker.

    For every combination of width, msb_mode, reverse_i and reverse_o,
    `tst` builds a reference one-hot picker out of plain boolean logic
    and proves the DUT output equal to it.
    """

    def tst(self, wid, msb_mode, reverse_i, reverse_o):
        """Formally verify one PriorityPicker configuration.

        Checks the recorded constructor parameters and port widths, then
        proves: the output is one-hot-or-zero, en_o mirrors "any bit
        picked", and the picked bit matches a first-set-bit scan (from
        the MSB end when msb_mode, LSB end otherwise) after undoing the
        optional input/output reversals.
        """
        assert isinstance(wid, int)
        assert isinstance(msb_mode, bool)
        assert isinstance(reverse_i, bool)
        assert isinstance(reverse_o, bool)
        dut = PriorityPicker(wid=wid, msb_mode=msb_mode, reverse_i=reverse_i,
                             reverse_o=reverse_o)
        self.assertEqual(wid, dut.wid)
        self.assertEqual(msb_mode, dut.msb_mode)
        self.assertEqual(reverse_i, dut.reverse_i)
        self.assertEqual(reverse_o, dut.reverse_o)
        self.assertEqual(len(dut.i), wid)
        self.assertEqual(len(dut.o), wid)
        self.assertEqual(len(dut.en_o), 1)
        m = Module()
        m.submodules.dut = dut
        m.d.comb += dut.i.eq(AnyConst(wid))

        # assert dut.o only has zero or one bit set
        m.d.comb += Assert((dut.o & (dut.o - 1)) == 0)

        # en_o is asserted exactly when some bit was picked
        m.d.comb += Assert((dut.o != 0) == dut.en_o)

        # undo the optional input reversal so the reference model below
        # can always work on the un-reversed bit order
        unreversed_i = Signal(wid)
        if reverse_i:
            m.d.comb += unreversed_i.eq(dut.i[::-1])
        else:
            m.d.comb += unreversed_i.eq(dut.i)

        unreversed_o = Signal(wid)
        if reverse_o:
            m.d.comb += unreversed_o.eq(dut.o[::-1])
        else:
            m.d.comb += unreversed_o.eq(dut.o)

        expected_unreversed_o = Signal(wid)

        # reference model: first set bit wins, scanning from the MSB end
        # in msb_mode and from the LSB end otherwise
        found = Const(False, 1)
        for i in reversed(range(wid)) if msb_mode else range(wid):
            m.d.comb += expected_unreversed_o[i].eq(unreversed_i[i] & ~found)
            found |= unreversed_i[i]

        m.d.comb += Assert(expected_unreversed_o == unreversed_o)

        self.assertFormal(m)

    # The test_* methods are generated programmatically instead of being
    # written out by hand (the original had 56 near-identical copies).
    # Generated names follow the original naming scheme so unittest
    # discovery and selection by name are unchanged:
    #   test_<wid>_msbm_<f|t>_revi_<f|t>_revo_<f|t>
    def _make_test(wid, msb_mode, reverse_i, reverse_o):
        def test(self):
            self.tst(wid=wid, msb_mode=msb_mode, reverse_i=reverse_i,
                     reverse_o=reverse_o)
        return test

    for _wid in (1, 2, 3, 4, 8, 32, 64):
        for _msbm in (False, True):
            for _revi in (False, True):
                for _revo in (False, True):
                    _name = ("test_%d_msbm_%s_revi_%s_revo_%s"
                             % (_wid, "t" if _msbm else "f",
                                "t" if _revi else "f",
                                "t" if _revo else "f"))
                    _test = _make_test(_wid, _msbm, _revi, _revo)
                    _test.__name__ = _name
                    _test.__qualname__ = "TestPriorityPicker." + _name
                    # in a class body locals() is the real class
                    # namespace, so this installs the method
                    locals()[_name] = _test

    # don't leave the generator machinery behind as class attributes
    del _make_test, _wid, _msbm, _revi, _revo, _name, _test
+
+
class TestMultiPriorityPicker(FHDLTestCase):
    """Formal equivalence tests for MultiPriorityPicker.

    Each configuration is proven equivalent to a chain of plain
    PriorityPickers, where every level picks from the input minus all
    bits claimed by earlier levels.
    """

    def make_dut(self, width, levels, indices, multi_in):
        """Build the DUT and sanity-check its recorded parameters.

        Overridden by subclasses to test alternative implementations.
        """
        dut = MultiPriorityPicker(wid=width, levels=levels, indices=indices,
                                  multi_in=multi_in)
        self.assertEqual(width, dut.wid)
        self.assertEqual(levels, dut.levels)
        self.assertEqual(indices, dut.indices)
        self.assertEqual(multi_in, dut.multi_in)
        return dut

    def tst(self, *, width, levels, indices, multi_in):
        """Formally verify one MultiPriorityPicker configuration.

        Checks port shapes/ports() ordering, writes the il for
        inspection, then proves each level's output is one-hot-or-zero,
        disjoint from earlier levels, matches the reference
        PriorityPicker chain, and (when enabled) that idx_o counts the
        previously-enabled levels.
        """
        assert isinstance(width, int) and width >= 1
        assert isinstance(levels, int) and 1 <= levels <= width
        assert isinstance(indices, bool)
        assert isinstance(multi_in, bool)
        dut = self.make_dut(width=width, levels=levels, indices=indices,
                            multi_in=multi_in)
        expected_ports = []
        if multi_in:
            # one input signal per level
            self.assertIsInstance(dut.i, (Array, list))
            self.assertEqual(len(dut.i), levels)
            for i in dut.i:
                self.assertIsInstance(i, Signal)
                self.assertEqual(len(i), width)
                expected_ports.append(i)
        else:
            self.assertIsInstance(dut.i, Signal)
            self.assertEqual(len(dut.i), width)
            expected_ports.append(dut.i)

        self.assertIsInstance(dut.o, (Array, list))
        self.assertEqual(len(dut.o), levels)
        for o in dut.o:
            self.assertIsInstance(o, Signal)
            self.assertEqual(len(o), width)
            expected_ports.append(o)

        self.assertEqual(len(dut.en_o), levels)
        expected_ports.append(dut.en_o)

        if indices:
            self.assertIsInstance(dut.idx_o, (Array, list))
            self.assertEqual(len(dut.idx_o), levels)
            for idx_o in dut.idx_o:
                self.assertIsInstance(idx_o, Signal)
                expected_ports.append(idx_o)
        else:
            # idx_o must not even exist when indices are disabled
            self.assertFalse(hasattr(dut, "idx_o"))

        self.assertListEqual(expected_ports, dut.ports())

        write_il(self, dut, ports=dut.ports())

        m = Module()
        m.submodules.dut = dut
        if multi_in:
            for i in dut.i:
                m.d.comb += i.eq(AnyConst(width))
        else:
            m.d.comb += dut.i.eq(AnyConst(width))

        # each output must be one-hot-or-zero, disjoint from all earlier
        # outputs, and en_o must mirror "this level picked something"
        prev_set = 0
        for o, en_o in zip(dut.o, dut.en_o):
            # assert o only has zero or one bit set
            m.d.comb += Assert((o & (o - 1)) == 0)
            # assert o doesn't overlap any previous outputs
            m.d.comb += Assert((o & prev_set) == 0)
            prev_set |= o

            m.d.comb += Assert((o != 0) == en_o)

        # reference model: a chain of single PriorityPickers, each fed
        # the input masked by the bits already claimed by earlier levels
        prev_set = Const(0, width)
        priority_pickers = [PriorityPicker(width) for _ in range(levels)]
        for level in range(levels):
            pp = priority_pickers[level]
            setattr(m.submodules, f"pp_{level}", pp)
            inp = dut.i[level] if multi_in else dut.i
            m.d.comb += pp.i.eq(inp & ~prev_set)
            cur_set = Signal(width, name=f"cur_set_{level}")
            m.d.comb += cur_set.eq(prev_set | pp.o)
            prev_set = cur_set
            m.d.comb += Assert(pp.o == dut.o[level])
            # idx_o[level] must equal the number of earlier levels that
            # actually picked a bit
            expected_idx = Signal(32, name=f"expected_idx_{level}")
            number_of_prev_en_o_set = reduce(
                operator.add, (i.en_o for i in priority_pickers[:level]), 0)
            m.d.comb += expected_idx.eq(number_of_prev_en_o_set)
            if indices:
                m.d.comb += Assert(expected_idx == dut.idx_o[level])

        self.assertFormal(m)

    # The test_* methods are generated programmatically instead of being
    # written out by hand (the original had 49 near-identical copies).
    # Generated names follow the original naming scheme so unittest
    # discovery and selection by name are unchanged:
    #   test_<width>_levels_<levels>_idxs_<f|t>_mi_<f|t>
    def _make_test(width, levels, indices, multi_in):
        def test(self):
            self.tst(width=width, levels=levels, indices=indices,
                     multi_in=multi_in)
        return test

    for _w, _lv in ([(4, l) for l in range(1, 5)]
                    + [(8, l) for l in range(1, 9)]):
        for _idx in (False, True):
            for _mi in (False, True):
                _name = ("test_%d_levels_%d_idxs_%s_mi_%s"
                         % (_w, _lv, "t" if _idx else "f",
                            "t" if _mi else "f"))
                _test = _make_test(_w, _lv, _idx, _mi)
                _test.__name__ = _name
                _test.__qualname__ = "TestMultiPriorityPicker." + _name
                # in a class body locals() is the real class namespace,
                # so this installs the method
                locals()[_name] = _test

    # the original also had this single extra wide configuration
    _test = _make_test(16, 16, False, False)
    _test.__name__ = "test_16_levels_16_idxs_f_mi_f"
    _test.__qualname__ = "TestMultiPriorityPicker." + _test.__name__
    locals()[_test.__name__] = _test

    # don't leave the generator machinery behind as class attributes
    del _make_test, _w, _lv, _idx, _mi, _name, _test
+
+
class TestBetterMultiPriorityPicker(TestMultiPriorityPicker):
    """Runs the inherited tests against BetterMultiPriorityPicker."""

    def make_dut(self, width, levels, indices, multi_in):
        # BetterMultiPriorityPicker only implements the plain
        # single-input, no-index configuration; skip the rest
        for unsupported, label in ((multi_in, "multi_in"),
                                   (indices, "indices")):
            if unsupported:
                self.skipTest(
                    f"{label} are not supported by BetterMultiPriorityPicker")
        dut = BetterMultiPriorityPicker(width=width, levels=levels)
        self.assertEqual((width, levels), (dut.width, dut.levels))
        return dut
+
+
# allow running this test file directly
if __name__ == "__main__":
    unittest.main()
diff --git a/src/nmutil/formal/test_plru.py b/src/nmutil/formal/test_plru.py
new file mode 100644 (file)
index 0000000..f7c3870
--- /dev/null
@@ -0,0 +1,245 @@
+# SPDX-License-Identifier: LGPL-3-or-later
+# Copyright 2022 Jacob Lifshay
+
+import unittest
+from nmigen.hdl.ast import (AnySeq, Assert, Signal, Value, Array, Value)
+from nmigen.hdl.dsl import Module
+from nmigen.sim import Delay, Tick
+from nmutil.formaltest import FHDLTestCase
+from nmutil.plru2 import PLRU  # , PLRUs
+from nmutil.sim_util import write_il, do_sim
+from nmutil.plain_data import plain_data
+
+
@plain_data()
class PrettyPrintState:
    """Mutable state for indentation-aware pretty-printing.

    Tracks the current indent level, the destination file and whether
    the next character begins a fresh line (and therefore needs the
    indent prefix emitted first).
    """
    __slots__ = "indent", "file", "at_line_start"

    def __init__(self, indent=0, file=None, at_line_start=True):
        self.indent = indent
        self.file = file
        self.at_line_start = at_line_start

    def write(self, text):
        # type: (str) -> None
        """Write ``text``, inserting 4-space indents at line starts."""
        out = self.file
        for ch in text:
            if ch == "\n":
                self.at_line_start = True
            else:
                if self.at_line_start:
                    # first character of a new line: emit indent prefix
                    self.at_line_start = False
                    print("    " * self.indent, file=out, end='')
            print(ch, file=out, end='')
+
+
@plain_data()
class PLRUNode:
    """Reference model of one node in a binary pseudo-LRU tree.

    Each node carries a 1-bit `state` Signal that mirrors the DUT's
    corresponding tree entry; children are wired up by the test after
    construction.
    """
    # id: index of this node within the flattened tree
    # state: 1-bit Signal, compared against dut._tree[id] by the test
    # left_child / right_child: child PLRUNodes, or None at the leaves
    __slots__ = "id", "state", "left_child", "right_child"

    def __init__(self, id, left_child=None, right_child=None):
        # type: (int, PLRUNode | None, PLRUNode | None) -> None
        self.id = id
        # named after the node id so it is recognizable in vcd traces
        self.state = Signal(name=f"state_{id}")
        self.left_child = left_child
        self.right_child = right_child

    @property
    def depth(self):
        # type: () -> int
        """Height of the subtree rooted here (0 for a leaf node)."""
        depth = 0
        if self.left_child is not None:
            depth = max(depth, 1 + self.left_child.depth)
        if self.right_child is not None:
            depth = max(depth, 1 + self.right_child.depth)
        return depth

    def __pretty_print(self, state):
        # type: (PrettyPrintState) -> None
        # recursive helper: renders this subtree constructor-style into
        # `state`, with children indented one level deeper
        state.write("PLRUNode(")
        state.indent += 1
        state.write(f"id={self.id!r},\n")
        state.write(f"state={self.state!r},\n")
        state.write("left_child=")
        if self.left_child is None:
            state.write("None")
        else:
            self.left_child.__pretty_print(state)
        state.write(",\nright_child=")
        if self.right_child is None:
            state.write("None")
        else:
            self.right_child.__pretty_print(state)
        state.indent -= 1
        state.write("\n)")

    def pretty_print(self, file=None):
        """Print a readable rendering of this subtree to `file`."""
        self.__pretty_print(PrettyPrintState(file=file))
        print(file=file)

    def set_states_from_index(self, m, index, ids):
        # type: (Module, Value, list[Signal]) -> None
        """Model an access to way ``index``.

        ``index`` is consumed MSB-first (index[-1] selects this level).
        The state bit is set to point away from the side just accessed:
        accessing the right half clears it, so the next LRU lookup goes
        left.  ``ids[0]`` records which node was visited at this level,
        for comparison against the DUT's debug id outputs.
        """
        m.d.sync += self.state.eq(~index[-1])
        m.d.comb += ids[0].eq(self.id)
        with m.If(index[-1]):
            if self.right_child is not None:
                self.right_child.set_states_from_index(m, index[:-1], ids[1:])
        with m.Else():
            if self.left_child is not None:
                self.left_child.set_states_from_index(m, index[:-1], ids[1:])

    def get_lru(self, m, ids):
        # type: (Module, list[Signal]) -> Signal
        """Build and return a Signal holding this subtree's LRU index.

        Follows state bits down the tree (state == 1 goes right); the
        bit chosen at this level lands in retval's MSB, matching the
        MSB-first indexing used by set_states_from_index.  ``ids[0]``
        records the node visited at this level.
        """
        retval = Signal(1 + self.depth, name=f"lru_{self.id}", reset=0)
        m.d.comb += retval[-1].eq(self.state)
        m.d.comb += ids[0].eq(self.id)
        with m.If(self.state):
            if self.right_child is not None:
                right_lru = self.right_child.get_lru(m, ids[1:])
                m.d.comb += retval[:-1].eq(right_lru)
        with m.Else():
            if self.left_child is not None:
                left_lru = self.left_child.get_lru(m, ids[1:])
                m.d.comb += retval[:-1].eq(left_lru)
        return retval
+
+
class TestPLRU(FHDLTestCase):
    """Checks the PLRU DUT against the PLRUNode reference tree, either
    by formal proof (test_seq is None) or by simulating a concrete
    access sequence (test_seq given)."""

    def tst(self, log2_num_ways, test_seq=None):
        # type: (int, list[int | None] | None) -> None
        # In formal mode every check becomes a hardware Assert; in
        # simulation mode the same checks are instead latched into
        # `asserts` and evaluated from the Python process below.

        @plain_data()
        class MyAssert:
            # test: the Value to check; en: Signal gating whether the
            # check is active on a given cycle
            __slots__ = "test", "en"

            def __init__(self, test, en):
                # type: (Value, Signal) -> None
                self.test = test
                self.en = en

        asserts = []  # type: list[MyAssert]

        def assert_(test):
            # returns statements to add to m.d.comb: a real Assert in
            # formal mode, or an enable-flag set in simulation mode
            if test_seq is None:
                return [Assert(test, src_loc_at=1)]
            assert_en = Signal(name="assert_en", src_loc_at=1, reset=False)
            asserts.append(MyAssert(test=test, en=assert_en))
            return [assert_en.eq(True)]

        dut = PLRU(log2_num_ways, debug=True)  # check debug works
        write_il(self, dut, ports=dut.ports())
        # debug clutters up vcd, so disable it for formal proofs
        dut = PLRU(log2_num_ways, debug=test_seq is not None)
        num_ways = 1 << log2_num_ways
        # sanity-check the DUT's recorded parameters and port shapes
        self.assertEqual(dut.log2_num_ways, log2_num_ways)
        self.assertEqual(dut.num_ways, num_ways)
        self.assertIsInstance(dut.acc_i, Signal)
        self.assertIsInstance(dut.acc_en_i, Signal)
        self.assertIsInstance(dut.lru_o, Signal)
        self.assertEqual(len(dut.acc_i), log2_num_ways)
        self.assertEqual(len(dut.acc_en_i), 1)
        self.assertEqual(len(dut.lru_o), log2_num_ways)
        write_il(self, dut, ports=dut.ports())
        m = Module()
        # build the reference tree: node i's children are 2i+1 / 2i+2
        # (heap layout), and every node's state must track dut._tree[i]
        nodes = [PLRUNode(i) for i in range(num_ways - 1)]
        self.assertIsInstance(dut._tree, Array)
        self.assertEqual(len(dut._tree), len(nodes))
        for i in range(len(nodes)):
            if i != 0:
                parent = (i + 1) // 2 - 1
                if i % 2:
                    nodes[parent].left_child = nodes[i]
                else:
                    nodes[parent].right_child = nodes[i]
            self.assertIsInstance(dut._tree[i], Signal)
            self.assertEqual(len(dut._tree[i]), 1)
            m.d.comb += assert_(nodes[i].state == dut._tree[i])

        if test_seq is None:
            # formal mode: let the solver pick a fresh access each cycle
            m.d.comb += [
                dut.acc_i.eq(AnySeq(log2_num_ways)),
                dut.acc_en_i.eq(AnySeq(1)),
            ]

        l2nwr = range(log2_num_ways)
        # per-level node ids visited during a state update; must match
        # the DUT's debug outputs
        upd_ids = [Signal(log2_num_ways, name=f"upd_id_{i}") for i in l2nwr]
        with m.If(dut.acc_en_i):
            nodes[0].set_states_from_index(m, dut.acc_i, upd_ids)

            self.assertEqual(len(dut._upd_lru_nodes), len(upd_ids))
            for l, r in zip(dut._upd_lru_nodes, upd_ids):
                m.d.comb += assert_(l == r)

        # per-level node ids visited during the LRU lookup, and the
        # reference LRU value itself, both checked against the DUT
        get_ids = [Signal(log2_num_ways, name=f"get_id_{i}") for i in l2nwr]
        lru = Signal(log2_num_ways)
        m.d.comb += lru.eq(nodes[0].get_lru(m, get_ids))
        m.d.comb += assert_(dut.lru_o == lru)
        self.assertEqual(len(dut._get_lru_nodes), len(get_ids))
        for l, r in zip(dut._get_lru_nodes, get_ids):
            m.d.comb += assert_(l == r)

        nodes[0].pretty_print()

        m.submodules.dut = dut
        if test_seq is None:
            self.assertFormal(m, mode="prove", depth=2)
        else:
            # simulation mode: trace everything useful into the vcd
            traces = [dut.acc_i, dut.acc_en_i, *dut._tree]
            for node in nodes:
                traces.append(node.state)
            traces += [
                dut.lru_o, lru, *dut._get_lru_nodes, *get_ids,
                *dut._upd_lru_nodes, *upd_ids,
            ]

            def subtest(acc_i, acc_en_i):
                # drive one access, step a clock, then (after settling)
                # evaluate every latched check that is enabled
                yield dut.acc_i.eq(acc_i)
                yield dut.acc_en_i.eq(acc_en_i)
                yield Tick()
                yield Delay(0.7e-6)
                for a in asserts:
                    if (yield a.en):
                        with self.subTest(
                                assert_loc=':'.join(map(str, a.en.src_loc))):
                            self.assertTrue((yield a.test))

            def process():
                # None in test_seq means an idle cycle (acc_en_i low)
                for test_item in test_seq:
                    if test_item is None:
                        with self.subTest(test_item="None"):
                            yield from subtest(acc_i=0, acc_en_i=0)
                    else:
                        with self.subTest(test_item=hex(test_item)):
                            yield from subtest(acc_i=test_item, acc_en_i=1)

            with do_sim(self, m, traces) as sim:
                sim.add_clock(1e-6)
                sim.add_process(process)
                sim.run()

    def test_bits_1(self):
        self.tst(1)

    def test_bits_2(self):
        self.tst(2)

    def test_bits_3(self):
        self.tst(3)

    def test_bits_4(self):
        self.tst(4)

    def test_bits_5(self):
        self.tst(5)

    def test_bits_6(self):
        self.tst(6)

    def test_bits_3_sim(self):
        # concrete 8-way access pattern: sequential then bit-reversed,
        # each followed by an idle cycle
        self.tst(3, [
            0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7,
            None,
            0x0, 0x4, 0x2, 0x6, 0x1, 0x5, 0x3, 0x7,
            None,
        ])
+
+
# allow running this test file directly
if __name__ == "__main__":
    unittest.main()
diff --git a/src/nmutil/formal/test_queue.py b/src/nmutil/formal/test_queue.py
new file mode 100644 (file)
index 0000000..d039071
--- /dev/null
@@ -0,0 +1,319 @@
+# SPDX-License-Identifier: LGPL-3-or-later
+# Copyright 2022 Jacob Lifshay
+
+import unittest
+from nmigen.hdl.ast import (AnySeq, Assert, Signal, Assume, Const,
+                            unsigned, AnyConst)
+from nmigen.hdl.dsl import Module
+from nmutil.formaltest import FHDLTestCase
+from nmutil.queue import Queue
+from nmutil.sim_util import write_il
+
+
+class TestQueue(FHDLTestCase):
+    def tst(self, width, depth, fwft, pipe):
+        assert isinstance(width, int)
+        assert isinstance(depth, int)
+        assert isinstance(fwft, bool)
+        assert isinstance(pipe, bool)
+        dut = Queue(width=width, depth=depth, fwft=fwft, pipe=pipe)
+        self.assertEqual(width, dut.width)
+        self.assertEqual(depth, dut.depth)
+        self.assertEqual(fwft, dut.fwft)
+        self.assertEqual(pipe, dut.pipe)
+        write_il(self, dut, ports=[
+            dut.level,
+            dut.r_data, dut.r_en, dut.r_rdy,
+            dut.w_data, dut.w_en, dut.w_rdy,
+        ])
+        m = Module()
+        m.submodules.dut = dut
+        m.d.comb += dut.r_en.eq(AnySeq(1))
+        m.d.comb += dut.w_data.eq(AnySeq(width))
+        m.d.comb += dut.w_en.eq(AnySeq(1))
+
+        index_width = 16
+        max_index = Const(-1, unsigned(index_width))
+
+        check_r_data = Signal(width)
+        check_r_data_valid = Signal(reset=False)
+        r_index = Signal(index_width)
+        check_w_data = Signal(width)
+        check_w_data_valid = Signal(reset=False)
+        w_index = Signal(index_width)
+        check_index = Signal(index_width)
+        m.d.comb += check_index.eq(AnyConst(index_width))
+
+        with m.If(dut.r_en & dut.r_rdy):
+            with m.If(r_index == check_index):
+                m.d.sync += [
+                    check_r_data.eq(dut.r_data),
+                    check_r_data_valid.eq(True),
+                ]
+            m.d.sync += [
+                Assume(r_index != max_index),
+                r_index.eq(r_index + 1),
+            ]
+
+        with m.If(dut.w_en & dut.w_rdy):
+            with m.If(w_index == check_index):
+                m.d.sync += [
+                    check_w_data.eq(dut.w_data),
+                    check_w_data_valid.eq(True),
+                ]
+            m.d.sync += [
+                Assume(w_index != max_index),
+                w_index.eq(w_index + 1),
+            ]
+
+        with m.If(check_r_data_valid & check_w_data_valid):
+            m.d.comb += Assert(check_r_data == check_w_data)
+
+        # 10 is enough to fully test smaller depth queues, larger queues are
+        # assumed to be correct because the logic doesn't really depend on
+        # queue depth past the first few values.
+        self.assertFormal(m, depth=10)
+
+    def test_have_all(self):
+        def bool_char(v):
+            if v:
+                return "t"
+            return "f"
+
+        missing = []
+
+        for width in [1, 8]:
+            for depth in range(8 + 1):
+                for fwft in (False, True):
+                    for pipe in (False, True):
+                        name = (f"test_{width}_"
+                                f"depth_{depth}_"
+                                f"fwft_{bool_char(fwft)}_"
+                                f"pipe_{bool_char(pipe)}")
+                        if not callable(getattr(self, name, None)):
+                            missing.append(f"    def {name}(self):\n"
+                                           f"        self.tst("
+                                           f"width={width}, depth={depth}, "
+                                           f"fwft={fwft}, pipe={pipe})\n")
+        missing = "\n".join(missing)
+        self.assertTrue(missing == "", f"missing functions:\n\n{missing}")
+
+    def test_1_depth_0_fwft_f_pipe_f(self):
+        self.tst(width=1, depth=0, fwft=False, pipe=False)
+
+    def test_1_depth_0_fwft_f_pipe_t(self):
+        self.tst(width=1, depth=0, fwft=False, pipe=True)
+
+    def test_1_depth_0_fwft_t_pipe_f(self):
+        self.tst(width=1, depth=0, fwft=True, pipe=False)
+
+    def test_1_depth_0_fwft_t_pipe_t(self):
+        self.tst(width=1, depth=0, fwft=True, pipe=True)
+
+    def test_1_depth_1_fwft_f_pipe_f(self):
+        self.tst(width=1, depth=1, fwft=False, pipe=False)
+
+    def test_1_depth_1_fwft_f_pipe_t(self):
+        self.tst(width=1, depth=1, fwft=False, pipe=True)
+
+    def test_1_depth_1_fwft_t_pipe_f(self):
+        self.tst(width=1, depth=1, fwft=True, pipe=False)
+
+    def test_1_depth_1_fwft_t_pipe_t(self):
+        self.tst(width=1, depth=1, fwft=True, pipe=True)
+
+    def test_1_depth_2_fwft_f_pipe_f(self):
+        self.tst(width=1, depth=2, fwft=False, pipe=False)
+
+    def test_1_depth_2_fwft_f_pipe_t(self):
+        self.tst(width=1, depth=2, fwft=False, pipe=True)
+
+    def test_1_depth_2_fwft_t_pipe_f(self):
+        self.tst(width=1, depth=2, fwft=True, pipe=False)
+
+    def test_1_depth_2_fwft_t_pipe_t(self):
+        self.tst(width=1, depth=2, fwft=True, pipe=True)
+
+    def test_1_depth_3_fwft_f_pipe_f(self):
+        self.tst(width=1, depth=3, fwft=False, pipe=False)
+
+    def test_1_depth_3_fwft_f_pipe_t(self):
+        self.tst(width=1, depth=3, fwft=False, pipe=True)
+
+    def test_1_depth_3_fwft_t_pipe_f(self):
+        self.tst(width=1, depth=3, fwft=True, pipe=False)
+
+    def test_1_depth_3_fwft_t_pipe_t(self):
+        self.tst(width=1, depth=3, fwft=True, pipe=True)
+
+    def test_1_depth_4_fwft_f_pipe_f(self):
+        self.tst(width=1, depth=4, fwft=False, pipe=False)
+
+    def test_1_depth_4_fwft_f_pipe_t(self):
+        self.tst(width=1, depth=4, fwft=False, pipe=True)
+
+    def test_1_depth_4_fwft_t_pipe_f(self):
+        self.tst(width=1, depth=4, fwft=True, pipe=False)
+
+    def test_1_depth_4_fwft_t_pipe_t(self):
+        self.tst(width=1, depth=4, fwft=True, pipe=True)
+
+    def test_1_depth_5_fwft_f_pipe_f(self):
+        self.tst(width=1, depth=5, fwft=False, pipe=False)
+
+    def test_1_depth_5_fwft_f_pipe_t(self):
+        self.tst(width=1, depth=5, fwft=False, pipe=True)
+
+    def test_1_depth_5_fwft_t_pipe_f(self):
+        self.tst(width=1, depth=5, fwft=True, pipe=False)
+
+    def test_1_depth_5_fwft_t_pipe_t(self):
+        self.tst(width=1, depth=5, fwft=True, pipe=True)
+
+    def test_1_depth_6_fwft_f_pipe_f(self):
+        self.tst(width=1, depth=6, fwft=False, pipe=False)
+
+    def test_1_depth_6_fwft_f_pipe_t(self):
+        self.tst(width=1, depth=6, fwft=False, pipe=True)
+
+    def test_1_depth_6_fwft_t_pipe_f(self):
+        self.tst(width=1, depth=6, fwft=True, pipe=False)
+
+    def test_1_depth_6_fwft_t_pipe_t(self):
+        self.tst(width=1, depth=6, fwft=True, pipe=True)
+
+    def test_1_depth_7_fwft_f_pipe_f(self):
+        self.tst(width=1, depth=7, fwft=False, pipe=False)
+
+    def test_1_depth_7_fwft_f_pipe_t(self):
+        self.tst(width=1, depth=7, fwft=False, pipe=True)
+
+    def test_1_depth_7_fwft_t_pipe_f(self):
+        self.tst(width=1, depth=7, fwft=True, pipe=False)
+
+    def test_1_depth_7_fwft_t_pipe_t(self):
+        self.tst(width=1, depth=7, fwft=True, pipe=True)
+
+    def test_1_depth_8_fwft_f_pipe_f(self):
+        self.tst(width=1, depth=8, fwft=False, pipe=False)
+
+    def test_1_depth_8_fwft_f_pipe_t(self):
+        self.tst(width=1, depth=8, fwft=False, pipe=True)
+
+    def test_1_depth_8_fwft_t_pipe_f(self):
+        self.tst(width=1, depth=8, fwft=True, pipe=False)
+
+    def test_1_depth_8_fwft_t_pipe_t(self):
+        self.tst(width=1, depth=8, fwft=True, pipe=True)
+
+    def test_8_depth_0_fwft_f_pipe_f(self):
+        self.tst(width=8, depth=0, fwft=False, pipe=False)
+
+    def test_8_depth_0_fwft_f_pipe_t(self):
+        self.tst(width=8, depth=0, fwft=False, pipe=True)
+
+    def test_8_depth_0_fwft_t_pipe_f(self):
+        self.tst(width=8, depth=0, fwft=True, pipe=False)
+
+    def test_8_depth_0_fwft_t_pipe_t(self):
+        self.tst(width=8, depth=0, fwft=True, pipe=True)
+
+    def test_8_depth_1_fwft_f_pipe_f(self):
+        self.tst(width=8, depth=1, fwft=False, pipe=False)
+
+    def test_8_depth_1_fwft_f_pipe_t(self):
+        self.tst(width=8, depth=1, fwft=False, pipe=True)
+
+    def test_8_depth_1_fwft_t_pipe_f(self):
+        self.tst(width=8, depth=1, fwft=True, pipe=False)
+
+    def test_8_depth_1_fwft_t_pipe_t(self):
+        self.tst(width=8, depth=1, fwft=True, pipe=True)
+
+    def test_8_depth_2_fwft_f_pipe_f(self):
+        self.tst(width=8, depth=2, fwft=False, pipe=False)
+
+    def test_8_depth_2_fwft_f_pipe_t(self):
+        self.tst(width=8, depth=2, fwft=False, pipe=True)
+
+    def test_8_depth_2_fwft_t_pipe_f(self):
+        self.tst(width=8, depth=2, fwft=True, pipe=False)
+
+    def test_8_depth_2_fwft_t_pipe_t(self):
+        self.tst(width=8, depth=2, fwft=True, pipe=True)
+
+    def test_8_depth_3_fwft_f_pipe_f(self):
+        self.tst(width=8, depth=3, fwft=False, pipe=False)
+
+    def test_8_depth_3_fwft_f_pipe_t(self):
+        self.tst(width=8, depth=3, fwft=False, pipe=True)
+
+    def test_8_depth_3_fwft_t_pipe_f(self):
+        self.tst(width=8, depth=3, fwft=True, pipe=False)
+
+    def test_8_depth_3_fwft_t_pipe_t(self):
+        self.tst(width=8, depth=3, fwft=True, pipe=True)
+
+    def test_8_depth_4_fwft_f_pipe_f(self):
+        self.tst(width=8, depth=4, fwft=False, pipe=False)
+
+    def test_8_depth_4_fwft_f_pipe_t(self):
+        self.tst(width=8, depth=4, fwft=False, pipe=True)
+
+    def test_8_depth_4_fwft_t_pipe_f(self):
+        self.tst(width=8, depth=4, fwft=True, pipe=False)
+
+    def test_8_depth_4_fwft_t_pipe_t(self):
+        self.tst(width=8, depth=4, fwft=True, pipe=True)
+
+    def test_8_depth_5_fwft_f_pipe_f(self):
+        self.tst(width=8, depth=5, fwft=False, pipe=False)
+
+    def test_8_depth_5_fwft_f_pipe_t(self):
+        self.tst(width=8, depth=5, fwft=False, pipe=True)
+
+    def test_8_depth_5_fwft_t_pipe_f(self):
+        self.tst(width=8, depth=5, fwft=True, pipe=False)
+
+    def test_8_depth_5_fwft_t_pipe_t(self):
+        self.tst(width=8, depth=5, fwft=True, pipe=True)
+
+    def test_8_depth_6_fwft_f_pipe_f(self):
+        self.tst(width=8, depth=6, fwft=False, pipe=False)
+
+    def test_8_depth_6_fwft_f_pipe_t(self):
+        self.tst(width=8, depth=6, fwft=False, pipe=True)
+
+    def test_8_depth_6_fwft_t_pipe_f(self):
+        self.tst(width=8, depth=6, fwft=True, pipe=False)
+
+    def test_8_depth_6_fwft_t_pipe_t(self):
+        self.tst(width=8, depth=6, fwft=True, pipe=True)
+
+    def test_8_depth_7_fwft_f_pipe_f(self):
+        self.tst(width=8, depth=7, fwft=False, pipe=False)
+
+    def test_8_depth_7_fwft_f_pipe_t(self):
+        self.tst(width=8, depth=7, fwft=False, pipe=True)
+
+    def test_8_depth_7_fwft_t_pipe_f(self):
+        self.tst(width=8, depth=7, fwft=True, pipe=False)
+
+    def test_8_depth_7_fwft_t_pipe_t(self):
+        self.tst(width=8, depth=7, fwft=True, pipe=True)
+
+    def test_8_depth_8_fwft_f_pipe_f(self):
+        self.tst(width=8, depth=8, fwft=False, pipe=False)
+
+    def test_8_depth_8_fwft_f_pipe_t(self):
+        self.tst(width=8, depth=8, fwft=False, pipe=True)
+
+    def test_8_depth_8_fwft_t_pipe_f(self):
+        self.tst(width=8, depth=8, fwft=True, pipe=False)
+
+    def test_8_depth_8_fwft_t_pipe_t(self):
+        self.tst(width=8, depth=8, fwft=True, pipe=True)
+
+
+if __name__ == "__main__":
+    unittest.main()
index 4b9b174848afa022a5708262183d66ee32da8745..05a5396b8198acab40145d486439f7fe788e78c0 100644 (file)
@@ -1,16 +1,13 @@
-import os
 import re
 import shutil
 import subprocess
 import textwrap
-import traceback
 import unittest
-import warnings
-from contextlib import contextmanager
-
+from nmigen.hdl.ast import Statement
 from nmigen.hdl.ir import Fragment
 from nmigen.back import rtlil
-from nmigen._toolchain import require_tool
+from nmutil.toolchain import require_tool
+from nmutil.get_test_path import get_test_path
 
 
 __all__ = ["FHDLTestCase"]
@@ -20,6 +17,7 @@ class FHDLTestCase(unittest.TestCase):
     def assertRepr(self, obj, repr_str):
         if isinstance(obj, list):
             obj = Statement.cast(obj)
+
         def prepare_repr(repr_str):
             repr_str = re.sub(r"\s+",   " ",  repr_str)
             repr_str = re.sub(r"\( (?=\()", "(", repr_str)
@@ -27,50 +25,20 @@ class FHDLTestCase(unittest.TestCase):
             return repr_str.strip()
         self.assertEqual(prepare_repr(repr(obj)), prepare_repr(repr_str))
 
-    @contextmanager
-    def assertRaises(self, exception, msg=None):
-        with super().assertRaises(exception) as cm:
-            yield
-        if msg is not None:
-            # WTF? unittest.assertRaises is completely broken.
-            self.assertEqual(str(cm.exception), msg)
-
-    @contextmanager
-    def assertRaisesRegex(self, exception, regex=None):
-        with super().assertRaises(exception) as cm:
-            yield
-        if regex is not None:
-            # unittest.assertRaisesRegex also seems broken...
-            self.assertRegex(str(cm.exception), regex)
-
-    @contextmanager
-    def assertWarns(self, category, msg=None):
-        with warnings.catch_warnings(record=True) as warns:
-            yield
-        self.assertEqual(len(warns), 1)
-        self.assertEqual(warns[0].category, category)
-        if msg is not None:
-            self.assertEqual(str(warns[0].message), msg)
-
-    def assertFormal(self, spec, mode="bmc", depth=1, solver=""):
-        caller, *_ = traceback.extract_stack(limit=2)
-        spec_root, _ = os.path.splitext(caller.filename)
-        spec_dir = os.path.dirname(spec_root)
-        spec_name = "{}_{}".format(
-            os.path.basename(spec_root).replace("test_", "spec_"),
-            caller.name.replace("test_", "")
-        )
+    def assertFormal(self, spec, mode="bmc", depth=1, solver="",
+                     base_path="formal_test_temp", smtbmc_opts="--logic=ALL"):
+        path = get_test_path(self, base_path)
 
         # The sby -f switch seems not fully functional when sby is
         # reading from stdin.
-        if os.path.exists(os.path.join(spec_dir, spec_name)):
-            shutil.rmtree(os.path.join(spec_dir, spec_name))
+        shutil.rmtree(path, ignore_errors=True)
+        path.mkdir(parents=True)
 
         if mode == "hybrid":
             # A mix of BMC and k-induction, as per personal
             # communication with Clifford Wolf.
             script = "setattr -unset init w:* a:nmigen.sample_reg %d"
-            mode   = "bmc"
+            mode = "bmc"
         else:
             script = ""
 
@@ -81,7 +49,7 @@ class FHDLTestCase(unittest.TestCase):
         wait on
 
         [engines]
-        smtbmc {solver}
+        smtbmc {solver} -- -- {smtbmc_opts}
 
         [script]
         read_ilang top.il
@@ -95,14 +63,14 @@ class FHDLTestCase(unittest.TestCase):
             depth=depth,
             solver=solver,
             script=script,
+            smtbmc_opts=smtbmc_opts,
             rtlil=rtlil.convert(Fragment.get(spec, platform="formal"))
         )
-        with subprocess.Popen([require_tool("sby"), "-f", "-d", spec_name],
-                              cwd=spec_dir,
+        with subprocess.Popen([require_tool("sby"), "-d", "job"],
+                              cwd=path,
                               universal_newlines=True,
                               stdin=subprocess.PIPE,
                               stdout=subprocess.PIPE) as proc:
             stdout, stderr = proc.communicate(config)
             if proc.returncode != 0:
-                self.fail("Formal verification failed:\n" + stdout)
-
+                self.fail(f"Formal verification failed:\nIn {path}\n{stdout}")
diff --git a/src/nmutil/get_test_path.py b/src/nmutil/get_test_path.py
new file mode 100644 (file)
index 0000000..f58ada8
--- /dev/null
@@ -0,0 +1,46 @@
+# SPDX-License-Identifier: LGPL-3-or-later
+# Copyright 2021 Jacob Lifshay
+
+# Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
+# of Horizon 2020 EU Programme 957073.
+
+import weakref
+from pathlib import Path
+
+
+class RunCounter:
+    def __init__(self):
+        self.__run_counts = {}
+        """dict mapping self.next() keys to the next int value returned by
+        self.next()"""
+
+    def next(self, k):
+        """get an incrementing run counter for a `str` key `k`. returns an `int`."""
+        retval = self.__run_counts.get(k, 0)
+        self.__run_counts[k] = retval + 1
+        return retval
+
+    __RUN_COUNTERS = {}
+    """dict mapping object ids (int) to a tuple of a weakref.ref to that
+    object, and the corresponding RunCounter"""
+
+    @staticmethod
+    def get(obj):
+        k = id(obj)
+        t = RunCounter.__RUN_COUNTERS
+        try:
+            return t[k][1]
+        except KeyError:
+            retval = RunCounter()
+
+            def on_finalize(obj):
+                del t[k]
+            t[k] = weakref.ref(obj, on_finalize), retval
+            return retval
+
+
+def get_test_path(test_case, base_path):
+    """get the `Path` for a particular unittest.TestCase instance
+    (`test_case`). base_path is either a str or a path-like."""
+    count = RunCounter.get(test_case).next(test_case.id())
+    return Path(base_path) / test_case.id() / str(count)
diff --git a/src/nmutil/grev.py b/src/nmutil/grev.py
new file mode 100644 (file)
index 0000000..2b22fe1
--- /dev/null
@@ -0,0 +1,270 @@
+# SPDX-License-Identifier: LGPL-3-or-later
+# Copyright 2021 Jacob Lifshay programmerjake@gmail.com
+# Copyright (C) 2021 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
+
+# Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
+# of Horizon 2020 EU Programme 957073.
+
+r"""Generalized bit-reverse.
+
+https://bugs.libre-soc.org/show_bug.cgi?id=755
+
+A generalized bit-reverse is the following operation:
+grev(input, chunk_sizes):
+    for i in range(input.width):
+        j = i XOR chunk_sizes
+        output bit i = input bit j
+    return output
+
+This is useful because many bit/byte reverse operations can be created by
+setting `chunk_sizes` to different values. Some examples for a 64-bit
+`grev` operation:
+* `0b111111` -- reverse all bits in the 64-bit word
+* `0b111000` -- reverse bytes in the 64-bit word
+* `0b011000` -- reverse bytes in each 32-bit word independently
+* `0b110000` -- reverse order of 16-bit words
+
+This is implemented by using a series of `log2_width`
+`width`-bit wide 2:1 muxes, arranged just like a butterfly network:
+https://en.wikipedia.org/wiki/Butterfly_network
+
+To compute `out = grev(inp, 0bxyz)`, where `x`, `y`, and `z` are single bits,
+the following permutation network is used:
+
+                inp[0]  inp[1]  inp[2]  inp[3]  inp[4]  inp[5]  inp[6]  inp[7]
+                  |       |       |       |       |       |       |       |
+the value here is |       |       |       |       |       |       |       |
+grev(inp, 0b000): |       |       |       |       |       |       |       |
+                  |       |       |       |       |       |       |       |
+                  +       +       +       +       +       +       +       +
+                  |\     /|       |\     /|       |\     /|       |\     /|
+                  | \   / |       | \   / |       | \   / |       | \   / |
+                  |  \ /  |       |  \ /  |       |  \ /  |       |  \ /  |
+swap 1-bit words: |   X   |       |   X   |       |   X   |       |   X   |
+                  |  / \  |       |  / \  |       |  / \  |       |  / \  |
+                  | /   \ |       | /   \ |       | /   \ |       | /   \ |
+              z--Mux  z--Mux  z--Mux  z--Mux  z--Mux  z--Mux  z--Mux  z--Mux
+                  |       |       |       |       |       |       |       |
+the value here is |       |       |       |       |       |       |       |
+grev(inp, 0b00z): |       |       |       |       |       |       |       |
+                  |       |       |       |       |       |       |       |
+                  |       | +-----|-------+       |       | +-----|-------+
+                  | +-----|-|-----+       |       | +-----|-|-----+       |
+                  | |     | |     |       |       | |     | |     |       |
+swap 2-bit words: | |     +-|-----|-----+ |       | |     +-|-----|-----+ |
+                  +-|-----|-|---+ |     | |       +-|-----|-|---+ |     | |
+                  | |     | |   | |     | |       | |     | |   | |     | |
+                  | /     | /   \ |     \ |       | /     | /   \ |     \ |
+              y--Mux  y--Mux  y--Mux  y--Mux  y--Mux  y--Mux  y--Mux  y--Mux
+                  |       |       |       |       |       |       |       |
+the value here is |       |       |       |       |       |       |       |
+grev(inp, 0b0yz): |       |       |       |       |       |       |       |
+                  |       |       |       |       |       |       |       |
+                  |       |       |       | +-----|-------|-------|-------+
+                  |       |       | +-----|-|-----|-------|-------+       |
+                  |       | +-----|-|-----|-|-----|-------+       |       |
+                  | +-----|-|-----|-|-----|-|-----+       |       |       |
+swap 4-bit words: | |     | |     | |     | |     |       |       |       |
+                  | |     | |     | |     +-|-----|-------|-------|-----+ |
+                  | |     | |     +-|-----|-|-----|-------|-----+ |     | |
+                  | |     +-|-----|-|-----|-|-----|-----+ |     | |     | |
+                  +-|-----|-|-----|-|-----|-|---+ |     | |     | |     | |
+                  | |     | |     | |     | |   | |     | |     | |     | |
+                  | /     | /     | /     | /   \ |     \ |     \ |     \ |
+              x--Mux  x--Mux  x--Mux  x--Mux  x--Mux  x--Mux  x--Mux  x--Mux
+                  |       |       |       |       |       |       |       |
+the value here is |       |       |       |       |       |       |       |
+grev(inp, 0bxyz): |       |       |       |       |       |       |       |
+                  |       |       |       |       |       |       |       |
+                out[0]  out[1]  out[2]  out[3]  out[4]  out[5]  out[6]  out[7]
+"""
+
+from nmigen.hdl.ast import Signal, Mux, Cat
+from nmigen.hdl.ast import Assert
+from nmigen.hdl.dsl import Module
+from nmigen.hdl.ir import Elaboratable
+from nmigen.back import rtlil
+import string
+
+
+def grev(inval, chunk_sizes, log2_width):
+    """Python reference implementation of generalized bit-reverse.
+    See `GRev` for documentation.
+    """
+    # mask inputs into range
+    inval &= 2 ** 2 ** log2_width - 1
+    chunk_sizes &= 2 ** log2_width - 1
+    # core algorithm:
+    retval = 0
+    for i in range(2 ** log2_width):
+        # don't use `if` so this can be used with nmigen values
+        bit = (inval & (1 << i)) != 0
+        retval |= bit << (i ^ chunk_sizes)
+    return retval
+
+
+class GRev(Elaboratable):
+    """Generalized bit-reverse.
+
+    See the module's documentation for a description of generalized
+    bit-reverse, as well as the permutation network created by this class.
+
+    Attributes:
+    log2_width: int
+        see __init__'s docs.
+    msb_first: bool
+        see __init__'s docs.
+    width: int
+        the input/output width of the grev operation. The value is
+        `2 ** self.log2_width`.
+    input: Signal with width=self.width
+        the input value of the grev operation.
+    chunk_sizes: Signal with width=self.log2_width
+        the input that describes which bits get swapped. See the module docs
+        for additional details.
+    output: Signal with width=self.width
+        the output value of the grev operation.
+    """
+
+    def __init__(self, log2_width, msb_first=False):
+        """Create a `GRev` instance.
+
+        log2_width: int
+            the base-2 logarithm of the input/output width of the grev
+            operation.
+        msb_first: bool
+            If `msb_first` is True, then the order will be the reverse of the
+            standard order -- swapping adjacent 8-bit words, then 4-bit words,
+            then 2-bit words, then 1-bit words -- using the bits of
+            `chunk_sizes` from MSB to LSB.
+            If `msb_first` is False (the default), then the order will be the
+            standard order -- swapping adjacent 1-bit words, then 2-bit words,
+            then 4-bit words, then 8-bit words -- using the bits of
+            `chunk_sizes` from LSB to MSB.
+        """
+        self.log2_width = log2_width
+        self.msb_first = msb_first
+        self.width = 1 << log2_width
+        self.input = Signal(self.width)
+        self.chunk_sizes = Signal(log2_width)
+        self.output = Signal(self.width)
+
+        # internal signals exposed for unit tests, should be ignored by
+        # external users. The signals are created in the constructor because
+        # that's where all class member variables should *always* be created.
+        # If we were to create the members in elaborate() instead, it would
+        # just make the class very confusing to use.
+        #
+        # `_intermediates[step_count]` is the value after `step_count` steps
+        # of muxing. e.g. (for `msb_first == False`) `_intermediates[4]` is the
+        # result of 4 steps of muxing, being the value `grev(inp,0b00wxyz)`.
+        self._intermediates = [self.__inter(i) for i in range(log2_width + 1)]
+
+    def _get_cs_bit_index(self, step_index):
+        """get the index of the bit of `chunk_sizes` that this step should mux
+        based off of."""
+        assert 0 <= step_index < self.log2_width
+        if self.msb_first:
+            # reverse so we start from the MSB, producing intermediate values
+            # like, for `step_index == 4`, `0buvwx00` rather than `0b00wxyz`
+            return self.log2_width - step_index - 1
+        return step_index
+
+    def __inter(self, step_count):
+        """make a signal with a name like `grev(inp,0b000xyz)` to match the
+        diagram in the module-level docs."""
+        # make the list of bits in LSB to MSB order
+        chunk_sizes_bits = ['0'] * self.log2_width
+        # for all steps already completed
+        for step_index in range(step_count):
+            bit_num = self._get_cs_bit_index(step_index)
+            ch = string.ascii_lowercase[-1 - bit_num]  # count from z to a
+            chunk_sizes_bits[bit_num] = ch
+        # reverse because the text is MSB first
+        chunk_sizes_val = '0b' + ''.join(reversed(chunk_sizes_bits))
+        # name works according to Verilog's rules for escaped identifiers
+        # because it has no spaces
+        name = f"grev(inp,{chunk_sizes_val})"
+        return Signal(self.width, name=name)
+
+    def __get_permutation(self, step_index):
+        """get the bit permutation for the current step. the returned value is
+        a list[int] where `retval[i] == j` means that this step's input bit `i`
+        goes to this step's output bit `j`."""
+        # we can extract just the latest bit for this step, since the previous
+        # step effectively has its value's grev arg as `0b000xyz`, and this
+        # step has its value's grev arg as `0b00wxyz`, so we only need to
+        # compute `grev(prev_step_output,0b00w000)` to get
+        # `grev(inp,0b00wxyz)`. `cur_chunk_sizes` is the `0b00w000`.
+        cur_chunk_sizes = 1 << self._get_cs_bit_index(step_index)
+        # compute bit permutation for `grev(...,0b00w000)`.
+        return [i ^ cur_chunk_sizes for i in range(self.width)]
+
+    def _sigs_and_expected(self, inp, chunk_sizes):
+        """the intermediate signals and the expected values, based off of the
+        passed-in `inp` and `chunk_sizes`."""
+        # we accumulate a mask of which chunk_sizes bits we have accounted for
+        # so far
+        chunk_sizes_mask = 0
+        for step_count, intermediate in enumerate(self._intermediates):
+            # mask out chunk_sizes to get the value
+            cur_chunk_sizes = chunk_sizes & chunk_sizes_mask
+            expected = grev(inp, cur_chunk_sizes, self.log2_width)
+            yield (intermediate, expected)
+            # if step_count is in-range for being a valid step_index
+            if step_count < self.log2_width:
+                # add current step's bit to the mask
+                chunk_sizes_mask |= 1 << self._get_cs_bit_index(step_count)
+        assert chunk_sizes_mask == 2 ** self.log2_width - 1, \
+            "should have got all the bits in chunk_sizes"
+
+    def elaborate(self, platform):
+        m = Module()
+
+        # value after zero steps is just the input
+        m.d.comb += self._intermediates[0].eq(self.input)
+
+        for step_index in range(self.log2_width):
+            step_inp = self._intermediates[step_index]
+            step_out = self._intermediates[step_index + 1]
+            # get permutation for current step
+            permutation = self.__get_permutation(step_index)
+            # figure out which `chunk_sizes` bit we want to pay attention to
+            # for this step.
+            sel = self.chunk_sizes[self._get_cs_bit_index(step_index)]
+            for in_index, out_index in enumerate(permutation):
+                # use in_index so we get the permuted bit
+                permuted_bit = step_inp[in_index]
+                # use out_index so we copy the bit straight thru
+                straight_bit = step_inp[out_index]
+                bit = Mux(sel, permuted_bit, straight_bit)
+                m.d.comb += step_out[out_index].eq(bit)
+        # value after all steps is just the output
+        m.d.comb += self.output.eq(self._intermediates[-1])
+
+        if platform != 'formal':
+            return m
+
+        # formal test comparing directly against the (simpler) version
+        m.d.comb += Assert(self.output == grev(self.input,
+                                               self.chunk_sizes,
+                                               self.log2_width))
+
+        for value, expected in self._sigs_and_expected(self.input,
+                                                       self.chunk_sizes):
+            m.d.comb += Assert(value == expected)
+        return m
+
+    def ports(self):
+        return [self.input, self.chunk_sizes, self.output]
+
+
+# useful to see what is going on:
+# python3 src/nmutil/test/test_grev.py
+# yosys <<<"read_ilang sim_test_out/__main__.TestGrev.test_small/0.il; proc; clean -purge; show top"
+
+if __name__ == '__main__':
+    dut = GRev(3)
+    vl = rtlil.convert(dut, ports=dut.ports())
+    with open("grev3.il", "w") as f:
+        f.write(vl)
index 92ffee20b761e37bd5d89c27d4a1beec7815a57a..3d32a49bde120cca3b52f195c854f518afbc97ad 100644 (file)
@@ -1,3 +1,4 @@
+# SPDX-License-Identifier: LGPL-3-or-later
 """ IO Control API
 
     This work is funded through NLnet under Grant 2019-02-012
@@ -42,9 +43,9 @@ class Object:
         self.fields = OrderedDict()
 
     def __setattr__(self, k, v):
-        print ("kv", k, v)
+        print("kv", k, v)
         if (k.startswith('_') or k in ["fields", "name", "src_loc"] or
-           k in dir(Object) or "fields" not in self.__dict__):
+                k in dir(Object) or "fields" not in self.__dict__):
             return object.__setattr__(self, k, v)
         self.fields[k] = v
 
@@ -67,16 +68,16 @@ class Object:
         res = []
         for (k, o) in self.fields.items():
             i = getattr(inp, k)
-            print ("eq", o, i)
+            print("eq", o, i)
             rres = o.eq(i)
             if isinstance(rres, Sequence):
                 res += rres
             else:
                 res.append(rres)
-        print (res)
+        print(res)
         return res
 
-    def ports(self): # being called "keys" would be much better
+    def ports(self):  # being called "keys" would be much better
         return list(self)
 
 
@@ -92,16 +93,15 @@ def add_prefix_to_record_signals(prefix, record):
 
 class RecordObject(Record):
     def __init__(self, layout=None, name=None):
-        #if name is None:
+        # if name is None:
         #    name = tracer.get_var_name(depth=2, default="$ro")
         Record.__init__(self, layout=layout or [], name=name)
 
-
     def __setattr__(self, k, v):
         #print(f"RecordObject setattr({k}, {v})")
         #print (dir(Record))
         if (k.startswith('_') or k in ["fields", "name", "src_loc"] or
-           k in dir(Record) or "fields" not in self.__dict__):
+                k in dir(Record) or "fields" not in self.__dict__):
             return object.__setattr__(self, k, v)
 
         if self.name is None:
@@ -126,7 +126,7 @@ class RecordObject(Record):
         self.layout.fields.update(newlayout)
 
     def __iter__(self):
-        for x in self.fields.values(): # remember: fields is an OrderedDict
+        for x in self.fields.values():  # remember: fields is an OrderedDict
             if hasattr(x, 'ports'):
                 yield from x.ports()
             elif isinstance(x, Record):
@@ -137,7 +137,7 @@ class RecordObject(Record):
             else:
                 yield x
 
-    def ports(self): # would be better being called "keys"
+    def ports(self):  # would be better being called "keys"
         return list(self)
 
 
@@ -147,18 +147,24 @@ class PrevControl(Elaboratable):
                    may be a multi-bit signal, where all bits are required
                    to be asserted to indicate "valid".
         * o_ready: output to next stage indicating readiness to accept data
-        * data_i : an input - MUST be added by the USER of this class
+        * i_data : an input - MUST be added by the USER of this class
     """
 
-    def __init__(self, i_width=1, stage_ctl=False, maskwid=0, offs=0):
+    def __init__(self, i_width=1, stage_ctl=False, maskwid=0, offs=0,
+                 name=None):
+        if name is None:
+            name = ""
+        n_piv = "p_i_valid"+name
+        n_por = "p_o_ready"+name
+
         self.stage_ctl = stage_ctl
         self.maskwid = maskwid
         if maskwid:
-            self.mask_i = Signal(maskwid)                # prev   >>in  self
-            self.stop_i = Signal(maskwid)                # prev   >>in  self
-        self.i_valid = Signal(i_width, name="p_i_valid") # prev   >>in  self
-        self._o_ready = Signal(name="p_o_ready")         # prev   <<out self
-        self.data_i = None # XXX MUST BE ADDED BY USER
+            self.mask_i = Signal(maskwid)              # prev   >>in  self
+            self.stop_i = Signal(maskwid)              # prev   >>in  self
+        self.i_valid = Signal(i_width, name=n_piv)     # prev   >>in  self
+        self._o_ready = Signal(name=n_por)             # prev   <<out self
+        self.i_data = None  # XXX MUST BE ADDED BY USER
         if stage_ctl:
             self.s_o_ready = Signal(name="p_s_o_rdy")    # prev   <<out self
         self.trigger = Signal(reset_less=True)
@@ -168,7 +174,7 @@ class PrevControl(Elaboratable):
         """ public-facing API: indicates (externally) that stage is ready
         """
         if self.stage_ctl:
-            return self.s_o_ready # set dynamically by stage
+            return self.s_o_ready  # set dynamically by stage
         return self._o_ready      # return this when not under dynamic control
 
     def _connect_in(self, prev, direct=False, fn=None,
@@ -185,8 +191,8 @@ class PrevControl(Elaboratable):
                 res.append(self.stop_i.eq(prev.stop_i))
         if do_data is False:
             return res
-        data_i = fn(prev.data_i) if fn is not None else prev.data_i
-        return res + [nmoperator.eq(self.data_i, data_i)]
+        i_data = fn(prev.i_data) if fn is not None else prev.i_data
+        return res + [nmoperator.eq(self.i_data, i_data)]
 
     @property
     def i_valid_test(self):
@@ -212,9 +218,9 @@ class PrevControl(Elaboratable):
         return m
 
     def eq(self, i):
-        res = [nmoperator.eq(self.data_i, i.data_i),
-                self.o_ready.eq(i.o_ready),
-                self.i_valid.eq(i.i_valid)]
+        res = [nmoperator.eq(self.i_data, i.i_data),
+               self.o_ready.eq(i.o_ready),
+               self.i_valid.eq(i.i_valid)]
         if self.maskwid:
             res.append(self.mask_i.eq(i.mask_i))
         return res
@@ -225,13 +231,13 @@ class PrevControl(Elaboratable):
         if self.maskwid:
             yield self.mask_i
             yield self.stop_i
-        if hasattr(self.data_i, "ports"):
-            yield from self.data_i.ports()
-        elif (isinstance(self.data_i, Sequence) or
-              isinstance(self.data_i, Iterable)):
-            yield from self.data_i
+        if hasattr(self.i_data, "ports"):
+            yield from self.i_data.ports()
+        elif (isinstance(self.i_data, Sequence) or
+              isinstance(self.i_data, Iterable)):
+            yield from self.i_data
         else:
-            yield self.data_i
+            yield self.i_data
 
     def ports(self):
         return list(self)
@@ -241,19 +247,25 @@ class NextControl(Elaboratable):
     """ contains the signals that go *to* the next stage (both in and out)
         * o_valid: output indicating to next stage that data is valid
         * i_ready: input from next stage indicating that it can accept data
-        * data_o : an output - MUST be added by the USER of this class
+        * o_data : an output - MUST be added by the USER of this class
     """
-    def __init__(self, stage_ctl=False, maskwid=0):
+
+    def __init__(self, stage_ctl=False, maskwid=0, name=None):
+        if name is None:
+            name = ""
+        n_nov = "n_o_valid"+name
+        n_nir = "n_i_ready"+name
+
         self.stage_ctl = stage_ctl
         self.maskwid = maskwid
         if maskwid:
-            self.mask_o = Signal(maskwid)       # self out>>  next
-            self.stop_o = Signal(maskwid)       # self out>>  next
-        self.o_valid = Signal(name="n_o_valid") # self out>>  next
-        self.i_ready = Signal(name="n_i_ready") # self <<in   next
-        self.data_o = None # XXX MUST BE ADDED BY USER
-        #if self.stage_ctl:
-        self.d_valid = Signal(reset=1) # INTERNAL (data valid)
+            self.mask_o = Signal(maskwid)  # self out>>  next
+            self.stop_o = Signal(maskwid)  # self out>>  next
+        self.o_valid = Signal(name=n_nov)  # self out>>  next
+        self.i_ready = Signal(name=n_nir)  # self <<in   next
+        self.o_data = None  # XXX MUST BE ADDED BY USER
+        # if self.stage_ctl:
+        self.d_valid = Signal(reset=1)  # INTERNAL (data valid)
         self.trigger = Signal(reset_less=True)
 
     @property
@@ -277,9 +289,9 @@ class NextControl(Elaboratable):
             if do_stop:
                 res.append(nxt.stop_i.eq(self.stop_o))
         if do_data:
-            res.append(nmoperator.eq(nxt.data_i, self.data_o))
-        print ("connect to next", self, self.maskwid, nxt.data_i,
-                                  do_data, do_stop)
+            res.append(nmoperator.eq(nxt.i_data, self.o_data))
+        print("connect to next", self, self.maskwid, nxt.i_data,
+              do_data, do_stop)
         return res
 
     def _connect_out(self, nxt, direct=False, fn=None,
@@ -296,8 +308,8 @@ class NextControl(Elaboratable):
                 res.append(nxt.stop_o.eq(self.stop_o))
         if not do_data:
             return res
-        data_o = fn(nxt.data_o) if fn is not None else nxt.data_o
-        return res + [nmoperator.eq(data_o, self.data_o)]
+        o_data = fn(nxt.o_data) if fn is not None else nxt.o_data
+        return res + [nmoperator.eq(o_data, self.o_data)]
 
     def elaborate(self, platform):
         m = Module()
@@ -310,14 +322,13 @@ class NextControl(Elaboratable):
         if self.maskwid:
             yield self.mask_o
             yield self.stop_o
-        if hasattr(self.data_o, "ports"):
-            yield from self.data_o.ports()
-        elif (isinstance(self.data_o, Sequence) or
-              isinstance(self.data_o, Iterable)):
-            yield from self.data_o
+        if hasattr(self.o_data, "ports"):
+            yield from self.o_data.ports()
+        elif (isinstance(self.o_data, Sequence) or
+              isinstance(self.o_data, Iterable)):
+            yield from self.o_data
         else:
-            yield self.data_o
+            yield self.o_data
 
     def ports(self):
         return list(self)
-
index 908c15cecf7ea77c955047934fa59704481adb9a..e2d7541d396fa7468668f4bd903346abd504a94f 100644 (file)
@@ -1,3 +1,4 @@
+# SPDX-License-Identifier: LGPL-3-or-later
 """
     This work is funded through NLnet under Grant 2019-02-012
 
@@ -45,8 +46,9 @@ def latchregister(m, incoming, outgoing, settrue, name=None):
     else:
         reg = Signal.like(incoming, name=name)
     m.d.comb += outgoing.eq(Mux(settrue, incoming, reg))
-    with m.If(settrue): # pass in some kind of expression/condition here
+    with m.If(settrue):  # pass in some kind of expression/condition here
         m.d.sync += reg.eq(incoming)      # latch input into register
+    return reg
 
 
 def mkname(prefix, suffix):
@@ -61,24 +63,27 @@ class SRLatch(Elaboratable):
         self.llen = llen
         s_n, r_n = mkname("s", name), mkname("r", name)
         q_n, qn_n = mkname("q", name), mkname("qn", name)
+        qint = mkname("qint", name)
         qlq_n = mkname("qlq", name)
         self.s = Signal(llen, name=s_n, reset=0)
-        self.r = Signal(llen, name=r_n, reset=(1<<llen)-1) # defaults to off
+        self.r = Signal(llen, name=r_n, reset=(1 << llen)-1)  # defaults to off
         self.q = Signal(llen, name=q_n, reset_less=True)
         self.qn = Signal(llen, name=qn_n, reset_less=True)
         self.qlq = Signal(llen, name=qlq_n, reset_less=True)
+        self.q_int = Signal(llen, name=qint, reset_less=True)
 
     def elaborate(self, platform):
         m = Module()
-        q_int = Signal(self.llen)
 
-        m.d.sync += q_int.eq((q_int & ~self.r) | self.s)
+        next_o = Signal(self.llen, reset_less=True)
+        m.d.comb += next_o.eq((self.q_int & ~self.r) | self.s)
+        m.d.sync += self.q_int.eq(next_o)
         if self.sync:
-            m.d.comb += self.q.eq(q_int)
+            m.d.comb += self.q.eq(self.q_int)
         else:
-            m.d.comb += self.q.eq((q_int & ~self.r) | self.s)
+            m.d.comb += self.q.eq(next_o)
         m.d.comb += self.qn.eq(~self.q)
-        m.d.comb += self.qlq.eq(self.q | q_int) # useful output
+        m.d.comb += self.qlq.eq(self.q | self.q_int)  # useful output
 
         return m
 
@@ -109,6 +114,7 @@ def sr_sim(dut):
     yield
     yield
 
+
 def test_sr():
     dut = SRLatch(llen=4)
     vl = rtlil.convert(dut, ports=dut.ports())
@@ -124,5 +130,6 @@ def test_sr():
 
     run_simulation(dut, sr_sim(dut), vcd_name='test_srlatch_async.vcd')
 
+
 if __name__ == '__main__':
     test_sr()
diff --git a/src/nmutil/lut.py b/src/nmutil/lut.py
new file mode 100644 (file)
index 0000000..755747a
--- /dev/null
@@ -0,0 +1,213 @@
+# SPDX-License-Identifier: LGPL-3-or-later
+# Copyright 2021 Jacob Lifshay
+# Copyright (C) 2021 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
+
+# Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
+# of Horizon 2020 EU Programme 957073.
+
+"""Bitwise logic operators implemented using a look-up table, like LUTs in
+FPGAs. Inspired by x86's `vpternlog[dq]` instructions.
+
+https://bugs.libre-soc.org/show_bug.cgi?id=745
+https://www.felixcloutier.com/x86/vpternlogd:vpternlogq
+"""
+
+from nmigen.hdl.ast import Array, Cat, Repl, Signal
+from nmigen.hdl.dsl import Module
+from nmigen.hdl.ir import Elaboratable
+from nmigen.cli import rtlil
+from nmutil.plain_data import plain_data
+
+
+class BitwiseMux(Elaboratable):
+    """Mux, but treating input/output Signals as bit vectors, rather than
+    integers. This means each bit in the output is independently multiplexed
+    based on the corresponding bit in each of the inputs.
+    """
+
+    def __init__(self, width):
+        self.sel = Signal(width)
+        self.t = Signal(width)
+        self.f = Signal(width)
+        self.output = Signal(width)
+
+    def elaborate(self, platform):
+        m = Module()
+        m.d.comb += self.output.eq((~self.sel & self.f) | (self.sel & self.t))
+        return m
+
+
+class BitwiseLut(Elaboratable):
+    """Bitwise logic operators implemented using a look-up table, like LUTs in
+    FPGAs. Inspired by x86's `vpternlog[dq]` instructions.
+
+    Each output bit `i` is set to `lut[Cat(inp[i] for inp in self.inputs)]`
+    """
+
+    def __init__(self, input_count, width):
+        """
+        input_count: int
+            the number of inputs. ternlog-style instructions have 3 inputs.
+        width: int
+            the number of bits in each input/output.
+        """
+        self.input_count = input_count
+        self.width = width
+
+        def inp(i):
+            return Signal(width, name=f"input{i}")
+        self.inputs = tuple(inp(i) for i in range(input_count))  # inputs
+        self.lut = Signal(2 ** input_count)                     # lookup input
+        self.output = Signal(width)                             # output
+
+    def elaborate(self, platform):
+        m = Module()
+        comb = m.d.comb
+        lut_array = Array(self.lut)  # create dynamic-indexable LUT array
+        out = []
+
+        for bit in range(self.width):
+            # take the bit'th bit of every input, create a LUT index from it
+            index = Signal(self.input_count, name="index%d" % bit)
+            comb += index.eq(Cat(inp[bit] for inp in self.inputs))
+            # store output bit in a list - Cat() it after (simplifies graphviz)
+            outbit = Signal(name="out%d" % bit)
+            comb += outbit.eq(lut_array[index])
+            out.append(outbit)
+
+        # finally Cat() all the output bits together
+        comb += self.output.eq(Cat(*out))
+        return m
+
+    def ports(self):
+        return list(self.inputs) + [self.lut, self.output]
+
+
+@plain_data()
+class _TreeMuxNode:
+    """Mux in tree for `TreeBitwiseLut`.
+
+    Attributes:
+    out: Signal
+    container: TreeBitwiseLut
+    parent: _TreeMuxNode | None
+    child0: _TreeMuxNode | None
+    child1: _TreeMuxNode | None
+    depth: int
+    """
+    __slots__ = "out", "container", "parent", "child0", "child1", "depth"
+
+    def __init__(self, out, container, parent, child0, child1, depth):
+        """ Arguments:
+        out: Signal
+        container: TreeBitwiseLut
+        parent: _TreeMuxNode | None
+        child0: _TreeMuxNode | None
+        child1: _TreeMuxNode | None
+        depth: int
+        """
+        self.out = out
+        self.container = container
+        self.parent = parent
+        self.child0 = child0
+        self.child1 = child1
+        self.depth = depth
+
+    @property
+    def child_index(self):
+        """index of this node, when looked up in this node's parent's children.
+        """
+        if self.parent is None:
+            return None
+        return int(self.parent.child1 is self)
+
+    def add_child(self, child_index):
+        node = _TreeMuxNode(
+            out=Signal(self.container.width),
+            container=self.container, parent=self,
+            child0=None, child1=None, depth=1 + self.depth)
+        if child_index:
+            assert self.child1 is None
+            self.child1 = node
+        else:
+            assert self.child0 is None
+            self.child0 = node
+        node.out.name = "node_out_" + node.key_str
+        return node
+
+    @property
+    def key(self):
+        retval = []
+        node = self
+        while node.parent is not None:
+            retval.append(node.child_index)
+            node = node.parent
+        retval.reverse()
+        return retval
+
+    @property
+    def key_str(self):
+        k = ['x'] * self.container.input_count
+        for i, v in enumerate(self.key):
+            k[i] = '1' if v else '0'
+        return '0b' + ''.join(reversed(k))
+
+
+class TreeBitwiseLut(Elaboratable):
+    """Tree-based version of BitwiseLut. Has identical API, so see `BitwiseLut`
+    for API documentation. This version may produce more efficient hardware.
+    """
+
+    def __init__(self, input_count, width):
+        self.input_count = input_count
+        self.width = width
+
+        def inp(i):
+            return Signal(width, name=f"input{i}")
+        self.inputs = tuple(inp(i) for i in range(input_count))
+        self.output = Signal(width)
+        self.lut = Signal(2 ** input_count)
+        self._tree_root = _TreeMuxNode(
+            out=self.output, container=self, parent=None,
+            child0=None, child1=None, depth=0)
+        self._build_tree(self._tree_root)
+
+    def _build_tree(self, node):
+        if node.depth < self.input_count:
+            self._build_tree(node.add_child(0))
+            self._build_tree(node.add_child(1))
+
+    def _elaborate_tree(self, m, node):
+        if node.depth < self.input_count:
+            mux = BitwiseMux(self.width)
+            setattr(m.submodules, "mux_" + node.key_str, mux)
+            m.d.comb += [
+                mux.f.eq(node.child0.out),
+                mux.t.eq(node.child1.out),
+                mux.sel.eq(self.inputs[node.depth]),
+                node.out.eq(mux.output),
+            ]
+            self._elaborate_tree(m, node.child0)
+            self._elaborate_tree(m, node.child1)
+        else:
+            index = int(node.key_str, base=2)
+            m.d.comb += node.out.eq(Repl(self.lut[index], self.width))
+
+    def elaborate(self, platform):
+        m = Module()
+        self._elaborate_tree(m, self._tree_root)
+        return m
+
+    def ports(self):
+        return [*self.inputs, self.lut, self.output]
+
+
+# useful to see what is going on:
+# python3 src/nmutil/test/test_lut.py
+# yosys <<<"read_ilang sim_test_out/__main__.TestBitwiseLut.test_tree/0.il; proc;;; show top"
+
+if __name__ == '__main__':
+    dut = BitwiseLut(2, 64)
+    vl = rtlil.convert(dut, ports=dut.ports())
+    with open("test_lut2.il", "w") as f:
+        f.write(vl)
index e98ca3b7db40e135bbf3ecf75f3bb5a2d0f8fbbe..3bd69f9b6c1a82e8338039d14821cc0e1b7c33d3 100644 (file)
@@ -1,3 +1,4 @@
+# SPDX-License-Identifier: LGPL-3-or-later
 """
     This work is funded through NLnet under Grant 2019-02-012
 
@@ -9,6 +10,7 @@
 from nmigen import Module, Signal, Elaboratable
 from nmigen.utils import log2_int
 
+
 def masked(m_out, m_in, mask):
     return (m_out & ~mask) | (m_in & mask)
 
@@ -16,15 +18,14 @@ def masked(m_out, m_in, mask):
 class Mask(Elaboratable):
     def __init__(self, sz):
         self.sz = sz
-        self.shift = Signal(log2_int(sz, False)+1)
+        self.shift = Signal(sz.bit_length()+1)
         self.mask = Signal(sz)
 
     def elaborate(self, platform):
         m = Module()
 
         for i in range(self.sz):
-            with m.If(self.shift > i):
+            with m.If(i < self.shift):
                 m.d.comb += self.mask[i].eq(1)
 
         return m
-
index 18315e31ec4b01821fc1e743d3b1e6ead1a1bcf8..0b91f8baf565528a10bd9297ba423674289b104b 100644 (file)
@@ -1,3 +1,4 @@
+# SPDX-License-Identifier: LGPL-3-or-later
 """ Combinatorial Multi-input and Multi-output multiplexer blocks
     conforming to Pipeline API
 
@@ -32,6 +33,7 @@ from nmutil.iocontrol import NextControl, PrevControl
 class MultiInControlBase(Elaboratable):
     """ Common functions for Pipeline API
     """
+
     def __init__(self, in_multi=None, p_len=1, maskwid=0, routemask=False):
         """ Multi-input Control class.  Conforms to same API as ControlBase...
             mostly.  has additional indices to the *multiple* input stages
@@ -40,21 +42,21 @@ class MultiInControlBase(Elaboratable):
             * n: contains ready/valid to the next stage
 
             User must also:
-            * add data_i members to PrevControl and
-            * add data_o member  to NextControl
+            * add i_data members to PrevControl and
+            * add o_data member  to NextControl
         """
         self.routemask = routemask
         # set up input and output IO ACK (prev/next ready/valid)
-        print ("multi_in", self, maskwid, p_len, routemask)
+        print("multi_in", self, maskwid, p_len, routemask)
         p = []
         for i in range(p_len):
             p.append(PrevControl(in_multi, maskwid=maskwid))
         self.p = Array(p)
         if routemask:
-            nmaskwid = maskwid # straight route mask mode
+            nmaskwid = maskwid  # straight route mask mode
         else:
-            nmaskwid = maskwid * p_len # fan-in mode
-        self.n = NextControl(maskwid=nmaskwid) # masks fan in (Cat)
+            nmaskwid = maskwid * p_len  # fan-in mode
+        self.n = NextControl(maskwid=nmaskwid)  # masks fan in (Cat)
 
     def connect_to_next(self, nxt, p_idx=0):
         """ helper function to connect to the next stage data/valid/ready.
@@ -80,7 +82,7 @@ class MultiInControlBase(Elaboratable):
     def set_input(self, i, idx=0):
         """ helper function to set the input data
         """
-        return eq(self.p[idx].data_i, i)
+        return eq(self.p[idx].i_data, i)
 
     def elaborate(self, platform):
         m = Module()
@@ -101,6 +103,7 @@ class MultiInControlBase(Elaboratable):
 class MultiOutControlBase(Elaboratable):
     """ Common functions for Pipeline API
     """
+
     def __init__(self, n_len=1, in_multi=None, maskwid=0, routemask=False):
         """ Multi-output Control class.  Conforms to same API as ControlBase...
             mostly.  has additional indices to the multiple *output* stages
@@ -110,14 +113,14 @@ class MultiOutControlBase(Elaboratable):
             * n: contains ready/valid to the next stages PLURAL
 
             User must also:
-            * add data_i member to PrevControl and
-            * add data_o members to NextControl
+            * add i_data member to PrevControl and
+            * add o_data members to NextControl
         """
 
         if routemask:
-            nmaskwid = maskwid # straight route mask mode
+            nmaskwid = maskwid  # straight route mask mode
         else:
-            nmaskwid = maskwid * n_len # fan-out mode
+            nmaskwid = maskwid * n_len  # fan-out mode
 
         # set up input and output IO ACK (prev/next ready/valid)
         self.p = PrevControl(in_multi, maskwid=nmaskwid)
@@ -155,7 +158,7 @@ class MultiOutControlBase(Elaboratable):
     def set_input(self, i):
         """ helper function to set the input data
         """
-        return eq(self.p.data_i, i)
+        return eq(self.p.i_data, i)
 
     def __iter__(self):
         yield from self.p
@@ -171,23 +174,23 @@ class CombMultiOutPipeline(MultiOutControlBase):
 
         Attributes:
         -----------
-        p.data_i : stage input data (non-array).  shaped according to ispec
-        n.data_o : stage output data array.       shaped according to ospec
+        p.i_data : stage input data (non-array).  shaped according to ispec
+        n.o_data : stage output data array.       shaped according to ospec
     """
 
     def __init__(self, stage, n_len, n_mux, maskwid=0, routemask=False):
         MultiOutControlBase.__init__(self, n_len=n_len, maskwid=maskwid,
-                                            routemask=routemask)
+                                     routemask=routemask)
         self.stage = stage
         self.maskwid = maskwid
         self.routemask = routemask
         self.n_mux = n_mux
 
         # set up the input and output data
-        self.p.data_i = _spec(stage.ispec, 'data_i') # input type
+        self.p.i_data = _spec(stage.ispec, 'i_data')  # input type
         for i in range(n_len):
-            name = 'data_o_%d' % i
-            self.n[i].data_o = _spec(stage.ospec, name) # output type
+            name = 'o_data_%d' % i
+            self.n[i].o_data = _spec(stage.ospec, name)  # output type
 
     def process(self, i):
         if hasattr(self.stage, "process"):
@@ -197,18 +200,18 @@ class CombMultiOutPipeline(MultiOutControlBase):
     def elaborate(self, platform):
         m = MultiOutControlBase.elaborate(self, platform)
 
-        if hasattr(self.n_mux, "elaborate"): # TODO: identify submodule?
+        if hasattr(self.n_mux, "elaborate"):  # TODO: identify submodule?
             m.submodules.n_mux = self.n_mux
 
         # need buffer register conforming to *input* spec
-        r_data = _spec(self.stage.ispec, 'r_data') # input type
+        r_data = _spec(self.stage.ispec, 'r_data')  # input type
         if hasattr(self.stage, "setup"):
             self.stage.setup(m, r_data)
 
         # multiplexer id taken from n_mux
         muxid = self.n_mux.m_id
-        print ("self.n_mux", self.n_mux)
-        print ("self.n_mux.m_id", self.n_mux.m_id)
+        print("self.n_mux", self.n_mux)
+        print("self.n_mux.m_id", self.n_mux.m_id)
 
         self.n_mux.m_id.name = "m_id"
 
@@ -216,7 +219,7 @@ class CombMultiOutPipeline(MultiOutControlBase):
         p_i_valid = Signal(reset_less=True)
         pv = Signal(reset_less=True)
         m.d.comb += p_i_valid.eq(self.p.i_valid_test)
-        #m.d.comb += pv.eq(self.p.i_valid) #& self.n[muxid].i_ready)
+        # m.d.comb += pv.eq(self.p.i_valid) #& self.n[muxid].i_ready)
         m.d.comb += pv.eq(self.p.i_valid & self.p.o_ready)
 
         # all outputs to next stages first initialised to zero (invalid)
@@ -224,29 +227,31 @@ class CombMultiOutPipeline(MultiOutControlBase):
         for i in range(len(self.n)):
             m.d.comb += self.n[i].o_valid.eq(0)
         if self.routemask:
-            #with m.If(pv):
+            # with m.If(pv):
             m.d.comb += self.n[muxid].o_valid.eq(pv)
             m.d.comb += self.p.o_ready.eq(self.n[muxid].i_ready)
         else:
             data_valid = self.n[muxid].o_valid
-            m.d.comb += self.p.o_ready.eq(~data_valid | self.n[muxid].i_ready)
-            m.d.comb += data_valid.eq(p_i_valid | \
-                                    (~self.n[muxid].i_ready & data_valid))
-
+            m.d.comb += self.p.o_ready.eq(self.n[muxid].i_ready)
+            m.d.comb += data_valid.eq(p_i_valid |
+                                      (~self.n[muxid].i_ready & data_valid))
 
         # send data on
-        #with m.If(pv):
-        m.d.comb += eq(r_data, self.p.data_i)
-        m.d.comb += eq(self.n[muxid].data_o, self.process(r_data))
+        # with m.If(pv):
+        m.d.comb += eq(r_data, self.p.i_data)
+        #m.d.comb += eq(self.n[muxid].o_data, self.process(r_data))
+        for i in range(len(self.n)):
+            with m.If(muxid == i):
+                m.d.comb += eq(self.n[i].o_data, self.process(r_data))
 
         if self.maskwid:
-            if self.routemask: # straight "routing" mode - treat like data
+            if self.routemask:  # straight "routing" mode - treat like data
                 m.d.comb += self.n[muxid].stop_o.eq(self.p.stop_i)
                 with m.If(pv):
                     m.d.comb += self.n[muxid].mask_o.eq(self.p.mask_i)
             else:
-                ml = [] # accumulate output masks
-                ms = [] # accumulate output stops
+                ml = []  # accumulate output masks
+                ms = []  # accumulate output stops
                 # fan-out mode.
                 # conditionally fan-out mask bits, always fan-out stop bits
                 for i in range(len(self.n)):
@@ -263,9 +268,9 @@ class CombMultiInPipeline(MultiInControlBase):
 
         Attributes:
         -----------
-        p.data_i : StageInput, shaped according to ispec
+        p.i_data : StageInput, shaped according to ispec
             The pipeline input
-        p.data_o : StageOutput, shaped according to ospec
+        p.o_data : StageOutput, shaped according to ospec
             The pipeline output
         r_data : input_shape according to ispec
             A temporary (buffered) copy of a prior (valid) input.
@@ -275,16 +280,16 @@ class CombMultiInPipeline(MultiInControlBase):
 
     def __init__(self, stage, p_len, p_mux, maskwid=0, routemask=False):
         MultiInControlBase.__init__(self, p_len=p_len, maskwid=maskwid,
-                                          routemask=routemask)
+                                    routemask=routemask)
         self.stage = stage
         self.maskwid = maskwid
         self.p_mux = p_mux
 
         # set up the input and output data
         for i in range(p_len):
-            name = 'data_i_%d' % i
-            self.p[i].data_i = _spec(stage.ispec, name) # input type
-        self.n.data_o = _spec(stage.ospec, 'data_o')
+            name = 'i_data_%d' % i
+            self.p[i].i_data = _spec(stage.ispec, name)  # input type
+        self.n.o_data = _spec(stage.ospec, 'o_data')
 
     def process(self, i):
         if hasattr(self.stage, "process"):
@@ -304,16 +309,15 @@ class CombMultiInPipeline(MultiInControlBase):
         p_len = len(self.p)
         for i in range(p_len):
             name = 'r_%d' % i
-            r = _spec(self.stage.ispec, name) # input type
+            r = _spec(self.stage.ispec, name)  # input type
             r_data.append(r)
             data_valid.append(Signal(name="data_valid", reset_less=True))
             p_i_valid.append(Signal(name="p_i_valid", reset_less=True))
             n_i_readyn.append(Signal(name="n_i_readyn", reset_less=True))
             if hasattr(self.stage, "setup"):
-                print ("setup", self, self.stage, r)
+                print("setup", self, self.stage, r)
                 self.stage.setup(m, r)
-        if len(r_data) > 1:
-            r_data = Array(r_data)
+        if True:  # len(r_data) > 1: # hmm always create an Array even of len 1
             p_i_valid = Array(p_i_valid)
             n_i_readyn = Array(n_i_readyn)
             data_valid = Array(data_valid)
@@ -321,7 +325,7 @@ class CombMultiInPipeline(MultiInControlBase):
         nirn = Signal(reset_less=True)
         m.d.comb += nirn.eq(~self.n.i_ready)
         mid = self.p_mux.m_id
-        print ("CombMuxIn mid", self, self.stage, self.routemask, mid, p_len)
+        print("CombMuxIn mid", self, self.stage, self.routemask, mid, p_len)
         for i in range(p_len):
             m.d.comb += data_valid[i].eq(0)
             m.d.comb += n_i_readyn[i].eq(1)
@@ -334,7 +338,8 @@ class CombMultiInPipeline(MultiInControlBase):
             m.d.comb += maskedout.eq(p.mask_i & ~p.stop_i)
         else:
             m.d.comb += maskedout.eq(1)
-        m.d.comb += p_i_valid[mid].eq(maskedout & self.p_mux.active)
+        m.d.comb += p_i_valid[mid].eq(maskedout & self.p_mux.active &
+                                      self.p[mid].i_valid)
         m.d.comb += self.p[mid].o_ready.eq(~data_valid[mid] | self.n.i_ready)
         m.d.comb += n_i_readyn[mid].eq(nirn & data_valid[mid])
         anyvalid = Signal(i, reset_less=True)
@@ -343,8 +348,8 @@ class CombMultiInPipeline(MultiInControlBase):
             av.append(data_valid[i])
         anyvalid = Cat(*av)
         m.d.comb += self.n.o_valid.eq(anyvalid.bool())
-        m.d.comb += data_valid[mid].eq(p_i_valid[mid] | \
-                                    (n_i_readyn[mid] ))
+        m.d.comb += data_valid[mid].eq(p_i_valid[mid] |
+                                       (n_i_readyn[mid]))
 
         if self.routemask:
             # XXX hack - fixes loop
@@ -361,10 +366,12 @@ class CombMultiInPipeline(MultiInControlBase):
                 #m.d.comb += vr.eq(p.i_valid & p.o_ready)
                 with m.If(vr):
                     m.d.comb += eq(self.n.mask_o, self.p[i].mask_i)
-                    m.d.comb += eq(r_data[i], self.p[i].data_i)
+                    m.d.comb += eq(r_data[i], self.process(self.p[i].i_data))
+                    with m.If(mid == i):
+                        m.d.comb += eq(self.n.o_data, r_data[i])
         else:
-            ml = [] # accumulate output masks
-            ms = [] # accumulate output stops
+            ml = []  # accumulate output masks
+            ms = []  # accumulate output stops
             for i in range(p_len):
                 vr = Signal(reset_less=True)
                 p = self.p[i]
@@ -376,7 +383,9 @@ class CombMultiInPipeline(MultiInControlBase):
                     m.d.comb += maskedout.eq(1)
                 m.d.comb += vr.eq(maskedout.bool() & p.i_valid & p.o_ready)
                 with m.If(vr):
-                    m.d.comb += eq(r_data[i], self.p[i].data_i)
+                    m.d.comb += eq(r_data[i], self.process(self.p[i].i_data))
+                    with m.If(mid == i):
+                        m.d.comb += eq(self.n.o_data, r_data[i])
                 if self.maskwid:
                     mlen = len(self.p[i].mask_i)
                     s = mlen*i
@@ -387,7 +396,8 @@ class CombMultiInPipeline(MultiInControlBase):
                 m.d.comb += self.n.mask_o.eq(Cat(*ml))
                 m.d.comb += self.n.stop_o.eq(Cat(*ms))
 
-        m.d.comb += eq(self.n.data_o, self.process(r_data[mid]))
+        #print ("o_data", self.n.o_data, "r_data[mid]", mid, r_data[mid])
+        #m.d.comb += eq(self.n.o_data, r_data[mid])
 
         return m
 
@@ -397,9 +407,9 @@ class NonCombMultiInPipeline(MultiInControlBase):
 
         Attributes:
         -----------
-        p.data_i : StageInput, shaped according to ispec
+        p.i_data : StageInput, shaped according to ispec
             The pipeline input
-        p.data_o : StageOutput, shaped according to ospec
+        p.o_data : StageOutput, shaped according to ospec
             The pipeline output
         r_data : input_shape according to ispec
             A temporary (buffered) copy of a prior (valid) input.
@@ -409,16 +419,16 @@ class NonCombMultiInPipeline(MultiInControlBase):
 
     def __init__(self, stage, p_len, p_mux, maskwid=0, routemask=False):
         MultiInControlBase.__init__(self, p_len=p_len, maskwid=maskwid,
-                                          routemask=routemask)
+                                    routemask=routemask)
         self.stage = stage
         self.maskwid = maskwid
         self.p_mux = p_mux
 
         # set up the input and output data
         for i in range(p_len):
-            name = 'data_i_%d' % i
-            self.p[i].data_i = _spec(stage.ispec, name) # input type
-        self.n.data_o = _spec(stage.ospec, 'data_o')
+            name = 'i_data_%d' % i
+            self.p[i].i_data = _spec(stage.ispec, name)  # input type
+        self.n.o_data = _spec(stage.ospec, 'o_data')
 
     def process(self, i):
         if hasattr(self.stage, "process"):
@@ -437,12 +447,12 @@ class NonCombMultiInPipeline(MultiInControlBase):
         p_len = len(self.p)
         for i in range(p_len):
             name = 'r_%d' % i
-            r = _spec(self.stage.ispec, name) # input type
+            r = _spec(self.stage.ispec, name)  # input type
             r_data.append(r)
             r_busy.append(Signal(name="r_busy%d" % i, reset_less=True))
             p_i_valid.append(Signal(name="p_i_valid%d" % i, reset_less=True))
             if hasattr(self.stage, "setup"):
-                print ("setup", self, self.stage, r)
+                print("setup", self, self.stage, r)
                 self.stage.setup(m, r)
         if len(r_data) > 1:
             r_data = Array(r_data)
@@ -452,7 +462,7 @@ class NonCombMultiInPipeline(MultiInControlBase):
         nirn = Signal(reset_less=True)
         m.d.comb += nirn.eq(~self.n.i_ready)
         mid = self.p_mux.m_id
-        print ("CombMuxIn mid", self, self.stage, self.routemask, mid, p_len)
+        print("CombMuxIn mid", self, self.stage, self.routemask, mid, p_len)
         for i in range(p_len):
             m.d.comb += r_busy[i].eq(0)
             m.d.comb += n_i_readyn[i].eq(1)
@@ -473,8 +483,8 @@ class NonCombMultiInPipeline(MultiInControlBase):
             av.append(data_valid[i])
         anyvalid = Cat(*av)
         m.d.comb += self.n.o_valid.eq(anyvalid.bool())
-        m.d.comb += data_valid[mid].eq(p_i_valid[mid] | \
-                                    (n_i_readyn[mid] ))
+        m.d.comb += data_valid[mid].eq(p_i_valid[mid] |
+                                       (n_i_readyn[mid]))
 
         if self.routemask:
             # XXX hack - fixes loop
@@ -491,10 +501,10 @@ class NonCombMultiInPipeline(MultiInControlBase):
                 #m.d.comb += vr.eq(p.i_valid & p.o_ready)
                 with m.If(vr):
                     m.d.comb += eq(self.n.mask_o, self.p[i].mask_i)
-                    m.d.comb += eq(r_data[i], self.p[i].data_i)
+                    m.d.comb += eq(r_data[i], self.p[i].i_data)
         else:
-            ml = [] # accumulate output masks
-            ms = [] # accumulate output stops
+            ml = []  # accumulate output masks
+            ms = []  # accumulate output stops
             for i in range(p_len):
                 vr = Signal(reset_less=True)
                 p = self.p[i]
@@ -506,7 +516,7 @@ class NonCombMultiInPipeline(MultiInControlBase):
                     m.d.comb += maskedout.eq(1)
                 m.d.comb += vr.eq(maskedout.bool() & p.i_valid & p.o_ready)
                 with m.If(vr):
-                    m.d.comb += eq(r_data[i], self.p[i].data_i)
+                    m.d.comb += eq(r_data[i], self.p[i].i_data)
                 if self.maskwid:
                     mlen = len(self.p[i].mask_i)
                     s = mlen*i
@@ -517,42 +527,42 @@ class NonCombMultiInPipeline(MultiInControlBase):
                 m.d.comb += self.n.mask_o.eq(Cat(*ml))
                 m.d.comb += self.n.stop_o.eq(Cat(*ms))
 
-        m.d.comb += eq(self.n.data_o, self.process(r_data[mid]))
+        m.d.comb += eq(self.n.o_data, self.process(r_data[mid]))
 
         return m
 
 
 class CombMuxOutPipe(CombMultiOutPipeline):
     def __init__(self, stage, n_len, maskwid=0, muxidname=None,
-                                     routemask=False):
+                 routemask=False):
         muxidname = muxidname or "muxid"
         # HACK: stage is also the n-way multiplexer
         CombMultiOutPipeline.__init__(self, stage, n_len=n_len,
-                                            n_mux=stage, maskwid=maskwid,
-                                            routemask=routemask)
+                                      n_mux=stage, maskwid=maskwid,
+                                      routemask=routemask)
 
         # HACK: n-mux is also the stage... so set the muxid equal to input muxid
-        muxid = getattr(self.p.data_i, muxidname)
-        print ("combmuxout", muxidname, muxid)
+        muxid = getattr(self.p.i_data, muxidname)
+        print("combmuxout", muxidname, muxid)
         stage.m_id = muxid
 
 
-
 class InputPriorityArbiter(Elaboratable):
     """ arbitration module for Input-Mux pipe, based on PriorityEncoder
     """
+
     def __init__(self, pipe, num_rows):
         self.pipe = pipe
         self.num_rows = num_rows
         self.mmax = int(log(self.num_rows) / log(2))
-        self.m_id = Signal(self.mmax, reset_less=True) # multiplex id
+        self.m_id = Signal(self.mmax, reset_less=True)  # multiplex id
         self.active = Signal(reset_less=True)
 
     def elaborate(self, platform):
         m = Module()
 
         assert len(self.pipe.p) == self.num_rows, \
-                "must declare input to be same size"
+            "must declare input to be same size"
         pe = PriorityEncoder(self.num_rows)
         m.submodules.selector = pe
 
@@ -568,7 +578,7 @@ class InputPriorityArbiter(Elaboratable):
             else:
                 m.d.comb += p_i_valid.eq(self.pipe.p[i].i_valid_test)
             in_ready.append(p_i_valid)
-        m.d.comb += pe.i.eq(Cat(*in_ready)) # array of input "valids"
+        m.d.comb += pe.i.eq(Cat(*in_ready))  # array of input "valids"
         m.d.comb += self.active.eq(~pe.n)   # encoder active (one input valid)
         m.d.comb += self.m_id.eq(pe.o)       # output one active input
 
@@ -578,7 +588,6 @@ class InputPriorityArbiter(Elaboratable):
         return [self.m_id, self.active]
 
 
-
 class PriorityCombMuxInPipe(CombMultiInPipeline):
     """ an example of how to use the combinatorial pipeline.
     """
index 8d0aafff9e32467f95c3ee7f670c178de51fbe18..5163bff740245a55e96eabaf5eb8c0f8e56e1b33 100644 (file)
@@ -1,3 +1,4 @@
+# SPDX-License-Identifier: LGPL-3-or-later
 """ nmigen operator functions / utils
 
     This work is funded through NLnet under Grant 2019-02-012
@@ -41,6 +42,7 @@ class Visitor2:
         python object, enumerate them, find out the list of Signals that way,
         and assign them.
     """
+
     def iterator2(self, o, i):
         if isinstance(o, dict):
             yield from self.dict_iter2(o, i)
@@ -48,22 +50,30 @@ class Visitor2:
         if not isinstance(o, Sequence):
             o, i = [o], [i]
         for (ao, ai) in zip(o, i):
-            #print ("visit", fn, ao, ai)
+            # print ("visit", ao, ai)
+            # print ("    isinstance Record(ao)", isinstance(ao, Record))
+            # print ("    isinstance ArrayProxy(ao)",
+            #            isinstance(ao, ArrayProxy))
+            # print ("    isinstance Value(ai)",
+            #            isinstance(ai, Value))
             if isinstance(ao, Record):
                 yield from self.record_iter2(ao, ai)
             elif isinstance(ao, ArrayProxy) and not isinstance(ai, Value):
                 yield from self.arrayproxy_iter2(ao, ai)
+            elif isinstance(ai, ArrayProxy) and not isinstance(ao, Value):
+                assert False, "whoops, input ArrayProxy not supported yet"
+                yield from self.arrayproxy_iter3(ao, ai)
             else:
                 yield (ao, ai)
 
     def dict_iter2(self, o, i):
         for (k, v) in o.items():
-            print ("d-iter", v, i[k])
+            print ("d-iter", v, i[k])
             yield (v, i[k])
         return res
 
     def _not_quite_working_with_all_unit_tests_record_iter2(self, ao, ai):
-        print ("record_iter2", ao, ai, type(ao), type(ai))
+        print ("record_iter2", ao, ai, type(ao), type(ai))
         if isinstance(ai, Value):
             if isinstance(ao, Sequence):
                 ao, ai = [ao], [ai]
@@ -75,10 +85,10 @@ class Visitor2:
                 val = ai.fields
             else:
                 val = ai
-            if hasattr(val, field_name): # check for attribute
+            if hasattr(val, field_name):  # check for attribute
                 val = getattr(val, field_name)
             else:
-                val = val[field_name] # dictionary-style specification
+                val = val[field_name]  # dictionary-style specification
             yield from self.iterator2(ao.fields[field_name], val)
 
     def record_iter2(self, ao, ai):
@@ -87,16 +97,23 @@ class Visitor2:
                 val = ai.fields
             else:
                 val = ai
-            if hasattr(val, field_name): # check for attribute
+            if hasattr(val, field_name):  # check for attribute
                 val = getattr(val, field_name)
             else:
-                val = val[field_name] # dictionary-style specification
+                val = val[field_name]  # dictionary-style specification
             yield from self.iterator2(ao.fields[field_name], val)
 
     def arrayproxy_iter2(self, ao, ai):
-        #print ("arrayproxy_iter2", ai.ports(), ai, ao)
+        # print ("arrayproxy_iter2", ai.ports(), ai, ao)
         for p in ai.ports():
-            #print ("arrayproxy - p", p, p.name, ao)
+            # print ("arrayproxy - p", p, p.name, ao)
+            op = getattr(ao, p.name)
+            yield from self.iterator2(op, p)
+
+    def arrayproxy_iter3(self, ao, ai):
+        # print ("arrayproxy_iter3", ao.ports(), ai, ao)
+        for p in ao.ports():
+            # print ("arrayproxy - p", p, p.name, ao)
             op = getattr(ao, p.name)
             yield from self.iterator2(op, p)
 
@@ -105,6 +122,7 @@ class Visitor:
     """ a helper class for iterating single-argument compound data structures.
         similar to Visitor2.
     """
+
     def iterate(self, i):
         """ iterate a compound structure recursively using yield
         """
@@ -126,10 +144,10 @@ class Visitor:
                 val = ai.fields
             else:
                 val = ai
-            if hasattr(val, field_name): # check for attribute
+            if hasattr(val, field_name):  # check for attribute
                 val = getattr(val, field_name)
             else:
-                val = val[field_name] # dictionary-style specification
+                val = val[field_name]  # dictionary-style specification
             #print ("recidx", idx, field_name, field_shape, val)
             yield from self.iterate(val)
 
@@ -166,8 +184,6 @@ def cat(i):
     """ flattens a compound structure recursively using Cat
     """
     from nmigen._utils import flatten
-    #res = list(flatten(i)) # works (as of nmigen commit f22106e5) HOWEVER...
-    res = list(Visitor().iterate(i)) # needed because input may be a sequence
+    # res = list(flatten(i)) # works (as of nmigen commit f22106e5) HOWEVER...
+    res = list(Visitor().iterate(i))  # needed because input may be a sequence
     return Cat(*res)
-
-
index ad7eb09d0281c600719870a2bd49adfd0fcdd8b1..01bc52ad19ed2c92d65ef6c7e6634f4c3364f7ea 100644 (file)
@@ -1,28 +1,33 @@
-import inspect, types
+import inspect
+import types
 
 ############## preliminary: two utility functions #####################
 
+
 def skip_redundant(iterable, skipset=None):
-   "Redundant items are repeated items or items in the original skipset."
-   if skipset is None: skipset = set()
-   for item in iterable:
-       if item not in skipset:
-           skipset.add(item)
-           yield item
+    "Redundant items are repeated items or items in the original skipset."
+    if skipset is None:
+        skipset = set()
+    for item in iterable:
+        if item not in skipset:
+            skipset.add(item)
+            yield item
 
 
 def remove_redundant(metaclasses):
-   skipset = set([type])
-   for meta in metaclasses: # determines the metaclasses to be skipped
-       skipset.update(inspect.getmro(meta)[1:])
-   return tuple(skip_redundant(metaclasses, skipset))
+    skipset = set([type])
+    for meta in metaclasses:  # determines the metaclasses to be skipped
+        skipset.update(inspect.getmro(meta)[1:])
+    return tuple(skip_redundant(metaclasses, skipset))
 
 ##################################################################
 ## now the core of the module: two mutually recursive functions ##
 ##################################################################
 
+
 memoized_metaclasses_map = {}
 
+
 def get_noconflict_metaclass(bases, left_metas, right_metas):
     """Not intended to be used outside of this module, unless you know
     what you are doing."""
@@ -32,24 +37,25 @@ def get_noconflict_metaclass(bases, left_metas, right_metas):
 
     # return existing confict-solving meta, if any
     if needed_metas in memoized_metaclasses_map:
-      return memoized_metaclasses_map[needed_metas]
+        return memoized_metaclasses_map[needed_metas]
     # nope: compute, memoize and return needed conflict-solving meta
     elif not needed_metas:         # wee, a trivial case, happy us
         meta = type
-    elif len(needed_metas) == 1: # another trivial case
-       meta = needed_metas[0]
+    elif len(needed_metas) == 1:  # another trivial case
+        meta = needed_metas[0]
     # check for recursion, can happen i.e. for Zope ExtensionClasses
-    elif needed_metas == bases: 
+    elif needed_metas == bases:
         raise TypeError("Incompatible root metatypes", needed_metas)
-    else: # gotta work ...
+    else:  # gotta work ...
         metaname = '_' + ''.join([m.__name__ for m in needed_metas])
         meta = classmaker()(metaname, needed_metas, {})
     memoized_metaclasses_map[needed_metas] = meta
     return meta
 
+
 def classmaker(left_metas=(), right_metas=()):
     def make_class(name, bases, adict):
-        print ("make_class", name)
+        print("make_class", name)
         metaclass = get_noconflict_metaclass(bases, left_metas, right_metas)
         return metaclass(name, bases, adict)
     return make_class
diff --git a/src/nmutil/p_lru.txt b/src/nmutil/p_lru.txt
deleted file mode 100644 (file)
index 4bac768..0000000
+++ /dev/null
@@ -1,51 +0,0 @@
-pseudo-LRU
-
-two-way set associative - one bit
-
-   indicates which line of the two has been reference more recently
-
-
-four-way set associative - three bits
-
-   each bit represents one branch point in a binary decision tree; let 1
-   represent that the left side has been referenced more recently than the
-   right side, and 0 vice-versa
-
-              are all 4 lines valid?
-                   /       \
-                 yes        no, use an invalid line
-                  |
-                  |
-                  |
-             bit_0 == 0?            state | replace      ref to | next state
-              /       \             ------+--------      -------+-----------
-             y         n             00x  |  line_0      line_0 |    11_
-            /           \            01x  |  line_1      line_1 |    10_
-     bit_1 == 0?    bit_2 == 0?      1x0  |  line_2      line_2 |    0_1
-       /    \          /    \        1x1  |  line_3      line_3 |    0_0
-      y      n        y      n
-     /        \      /        \        ('x' means       ('_' means unchanged)
-   line_0  line_1  line_2  line_3      don't care)
-
-   (see Figure 3-7, p. 3-18, in Intel Embedded Pentium Processor Family Dev.
-    Manual, 1998, http://www.intel.com/design/intarch/manuals/273204.htm)
-
-
-note that there is a 6-bit encoding for true LRU for four-way set associative
-
-  bit 0: bank[1] more recently used than bank[0]
-  bit 1: bank[2] more recently used than bank[0]
-  bit 2: bank[2] more recently used than bank[1]
-  bit 3: bank[3] more recently used than bank[0]
-  bit 4: bank[3] more recently used than bank[1]
-  bit 5: bank[3] more recently used than bank[2]
-
-  this results in 24 valid bit patterns within the 64 possible bit patterns
-  (4! possible valid traces for bank references)
-
-  e.g., a trace of 0 1 2 3, where 0 is LRU and 3 is MRU, is encoded as 111111
-
-  you can implement a state machine with a 256x6 ROM (6-bit state encoding
-  appended with a 2-bit bank reference input will yield a new 6-bit state),
-  and you can implement an LRU bank indicator with a 64x2 ROM
-
index 0babc074a9eeb9663793bcc4fd1bfdb259d631bc..7aab175d8d60c031020b7b208727156123b8c09f 100644 (file)
@@ -1,4 +1,5 @@
-""" Priority Picker: optimised back-to-back PriorityEncoder and Decoder
+# SPDX-License-Identifier: LGPL-3-or-later
+""" Priority Picker: optimized back-to-back PriorityEncoder and Decoder
     and MultiPriorityPicker: cascading mutually-exclusive pickers
 
     This work is funded through NLnet under Grant 2019-02-012
 """
 
 from nmigen import Module, Signal, Cat, Elaboratable, Array, Const, Mux
-from nmigen.cli import verilog, rtlil
+from nmigen.utils import bits_for
+from nmigen.cli import rtlil
 import math
+from nmutil.prefix_sum import prefix_sum
 
 
 class PriorityPicker(Elaboratable):
     """ implements a priority-picker.  input: N bits, output: N bits
 
-        * lsb_mode is for a LSB-priority picker
-        * reverse_i=True is for convenient reverseal of the input bits
+        * msb_mode is for a MSB-priority picker
+        * reverse_i=True is for convenient reversal of the input bits
         * reverse_o=True is for convenient reversal of the output bits
+        * `msb_mode=True` is redundant with `reverse_i=True, reverse_o=True`
+            but is allowed for backwards compatibility.
     """
-    def __init__(self, wid, lsb_mode=False, reverse_i=False, reverse_o=False):
+
+    def __init__(self, wid, msb_mode=False, reverse_i=False, reverse_o=False):
         self.wid = wid
         # inputs
-        self.lsb_mode = lsb_mode
+        self.msb_mode = msb_mode
         self.reverse_i = reverse_i
         self.reverse_o = reverse_o
         self.i = Signal(wid, reset_less=True)
         self.o = Signal(wid, reset_less=True)
-        self.en_o = Signal(reset_less=True) # true if any output is true
+
+        self.en_o = Signal(reset_less=True)
+        "true if any output is true"
 
     def elaborate(self, platform):
         m = Module()
 
         # works by saying, "if all previous bits were zero, we get a chance"
         res = []
-        ni = Signal(self.wid, reset_less = True)
+        ni = Signal(self.wid, reset_less=True)
         i = list(self.i)
         if self.reverse_i:
             i.reverse()
+        if self.msb_mode:
+            i.reverse()
         m.d.comb += ni.eq(~Cat(*i))
         prange = list(range(0, self.wid))
-        if self.lsb_mode:
+        if self.msb_mode:
             prange.reverse()
         for n in prange:
-            t = Signal(name="t%d" % n, reset_less = True)
+            t = Signal(name="t%d" % n, reset_less=True)
             res.append(t)
             if n == 0:
                 m.d.comb += t.eq(i[n])
@@ -71,7 +81,7 @@ class PriorityPicker(Elaboratable):
         # we like Cat(*xxx).  turn lists into concatenated bits
         m.d.comb += self.o.eq(Cat(*res))
         # useful "is any output enabled" signal
-        m.d.comb += self.en_o.eq(self.o.bool()) # true if 1 input is true
+        m.d.comb += self.en_o.eq(self.o.bool())  # true if 1 input is true
 
         return m
 
@@ -95,16 +105,16 @@ class MultiPriorityPicker(Elaboratable):
 
         Also outputted (optional): an index for each picked "thing".
     """
-    def __init__(self, wid, levels, indices=False, multiin=False):
+
+    def __init__(self, wid, levels, indices=False, multi_in=False):
         self.levels = levels
         self.wid = wid
         self.indices = indices
-        self.multiin = multiin
+        self.multi_in = multi_in
 
-
-        if multiin:
+        if multi_in:
             # multiple inputs, multiple outputs.
-            i_l = [] # array of picker outputs
+            i_l = []  # array of picker outputs
             for j in range(self.levels):
                 i = Signal(self.wid, name="i_%d" % j, reset_less=True)
                 i_l.append(i)
@@ -114,7 +124,7 @@ class MultiPriorityPicker(Elaboratable):
             self.i = Signal(self.wid, reset_less=True)
 
         # create array of (single-bit) outputs (unary)
-        o_l = [] # array of picker outputs
+        o_l = []  # array of picker outputs
         for j in range(self.levels):
             o = Signal(self.wid, name="o_%d" % j, reset_less=True)
             o_l.append(o)
@@ -128,7 +138,7 @@ class MultiPriorityPicker(Elaboratable):
 
         # add an array of indices
         lidx = math.ceil(math.log2(self.levels))
-        idx_o = [] # store the array of indices
+        idx_o = []  # store the array of indices
         for j in range(self.levels):
             i = Signal(lidx, name="idxo_%d" % j, reset_less=True)
             idx_o.append(i)
@@ -146,7 +156,7 @@ class MultiPriorityPicker(Elaboratable):
         p_mask = None
         pp_l = []
         for j in range(self.levels):
-            if self.multiin:
+            if self.multi_in:
                 i = self.i[j]
             else:
                 i = self.i
@@ -160,10 +170,10 @@ class MultiPriorityPicker(Elaboratable):
                 p_mask = Const(0, self.wid)
             else:
                 mask = Signal(self.wid, name="m_%d" % j, reset_less=True)
-                comb += mask.eq(prev_pp.o | p_mask) # accumulate output bits
+                comb += mask.eq(prev_pp.o | p_mask)  # accumulate output bits
                 comb += pp.i.eq(i & ~mask)          # mask out input
                 p_mask = mask
-            i = pp.i # for input to next round
+            i = pp.i  # for input to next round
             prev_pp = pp
 
         # accumulate the enables
@@ -178,34 +188,89 @@ class MultiPriorityPicker(Elaboratable):
 
         # for each picker enabled, pass that out and set a cascading index
         lidx = math.ceil(math.log2(self.levels))
-        prev_count = None
+        prev_count = 0
         for j in range(self.levels):
             en_o = pp_l[j].en_o
-            if prev_count is None:
-                comb += self.idx_o[j].eq(0)
-            else:
-                count1 = Signal(lidx, name="count_%d" % j, reset_less=True)
-                comb += count1.eq(prev_count + Const(1, lidx))
-                comb += self.idx_o[j].eq(Mux(en_o, count1, prev_count))
-            prev_count = self.idx_o[j]
+            count1 = Signal(lidx, name="count_%d" % j, reset_less=True)
+            comb += count1.eq(prev_count + Const(1, lidx))
+            comb += self.idx_o[j].eq(prev_count)
+            prev_count = Mux(en_o, count1, prev_count)
 
         return m
 
     def __iter__(self):
-        if self.multiin:
+        if self.multi_in:
             yield from self.i
         else:
             yield self.i
         yield from self.o
+        yield self.en_o
         if not self.indices:
             return
-        yield self.en_o
         yield from self.idx_o
 
     def ports(self):
         return list(self)
 
 
+class BetterMultiPriorityPicker(Elaboratable):
+    """A better replacement for MultiPriorityPicker that has O(log levels)
+        latency, rather than > O(levels) latency.
+    """
+
+    def __init__(self, width, levels, *, work_efficient=False):
+        assert isinstance(width, int) and width >= 1
+        assert isinstance(levels, int) and 1 <= levels <= width
+        assert isinstance(work_efficient, bool)
+        self.width = width
+        self.levels = levels
+        self.work_efficient = work_efficient
+        assert self.__index_sat > self.levels - 1
+        self.i = Signal(width)
+        self.o = [Signal(width, name=f"o_{i}") for i in range(levels)]
+        self.en_o = Signal(levels)
+
+    @property
+    def __index_width(self):
+        return bits_for(self.levels)
+
+    @property
+    def __index_sat(self):
+        return (1 << self.__index_width) - 1
+
+    def elaborate(self, platform):
+        m = Module()
+
+        def sat_add(a, b):
+            sum = Signal(self.__index_width + 1)
+            m.d.comb += sum.eq(a + b)
+            retval = Signal(self.__index_width)
+            m.d.comb += retval.eq(Mux(sum[-1], self.__index_sat, sum))
+            return retval
+        indexes = prefix_sum((self.i[i] for i in range(self.width - 1)),
+                             sat_add, work_efficient=self.work_efficient)
+        indexes.insert(0, 0)
+        for i in range(self.width):
+            sig = Signal(self.__index_width, name=f"index_{i}")
+            m.d.comb += sig.eq(indexes[i])
+            indexes[i] = sig
+        for level in range(self.levels):
+            m.d.comb += self.en_o[level].eq(self.o[level].bool())
+            for i in range(self.width):
+                index_matches = indexes[i] == level
+                m.d.comb += self.o[level][i].eq(index_matches & self.i[i])
+
+        return m
+
+    def __iter__(self):
+        yield self.i
+        yield from self.o
+        yield self.en_o
+
+    def ports(self):
+        return list(self)
+
+
 if __name__ == '__main__':
     dut = PriorityPicker(16)
     vl = rtlil.convert(dut, ports=dut.ports())
index 4646040071fc751c806363afb11c9bb956704cb7..8d5559a88d61d77b5dd6e33b9ac0044ae645872f 100644 (file)
@@ -1,3 +1,4 @@
+# SPDX-License-Identifier: LGPL-3-or-later
 """
     This work is funded through NLnet under Grant 2019-02-012
 
@@ -6,27 +7,26 @@
 
 """
 
-""" Example 5: Making use of PyRTL and Introspection. """
-
 from collections.abc import Sequence
-
 from nmigen import Signal
 from nmigen.hdl.rec import Record
 from nmigen import tracer
 from nmigen.compat.fhdl.bitcontainer import value_bits_sign
 from contextlib import contextmanager
-
 from nmutil.nmoperator import eq
 from nmutil.singlepipe import StageCls, ControlBase, BufferedHandshake
 from nmutil.singlepipe import UnbufferedPipeline
 
+""" Example 5: Making use of PyRTL and Introspection.
+
+    The following example shows how pyrtl can be used to make some interesting
+    hardware structures using python introspection.  In particular, this example
+    makes a N-stage pipeline structure.  Any specific pipeline is then a derived
+    class of SimplePipeline where methods with names starting with "stage" are
+    stages, and new members with names not starting with "_" are to be registered
+    for the next stage.
+"""
 
-# The following example shows how pyrtl can be used to make some interesting
-# hardware structures using python introspection.  In particular, this example
-# makes a N-stage pipeline structure.  Any specific pipeline is then a derived
-# class of SimplePipeline where methods with names starting with "stage" are
-# stages, and new members with names not starting with "_" are to be registered
-# for the next stage.
 
 def like(value, rname, pipe, pipemode=False):
     if isinstance(value, ObjectProxy):
@@ -34,9 +34,10 @@ def like(value, rname, pipe, pipemode=False):
                                 name=rname, reset_less=True)
     else:
         return Signal(value_bits_sign(value), name=rname,
-                             reset_less=True)
+                      reset_less=True)
         return Signal.like(value, name=rname, reset_less=True)
 
+
 def get_assigns(_assigns):
     assigns = []
     for e in _assigns:
@@ -72,11 +73,11 @@ class ObjectProxy:
     @classmethod
     def like(cls, m, value, pipemode=False, name=None, src_loc_at=0, **kwargs):
         name = name or tracer.get_var_name(depth=2 + src_loc_at,
-                                            default="$like")
+                                           default="$like")
 
         src_loc_at_1 = 1 + src_loc_at
         r = ObjectProxy(m, value.name, pipemode)
-        #for a, aname in value._preg_map.items():
+        # for a, aname in value._preg_map.items():
         #    r._preg_map[aname] = like(a, aname, m, pipemode)
         for a in value.ports():
             aname = a.name
@@ -101,7 +102,7 @@ class ObjectProxy:
         return res
 
     def eq(self, i):
-        print ("ObjectProxy eq", self, i)
+        print("ObjectProxy eq", self, i)
         res = []
         for a in self.ports():
             aname = a.name
@@ -122,7 +123,7 @@ class ObjectProxy:
         try:
             v = self._preg_map[name]
             return v
-            #return like(v, name, self._m)
+            # return like(v, name, self._m)
         except KeyError:
             raise AttributeError(
                 'error, no pipeline register "%s" defined for OP %s'
@@ -170,17 +171,17 @@ class PipelineStage:
         if ispec:
             self._preg_map[self._stagename] = ispec
         if prev:
-            print ("prev", prev._stagename, prev._preg_map)
-            #if prev._stagename in prev._preg_map:
+            print("prev", prev._stagename, prev._preg_map)
+            # if prev._stagename in prev._preg_map:
             #    m = prev._preg_map[prev._stagename]
             #    self._preg_map[prev._stagename] = m
             if '__nextstage__' in prev._preg_map:
                 m = prev._preg_map['__nextstage__']
                 m = likedict(m)
                 self._preg_map[self._stagename] = m
-                #for k, v in m.items():
-                    #m[k] = like(v, k, self._m)
-                print ("make current", self._stagename, m)
+                # for k, v in m.items():
+                #m[k] = like(v, k, self._m)
+                print("make current", self._stagename, m)
         self._pipemode = pipemode
         self._eqs = {}
         self._assigns = []
@@ -188,13 +189,13 @@ class PipelineStage:
     def __getattribute__(self, name):
         if name.startswith('_'):
             return object.__getattribute__(self, name)
-        #if name in self._preg_map['__nextstage__']:
+        # if name in self._preg_map['__nextstage__']:
         #    return self._preg_map['__nextstage__'][name]
         try:
-            print ("getattr", name, object.__getattribute__(self, '_preg_map'))
+            print("getattr", name, object.__getattribute__(self, '_preg_map'))
             v = self._preg_map[self._stagename][name]
             return v
-            #return like(v, name, self._m)
+            # return like(v, name, self._m)
         except KeyError:
             raise AttributeError(
                 'error, no pipeline register "%s" defined for stage %s'
@@ -212,37 +213,39 @@ class PipelineStage:
         if next_stage not in self._preg_map:
             self._preg_map[next_stage] = {}
         self._preg_map[next_stage][name] = new_pipereg
-        print ("setattr", name, value, self._preg_map)
+        print("setattr", name, value, self._preg_map)
         if self._pipemode:
             self._eqs[name] = new_pipereg
             assign = eq(new_pipereg, value)
-            print ("pipemode: append", new_pipereg, value, assign)
+            print("pipemode: append", new_pipereg, value, assign)
             if isinstance(value, ObjectProxy):
-                print ("OP, assigns:", value._assigns)
+                print("OP, assigns:", value._assigns)
                 self._assigns += value._assigns
                 self._eqs[name]._eqs = value._eqs
             #self._m.d.comb += assign
             self._assigns += assign
         elif self._m:
-            print ("!pipemode: assign", new_pipereg, value)
+            print("!pipemode: assign", new_pipereg, value)
             assign = eq(new_pipereg, value)
             self._m.d.sync += assign
         else:
-            print ("!pipemode !m: defer assign", new_pipereg, value)
+            print("!pipemode !m: defer assign", new_pipereg, value)
             assign = eq(new_pipereg, value)
             self._eqs[name] = new_pipereg
             self._assigns += assign
             if isinstance(value, ObjectProxy):
-                print ("OP, defer assigns:", value._assigns)
+                print("OP, defer assigns:", value._assigns)
                 self._assigns += value._assigns
                 self._eqs[name]._eqs = value._eqs
 
+
 def likelist(specs):
     res = []
     for v in specs:
         res.append(like(v, v.name, None, pipemode=True))
     return res
 
+
 def likedict(specs):
     if not isinstance(specs, dict):
         return like(specs, specs.name, None, pipemode=True)
@@ -257,18 +260,19 @@ class AutoStage(StageCls):
         self.inspecs, self.outspecs = inspecs, outspecs
         self.eqs, self.assigns = eqs, assigns
         #self.o = self.ospec()
+
     def ispec(self): return likedict(self.inspecs)
     def ospec(self): return likedict(self.outspecs)
 
     def process(self, i):
-        print ("stage process", i)
+        print("stage process", i)
         return self.eqs
 
     def setup(self, m, i):
-        print ("stage setup i", i, m)
-        print ("stage setup inspecs", self.inspecs)
-        print ("stage setup outspecs", self.outspecs)
-        print ("stage setup eqs", self.eqs)
+        print("stage setup i", i, m)
+        print("stage setup inspecs", self.inspecs)
+        print("stage setup outspecs", self.outspecs)
+        print("stage setup eqs", self.eqs)
         #self.o = self.ospec()
         m.d.comb += eq(self.inspecs, i)
         #m.d.comb += eq(self.outspecs, self.eqs)
@@ -283,7 +287,7 @@ class AutoPipe(UnbufferedPipeline):
     def elaborate(self, platform):
         m = UnbufferedPipeline.elaborate(self, platform)
         m.d.comb += self.assigns
-        print ("assigns", self.assigns, m)
+        print("assigns", self.assigns, m)
         return m
 
 
@@ -297,29 +301,29 @@ class PipeManager:
     def Stage(self, name, prev=None, ispec=None):
         if ispec:
             ispec = likedict(ispec)
-        print ("start stage", name, ispec)
+        print("start stage", name, ispec)
         stage = PipelineStage(name, None, prev, self.pipemode, ispec=ispec)
         try:
-            yield stage, self.m #stage._m
+            yield stage, self.m  # stage._m
         finally:
             pass
         if self.pipemode:
             if stage._ispec:
-                print ("use ispec", stage._ispec)
+                print("use ispec", stage._ispec)
                 inspecs = stage._ispec
             else:
                 inspecs = self.get_specs(stage, name)
                 #inspecs = likedict(inspecs)
             outspecs = self.get_specs(stage, '__nextstage__', liked=True)
-            print ("stage inspecs", name, inspecs)
-            print ("stage outspecs", name, outspecs)
-            eqs = stage._eqs # get_eqs(stage._eqs)
+            print("stage inspecs", name, inspecs)
+            print("stage outspecs", name, outspecs)
+            eqs = stage._eqs  # get_eqs(stage._eqs)
             assigns = get_assigns(stage._assigns)
-            print ("stage eqs", name, eqs)
-            print ("stage assigns", name, assigns)
+            print("stage eqs", name, eqs)
+            print("stage assigns", name, assigns)
             s = AutoStage(inspecs, outspecs, eqs, assigns)
             self.stages.append(s)
-        print ("end stage", name, self.pipemode, "\n")
+        print("end stage", name, self.pipemode, "\n")
 
     def get_specs(self, stage, name, liked=False):
         return stage._preg_map[name]
@@ -328,7 +332,7 @@ class PipeManager:
             for k, v in stage._preg_map[name].items():
                 #v = like(v, k, stage._m)
                 res.append(v)
-                #if isinstance(v, ObjectProxy):
+                # if isinstance(v, ObjectProxy):
                 #    res += v.get_specs()
             return res
         return {}
@@ -338,11 +342,11 @@ class PipeManager:
         return self
 
     def __exit__(self, *args):
-        print ("exit stage", args)
+        print("exit stage", args)
         pipes = []
         cb = ControlBase()
         for s in self.stages:
-            print ("stage specs", s, s.inspecs, s.outspecs)
+            print("stage specs", s, s.inspecs, s.outspecs)
             if self.pipetype == 'buffered':
                 p = BufferedHandshake(s)
             else:
@@ -388,15 +392,14 @@ class SimplePipeline:
         next_stage = self._current_stage_num + 1
         pipereg_id = str(self._current_stage_num) + 'to' + str(next_stage)
         rname = 'pipereg_' + pipereg_id + '_' + name
-        #new_pipereg = Signal(value_bits_sign(value), name=rname,
+        # new_pipereg = Signal(value_bits_sign(value), name=rname,
         #                     reset_less=True)
         if isinstance(value, ObjectProxy):
             new_pipereg = ObjectProxy.like(self._m, value,
-                                           name=rname, reset_less = True)
+                                           name=rname, reset_less=True)
         else:
-            new_pipereg = Signal.like(value, name=rname, reset_less = True)
+            new_pipereg = Signal.like(value, name=rname, reset_less=True)
         if next_stage not in self._pipeline_register_map:
             self._pipeline_register_map[next_stage] = {}
         self._pipeline_register_map[next_stage][name] = new_pipereg
         self._m.d.sync += eq(new_pipereg, value)
-
index 1706e97c076d6abecb61251d167faa26d16722af..c07623085e50461cca08ca561430456a740bdd12 100644 (file)
@@ -1,3 +1,4 @@
+# SPDX-License-Identifier: LGPL-3-or-later
 """
     This work is funded through NLnet under Grant 2019-02-012
 
@@ -16,8 +17,9 @@ from nmutil.singlepipe import StageChain
 class PipeModBase(Elaboratable):
     """PipeModBase: common code between nearly every pipeline module
     """
+
     def __init__(self, pspec, modname):
-        self.modname = modname # use this to give a name to this module
+        self.modname = modname  # use this to give a name to this module
         self.pspec = pspec
         self.i = self.ispec()
         self.o = self.ospec()
@@ -39,6 +41,7 @@ class PipeModBaseChain(DynamicPipe):
     and uses pspec.pipekls to dynamically select the pipeline type
     Also conforms to the Pipeline Stage API
     """
+
     def __init__(self, pspec):
         self.pspec = pspec
         self.chain = self.get_chain()
@@ -55,10 +58,10 @@ class PipeModBaseChain(DynamicPipe):
         return self.chain[-1].ospec()
 
     def process(self, i):
-        return self.o # ... returned here (see setup comment below)
+        return self.o  # ... returned here (see setup comment below)
 
     def setup(self, m, i):
         """ links module to inputs and outputs
         """
-        StageChain(self.chain).setup(m, i) # input linked here, through chain
-        self.o = self.chain[-1].o # output is the last thing in the chain...
+        StageChain(self.chain).setup(m, i)  # input linked here, through chain
+        self.o = self.chain[-1].o  # output is the last thing in the chain...
diff --git a/src/nmutil/plain_data.py b/src/nmutil/plain_data.py
new file mode 100644 (file)
index 0000000..7bde6ba
--- /dev/null
@@ -0,0 +1,373 @@
+# SPDX-License-Identifier: LGPL-3-or-later
+# Copyright 2022 Jacob Lifshay programmerjake@gmail.com
+import keyword
+
+
+class FrozenPlainDataError(AttributeError):
+    pass
+
+
+class __NotSet:
+    """ helper for __repr__ for when fields aren't set """
+
+    def __repr__(self):
+        return "<not set>"
+
+
+__NOT_SET = __NotSet()
+
+
+def __ignored_classes():
+    classes = [object]  # type: list[type]
+
+    from abc import ABC
+
+    classes += [ABC]
+
+    from typing import (
+        Generic, SupportsAbs, SupportsBytes, SupportsComplex, SupportsFloat,
+        SupportsInt, SupportsRound)
+
+    classes += [
+        Generic, SupportsAbs, SupportsBytes, SupportsComplex, SupportsFloat,
+        SupportsInt, SupportsRound]
+
+    from collections.abc import (
+        Awaitable, Coroutine, AsyncIterable, AsyncIterator, AsyncGenerator,
+        Hashable, Iterable, Iterator, Generator, Reversible, Sized, Container,
+        Callable, Collection, Set, MutableSet, Mapping, MutableMapping,
+        MappingView, KeysView, ItemsView, ValuesView, Sequence,
+        MutableSequence)
+
+    classes += [
+        Awaitable, Coroutine, AsyncIterable, AsyncIterator, AsyncGenerator,
+        Hashable, Iterable, Iterator, Generator, Reversible, Sized, Container,
+        Callable, Collection, Set, MutableSet, Mapping, MutableMapping,
+        MappingView, KeysView, ItemsView, ValuesView, Sequence,
+        MutableSequence]
+
+    # rest aren't supported by python 3.7, so try to import them and skip if
+    # that errors
+
+    try:
+        # typing_extensions uses typing.Protocol if available
+        from typing_extensions import Protocol
+        classes.append(Protocol)
+    except ImportError:
+        pass
+
+    for cls in classes:
+        yield from cls.__mro__
+
+
+__IGNORED_CLASSES = frozenset(__ignored_classes())
+
+
+def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen):
+    if not isinstance(cls, type):
+        raise TypeError(
+            "plain_data() can only be used as a class decorator")
+    # slots is an ordered set by using dict keys.
+    # always add __dict__ and __weakref__
+    slots = {"__dict__": None, "__weakref__": None}
+    if frozen:
+        slots["__plain_data_init_done"] = None
+    fields = []
+    any_parents_have_dict = False
+    any_parents_have_weakref = False
+    for cur_cls in reversed(cls.__mro__):
+        d = getattr(cur_cls, "__dict__", {})
+        if cur_cls is not cls:
+            if "__dict__" in d:
+                any_parents_have_dict = True
+            if "__weakref__" in d:
+                any_parents_have_weakref = True
+        if cur_cls in __IGNORED_CLASSES:
+            continue
+        try:
+            cur_slots = cur_cls.__slots__
+        except AttributeError as e:
+            raise TypeError(f"{cur_cls.__module__}.{cur_cls.__qualname__}"
+                            " must have __slots__ so plain_data() can "
+                            "determine what fields exist in "
+                            f"{cls.__module__}.{cls.__qualname__}") from e
+        if not isinstance(cur_slots, tuple):
+            raise TypeError("plain_data() requires __slots__ to be a "
+                            "tuple of str")
+        for field in cur_slots:
+            if not isinstance(field, str):
+                raise TypeError("plain_data() requires __slots__ to be a "
+                                "tuple of str")
+            if not field.isidentifier() or keyword.iskeyword(field):
+                raise TypeError(
+                    "plain_data() requires __slots__ entries to be valid "
+                    "Python identifiers and not keywords")
+            if field not in slots:
+                fields.append(field)
+            slots[field] = None
+
+    fields = tuple(fields)  # fields needs to be immutable
+
+    if any_parents_have_dict:
+        # work around a CPython bug that unnecessarily checks if parent
+        # classes already have the __dict__ slot.
+        del slots["__dict__"]
+
+    if any_parents_have_weakref:
+        # work around a CPython bug that unnecessarily checks if parent
+        # classes already have the __weakref__ slot.
+        del slots["__weakref__"]
+
+    # now create a new class having everything we need
+    retval_dict = dict(cls.__dict__)
+    # remove all old descriptors:
+    for name in slots.keys():
+        retval_dict.pop(name, None)
+
+    retval_dict["__plain_data_fields"] = fields
+
+    def add_method_or_error(value, replace=False):
+        name = value.__name__
+        if name in retval_dict and not replace:
+            raise TypeError(
+                f"can't generate {name} method: attribute already exists")
+        value.__qualname__ = f"{cls.__qualname__}.{value.__name__}"
+        retval_dict[name] = value
+
+    if frozen:
+        def __setattr__(self, name: str, value):
+            if getattr(self, "__plain_data_init_done", False):
+                raise FrozenPlainDataError(f"cannot assign to field {name!r}")
+            elif name not in slots and not name.startswith("_"):
+                raise AttributeError(
+                    f"cannot assign to unknown field {name!r}")
+            object.__setattr__(self, name, value)
+
+        add_method_or_error(__setattr__)
+
+        def __delattr__(self, name):
+            if getattr(self, "__plain_data_init_done", False):
+                raise FrozenPlainDataError(f"cannot delete field {name!r}")
+            object.__delattr__(self, name)
+
+        add_method_or_error(__delattr__)
+
+        old_init = cls.__init__
+
+        def __init__(self, *args, **kwargs):
+            if hasattr(self, "__plain_data_init_done"):
+                # we're already in an __init__ call (probably a
+                # superclass's __init__), don't set
+                # __plain_data_init_done too early
+                return old_init(self, *args, **kwargs)
+            object.__setattr__(self, "__plain_data_init_done", False)
+            try:
+                return old_init(self, *args, **kwargs)
+            finally:
+                object.__setattr__(self, "__plain_data_init_done", True)
+
+        add_method_or_error(__init__, replace=True)
+    else:
+        old_init = None
+
+    # set __slots__ to have everything we need in the preferred order
+    retval_dict["__slots__"] = tuple(slots.keys())
+
+    def __getstate__(self):
+        # pickling support
+        return [getattr(self, name) for name in fields]
+
+    add_method_or_error(__getstate__)
+
+    def __setstate__(self, state):
+        # pickling support
+        for name, value in zip(fields, state):
+            # bypass frozen setattr
+            object.__setattr__(self, name, value)
+
+    add_method_or_error(__setstate__)
+
+    # get source code that gets a tuple of all fields
+    def fields_tuple(var):
+        # type: (str) -> str
+        l = []
+        for name in fields:
+            l.append(f"{var}.{name}, ")
+        return "(" + "".join(l) + ")"
+
+    if eq:
+        exec(f"""
+def __eq__(self, other):
+    if other.__class__ is not self.__class__:
+        return NotImplemented
+    return {fields_tuple('self')} == {fields_tuple('other')}
+
+add_method_or_error(__eq__)
+""")
+
+    if unsafe_hash:
+        exec(f"""
+def __hash__(self):
+    return hash({fields_tuple('self')})
+
+add_method_or_error(__hash__)
+""")
+
+    if order:
+        exec(f"""
+def __lt__(self, other):
+    if other.__class__ is not self.__class__:
+        return NotImplemented
+    return {fields_tuple('self')} < {fields_tuple('other')}
+
+add_method_or_error(__lt__)
+
+def __le__(self, other):
+    if other.__class__ is not self.__class__:
+        return NotImplemented
+    return {fields_tuple('self')} <= {fields_tuple('other')}
+
+add_method_or_error(__le__)
+
+def __gt__(self, other):
+    if other.__class__ is not self.__class__:
+        return NotImplemented
+    return {fields_tuple('self')} > {fields_tuple('other')}
+
+add_method_or_error(__gt__)
+
+def __ge__(self, other):
+    if other.__class__ is not self.__class__:
+        return NotImplemented
+    return {fields_tuple('self')} >= {fields_tuple('other')}
+
+add_method_or_error(__ge__)
+""")
+
+    if repr_:
+        def __repr__(self):
+            parts = []
+            for name in fields:
+                parts.append(f"{name}={getattr(self, name, __NOT_SET)!r}")
+            return f"{self.__class__.__qualname__}({', '.join(parts)})"
+
+        add_method_or_error(__repr__)
+
+    # construct class
+    retval = type(cls)(cls.__name__, cls.__bases__, retval_dict)
+
+    # add __qualname__
+    retval.__qualname__ = cls.__qualname__
+
+    def fix_super_and_class(value):
+        # fixup super() and __class__
+        # derived from: https://stackoverflow.com/a/71666065/2597900
+        try:
+            closure = value.__closure__
+            if isinstance(closure, tuple):
+                if closure[0].cell_contents is cls:
+                    closure[0].cell_contents = retval
+        except (AttributeError, IndexError):
+            pass
+
+    for value in retval.__dict__.values():
+        fix_super_and_class(value)
+
+    if old_init is not None:
+        fix_super_and_class(old_init)
+
+    return retval
+
+
+def plain_data(*, eq=True, unsafe_hash=False, order=False, repr=True,
+               frozen=False):
+    # defaults match dataclass, with the exception of `init`
+    """ Decorator for adding equality comparison, ordered comparison,
+    `repr` support, `hash` support, and frozen type (read-only fields)
+    support to classes that are just plain data.
+
+    This is kinda like dataclasses, but uses `__slots__` instead of type
+    annotations, as well as requiring you to write your own `__init__`
+    """
+    def decorator(cls):
+        return _decorator(cls, eq=eq, unsafe_hash=unsafe_hash, order=order,
+                          repr_=repr, frozen=frozen)
+    return decorator
+
+
+def fields(pd):
+    """ get the tuple of field names of the passed-in
+    `@plain_data()`-decorated class.
+
+    This is similar to `dataclasses.fields`, except this returns a
+    different type.
+
+    Returns: tuple[str, ...]
+
+    e.g.:
+    ```
+    @plain_data()
+    class MyBaseClass:
+        __slots__ = "a_field", "field2"
+        def __init__(self, a_field, field2):
+            self.a_field = a_field
+            self.field2 = field2
+
+    assert fields(MyBaseClass) == ("a_field", "field2")
+    assert fields(MyBaseClass(1, 2)) == ("a_field", "field2")
+
+    @plain_data()
+    class MyClass(MyBaseClass):
+        __slots__ = "child_field",
+        def __init__(self, a_field, field2, child_field):
+            super().__init__(a_field=a_field, field2=field2)
+            self.child_field = child_field
+
+    assert fields(MyClass) == ("a_field", "field2", "child_field")
+    assert fields(MyClass(1, 2, 3)) == ("a_field", "field2", "child_field")
+    ```
+    """
+    retval = getattr(pd, "__plain_data_fields", None)
+    if not isinstance(retval, tuple):
+        raise TypeError("the passed-in object must be a class or an instance"
+                        " of a class decorated with @plain_data()")
+    return retval
+
+
+__NOT_SPECIFIED = object()
+
+
+def replace(pd, **changes):
+    """ Return a new instance of the passed-in `@plain_data()`-decorated
+    object, but with the specified fields replaced with new values.
+    This is quite useful with frozen `@plain_data()` classes.
+
+    e.g.:
+    ```
+    @plain_data(frozen=True)
+    class MyClass:
+        __slots__ = "a", "b", "c"
+        def __init__(self, a, b, *, c):
+            self.a = a
+            self.b = b
+            self.c = c
+
+    v1 = MyClass(1, 2, c=3)
+    v2 = replace(v1, b=4)
+    assert v2 == MyClass(a=1, b=4, c=3)
+    assert v2 is not v1
+    ```
+    """
+    kwargs = {}
+    ty = type(pd)
+    # call fields on ty rather than pd to ensure we're not called with a
+    # class rather than an instance.
+    for name in fields(ty):
+        value = changes.pop(name, __NOT_SPECIFIED)
+        if value is __NOT_SPECIFIED:
+            kwargs[name] = getattr(pd, name)
+        else:
+            kwargs[name] = value
+    if len(changes) != 0:
+        raise TypeError(f"can't set unknown field {changes.popitem()[0]!r}")
+    return ty(**kwargs)
diff --git a/src/nmutil/plain_data.pyi b/src/nmutil/plain_data.pyi
new file mode 100644 (file)
index 0000000..520a4d9
--- /dev/null
@@ -0,0 +1,24 @@
+# SPDX-License-Identifier: LGPL-3-or-later
+# Copyright 2022 Jacob Lifshay programmerjake@gmail.com
+
+from typing import TypeVar, Type, Callable, Any
+
+_T = TypeVar("_T")
+
+
+class FrozenPlainDataError(AttributeError):
+    pass
+
+
+def plain_data(*, eq: bool = True, unsafe_hash: bool = False,
+               order: bool = False, repr: bool = True,
+               frozen: bool = False) -> Callable[[Type[_T]], Type[_T]]:
+    ...
+
+
+def fields(pd: Any) -> tuple[str, ...]:
+    ...
+
+
+def replace(pd: _T, **changes: Any) -> _T:
+    ...
index 2f0642ec6577db267d5245d12ef9292e0aa7e05a..265af7fbd995a9fba1b32d0c678abfc9be5e15ea 100644 (file)
@@ -1,13 +1,18 @@
 # based on ariane plru, from tlb.sv
 
-from nmigen import Signal, Module, Cat, Const, Repl
+# Old PLRU API: once all users have migrated to the new API in plru2.py,
+# plru2.py will be renamed to plru.py, replacing this file.
+
+from nmigen import Signal, Module, Cat, Const, Repl, Array
 from nmigen.hdl.ir import Elaboratable
 from nmigen.cli import rtlil
 from nmigen.utils import log2_int
+from nmigen.lib.coding import Decoder
+from warnings import warn
 
 
 class PLRU(Elaboratable):
-    """ PLRU - Pseudo Least Recently Used Replacement
+    r""" PLRU - Pseudo Least Recently Used Replacement
 
         PLRU-tree indexing:
         lvl0        0
@@ -21,17 +26,24 @@ class PLRU(Elaboratable):
     """
 
     def __init__(self, BITS):
+        warn("nmutil.plru.PLRU is deprecated due to having a broken API, use "
+             "nmutil.plru2.PLRU instead", DeprecationWarning, stacklevel=2)
         self.BITS = BITS
-        self.acc_en = Signal(BITS)
-        self.acc_i = Signal()
+        self.acc_i = Signal(BITS)
+        self.acc_en = Signal()
         self.lru_o = Signal(BITS)
 
+        self._plru_tree = Signal(self.TLBSZ)
+        """ exposed only for testing """
+
+    @property
+    def TLBSZ(self):
+        return 2 * (self.BITS - 1)
+
     def elaborate(self, platform=None):
         m = Module()
 
         # Tree (bit per entry)
-        TLBSZ = 2*(self.BITS-1)
-        plru_tree = Signal(TLBSZ)
 
         # Just predefine which nodes will be set/cleared
         # E.g. for a TLB with 8 entries, the for-loop is semantically
@@ -50,7 +62,7 @@ class PLRU(Elaboratable):
 
         LOG_TLB = log2_int(self.BITS, False)
         hit = Signal(self.BITS, reset_less=True)
-        m.d.comb += hit.eq(Repl(self.acc_i, self.BITS) & self.acc_en)
+        m.d.comb += hit.eq(Repl(self.acc_en, self.BITS) & self.acc_i)
 
         for i in range(self.BITS):
             # we got a hit so update the pointer as it was least recently used
@@ -62,9 +74,9 @@ class PLRU(Elaboratable):
                     shift = LOG_TLB - lvl
                     new_idx = Const(~((i >> (shift-1)) & 1), 1)
                     plru_idx = idx_base + (i >> shift)
-                    #print("plru", i, lvl, hex(idx_base),
+                    # print("plru", i, lvl, hex(idx_base),
                     #      plru_idx, shift, new_idx)
-                    m.d.sync += plru_tree[plru_idx].eq(new_idx)
+                    m.d.sync += self._plru_tree[plru_idx].eq(new_idx)
 
         # Decode tree to write enable signals
         # Next for-loop basically creates the following logic for e.g.
@@ -90,9 +102,9 @@ class PLRU(Elaboratable):
                 new_idx = (i >> (shift-1)) & 1
                 plru_idx = idx_base + (i >> shift)
                 plru = Signal(reset_less=True,
-                              name="plru-%d-%d-%d-%d" % \
-                                    (i, lvl, plru_idx, new_idx))
-                m.d.comb += plru.eq(plru_tree[plru_idx])
+                              name="plru-%d-%d-%d-%d" %
+                              (i, lvl, plru_idx, new_idx))
+                m.d.comb += plru.eq(self._plru_tree[plru_idx])
                 if new_idx:
                     en.append(~plru)  # yes inverted (using bool() below)
                 else:
@@ -109,9 +121,59 @@ class PLRU(Elaboratable):
         return [self.acc_en, self.lru_o, self.acc_i]
 
 
+class PLRUs(Elaboratable):
+    def __init__(self, n_plrus, n_bits):
+        warn("nmutil.plru.PLRUs is deprecated due to having a broken API, use "
+             "nmutil.plru2.PLRUs instead", DeprecationWarning, stacklevel=2)
+        self.n_plrus = n_plrus
+        self.n_bits = n_bits
+        self.valid = Signal()
+        self.way = Signal(n_bits)
+        self.index = Signal(n_plrus.bit_length())
+        self.isel = Signal(n_plrus.bit_length())
+        self.o_index = Signal(n_bits)
+
+    def elaborate(self, platform):
+        """Generate TLB PLRUs
+        """
+        m = Module()
+        comb = m.d.comb
+
+        if self.n_plrus == 0:
+            return m
+
+        # Binary-to-Unary one-hot, enabled by valid
+        m.submodules.te = te = Decoder(self.n_plrus)
+        comb += te.n.eq(~self.valid)
+        comb += te.i.eq(self.index)
+
+        out = Array(Signal(self.n_bits, name="plru_out%d" % x)
+                    for x in range(self.n_plrus))
+
+        for i in range(self.n_plrus):
+            # PLRU interface
+            m.submodules["plru_%d" % i] = plru = PLRU(self.n_bits)
+
+            comb += plru.acc_en.eq(te.o[i])
+            comb += plru.acc_i.eq(self.way)
+            comb += out[i].eq(plru.lru_o)
+
+        # select output based on index
+        comb += self.o_index.eq(out[self.isel])
+
+        return m
+
+    def ports(self):
+        return [self.valid, self.way, self.index, self.isel, self.o_index]
+
+
 if __name__ == '__main__':
-    dut = PLRU(8)
+    dut = PLRU(3)
     vl = rtlil.convert(dut, ports=dut.ports())
     with open("test_plru.il", "w") as f:
         f.write(vl)
 
+    dut = PLRUs(4, 2)
+    vl = rtlil.convert(dut, ports=dut.ports())
+    with open("test_plrus.il", "w") as f:
+        f.write(vl)
diff --git a/src/nmutil/plru.txt b/src/nmutil/plru.txt
new file mode 100644 (file)
index 0000000..4bac768
--- /dev/null
@@ -0,0 +1,51 @@
+pseudo-LRU
+
+two-way set associative - one bit
+
+   indicates which line of the two has been referenced more recently
+
+
+four-way set associative - three bits
+
+   each bit represents one branch point in a binary decision tree; let 1
+   represent that the left side has been referenced more recently than the
+   right side, and 0 vice-versa
+
+              are all 4 lines valid?
+                   /       \
+                 yes        no, use an invalid line
+                  |
+                  |
+                  |
+             bit_0 == 0?            state | replace      ref to | next state
+              /       \             ------+--------      -------+-----------
+             y         n             00x  |  line_0      line_0 |    11_
+            /           \            01x  |  line_1      line_1 |    10_
+     bit_1 == 0?    bit_2 == 0?      1x0  |  line_2      line_2 |    0_1
+       /    \          /    \        1x1  |  line_3      line_3 |    0_0
+      y      n        y      n
+     /        \      /        \        ('x' means       ('_' means unchanged)
+   line_0  line_1  line_2  line_3      don't care)
+
+   (see Figure 3-7, p. 3-18, in Intel Embedded Pentium Processor Family Dev.
+    Manual, 1998, http://www.intel.com/design/intarch/manuals/273204.htm)
+
+
+note that there is a 6-bit encoding for true LRU for four-way set associative
+
+  bit 0: bank[1] more recently used than bank[0]
+  bit 1: bank[2] more recently used than bank[0]
+  bit 2: bank[2] more recently used than bank[1]
+  bit 3: bank[3] more recently used than bank[0]
+  bit 4: bank[3] more recently used than bank[1]
+  bit 5: bank[3] more recently used than bank[2]
+
+  this results in 24 valid bit patterns within the 64 possible bit patterns
+  (4! possible valid traces for bank references)
+
+  e.g., a trace of 0 1 2 3, where 0 is LRU and 3 is MRU, is encoded as 111111
+
+  you can implement a state machine with a 256x6 ROM (6-bit state encoding
+  appended with a 2-bit bank reference input will yield a new 6-bit state),
+  and you can implement an LRU bank indicator with a 64x2 ROM
+
diff --git a/src/nmutil/plru2.py b/src/nmutil/plru2.py
new file mode 100644 (file)
index 0000000..187a7b4
--- /dev/null
@@ -0,0 +1,194 @@
+# based on microwatt plru.vhdl
+# https://github.com/antonblanchard/microwatt/blob/f67b1431655c291fc1c99857a5c1ef624d5b264c/plru.vhdl
+
+# new PLRU API, once all users have migrated to new API in plru2.py, then
+# plru2.py will be renamed to plru.py.
+# IMPORTANT: since the API will change more, migration should be blocked on:
+# https://bugs.libre-soc.org/show_bug.cgi?id=913
+
+from nmigen.hdl.ir import Elaboratable, Display, Signal, Array, Const, Value
+from nmigen.hdl.dsl import Module
+from nmigen.cli import rtlil
+from nmigen.lib.coding import Decoder
+
+
+class PLRU(Elaboratable):
+    r""" PLRU - Pseudo Least Recently Used Replacement
+
+        IMPORTANT: since the API will change more, migration should be blocked on:
+        https://bugs.libre-soc.org/show_bug.cgi?id=913
+
+        PLRU-tree indexing:
+        lvl0        0
+                   / \
+                  /   \
+                 /     \
+        lvl1    1       2
+               / \     / \
+        lvl2  3   4   5   6
+             / \ / \ / \ / \
+             ... ... ... ...
+    """
+
+    def __init__(self, log2_num_ways, debug=False):
+        # type: (int, bool) -> None
+        """
+        IMPORTANT: since the API will change more, migration should be blocked on:
+        https://bugs.libre-soc.org/show_bug.cgi?id=913
+
+        Arguments:
+        log2_num_ways: int
+            the log-base-2 of the number of cache ways -- BITS in plru.vhdl
+        debug: bool
+            true if this should print debugging messages at simulation time.
+        """
+        assert log2_num_ways > 0
+        self.log2_num_ways = log2_num_ways
+        self.debug = debug
+        self.acc_i = Signal(log2_num_ways)
+        self.acc_en_i = Signal()
+        self.lru_o = Signal(log2_num_ways)
+
+        def mk_tree(i):
+            return Signal(name=f"tree_{i}", reset=0)
+
+        # original vhdl has array 1 too big, last entry is never used,
+        # subtract 1 to compensate
+        self._tree = Array(mk_tree(i) for i in range(self.num_ways - 1))
+        """ exposed only for testing """
+
+        def mk_node(i, prefix):
+            return Signal(range(self.num_ways), name=f"{prefix}_node_{i}",
+                          reset=0)
+
+        nodes_range = range(self.log2_num_ways)
+
+        self._get_lru_nodes = [mk_node(i, "get_lru") for i in nodes_range]
+        """ exposed only for testing """
+
+        self._upd_lru_nodes = [mk_node(i, "upd_lru") for i in nodes_range]
+        """ exposed only for testing """
+
+    @property
+    def num_ways(self):
+        return 1 << self.log2_num_ways
+
+    def _display(self, msg, *args):
+        if not self.debug:
+            return []
+        # work around not yet having
+        # https://gitlab.com/nmigen/nmigen/-/merge_requests/10
+        # by sending through Value.cast()
+        return [Display(msg, *map(Value.cast, args))]
+
+    def _get_lru(self, m):
+        """ get_lru process in plru.vhdl """
+        # XXX Check if we can turn that into a little ROM instead that
+        # takes the tree bit vector and returns the LRU. See if it's better
+        # in term of FPGA resource usage...
+        m.d.comb += self._get_lru_nodes[0].eq(0)
+        for i in range(self.log2_num_ways):
+            node = self._get_lru_nodes[i]
+            val = self._tree[node]
+            m.d.comb += self._display("GET: i:%i node:%#x val:%i",
+                                      i, node, val)
+            m.d.comb += self.lru_o[self.log2_num_ways - 1 - i].eq(val)
+            if i != self.log2_num_ways - 1:
+                # modified from microwatt version, it uses `node * 2` value
+                # to index into tree, rather than using node like is used
+                # earlier in this loop iteration
+                node <<= 1
+                with m.If(val):
+                    m.d.comb += self._get_lru_nodes[i + 1].eq(node + 2)
+                with m.Else():
+                    m.d.comb += self._get_lru_nodes[i + 1].eq(node + 1)
+
+    def _update_lru(self, m):
+        """ update_lru process in plru.vhdl """
+        with m.If(self.acc_en_i):
+            m.d.comb += self._upd_lru_nodes[0].eq(0)
+            for i in range(self.log2_num_ways):
+                node = self._upd_lru_nodes[i]
+                abit = self.acc_i[self.log2_num_ways - 1 - i]
+                m.d.sync += [
+                    self._tree[node].eq(~abit),
+                    self._display("UPD: i:%i node:%#x val:%i",
+                                  i, node, ~abit),
+                ]
+                if i != self.log2_num_ways - 1:
+                    node <<= 1
+                    with m.If(abit):
+                        m.d.comb += self._upd_lru_nodes[i + 1].eq(node + 2)
+                    with m.Else():
+                        m.d.comb += self._upd_lru_nodes[i + 1].eq(node + 1)
+
+    def elaborate(self, platform=None):
+        m = Module()
+        self._get_lru(m)
+        self._update_lru(m)
+        return m
+
+    def __iter__(self):
+        yield self.acc_i
+        yield self.acc_en_i
+        yield self.lru_o
+
+    def ports(self):
+        return list(self)
+
+
+# FIXME: convert PLRUs to new API
+# class PLRUs(Elaboratable):
+#     def __init__(self, n_plrus, n_bits):
+#         self.n_plrus = n_plrus
+#         self.n_bits = n_bits
+#         self.valid = Signal()
+#         self.way = Signal(n_bits)
+#         self.index = Signal(n_plrus.bit_length())
+#         self.isel = Signal(n_plrus.bit_length())
+#         self.o_index = Signal(n_bits)
+#
+#     def elaborate(self, platform):
+#         """Generate TLB PLRUs
+#         """
+#         m = Module()
+#         comb = m.d.comb
+#
+#         if self.n_plrus == 0:
+#             return m
+#
+#         # Binary-to-Unary one-hot, enabled by valid
+#         m.submodules.te = te = Decoder(self.n_plrus)
+#         comb += te.n.eq(~self.valid)
+#         comb += te.i.eq(self.index)
+#
+#         out = Array(Signal(self.n_bits, name="plru_out%d" % x)
+#                     for x in range(self.n_plrus))
+#
+#         for i in range(self.n_plrus):
+#             # PLRU interface
+#             m.submodules["plru_%d" % i] = plru = PLRU(self.n_bits)
+#
+#             comb += plru.acc_en.eq(te.o[i])
+#             comb += plru.acc_i.eq(self.way)
+#             comb += out[i].eq(plru.lru_o)
+#
+#         # select output based on index
+#         comb += self.o_index.eq(out[self.isel])
+#
+#         return m
+#
+#     def ports(self):
+#         return [self.valid, self.way, self.index, self.isel, self.o_index]
+
+
+if __name__ == '__main__':
+    dut = PLRU(3)
+    vl = rtlil.convert(dut, ports=dut.ports())
+    with open("test_plru.il", "w") as f:
+        f.write(vl)
+
+    # dut = PLRUs(4, 2)
+    # vl = rtlil.convert(dut, ports=dut.ports())
+    # with open("test_plrus.il", "w") as f:
+    #     f.write(vl)
diff --git a/src/nmutil/popcount.py b/src/nmutil/popcount.py
new file mode 100644 (file)
index 0000000..b3e3bea
--- /dev/null
@@ -0,0 +1,61 @@
+# SPDX-License-Identifier: LGPL-3-or-later
+# Copyright 2022 Jacob Lifshay programmerjake@gmail.com
+
+# Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
+# of Horizon 2020 EU Programme 957073.
+
+from nmigen import Module
+from nmigen.hdl.ast import Value, Const, Signal
+from nmutil.plain_data import plain_data
+from nmutil.prefix_sum import tree_reduction
+from nmigen.cli import rtlil
+
+
+def pop_count(v, *, width=None, process_temporary=lambda v: v):
+    """return the population count (number of 1 bits) of `v`.
+    Arguments:
+    v: nmigen.Value | int
+        the value to calculate the pop-count of.
+    width: int | None
+        the bit-width of `v`.
+        If `width` is None, then `v` must be a nmigen Value; if both
+        are given, `width` must match the width of the Value `v`.
+    process_temporary: function of (type(v)) -> type(v)
+        called after every addition operation, can be used to introduce
+        `Signal`s for the intermediate values in the pop-count computation
+        like so:
+
+        ```
+        def process_temporary(v):
+            sig = Signal.like(v)
+            m.d.comb += sig.eq(v)
+            return sig
+        ```
+    """
+    if isinstance(v, Value):
+        if width is None:
+            width = len(v)
+        assert width == len(v)
+        bits = [v[i] for i in range(width)]
+        if len(bits) == 0:
+            return Const(0)
+    else:
+        assert width is not None, "width must be given"
+        # v and width are ints
+        bits = [(v & (1 << i)) != 0 for i in range(width)]
+        if len(bits) == 0:
+            return 0
+    return tree_reduction(bits, fn=lambda a, b: process_temporary(a + b))
+
+
+# run this as simply "python3 popcount.py" to create an ilang file that
+# can be viewed with yosys "read_ilang test_popcount.il; show top"
+if __name__ == "__main__":
+    m = Module()
+    v = Signal(8)
+    x = Signal(8)
+    pc = pop_count(v, width=8)
+    m.d.comb += v.eq(pc)
+    vl = rtlil.convert(m)
+    with open("test_popcount.il", "w") as f:
+        f.write(vl)
diff --git a/src/nmutil/prefix_sum.py b/src/nmutil/prefix_sum.py
new file mode 100644 (file)
index 0000000..23eca36
--- /dev/null
@@ -0,0 +1,300 @@
+# SPDX-License-Identifier: LGPL-3-or-later
+# Copyright 2022 Jacob Lifshay programmerjake@gmail.com
+
+# Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
+# of Horizon 2020 EU Programme 957073.
+
+from collections import defaultdict
+import operator
+from nmigen.hdl.ast import Value, Const
+from nmutil.plain_data import plain_data
+
+
+@plain_data(order=True, unsafe_hash=True, frozen=True)
+class Op:
+    """An associative operation in a prefix-sum.
+    The operation is `items[self.out] = fn(items[self.lhs], items[self.rhs])`.
+    The operation is not assumed to be commutative.
+    """
+    __slots__ = "out", "lhs", "rhs", "row"
+
+    def __init__(self, out, lhs, rhs, row):
+        self.out = out
+        "index of the item to output to"
+
+        self.lhs = lhs
+        "index of the item the left-hand-side input comes from"
+
+        self.rhs = rhs
+        "index of the item the right-hand-side input comes from"
+
+        self.row = row
+        "row in the prefix-sum diagram"
+
+
+def prefix_sum_ops(item_count, *, work_efficient=False):
+    """Get the associative operations needed to compute a parallel prefix-sum
+    of `item_count` items.
+
+    The operations aren't assumed to be commutative.
+
+    This has a depth of `O(log(N))` and an operation count of `O(N)` if
+    `work_efficient` is true, otherwise `O(N*log(N))`.
+
+    The algorithms used are derived from:
+    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
+    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient
+
+    Parameters:
+    item_count: int
+        number of input items.
+    work_efficient: bool
+        True if the algorithm used should be work-efficient -- has a larger
+        depth (about twice as large) but does only `O(N)` operations total
+        instead of `O(N*log(N))`.
+    Returns: Iterable[Op]
+        output associative operations.
+    """
+    assert isinstance(item_count, int)
+    # compute the partial sums using a set of binary trees
+    # this is the first half of the work-efficient algorithm and the whole of
+    # the non-work-efficient algorithm.
+    dist = 1
+    row = 0
+    while dist < item_count:
+        start = dist * 2 - 1 if work_efficient else dist
+        step = dist * 2 if work_efficient else 1
+        for i in reversed(range(start, item_count, step)):
+            yield Op(out=i, lhs=i - dist, rhs=i, row=row)
+        dist <<= 1
+        row += 1
+    if work_efficient:
+        # express all output items in terms of the computed partial sums.
+        dist >>= 1
+        while dist >= 1:
+            for i in reversed(range(dist * 3 - 1, item_count, dist * 2)):
+                yield Op(out=i, lhs=i - dist, rhs=i, row=row)
+            row += 1
+            dist >>= 1
+
+
+def prefix_sum(items, fn=operator.add, *, work_efficient=False):
+    """Compute the parallel prefix-sum of `items`, using associative operator
+    `fn` instead of addition.
+
+    This has a depth of `O(log(N))` and an operation count of `O(N)` if
+    `work_efficient` is true, otherwise `O(N*log(N))`.
+
+    The algorithms used are derived from:
+    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
+    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient
+
+    Parameters:
+    items: Iterable[_T]
+        input items.
+    fn: Callable[[_T, _T], _T]
+        Operation to use for the prefix-sum algorithm instead of addition.
+        Assumed to be associative, but not necessarily commutative.
+    work_efficient: bool
+        True if the algorithm used should be work-efficient -- has a larger
+        depth (about twice as large) but does only `O(N)` operations total
+        instead of `O(N*log(N))`.
+    Returns: list[_T]
+        output items.
+    """
+    items = list(items)
+    for op in prefix_sum_ops(len(items), work_efficient=work_efficient):
+        items[op.out] = fn(items[op.lhs], items[op.rhs])
+    return items
+
+
+@plain_data()
+class _Cell:
+    __slots__ = "slant", "plus", "tee"
+
+    def __init__(self, slant, plus, tee):
+        self.slant = slant
+        self.plus = plus
+        self.tee = tee
+
+
+def render_prefix_sum_diagram(item_count, *, work_efficient=False,
+                              sp=" ", vbar="|", plus="⊕",
+                              slant="\\", connect="●", no_connect="X",
+                              padding=1,
+                              ):
+    """renders a prefix-sum diagram, matches `prefix_sum_ops`.
+
+    Parameters:
+    item_count: int
+        number of input items.
+    work_efficient: bool
+        True if the algorithm used should be work-efficient -- has a larger
+        depth (about twice as large) but does only `O(N)` operations total
+        instead of `O(N*log(N))`.
+    sp: str
+        character used for blank space
+    vbar: str
+        character used for a vertical bar
+    plus: str
+        character used for the addition operation
+    slant: str
+        character used to draw a line from the top left to the bottom right
+    connect: str
+        character used to draw a connection between a vertical line and a line
+        going from the center of this character to the bottom right
+    no_connect: str
+        character used to draw two lines crossing but not connecting, the lines
+        are vertical and diagonal from top left to the bottom right
+    padding: int
+        amount of padding characters in the output cells.
+    Returns: str
+        rendered diagram
+    """
+    ops_by_row = defaultdict(set)
+    for op in prefix_sum_ops(item_count, work_efficient=work_efficient):
+        assert op.out == op.rhs, f"can't draw op: {op}"
+        assert op not in ops_by_row[op.row], f"duplicate op: {op}"
+        ops_by_row[op.row].add(op)
+
+    def blank_row():
+        return [_Cell(slant=False, plus=False, tee=False)
+                for _ in range(item_count)]
+
+    cells = [blank_row()]
+
+    for row in sorted(ops_by_row.keys()):
+        ops = ops_by_row[row]
+        max_distance = max(op.rhs - op.lhs for op in ops)
+        cells.extend(blank_row() for _ in range(max_distance))
+        for op in ops:
+            assert op.lhs < op.rhs and op.out == op.rhs, f"can't draw op: {op}"
+            y = len(cells) - 1
+            x = op.out
+            cells[y][x].plus = True
+            x -= 1
+            y -= 1
+            while op.lhs < x:
+                cells[y][x].slant = True
+                x -= 1
+                y -= 1
+            cells[y][x].tee = True
+
+    lines = []
+    for cells_row in cells:
+        row_text = [[] for y in range(2 * padding + 1)]
+        for cell in cells_row:
+            # top padding
+            for y in range(padding):
+                # top left padding
+                for x in range(padding):
+                    is_slant = x == y and (cell.plus or cell.slant)
+                    row_text[y].append(slant if is_slant else sp)
+                # top vertical bar
+                row_text[y].append(vbar)
+                # top right padding
+                for x in range(padding):
+                    row_text[y].append(sp)
+            # center left padding
+            for x in range(padding):
+                row_text[padding].append(sp)
+            # center
+            center = vbar
+            if cell.plus:
+                center = plus
+            elif cell.tee:
+                center = connect
+            elif cell.slant:
+                center = no_connect
+            row_text[padding].append(center)
+            # center right padding
+            for x in range(padding):
+                row_text[padding].append(sp)
+            # bottom padding
+            for y in range(padding + 1, 2 * padding + 1):
+                # bottom left padding
+                for x in range(padding):
+                    row_text[y].append(sp)
+                # bottom vertical bar
+                row_text[y].append(vbar)
+                # bottom right padding
+                for x in range(padding + 1, 2 * padding + 1):
+                    is_slant = x == y and (cell.tee or cell.slant)
+                    row_text[y].append(slant if is_slant else sp)
+        for line in row_text:
+            lines.append("".join(line))
+
+    return "\n".join(map(str.rstrip, lines))
+
+
+def partial_prefix_sum_ops(needed_outputs, *, work_efficient=False):
+    """ Get the associative operations needed to compute a parallel prefix-sum
+    of `len(needed_outputs)` items.
+
+    The operations aren't assumed to be commutative.
+
+    This has a depth of `O(log(N))` and an operation count of `O(N)` if
+    `work_efficient` is true, otherwise `O(N*log(N))`.
+
+    The algorithms used are derived from:
+    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
+    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient
+
+    Parameters:
+    needed_outputs: Iterable[bool]
+        The length is the number of input/output items.
+        Each item is True if that corresponding output is needed.
+        Unneeded outputs have unspecified value.
+    work_efficient: bool
+        True if the algorithm used should be work-efficient -- has a larger
+        depth (about twice as large) but does only `O(N)` operations total
+        instead of `O(N*log(N))`.
+    Returns: Iterable[Op]
+        output associative operations.
+    """
+
+    # needed_outputs is an iterable, we need to construct a new list so we
+    # don't modify the passed-in value
+    items_live_flags = [bool(i) for i in needed_outputs]
+    ops = list(prefix_sum_ops(item_count=len(items_live_flags),
+                              work_efficient=work_efficient))
+    ops_live_flags = [False] * len(ops)
+    for i in reversed(range(len(ops))):
+        op = ops[i]
+        out_live = items_live_flags[op.out]
+        items_live_flags[op.out] = False
+        items_live_flags[op.lhs] |= out_live
+        items_live_flags[op.rhs] |= out_live
+        ops_live_flags[i] = out_live
+    for op, live_flag in zip(ops, ops_live_flags):
+        if live_flag:
+            yield op
+
+
+def tree_reduction_ops(item_count):
+    assert item_count >= 1
+    needed_outputs = (i == item_count - 1 for i in range(item_count))
+    return partial_prefix_sum_ops(needed_outputs)
+
+
+def tree_reduction(items, fn=operator.add):
+    items = list(items)
+    for op in tree_reduction_ops(len(items)):
+        items[op.out] = fn(items[op.lhs], items[op.rhs])
+    return items[-1]
+
+
+if __name__ == "__main__":
+    print("the non-work-efficient algorithm, matches the diagram in wikipedia:"
+          "\n"
+          "https://commons.wikimedia.org/wiki/File:Hillis-Steele_Prefix_Sum.svg"
+          "\n\n")
+    print(render_prefix_sum_diagram(16, work_efficient=False))
+    print()
+    print()
+    print("the work-efficient algorithm, matches the diagram in wikipedia:")
+    print("https://en.wikipedia.org/wiki/File:Prefix_sum_16.svg")
+    print()
+    print(render_prefix_sum_diagram(16, work_efficient=True))
+    print()
+    print()
index ec592c83a56e8ea3e7416163dce28d9c1615d9da..36512024ff7c1ed26b22eb957c40de2e1409469c 100644 (file)
@@ -23,8 +23,7 @@
 # TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR
 # MODIFICATIONS.
 
-from nmigen.compat.fhdl.specials import Memory
-from nmigen import Module, Signal, Mux, Elaboratable
+from nmigen import Module, Signal, Mux, Elaboratable, Memory
 from nmigen.utils import bits_for
 from nmigen.cli import main
 from nmigen.lib.fifo import FIFOInterface
@@ -62,8 +61,8 @@ class Queue(FIFOInterface, Elaboratable):
         m = Module()
 
         # set up an SRAM.  XXX bug in Memory: cannot create SRAM of depth 1
-        ram = Memory(self.width, self.depth if self.depth > 1 else 2)
-        m.submodules.ram = ram
+        ram = Memory(width=self.width,
+                     depth=self.depth if self.depth > 1 else 2)
         m.submodules.ram_read = ram_read = ram.read_port(domain="comb")
         m.submodules.ram_write = ram_write = ram.write_port()
 
@@ -74,17 +73,19 @@ class Queue(FIFOInterface, Elaboratable):
         # deq is "dequeue" (data out, aka "next stage")
         p_o_ready = self.w_rdy
         p_i_valid = self.w_en
-        enq_data = self.w_data # aka p_data_i
+        enq_data = self.w_data  # aka p_i_data
 
         n_o_valid = self.r_rdy
         n_i_ready = self.r_en
-        deq_data = self.r_data # aka n_data_o
+        deq_data = self.r_data  # aka n_o_data
 
         # intermediaries
         ptr_width = bits_for(self.depth - 1) if self.depth > 1 else 0
-        enq_ptr = Signal(ptr_width) # cyclic pointer to "insert" point (wrport)
-        deq_ptr = Signal(ptr_width) # cyclic pointer to "remove" point (rdport)
-        maybe_full = Signal() # not reset_less (set by sync)
+        # cyclic pointer to "insert" point (wrport)
+        enq_ptr = Signal(ptr_width)
+        # cyclic pointer to "remove" point (rdport)
+        deq_ptr = Signal(ptr_width)
+        maybe_full = Signal()  # not reset_less (set by sync)
 
         # temporaries
         do_enq = Signal(reset_less=True)
@@ -96,17 +97,17 @@ class Queue(FIFOInterface, Elaboratable):
         enq_max = Signal(reset_less=True)
         deq_max = Signal(reset_less=True)
 
-        m.d.comb += [ptr_match.eq(enq_ptr == deq_ptr), # read-ptr = write-ptr
+        m.d.comb += [ptr_match.eq(enq_ptr == deq_ptr),  # read-ptr = write-ptr
                      ptr_diff.eq(enq_ptr - deq_ptr),
                      enq_max.eq(enq_ptr == self.depth - 1),
                      deq_max.eq(deq_ptr == self.depth - 1),
                      empty.eq(ptr_match & ~maybe_full),
                      full.eq(ptr_match & maybe_full),
-                     do_enq.eq(p_o_ready & p_i_valid), # write conditions ok
-                     do_deq.eq(n_i_ready & n_o_valid), # read conditions ok
+                     do_enq.eq(p_o_ready & p_i_valid),  # write conditions ok
+                     do_deq.eq(n_i_ready & n_o_valid),  # read conditions ok
 
                      # set r_rdy and w_rdy (NOTE: see pipe mode below)
-                     n_o_valid.eq(~empty), # cannot read if empty!
+                     n_o_valid.eq(~empty),  # cannot read if empty!
                      p_o_ready.eq(~full),  # cannot write if full!
 
                      # set up memory and connect to input and output
@@ -114,8 +115,9 @@ class Queue(FIFOInterface, Elaboratable):
                      ram_write.data.eq(enq_data),
                      ram_write.en.eq(do_enq),
                      ram_read.addr.eq(deq_ptr),
-                     deq_data.eq(ram_read.data) # NOTE: overridden in fwft mode
-                    ]
+                     # NOTE: overridden in fwft mode
+                     deq_data.eq(ram_read.data)
+                     ]
 
         # under write conditions, SRAM write-pointer moves on next clock
         with m.If(do_enq):
index 96a49826029c4a8fa85a03d2d863e2bf0b18ffdf..1afa05dc32f72a8f9c844ac64d8ced02e6d0e28d 100644 (file)
@@ -1,3 +1,4 @@
+# SPDX-License-Identifier: LGPL-3-or-later
 """
     This work is funded through NLnet under Grant 2019-02-012
 
@@ -17,14 +18,23 @@ from nmigen import Signal, Module, Elaboratable, Mux, Cat
 from nmigen.cli import main
 
 
-class RippleLSB(Elaboratable):
-    """RippleLSB
+class Ripple(Elaboratable):
+    """Ripple
 
-    based on a partition mask, the LSB is "rippled" (duplicated)
-    up to the beginning of the next partition.
+    starting from certain bits (marked by "gates"), each bit is "rippled"
+    up to the point where a gate bit is no longer set.  ripple direction can
+    be set by start_lsb.
+
+    if start_lsb=True:
+        gates      =>  1 1 0 0 0 1 1
+        (ripples)     <<<< xxx <<<<<
+        results_in => 0 0 1 0 1 0 0 1
+        output     => 1 1 1 0 0 1 1 1
     """
-    def __init__(self, width):
+
+    def __init__(self, width, start_lsb=True):
         self.width = width
+        self.start_lsb = start_lsb
         self.results_in = Signal(width, reset_less=True)
         self.gates = Signal(width-1, reset_less=True)
         self.output = Signal(width, reset_less=True)
@@ -34,16 +44,43 @@ class RippleLSB(Elaboratable):
         comb = m.d.comb
         width = self.width
 
-        current_result = self.results_in[0]
-        comb += self.output[0].eq(current_result)
+        results_in = list(self.results_in)
+        if not self.start_lsb:
+            results_in = reversed(results_in)
+        l = [results_in[0]]
 
         for i in range(width-1):
-            cur = Mux(self.gates[i], self.results_in[i+1], self.output[i])
-            comb += self.output[i+1].eq(cur)
+            l.append(Mux(self.gates[i], results_in[i+1], self.output[i]))
+
+        if not self.start_lsb:
+            l = reversed(l)
+        comb += self.output.eq(Cat(*l))
 
         return m
 
 
+class RippleLSB(Ripple):
+    """RippleLSB
+
+    based on a partition mask, the LSB is "rippled" (duplicated)
+    up to the beginning of the next partition.
+    """
+
+    def __init__(self, width):
+        Ripple.__init__(self, width, start_lsb=True)
+
+
+class RippleMSB(Ripple):
+    """RippleMSB
+
+    based on a partition mask, the MSB is "rippled" (duplicated)
+    down to the beginning of the next partition.
+    """
+
+    def __init__(self, width):
+        Ripple.__init__(self, width, start_lsb=False)
+
+
 class MoveMSBDown(Elaboratable):
     """MoveMSBDown
 
@@ -53,6 +90,7 @@ class MoveMSBDown(Elaboratable):
     into its own useful module), then ANDs the (new) LSB with the
     partition mask to isolate it.
     """
+
     def __init__(self, width):
         self.width = width
         self.results_in = Signal(width, reset_less=True)
@@ -66,14 +104,14 @@ class MoveMSBDown(Elaboratable):
         intermed = Signal(width, reset_less=True)
 
         # first propagate MSB down until the nearest partition gate
-        comb += intermed[-1].eq(self.results_in[-1]) # start at MSB
+        comb += intermed[-1].eq(self.results_in[-1])  # start at MSB
         for i in range(width-2, -1, -1):
             cur = Mux(self.gates[i], self.results_in[i], intermed[i+1])
             comb += intermed[i].eq(cur)
 
         # now only select those bits where the mask starts
-        out = [intermed[0]] # LSB of first part always set
-        for i in range(width-1): # length of partition gates
+        out = [intermed[0]]  # LSB of first part always set
+        for i in range(width-1):  # length of partition gates
             out.append(self.gates[i] & intermed[i+1])
         comb += self.output.eq(Cat(*out))
 
@@ -85,4 +123,3 @@ if __name__ == "__main__":
     # then check with yosys "read_ilang ripple.il; show top"
     alu = MoveMSBDown(width=4)
     main(alu, ports=[alu.results_in, alu.gates, alu.output])
-
index 2b2e3a2af2139c78b07c32ade0e1ea502516498e..21ecb743e8ac94f9ab531378f09f856323f58771 100644 (file)
@@ -41,13 +41,13 @@ except ImportError:
                                        Delay, Settle, Tick, Passive)
 
 nmigen_sim_environ_variable = os.environ.get("NMIGEN_SIM_MODE") \
-                              or "pysim"
+    or "pysim"
 """Detected run-time engine from environment"""
 
 
 def Simulator(*args, **kwargs):
     """Wrapper that allows run-time selection of simulator engine"""
-    if detected_new_api:
+    if detected_new_api and 'engine' not in kwargs:
         kwargs['engine'] = nmigen_sim_environ_variable
     return RealSimulator(*args, **kwargs)
 
diff --git a/src/nmutil/sim_util.py b/src/nmutil/sim_util.py
new file mode 100644 (file)
index 0000000..d5678c2
--- /dev/null
@@ -0,0 +1,46 @@
+# SPDX-License-Identifier: LGPL-3-or-later
+# Copyright 2021 Jacob Lifshay
+
+# Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
+# of Horizon 2020 EU Programme 957073.
+
+from contextlib import contextmanager
+from hashlib import sha256
+
+from nmigen.hdl.ir import Fragment, ClockDomain
+from nmutil.get_test_path import get_test_path
+from nmigen.sim import Simulator
+from nmigen.back.rtlil import convert
+
+
+def hash_256(v):
+    return int.from_bytes(
+        sha256(bytes(v, encoding='utf-8')).digest(),
+        byteorder='little'
+    )
+
+
+def write_il(test_case, dut, ports=()):
+    # only elaborate once, since user code may break if elaborated twice
+    dut = Fragment.get(dut, platform=None)
+    if "sync" not in dut.domains:
+        dut.add_domains(ClockDomain("sync"))
+    path = get_test_path(test_case, "sim_test_out")
+    path.parent.mkdir(parents=True, exist_ok=True)
+    il_path = path.with_suffix(".il")
+    il_path.write_text(convert(dut, ports=ports), encoding="utf-8")
+    return dut, path
+
+
+@contextmanager
+def do_sim(test_case, dut, traces=(), ports=None):
+    if ports is None:
+        ports = traces
+    dut, path = write_il(test_case, dut, ports)
+    sim = Simulator(dut)
+    vcd_path = path.with_suffix(".vcd")
+    gtkw_path = path.with_suffix(".gtkw")
+    with sim.write_vcd(vcd_path.open("wt", encoding="utf-8"),
+                       gtkw_path.open("wt", encoding="utf-8"),
+                       traces=traces):
+        yield sim
index 6f0d155fb5a62b4d62c6cbb9d494207cb0cea6a3..eb57e02af5beb86de557283ef368380f3c3eae9b 100644 (file)
@@ -1,3 +1,4 @@
+# SPDX-License-Identifier: LGPL-3-or-later
 """ Pipeline API.  For multi-input and multi-output variants, see multipipe.
 
     This work is funded through NLnet under Grant 2019-02-012
@@ -160,11 +161,13 @@ class RecordBasedStage(Stage):
         honestly it's a lot easier just to create a direct Records-based
         class (see ExampleAddRecordStage)
     """
+
     def __init__(self, in_shape, out_shape, processfn, setupfn=None):
         self.in_shape = in_shape
         self.out_shape = out_shape
         self.__process = processfn
         self.__setup = setupfn
+
     def ispec(self): return Record(self.in_shape)
     def ospec(self): return Record(self.out_shape)
     def process(seif, i): return self.__process(i)
@@ -179,6 +182,7 @@ class PassThroughStage(StageCls):
         (many APIs would potentially use a static "wrap" method in e.g.
          StageCls to achieve a similar effect)
     """
+
     def __init__(self, iospecfn): self.iospecfn = iospecfn
     def ispec(self): return self.iospecfn()
     def ospec(self): return self.iospecfn()
@@ -194,6 +198,7 @@ class ControlBase(StageHelper, Elaboratable):
         *BYPASSES* a ControlBase instance ready/valid signalling, which
         clearly should not be done without a really, really good reason.
     """
+
     def __init__(self, stage=None, in_multi=None, stage_ctl=False, maskwid=0):
         """ Base class containing ready/valid/data to previous and next stages
 
@@ -201,11 +206,11 @@ class ControlBase(StageHelper, Elaboratable):
             * n: contains ready/valid to the next stage
 
             Except when calling Controlbase.connect(), user must also:
-            * add data_i member to PrevControl (p) and
-            * add data_o member to NextControl (n)
+            * add i_data member to PrevControl (p) and
+            * add o_data member to NextControl (n)
             Calling ControlBase._new_data is a good way to do that.
         """
-        print ("ControlBase", self, stage, in_multi, stage_ctl)
+        print("ControlBase", self, stage, in_multi, stage_ctl)
         StageHelper.__init__(self, stage)
 
         # set up input and output IO ACK (prev/next ready/valid)
@@ -217,13 +222,13 @@ class ControlBase(StageHelper, Elaboratable):
             self._new_data("data")
 
     def _new_data(self, name):
-        """ allocates new data_i and data_o
+        """ allocates new i_data and o_data
         """
-        self.p.data_i, self.n.data_o = self.new_specs(name)
+        self.p.i_data, self.n.o_data = self.new_specs(name)
 
     @property
     def data_r(self):
-        return self.process(self.p.data_i)
+        return self.process(self.p.i_data)
 
     def connect_to_next(self, nxt):
         """ helper function to connect to the next stage data/valid/ready.
@@ -254,9 +259,9 @@ class ControlBase(StageHelper, Elaboratable):
                        v    |  v    |  v     |
                      out---in out--in out---in
 
-            Also takes care of allocating data_i/data_o, by looking up
+            Also takes care of allocating i_data/o_data, by looking up
             the data spec for each end of the pipechain.  i.e It is NOT
-            necessary to allocate self.p.data_i or self.n.data_o manually:
+            necessary to allocate self.p.i_data or self.n.o_data manually:
             this is handled AUTOMATICALLY, here.
 
             Basically this function is the direct equivalent of StageChain,
@@ -284,19 +289,19 @@ class ControlBase(StageHelper, Elaboratable):
         """
         assert len(pipechain) > 0, "pipechain must be non-zero length"
         assert self.stage is None, "do not use connect with a stage"
-        eqs = [] # collated list of assignment statements
+        eqs = []  # collated list of assignment statements
 
         # connect inter-chain
         for i in range(len(pipechain)-1):
             pipe1 = pipechain[i]                # earlier
             pipe2 = pipechain[i+1]              # later (by 1)
-            eqs += pipe1.connect_to_next(pipe2) # earlier n to later p
+            eqs += pipe1.connect_to_next(pipe2)  # earlier n to later p
 
         # connect front and back of chain to ourselves
         front = pipechain[0]                # first in chain
         end = pipechain[-1]                 # last in chain
-        self.set_specs(front, end) # sets up ispec/ospec functions
-        self._new_data("chain") # NOTE: REPLACES existing data
+        self.set_specs(front, end)  # sets up ispec/ospec functions
+        self._new_data("chain")  # NOTE: REPLACES existing data
         eqs += front._connect_in(self)      # front p to our p
         eqs += end._connect_out(self)       # end n   to our n
 
@@ -305,11 +310,11 @@ class ControlBase(StageHelper, Elaboratable):
     def set_input(self, i):
         """ helper function to set the input data (used in unit tests)
         """
-        return nmoperator.eq(self.p.data_i, i)
+        return nmoperator.eq(self.p.i_data, i)
 
     def __iter__(self):
-        yield from self.p # yields ready/valid/data (data also gets yielded)
-        yield from self.n # ditto
+        yield from self.p  # yields ready/valid/data (data also gets yielded)
+        yield from self.n  # ditto
 
     def ports(self):
         return list(self)
@@ -321,7 +326,7 @@ class ControlBase(StageHelper, Elaboratable):
         m.submodules.p = self.p
         m.submodules.n = self.n
 
-        self.setup(m, self.p.data_i)
+        self.setup(m, self.p.i_data)
 
         if not self.p.stage_ctl:
             return m
@@ -345,13 +350,13 @@ class BufferedHandshake(ControlBase):
 
         stage-1   p.i_valid >>in   stage   n.o_valid out>>   stage+1
         stage-1   p.o_ready <<out  stage   n.i_ready <<in    stage+1
-        stage-1   p.data_i  >>in   stage   n.data_o  out>>   stage+1
+        stage-1   p.i_data  >>in   stage   n.o_data  out>>   stage+1
                               |             |
                             process --->----^
                               |             |
                               +-- r_data ->-+
 
-        input data p.data_i is read (only), is processed and goes into an
+        input data p.i_data is read (only), is processed and goes into an
         intermediate result store [process()].  this is updated combinatorially.
 
         in a non-stall condition, the intermediate result will go into the
@@ -383,36 +388,39 @@ class BufferedHandshake(ControlBase):
         por_pivn = Signal(reset_less=True)
         npnn = Signal(reset_less=True)
         self.m.d.comb += [p_i_valid.eq(self.p.i_valid_test),
-                     o_n_validn.eq(~self.n.o_valid),
-                     n_i_ready.eq(self.n.i_ready_test),
-                     nir_por.eq(n_i_ready & self.p._o_ready),
-                     nir_por_n.eq(n_i_ready & ~self.p._o_ready),
-                     nir_novn.eq(n_i_ready | o_n_validn),
-                     nirn_novn.eq(~n_i_ready & o_n_validn),
-                     npnn.eq(nir_por | nirn_novn),
-                     por_pivn.eq(self.p._o_ready & ~p_i_valid)
-        ]
+                          o_n_validn.eq(~self.n.o_valid),
+                          n_i_ready.eq(self.n.i_ready_test),
+                          nir_por.eq(n_i_ready & self.p._o_ready),
+                          nir_por_n.eq(n_i_ready & ~self.p._o_ready),
+                          nir_novn.eq(n_i_ready | o_n_validn),
+                          nirn_novn.eq(~n_i_ready & o_n_validn),
+                          npnn.eq(nir_por | nirn_novn),
+                          por_pivn.eq(self.p._o_ready & ~p_i_valid)
+                          ]
 
         # store result of processing in combinatorial temporary
         self.m.d.comb += nmoperator.eq(result, self.data_r)
 
         # if not in stall condition, update the temporary register
-        with self.m.If(self.p.o_ready): # not stalled
-            self.m.d.sync += nmoperator.eq(r_data, result) # update buffer
+        with self.m.If(self.p.o_ready):  # not stalled
+            self.m.d.sync += nmoperator.eq(r_data, result)  # update buffer
 
         # data pass-through conditions
         with self.m.If(npnn):
-            data_o = self._postprocess(result) # XXX TBD, does nothing right now
-            self.m.d.sync += [self.n.o_valid.eq(p_i_valid), # valid if p_valid
-                              nmoperator.eq(self.n.data_o, data_o), # update out
-                             ]
+            # XXX TBD, does nothing right now
+            o_data = self._postprocess(result)
+            self.m.d.sync += [self.n.o_valid.eq(p_i_valid),  # valid if p_valid
+                              # update out
+                              nmoperator.eq(self.n.o_data, o_data),
+                              ]
         # buffer flush conditions (NOTE: can override data passthru conditions)
-        with self.m.If(nir_por_n): # not stalled
+        with self.m.If(nir_por_n):  # not stalled
             # Flush the [already processed] buffer to the output port.
-            data_o = self._postprocess(r_data) # XXX TBD, does nothing right now
+            # XXX TBD, does nothing right now
+            o_data = self._postprocess(r_data)
             self.m.d.sync += [self.n.o_valid.eq(1),  # reg empty
-                              nmoperator.eq(self.n.data_o, data_o), # flush
-                             ]
+                              nmoperator.eq(self.n.o_data, o_data),  # flush
+                              ]
         # output ready conditions
         self.m.d.sync += self.p._o_ready.eq(nir_novn | por_pivn)
 
@@ -429,10 +437,11 @@ class MaskNoDelayCancellable(ControlBase):
 
         stage-1   p.i_valid >>in   stage   n.o_valid out>>   stage+1
         stage-1   p.o_ready <<out  stage   n.i_ready <<in    stage+1
-        stage-1   p.data_i  >>in   stage   n.data_o  out>>   stage+1
+        stage-1   p.i_data  >>in   stage   n.o_data  out>>   stage+1
                               |             |
                               +--process->--^
     """
+
     def __init__(self, stage, maskwid, in_multi=None, stage_ctl=False):
         ControlBase.__init__(self, stage, in_multi, stage_ctl, maskwid)
 
@@ -448,7 +457,7 @@ class MaskNoDelayCancellable(ControlBase):
         # XXX EXCEPTIONAL CIRCUMSTANCES: inspection of the data payload
         # is NOT "normal" for the Stage API.
         p_i_valid = Signal(reset_less=True)
-        #print ("self.p.data_i", self.p.data_i)
+        #print ("self.p.i_data", self.p.i_data)
         maskedout = Signal(len(self.p.mask_i), reset_less=True)
         m.d.comb += maskedout.eq(self.p.mask_i & ~self.p.stop_i)
         m.d.comb += p_i_valid.eq(maskedout.bool())
@@ -459,8 +468,9 @@ class MaskNoDelayCancellable(ControlBase):
         m.d.sync += self.n.o_valid.eq(p_i_valid)
         m.d.sync += self.n.mask_o.eq(Mux(p_i_valid, maskedout, 0))
         with m.If(p_i_valid):
-            data_o = self._postprocess(result) # XXX TBD, does nothing right now
-            m.d.sync += nmoperator.eq(self.n.data_o, data_o) # update output
+            # XXX TBD, does nothing right now
+            o_data = self._postprocess(result)
+            m.d.sync += nmoperator.eq(self.n.o_data, o_data)  # update output
 
         # output valid if
         # input always "ready"
@@ -488,12 +498,13 @@ class MaskCancellable(ControlBase):
 
         stage-1   p.i_valid >>in   stage   n.o_valid out>>   stage+1
         stage-1   p.o_ready <<out  stage   n.i_ready <<in    stage+1
-        stage-1   p.data_i  >>in   stage   n.data_o  out>>   stage+1
+        stage-1   p.i_data  >>in   stage   n.o_data  out>>   stage+1
                               |             |
                               +--process->--^
     """
+
     def __init__(self, stage, maskwid, in_multi=None, stage_ctl=False,
-                       dynamic=False):
+                 dynamic=False):
         ControlBase.__init__(self, stage, in_multi, stage_ctl, maskwid)
         self.dynamic = dynamic
         if dynamic:
@@ -515,7 +526,7 @@ class MaskCancellable(ControlBase):
             # establish if the data should be passed on.  cancellation is
             # a global signal.
             p_i_valid = Signal(reset_less=True)
-            #print ("self.p.data_i", self.p.data_i)
+            #print ("self.p.i_data", self.p.i_data)
             maskedout = Signal(len(self.p.mask_i), reset_less=True)
             m.d.comb += maskedout.eq(self.p.mask_i & ~self.p.stop_i)
 
@@ -525,7 +536,7 @@ class MaskCancellable(ControlBase):
             m.d.comb += [p_i_valid.eq(self.p.i_valid_test & maskedout.bool()),
                          n_i_ready.eq(self.n.i_ready_test),
                          p_i_valid_p_o_ready.eq(p_i_valid & self.p.o_ready),
-            ]
+                         ]
 
             # if idmask nonzero, mask gets passed on (and register set).
             # register is left as-is if idmask is zero, but out-mask is set to
@@ -548,10 +559,10 @@ class MaskCancellable(ControlBase):
                 m.d.sync += r_busy.eq(1)      # output valid
             # previous invalid or not ready, however next is accepting
             with m.Elif(n_i_ready):
-                m.d.sync += r_busy.eq(0) # ...so set output invalid
+                m.d.sync += r_busy.eq(0)  # ...so set output invalid
 
             # output set combinatorially from latch
-            m.d.comb += nmoperator.eq(self.n.data_o, r_latch)
+            m.d.comb += nmoperator.eq(self.n.o_data, r_latch)
 
             m.d.comb += self.n.o_valid.eq(r_busy)
             # if next is ready, so is previous
@@ -567,7 +578,7 @@ class MaskCancellable(ControlBase):
             m.d.comb += self.p._o_ready.eq(self.n.i_ready_test)
             m.d.comb += self.n.stop_o.eq(self.p.stop_i)
             m.d.comb += self.n.mask_o.eq(self.p.mask_i)
-            m.d.comb += nmoperator.eq(self.n.data_o, data_r)
+            m.d.comb += nmoperator.eq(self.n.o_data, data_r)
 
         return self.m
 
@@ -580,7 +591,7 @@ class SimpleHandshake(ControlBase):
 
         stage-1   p.i_valid >>in   stage   n.o_valid out>>   stage+1
         stage-1   p.o_ready <<out  stage   n.i_ready <<in    stage+1
-        stage-1   p.data_i  >>in   stage   n.data_o  out>>   stage+1
+        stage-1   p.i_data  >>in   stage   n.o_data  out>>   stage+1
                               |             |
                               +--process->--^
         Truth Table
@@ -594,23 +605,23 @@ class SimpleHandshake(ControlBase):
         -------   -    -     - -
         0 0 0 0   0    0    >0 0    reg
         0 0 0 1   0    1    >1 0    reg
-        0 0 1 0   0    0     0 1    process(data_i)
-        0 0 1 1   0    0     0 1    process(data_i)
+        0 0 1 0   0    0     0 1    process(i_data)
+        0 0 1 1   0    0     0 1    process(i_data)
         -------   -    -     - -
         0 1 0 0   0    0    >0 0    reg
         0 1 0 1   0    1    >1 0    reg
-        0 1 1 0   0    0     0 1    process(data_i)
-        0 1 1 1   0    0     0 1    process(data_i)
+        0 1 1 0   0    0     0 1    process(i_data)
+        0 1 1 1   0    0     0 1    process(i_data)
         -------   -    -     - -
         1 0 0 0   0    0    >0 0    reg
         1 0 0 1   0    1    >1 0    reg
-        1 0 1 0   0    0     0 1    process(data_i)
-        1 0 1 1   0    0     0 1    process(data_i)
+        1 0 1 0   0    0     0 1    process(i_data)
+        1 0 1 1   0    0     0 1    process(i_data)
         -------   -    -     - -
-        1 1 0 0   1    0     1 0    process(data_i)
-        1 1 0 1   1    1     1 0    process(data_i)
-        1 1 1 0   1    0     1 1    process(data_i)
-        1 1 1 1   1    0     1 1    process(data_i)
+        1 1 0 0   1    0     1 0    process(i_data)
+        1 1 0 1   1    1     1 0    process(i_data)
+        1 1 1 0   1    0     1 1    process(i_data)
+        1 1 1 1   1    0     1 1    process(i_data)
         -------   -    -     - -
     """
 
@@ -627,24 +638,26 @@ class SimpleHandshake(ControlBase):
         m.d.comb += [p_i_valid.eq(self.p.i_valid_test),
                      n_i_ready.eq(self.n.i_ready_test),
                      p_i_valid_p_o_ready.eq(p_i_valid & self.p.o_ready),
-        ]
+                     ]
 
         # store result of processing in combinatorial temporary
         m.d.comb += nmoperator.eq(result, self.data_r)
 
         # previous valid and ready
         with m.If(p_i_valid_p_o_ready):
-            data_o = self._postprocess(result) # XXX TBD, does nothing right now
+            # XXX TBD, does nothing right now
+            o_data = self._postprocess(result)
             m.d.sync += [r_busy.eq(1),      # output valid
-                         nmoperator.eq(self.n.data_o, data_o), # update output
-                        ]
+                         nmoperator.eq(self.n.o_data, o_data),  # update output
+                         ]
         # previous invalid or not ready, however next is accepting
         with m.Elif(n_i_ready):
-            data_o = self._postprocess(result) # XXX TBD, does nothing right now
-            m.d.sync += [nmoperator.eq(self.n.data_o, data_o)]
+            # XXX TBD, does nothing right now
+            o_data = self._postprocess(result)
+            m.d.sync += [nmoperator.eq(self.n.o_data, o_data)]
             # TODO: could still send data here (if there was any)
-            #m.d.sync += self.n.o_valid.eq(0) # ...so set output invalid
-            m.d.sync += r_busy.eq(0) # ...so set output invalid
+            # m.d.sync += self.n.o_valid.eq(0) # ...so set output invalid
+            m.d.sync += r_busy.eq(0)  # ...so set output invalid
 
         m.d.comb += self.n.o_valid.eq(r_busy)
         # if next is ready, so is previous
@@ -669,7 +682,7 @@ class UnbufferedPipeline(ControlBase):
 
         stage-1   p.i_valid >>in   stage   n.o_valid out>>   stage+1
         stage-1   p.o_ready <<out  stage   n.i_ready <<in    stage+1
-        stage-1   p.data_i  >>in   stage   n.data_o  out>>   stage+1
+        stage-1   p.i_data  >>in   stage   n.o_data  out>>   stage+1
                               |             |
                             r_data        result
                               |             |
@@ -677,9 +690,9 @@ class UnbufferedPipeline(ControlBase):
 
         Attributes:
         -----------
-        p.data_i : StageInput, shaped according to ispec
+        p.i_data : StageInput, shaped according to ispec
             The pipeline input
-        p.data_o : StageOutput, shaped according to ospec
+        p.o_data : StageOutput, shaped according to ospec
             The pipeline output
         r_data : input_shape according to ispec
             A temporary (buffered) copy of a prior (valid) input.
@@ -713,10 +726,10 @@ class UnbufferedPipeline(ControlBase):
         1 0 1 0   0    1 1    reg
         1 0 1 1   0    1 1    reg
         -------   -    - -
-        1 1 0 0   0    1 1    process(data_i)
-        1 1 0 1   1    1 0    process(data_i)
-        1 1 1 0   0    1 1    process(data_i)
-        1 1 1 1   0    1 1    process(data_i)
+        1 1 0 0   0    1 1    process(i_data)
+        1 1 0 1   1    1 0    process(i_data)
+        1 1 1 0   0    1 1    process(i_data)
+        1 1 1 1   0    1 1    process(i_data)
         -------   -    - -
 
         Note: PoR is *NOT* involved in the above decision-making.
@@ -725,8 +738,8 @@ class UnbufferedPipeline(ControlBase):
     def elaborate(self, platform):
         self.m = m = ControlBase.elaborate(self, platform)
 
-        data_valid = Signal() # is data valid or not
-        r_data = _spec(self.stage.ospec, "r_tmp") # output type
+        data_valid = Signal()  # is data valid or not
+        r_data = _spec(self.stage.ospec, "r_tmp")  # output type
 
         # some temporaries
         p_i_valid = Signal(reset_less=True)
@@ -742,8 +755,8 @@ class UnbufferedPipeline(ControlBase):
 
         with m.If(pv):
             m.d.sync += nmoperator.eq(r_data, self.data_r)
-        data_o = self._postprocess(r_data) # XXX TBD, does nothing right now
-        m.d.comb += nmoperator.eq(self.n.data_o, data_o)
+        o_data = self._postprocess(r_data)  # XXX TBD, does nothing right now
+        m.d.comb += nmoperator.eq(self.n.o_data, o_data)
 
         return self.m
 
@@ -764,14 +777,14 @@ class UnbufferedPipeline2(ControlBase):
 
         stage-1   p.i_valid >>in   stage   n.o_valid out>>   stage+1
         stage-1   p.o_ready <<out  stage   n.i_ready <<in    stage+1
-        stage-1   p.data_i  >>in   stage   n.data_o  out>>   stage+1
+        stage-1   p.i_data  >>in   stage   n.o_data  out>>   stage+1
                               |             |    |
                               +- process-> buf <-+
         Attributes:
         -----------
-        p.data_i : StageInput, shaped according to ispec
+        p.i_data : StageInput, shaped according to ispec
             The pipeline input
-        p.data_o : StageOutput, shaped according to ospec
+        p.o_data : StageOutput, shaped according to ospec
             The pipeline output
         buf : output_shape according to ospec
             A temporary (buffered) copy of a valid output
@@ -785,25 +798,25 @@ class UnbufferedPipeline2(ControlBase):
         V R R V        V R
 
         -------   -    - -
-        0 0 0 0   0    0 1   process(data_i)
+        0 0 0 0   0    0 1   process(i_data)
         0 0 0 1   1    1 0   reg (odata, unchanged)
-        0 0 1 0   0    0 1   process(data_i)
-        0 0 1 1   0    0 1   process(data_i)
+        0 0 1 0   0    0 1   process(i_data)
+        0 0 1 1   0    0 1   process(i_data)
         -------   -    - -
-        0 1 0 0   0    0 1   process(data_i)
+        0 1 0 0   0    0 1   process(i_data)
         0 1 0 1   1    1 0   reg (odata, unchanged)
-        0 1 1 0   0    0 1   process(data_i)
-        0 1 1 1   0    0 1   process(data_i)
+        0 1 1 0   0    0 1   process(i_data)
+        0 1 1 1   0    0 1   process(i_data)
         -------   -    - -
-        1 0 0 0   0    1 1   process(data_i)
+        1 0 0 0   0    1 1   process(i_data)
         1 0 0 1   1    1 0   reg (odata, unchanged)
-        1 0 1 0   0    1 1   process(data_i)
-        1 0 1 1   0    1 1   process(data_i)
+        1 0 1 0   0    1 1   process(i_data)
+        1 0 1 1   0    1 1   process(i_data)
         -------   -    - -
-        1 1 0 0   0    1 1   process(data_i)
+        1 1 0 0   0    1 1   process(i_data)
         1 1 0 1   1    1 0   reg (odata, unchanged)
-        1 1 1 0   0    1 1   process(data_i)
-        1 1 1 1   0    1 1   process(data_i)
+        1 1 1 0   0    1 1   process(i_data)
+        1 1 1 1   0    1 1   process(i_data)
         -------   -    - -
 
         Note: PoR is *NOT* involved in the above decision-making.
@@ -812,8 +825,8 @@ class UnbufferedPipeline2(ControlBase):
     def elaborate(self, platform):
         self.m = m = ControlBase.elaborate(self, platform)
 
-        buf_full = Signal() # is data valid or not
-        buf = _spec(self.stage.ospec, "r_tmp") # output type
+        buf_full = Signal()  # is data valid or not
+        buf = _spec(self.stage.ospec, "r_tmp")  # output type
 
         # some temporaries
         p_i_valid = Signal(reset_less=True)
@@ -823,10 +836,10 @@ class UnbufferedPipeline2(ControlBase):
         m.d.comb += self.p._o_ready.eq(~buf_full)
         m.d.sync += buf_full.eq(~self.n.i_ready_test & self.n.o_valid)
 
-        data_o = Mux(buf_full, buf, self.data_r)
-        data_o = self._postprocess(data_o) # XXX TBD, does nothing right now
-        m.d.comb += nmoperator.eq(self.n.data_o, data_o)
-        m.d.sync += nmoperator.eq(buf, self.n.data_o)
+        o_data = Mux(buf_full, buf, self.data_r)
+        o_data = self._postprocess(o_data)  # XXX TBD, does nothing right now
+        m.d.comb += nmoperator.eq(self.n.o_data, o_data)
+        m.d.sync += nmoperator.eq(buf, self.n.o_data)
 
         return self.m
 
@@ -867,7 +880,7 @@ class PassThroughHandshake(ControlBase):
     def elaborate(self, platform):
         self.m = m = ControlBase.elaborate(self, platform)
 
-        r_data = _spec(self.stage.ospec, "r_tmp") # output type
+        r_data = _spec(self.stage.ospec, "r_tmp")  # output type
 
         # temporaries
         p_i_valid = Signal(reset_less=True)
@@ -875,22 +888,23 @@ class PassThroughHandshake(ControlBase):
         m.d.comb += p_i_valid.eq(self.p.i_valid_test)
         m.d.comb += pvr.eq(p_i_valid & self.p.o_ready)
 
-        m.d.comb += self.p.o_ready.eq(~self.n.o_valid |  self.n.i_ready_test)
-        m.d.sync += self.n.o_valid.eq(p_i_valid       | ~self.p.o_ready)
+        m.d.comb += self.p.o_ready.eq(~self.n.o_valid | self.n.i_ready_test)
+        m.d.sync += self.n.o_valid.eq(p_i_valid | ~self.p.o_ready)
 
         odata = Mux(pvr, self.data_r, r_data)
         m.d.sync += nmoperator.eq(r_data, odata)
-        r_data = self._postprocess(r_data) # XXX TBD, does nothing right now
-        m.d.comb += nmoperator.eq(self.n.data_o, r_data)
+        r_data = self._postprocess(r_data)  # XXX TBD, does nothing right now
+        m.d.comb += nmoperator.eq(self.n.o_data, r_data)
 
         return m
 
 
 class RegisterPipeline(UnbufferedPipeline):
     """ A pipeline stage that delays by one clock cycle, creating a
-        sync'd latch out of data_o and o_valid as an indirect byproduct
+        sync'd latch out of o_data and o_valid as an indirect byproduct
         of using PassThroughStage
     """
+
     def __init__(self, iospecfn):
         UnbufferedPipeline.__init__(self, PassThroughStage(iospecfn))
 
@@ -899,10 +913,11 @@ class FIFOControl(ControlBase):
     """ FIFO Control.  Uses Queue to store data, coincidentally
         happens to have same valid/ready signalling as Stage API.
 
-        data_i -> fifo.din -> FIFO -> fifo.dout -> data_o
+        i_data -> fifo.din -> FIFO -> fifo.dout -> o_data
     """
+
     def __init__(self, depth, stage, in_multi=None, stage_ctl=False,
-                                     fwft=True, pipe=False):
+                 fwft=True, pipe=False):
         """ FIFO Control
 
             * :depth: number of entries in the FIFO
@@ -923,7 +938,7 @@ class FIFOControl(ControlBase):
             data is processed (and located) as follows:
 
             self.p  self.stage temp    fn temp  fn  temp  fp   self.n
-            data_i->process()->result->cat->din.FIFO.dout->cat(data_o)
+            i_data->process()->result->cat->din.FIFO.dout->cat(o_data)
 
             yes, really: cat produces a Cat() which can be assigned to.
             this is how the FIFO gets de-catted without needing a de-cat
@@ -937,36 +952,37 @@ class FIFOControl(ControlBase):
     def elaborate(self, platform):
         self.m = m = ControlBase.elaborate(self, platform)
 
-        # make a FIFO with a signal of equal width to the data_o.
-        (fwidth, _) = nmoperator.shape(self.n.data_o)
+        # make a FIFO with a signal of equal width to the o_data.
+        (fwidth, _) = nmoperator.shape(self.n.o_data)
         fifo = Queue(fwidth, self.fdepth, fwft=self.fwft, pipe=self.pipe)
         m.submodules.fifo = fifo
 
-        def processfn(data_i):
+        def processfn(i_data):
             # store result of processing in combinatorial temporary
             result = _spec(self.stage.ospec, "r_temp")
-            m.d.comb += nmoperator.eq(result, self.process(data_i))
+            m.d.comb += nmoperator.eq(result, self.process(i_data))
             return nmoperator.cat(result)
 
-        ## prev: make the FIFO (Queue object) "look" like a PrevControl...
+        # prev: make the FIFO (Queue object) "look" like a PrevControl...
         m.submodules.fp = fp = PrevControl()
-        fp.i_valid, fp._o_ready, fp.data_i = fifo.w_en, fifo.w_rdy, fifo.w_data
+        fp.i_valid, fp._o_ready, fp.i_data = fifo.w_en, fifo.w_rdy, fifo.w_data
         m.d.comb += fp._connect_in(self.p, fn=processfn)
 
         # next: make the FIFO (Queue object) "look" like a NextControl...
         m.submodules.fn = fn = NextControl()
-        fn.o_valid, fn.i_ready, fn.data_o  = fifo.r_rdy, fifo.r_en, fifo.r_data
+        fn.o_valid, fn.i_ready, fn.o_data = fifo.r_rdy, fifo.r_en, fifo.r_data
         connections = fn._connect_out(self.n, fn=nmoperator.cat)
-        valid_eq, ready_eq, data_o = connections
+        valid_eq, ready_eq, o_data = connections
 
         # ok ok so we can't just do the ready/valid eqs straight:
         # first 2 from connections are the ready/valid, 3rd is data.
         if self.fwft:
-            m.d.comb += [valid_eq, ready_eq] # combinatorial on next ready/valid
+            # combinatorial on next ready/valid
+            m.d.comb += [valid_eq, ready_eq]
         else:
-            m.d.sync += [valid_eq, ready_eq] # non-fwft mode needs sync
-        data_o = self._postprocess(data_o) # XXX TBD, does nothing right now
-        m.d.comb += data_o
+            m.d.sync += [valid_eq, ready_eq]  # non-fwft mode needs sync
+        o_data = self._postprocess(o_data)  # XXX TBD, does nothing right now
+        m.d.comb += o_data
 
         return m
 
@@ -975,19 +991,23 @@ class FIFOControl(ControlBase):
 class UnbufferedPipeline(FIFOControl):
     def __init__(self, stage, in_multi=None, stage_ctl=False):
         FIFOControl.__init__(self, 1, stage, in_multi, stage_ctl,
-                                   fwft=True, pipe=False)
+                             fwft=True, pipe=False)
 
 # aka "BreakReadyStage" XXX had to set fwft=True to get it to work
+
+
 class PassThroughHandshake(FIFOControl):
     def __init__(self, stage, in_multi=None, stage_ctl=False):
         FIFOControl.__init__(self, 1, stage, in_multi, stage_ctl,
-                                   fwft=True, pipe=True)
+                             fwft=True, pipe=True)
 
 # this is *probably* BufferedHandshake, although test #997 now succeeds.
+
+
 class BufferedHandshake(FIFOControl):
     def __init__(self, stage, in_multi=None, stage_ctl=False):
         FIFOControl.__init__(self, 2, stage, in_multi, stage_ctl,
-                                   fwft=True, pipe=False)
+                             fwft=True, pipe=False)
 
 
 """
index fc5d709f43809a8d5ae34a5ebe61486264db1a62..17c4f6509852045a4b812f2121aec6b14ea695ef 100644 (file)
@@ -1,3 +1,4 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
 """ Stage API
 
     This work is funded through NLnet under Grant 2019-02-012
@@ -118,10 +119,10 @@ class StageCls(metaclass=ABCMeta):
     def ispec(self): pass       # REQUIRED
     @abstractmethod
     def ospec(self): pass       # REQUIRED
-    #@abstractmethod
-    #def setup(self, m, i): pass # OPTIONAL
-    #@abstractmethod
-    #def process(self, i): pass  # OPTIONAL
+    # @abstractmethod
+    # def setup(self, m, i): pass # OPTIONAL
+    # @abstractmethod
+    # def process(self, i): pass  # OPTIONAL
 
 
 class Stage(metaclass=ABCMeta):
@@ -140,12 +141,12 @@ class Stage(metaclass=ABCMeta):
     @abstractmethod
     def ospec(): pass
 
-    #@staticmethod
-    #@abstractmethod
+    # @staticmethod
+    # @abstractmethod
     #def setup(m, i): pass
 
-    #@staticmethod
-    #@abstractmethod
+    # @staticmethod
+    # @abstractmethod
     #def process(i): pass
 
 
@@ -157,6 +158,7 @@ class StageHelper(Stage):
         it differs from the stage that it wraps in that all the "optional"
         functions are provided (hence the designation "convenience wrapper")
     """
+
     def __init__(self, stage):
         self.stage = stage
         self._ispecfn = None
@@ -195,10 +197,11 @@ class StageHelper(Stage):
 
     def setup(self, m, i):
         if self.stage is not None and hasattr(self.stage, "setup"):
-            self.stage.setup(m, i)
+            if self.stage is not self:  # stop infinite recursion
+                self.stage.setup(m, i)
 
-    def _postprocess(self, i): # XXX DISABLED
-        return i # RETURNS INPUT
+    def _postprocess(self, i):  # XXX DISABLED
+        return i  # RETURNS INPUT
         if hasattr(self.stage, "postprocess"):
             return self.stage.postprocess(i)
         return i
@@ -243,6 +246,7 @@ class StageChain(StageHelper):
         side-effects (state-based / clock-based input) or conditional
         (inter-chain) dependencies, unless you really know what you are doing.
     """
+
     def __init__(self, chain, specallocate=False):
         assert len(chain) > 0, "stage chain must be non-zero length"
         self.chain = chain
@@ -262,12 +266,13 @@ class StageChain(StageHelper):
             o = _spec(ofn, cname)
             if isinstance(o, Elaboratable):
                 setattr(m.submodules, cname, o)
-            m.d.comb += nmoperator.eq(o, c.process(i)) # process input into "o"
+            # process input into "o"
+            m.d.comb += nmoperator.eq(o, c.process(i))
             if idx == len(self.chain)-1:
                 break
             ifn = self.chain[idx+1].ispec   # new input on next loop
             i = _spec(ifn, 'chainin%d' % (idx+1))
-            m.d.comb += nmoperator.eq(i, o) # assign to next input
+            m.d.comb += nmoperator.eq(i, o)  # assign to next input
         self.o = o
         return self.o                       # last loop is the output
 
@@ -280,6 +285,4 @@ class StageChain(StageHelper):
         return self.o                       # last loop is the output
 
     def process(self, i):
-        return self.o # conform to Stage API: return last-loop output
-
-
+        return self.o  # conform to Stage API: return last-loop output
index 61e9b1346e5cb1cf7402dd27b1393561cd710b0c..823a999b52052a7a11e5ddb1c845b127f1ffae22 100644 (file)
@@ -4,8 +4,8 @@
 from nmutil.nmoperator import eq
 from nmutil.iocontrol import (PrevControl, NextControl)
 from nmutil.singlepipe import (PrevControl, NextControl, ControlBase,
-                        StageCls, Stage, StageChain,
-                        BufferedHandshake, UnbufferedPipeline)
+                               StageCls, Stage, StageChain,
+                               BufferedHandshake, UnbufferedPipeline)
 
 from nmigen import Signal, Module
 from nmigen.cli import verilog, rtlil
index c242e255a6c0efe9d0e0c7e3b009fd268baec1d9..c01bfac366da498dadb5537221ee6a08dd81fc67 100644 (file)
@@ -14,9 +14,9 @@ class Shifter(Elaboratable):
 
     * "Prev" port:
 
-        * ``p_data_i``: value to be shifted
+        * ``p_i_data``: value to be shifted
 
-        * ``p_shift_i``: shift amount
+        * ``p_i_shift``: shift amount
 
         * ``op__sdir``: shift direction (0 = left, 1 = right)
 
@@ -24,19 +24,20 @@ class Shifter(Elaboratable):
 
     * "Next" port:
 
-        * ``n_data_o``: shifted value
+        * ``n_o_data``: shifted value
 
         * ``n_o_valid`` and ``n_i_ready``: handshake
     """
+
     def __init__(self, width):
         self.width = width
         """data width"""
-        self.p_data_i = Signal(width)
-        self.p_shift_i = Signal(width)
+        self.p_i_data = Signal(width)
+        self.p_i_shift = Signal(width)
         self.op__sdir = Signal()
         self.p_i_valid = Signal()
         self.p_o_ready = Signal()
-        self.n_data_o = Signal(width)
+        self.n_o_data = Signal(width)
         self.n_o_valid = Signal()
         self.n_i_ready = Signal()
 
@@ -57,8 +58,8 @@ class Shifter(Elaboratable):
         # build the data flow
         m.d.comb += [
             # connect input and output
-            shift_in.eq(self.p_data_i),
-            self.n_data_o.eq(shift_reg),
+            shift_in.eq(self.p_i_data),
+            self.n_o_data.eq(shift_reg),
             # generate shifted views of the register
             shift_left_by_1.eq(Cat(0, shift_reg[:-1])),
             shift_right_by_1.eq(Cat(shift_reg[1:], 0)),
@@ -91,7 +92,7 @@ class Shifter(Elaboratable):
                     self.p_o_ready.eq(1),
                     # keep loading the shift register and shift count
                     load.eq(1),
-                    next_count.eq(self.p_shift_i),
+                    next_count.eq(self.p_i_shift),
                 ]
                 # capture the direction bit as well
                 m.d.sync += direction.eq(self.op__sdir)
@@ -123,13 +124,13 @@ class Shifter(Elaboratable):
 
     def __iter__(self):
         yield self.op__sdir
-        yield self.p_data_i
-        yield self.p_shift_i
+        yield self.p_i_data
+        yield self.p_i_shift
         yield self.p_i_valid
         yield self.p_o_ready
         yield self.n_i_ready
         yield self.n_o_valid
-        yield self.n_data_o
+        yield self.n_o_data
 
     def ports(self):
         return list(self)
@@ -156,9 +157,9 @@ def write_gtkw_direct():
         with gtkw.group("prev port"):
             gtkw.trace(dut + "op__sdir", color=style_input)
             # demonstrates using decimal base (default is hex)
-            gtkw.trace(dut + "p_data_i[7:0]", color=style_input,
+            gtkw.trace(dut + "p_i_data[7:0]", color=style_input,
                        datafmt='dec')
-            gtkw.trace(dut + "p_shift_i[7:0]", color=style_input,
+            gtkw.trace(dut + "p_i_shift[7:0]", color=style_input,
                        datafmt='dec')
             gtkw.trace(dut + "p_i_valid", color=style_input)
             gtkw.trace(dut + "p_o_ready", color=style_output)
@@ -175,7 +176,7 @@ def write_gtkw_direct():
             gtkw.trace(dut + "count[3:0]")
             gtkw.trace(dut + "shift_reg[7:0]", datafmt='dec')
         with gtkw.group("next port"):
-            gtkw.trace(dut + "n_data_o[7:0]", color=style_output,
+            gtkw.trace(dut + "n_o_data[7:0]", color=style_output,
                        datafmt='dec')
             gtkw.trace(dut + "n_o_valid", color=style_output)
             gtkw.trace(dut + "n_i_ready", color=style_input)
@@ -225,8 +226,8 @@ def test_shifter():
         ('prev port', [
             # attach a class style for each signal
             ('op__sdir', 'in'),
-            ('p_data_i[7:0]', 'in'),
-            ('p_shift_i[7:0]', 'in'),
+            ('p_i_data[7:0]', 'in'),
+            ('p_i_shift[7:0]', 'in'),
             ('p_i_valid', 'in'),
             ('p_o_ready', 'out'),
         ]),
@@ -246,7 +247,7 @@ def test_shifter():
             'shift_reg[7:0]',
         ]),
         ('next port', [
-            ('n_data_o[7:0]', 'out'),
+            ('n_o_data[7:0]', 'out'),
             ('n_o_valid', 'out'),
             ('n_i_ready', 'in'),
         ]),
@@ -278,8 +279,8 @@ def test_shifter():
 
     def send(data, shift, direction):
         # present input data and assert i_valid
-        yield dut.p_data_i.eq(data)
-        yield dut.p_shift_i.eq(shift)
+        yield dut.p_i_data.eq(data)
+        yield dut.p_i_shift.eq(shift)
         yield dut.op__sdir.eq(direction)
         yield dut.p_i_valid.eq(1)
         yield
@@ -297,8 +298,8 @@ def test_shifter():
         yield msg.eq(1)
         # clear input data and negate p.i_valid
         yield dut.p_i_valid.eq(0)
-        yield dut.p_data_i.eq(0)
-        yield dut.p_shift_i.eq(0)
+        yield dut.p_i_data.eq(0)
+        yield dut.p_i_shift.eq(0)
         yield dut.op__sdir.eq(0)
 
     def receive(expected):
@@ -309,7 +310,7 @@ def test_shifter():
         while not (yield dut.n_o_valid):
             yield
         # read result
-        result = yield dut.n_data_o
+        result = yield dut.n_o_data
         # negate n.i_ready
         yield dut.n_i_ready.eq(0)
         # check result
index a52964085b929f3405d579d33f4a13943870ce81..e738657a08c20fce0b97f8d87be3c21b22fa88be 100644 (file)
@@ -59,25 +59,25 @@ def tbench(dut):
     yield
     # yield dut.i_p_rst.eq(0)
     yield dut.n.i_ready.eq(1)
-    yield dut.p.data_i.eq(5)
+    yield dut.p.i_data.eq(5)
     yield dut.p.i_valid.eq(1)
     yield
 
-    yield dut.p.data_i.eq(7)
+    yield dut.p.i_data.eq(7)
     yield from check_o_n_valid(dut, 0)  # effects of i_p_valid delayed
     yield
     yield from check_o_n_valid(dut, 1)  # ok *now* i_p_valid effect is felt
 
-    yield dut.p.data_i.eq(2)
+    yield dut.p.i_data.eq(2)
     yield
     # begin going into "stall" (next stage says ready)
     yield dut.n.i_ready.eq(0)
-    yield dut.p.data_i.eq(9)
+    yield dut.p.i_data.eq(9)
     yield
     yield dut.p.i_valid.eq(0)
-    yield dut.p.data_i.eq(12)
+    yield dut.p.i_data.eq(12)
     yield
-    yield dut.p.data_i.eq(32)
+    yield dut.p.i_data.eq(32)
     yield dut.n.i_ready.eq(1)
     yield
     yield from check_o_n_valid(dut, 1)  # buffer still needs to output
@@ -96,28 +96,28 @@ def tbench2(dut):
     yield
     # yield dut.p.i_rst.eq(0)
     yield dut.n.i_ready.eq(1)
-    yield dut.p.data_i.eq(5)
+    yield dut.p.i_data.eq(5)
     yield dut.p.i_valid.eq(1)
     yield
 
-    yield dut.p.data_i.eq(7)
+    yield dut.p.i_data.eq(7)
     # effects of i_p_valid delayed 2 clocks
     yield from check_o_n_valid2(dut, 0)
     yield
     # effects of i_p_valid delayed 2 clocks
     yield from check_o_n_valid2(dut, 0)
 
-    yield dut.p.data_i.eq(2)
+    yield dut.p.i_data.eq(2)
     yield
     yield from check_o_n_valid2(dut, 1)  # ok *now* i_p_valid effect is felt
     # begin going into "stall" (next stage says ready)
     yield dut.n.i_ready.eq(0)
-    yield dut.p.data_i.eq(9)
+    yield dut.p.i_data.eq(9)
     yield
     yield dut.p.i_valid.eq(0)
-    yield dut.p.data_i.eq(12)
+    yield dut.p.i_data.eq(12)
     yield
-    yield dut.p.data_i.eq(32)
+    yield dut.p.i_data.eq(32)
     yield dut.n.i_ready.eq(1)
     yield
     yield from check_o_n_valid2(dut, 1)  # buffer still needs to output
@@ -157,7 +157,7 @@ class Test3:
                     continue
                 if send and self.i != len(self.data):
                     yield self.dut.p.i_valid.eq(1)
-                    yield self.dut.p.data_i.eq(self.data[self.i])
+                    yield self.dut.p.i_data.eq(self.data[self.i])
                     self.i += 1
                 else:
                     yield self.dut.p.i_valid.eq(0)
@@ -174,17 +174,17 @@ class Test3:
                 i_n_ready = yield self.dut.n.i_ready_test
                 if not o_n_valid or not i_n_ready:
                     continue
-                data_o = yield self.dut.n.data_o
-                self.resultfn(data_o, self.data[self.o], self.i, self.o)
+                o_data = yield self.dut.n.o_data
+                self.resultfn(o_data, self.data[self.o], self.i, self.o)
                 self.o += 1
                 if self.o == len(self.data):
                     break
 
 
-def resultfn_3(data_o, expected, i, o):
-    assert data_o == expected + 1, \
+def resultfn_3(o_data, expected, i, o):
+    assert o_data == expected + 1, \
         "%d-%d data %x not match %x\n" \
-        % (i, o, data_o, expected)
+        % (i, o, o_data, expected)
 
 
 def data_placeholder():
@@ -254,14 +254,14 @@ class Test5:
                 i_n_ready = yield self.dut.n.i_ready_test
                 if not o_n_valid or not i_n_ready:
                     continue
-                if isinstance(self.dut.n.data_o, Record):
-                    data_o = {}
-                    dod = self.dut.n.data_o
+                if isinstance(self.dut.n.o_data, Record):
+                    o_data = {}
+                    dod = self.dut.n.o_data
                     for k, v in dod.fields.items():
-                        data_o[k] = yield v
+                        o_data[k] = yield v
                 else:
-                    data_o = yield self.dut.n.data_o
-                self.resultfn(data_o, self.data[self.o], self.i, self.o)
+                    o_data = yield self.dut.n.o_data
+                self.resultfn(o_data, self.data[self.o], self.i, self.o)
                 self.o += 1
                 if self.o == len(self.data):
                     break
@@ -337,25 +337,25 @@ class TestMask:
                 i_n_ready = yield self.dut.n.i_ready_test
                 if not o_n_valid or not i_n_ready:
                     continue
-                if isinstance(self.dut.n.data_o, Record):
-                    data_o = {}
-                    dod = self.dut.n.data_o
+                if isinstance(self.dut.n.o_data, Record):
+                    o_data = {}
+                    dod = self.dut.n.o_data
                     for k, v in dod.fields.items():
-                        data_o[k] = yield v
+                        o_data[k] = yield v
                 else:
-                    data_o = yield self.dut.n.data_o
-                print("recv", self.o, data_o)
-                self.resultfn(data_o, self.data[self.o], self.i, self.o)
+                    o_data = yield self.dut.n.o_data
+                print("recv", self.o, o_data)
+                self.resultfn(o_data, self.data[self.o], self.i, self.o)
                 self.o += 1
                 if self.o == len(self.data):
                     break
 
 
-def resultfn_5(data_o, expected, i, o):
+def resultfn_5(o_data, expected, i, o):
     res = expected[0] + expected[1]
-    assert data_o == res, \
+    assert o_data == res, \
         "%d-%d data %x not match %s\n" \
-        % (i, o, data_o, repr(expected))
+        % (i, o, o_data, repr(expected))
 
 
 def tbench4(dut):
@@ -373,7 +373,7 @@ def tbench4(dut):
         if o_p_ready:
             if send and i != len(data):
                 yield dut.p.i_valid.eq(1)
-                yield dut.p.data_i.eq(data[i])
+                yield dut.p.i_data.eq(data[i])
                 i += 1
             else:
                 yield dut.p.i_valid.eq(0)
@@ -381,9 +381,9 @@ def tbench4(dut):
         o_n_valid = yield dut.n.o_valid
         i_n_ready = yield dut.n.i_ready_test
         if o_n_valid and i_n_ready:
-            data_o = yield dut.n.data_o
-            assert data_o == data[o] + 2, "%d-%d data %x not match %x\n" \
-                % (i, o, data_o, data[o])
+            o_data = yield dut.n.o_data
+            assert o_data == data[o] + 2, "%d-%d data %x not match %x\n" \
+                % (i, o, o_data, data[o])
             o += 1
             if o == len(data):
                 break
@@ -433,11 +433,11 @@ def data_chain2():
     return data
 
 
-def resultfn_9(data_o, expected, i, o):
+def resultfn_9(o_data, expected, i, o):
     res = expected + 2
-    assert data_o == res, \
+    assert o_data == res, \
         "%d-%d received data %x not match expected %x\n" \
-        % (i, o, data_o, res)
+        % (i, o, o_data, res)
 
 
 ######################################################################
@@ -525,11 +525,11 @@ class ExampleLTBufferedPipeDerived(BufferedHandshake):
         BufferedHandshake.__init__(self, stage)
 
 
-def resultfn_6(data_o, expected, i, o):
+def resultfn_6(o_data, expected, i, o):
     res = 1 if expected[0] < expected[1] else 0
-    assert data_o == res, \
+    assert o_data == res, \
         "%d-%d data %x not match %s\n" \
-        % (i, o, data_o, repr(expected))
+        % (i, o, o_data, repr(expected))
 
 
 ######################################################################
@@ -600,11 +600,11 @@ class ExampleAddRecordPipe(UnbufferedPipeline):
         UnbufferedPipeline.__init__(self, stage)
 
 
-def resultfn_7(data_o, expected, i, o):
+def resultfn_7(o_data, expected, i, o):
     res = (expected['src1'] + 1, expected['src2'] + 1)
-    assert data_o['src1'] == res[0] and data_o['src2'] == res[1], \
+    assert o_data['src1'] == res[0] and o_data['src2'] == res[1], \
         "%d-%d data %s not match %s\n" \
-        % (i, o, repr(data_o), repr(expected))
+        % (i, o, repr(o_data), repr(expected))
 
 
 class ExampleAddRecordPlaceHolderPipe(UnbufferedPipeline):
@@ -616,12 +616,12 @@ class ExampleAddRecordPlaceHolderPipe(UnbufferedPipeline):
         UnbufferedPipeline.__init__(self, stage)
 
 
-def resultfn_11(data_o, expected, i, o):
+def resultfn_11(o_data, expected, i, o):
     res1 = expected.src1 + 1
     res2 = expected.src2 + 1
-    assert data_o['src1'] == res1 and data_o['src2'] == res2, \
+    assert o_data['src1'] == res1 and o_data['src2'] == res2, \
         "%d-%d data %s not match %s\n" \
-        % (i, o, repr(data_o), repr(expected))
+        % (i, o, repr(o_data), repr(expected))
 
 
 ######################################################################
@@ -684,11 +684,11 @@ class TestInputAdd:
         self.op2 = op2
 
 
-def resultfn_8(data_o, expected, i, o):
+def resultfn_8(o_data, expected, i, o):
     res = expected.op1 + expected.op2  # these are a TestInputAdd instance
-    assert data_o == res, \
+    assert o_data == res, \
         "%d-%d data %s res %x not match %s\n" \
-        % (i, o, repr(data_o), res, repr(expected))
+        % (i, o, repr(o_data), res, repr(expected))
 
 
 def data_2op():
@@ -762,11 +762,11 @@ def data_chain1():
     return data
 
 
-def resultfn_12(data_o, expected, i, o):
+def resultfn_12(o_data, expected, i, o):
     res = expected + 1
-    assert data_o == res, \
+    assert o_data == res, \
         "%d-%d data %x not match %x\n" \
-        % (i, o, data_o, res)
+        % (i, o, o_data, res)
 
 
 ######################################################################
@@ -841,11 +841,11 @@ class PassThroughTest(PassThroughHandshake):
         PassThroughHandshake.__init__(self, stage)
 
 
-def resultfn_identical(data_o, expected, i, o):
+def resultfn_identical(o_data, expected, i, o):
     res = expected
-    assert data_o == res, \
+    assert o_data == res, \
         "%d-%d data %x not match %x\n" \
-        % (i, o, data_o, res)
+        % (i, o, o_data, res)
 
 
 ######################################################################
@@ -1230,13 +1230,13 @@ def data_chain0(n_tests):
     return data
 
 
-def resultfn_0(data_o, expected, i, o):
-    assert data_o['src1'] == expected.src1 + 2, \
+def resultfn_0(o_data, expected, i, o):
+    assert o_data['src1'] == expected.src1 + 2, \
         "src1 %x-%x received data no match\n" \
-        % (data_o['src1'], expected.src1 + 2)
-    assert data_o['src2'] == expected.src2 + 2, \
+        % (o_data['src1'], expected.src1 + 2)
+    assert o_data['src2'] == expected.src2 + 2, \
         "src2 %x-%x received data no match\n" \
-        % (data_o['src2'], expected.src2 + 2)
+        % (o_data['src2'], expected.src2 + 2)
 
 
 ######################################################################
@@ -1252,7 +1252,7 @@ def test0():
     dut = MaskCancellablePipe(maskwid)
     ports = [dut.p.i_valid, dut.n.i_ready,
              dut.n.o_valid, dut.p.o_ready] + \
-        dut.p.data_i.ports() + dut.n.data_o.ports()
+        dut.p.i_data.ports() + dut.n.o_data.ports()
     vl = rtlil.convert(dut, ports=ports)
     with open("test_maskchain0.il", "w") as f:
         f.write(vl)
@@ -1268,7 +1268,7 @@ def test0_1():
     dut = MaskCancellableDynamic(maskwid=maskwid)
     ports = [dut.p.i_valid, dut.n.i_ready,
              dut.n.o_valid, dut.p.o_ready]  # + \
-    #dut.p.data_i.ports() + dut.n.data_o.ports()
+    #dut.p.i_data.ports() + dut.n.o_data.ports()
     vl = rtlil.convert(dut, ports=ports)
     with open("test_maskchain0_dynamic.il", "w") as f:
         f.write(vl)
@@ -1290,7 +1290,7 @@ def notworking2():
     run_simulation(dut, tbench2(dut), vcd_name="test_bufpipe2.vcd")
     ports = [dut.p.i_valid, dut.n.i_ready,
              dut.n.o_valid, dut.p.o_ready] + \
-        [dut.p.data_i] + [dut.n.data_o]
+        [dut.p.i_data] + [dut.n.o_data]
     vl = rtlil.convert(dut, ports=ports)
     with open("test_bufpipe2.il", "w") as f:
         f.write(vl)
@@ -1334,7 +1334,7 @@ def test6():
 
     ports = [dut.p.i_valid, dut.n.i_ready,
              dut.n.o_valid, dut.p.o_ready] + \
-        list(dut.p.data_i) + [dut.n.data_o]
+        list(dut.p.i_data) + [dut.n.o_data]
     vl = rtlil.convert(dut, ports=ports)
     with open("test_ltcomb_pipe.il", "w") as f:
         f.write(vl)
@@ -1347,8 +1347,8 @@ def test7():
     test = Test5(dut, resultfn_7, data=data)
     ports = [dut.p.i_valid, dut.n.i_ready,
              dut.n.o_valid, dut.p.o_ready,
-             dut.p.data_i.src1, dut.p.data_i.src2,
-             dut.n.data_o.src1, dut.n.data_o.src2]
+             dut.p.i_data.src1, dut.p.i_data.src2,
+             dut.n.o_data.src1, dut.n.o_data.src2]
     vl = rtlil.convert(dut, ports=ports)
     with open("test_recordcomb_pipe.il", "w") as f:
         f.write(vl)
@@ -1370,7 +1370,7 @@ def test9():
     dut = ExampleBufPipeChain2()
     ports = [dut.p.i_valid, dut.n.i_ready,
              dut.n.o_valid, dut.p.o_ready] + \
-        [dut.p.data_i] + [dut.n.data_o]
+        [dut.p.i_data] + [dut.n.o_data]
     vl = rtlil.convert(dut, ports=ports)
     with open("test_bufpipechain2.il", "w") as f:
         f.write(vl)
@@ -1412,7 +1412,7 @@ def test12():
                    vcd_name="test_bufpipe12.vcd")
     ports = [dut.p.i_valid, dut.n.i_ready,
              dut.n.o_valid, dut.p.o_ready] + \
-        [dut.p.data_i] + [dut.n.data_o]
+        [dut.p.i_data] + [dut.n.o_data]
     vl = rtlil.convert(dut, ports=ports)
     with open("test_bufpipe12.il", "w") as f:
         f.write(vl)
@@ -1427,7 +1427,7 @@ def test13():
                    vcd_name="test_unbufpipe13.vcd")
     ports = [dut.p.i_valid, dut.n.i_ready,
              dut.n.o_valid, dut.p.o_ready] + \
-        [dut.p.data_i] + [dut.n.data_o]
+        [dut.p.i_data] + [dut.n.o_data]
     vl = rtlil.convert(dut, ports=ports)
     with open("test_unbufpipe13.il", "w") as f:
         f.write(vl)
@@ -1442,7 +1442,7 @@ def test15():
                    vcd_name="test_bufunbuf15.vcd")
     ports = [dut.p.i_valid, dut.n.i_ready,
              dut.n.o_valid, dut.p.o_ready] + \
-        [dut.p.data_i] + [dut.n.data_o]
+        [dut.p.i_data] + [dut.n.o_data]
     vl = rtlil.convert(dut, ports=ports)
     with open("test_bufunbuf15.il", "w") as f:
         f.write(vl)
@@ -1457,7 +1457,7 @@ def test16():
                    vcd_name="test_bufunbuf16.vcd")
     ports = [dut.p.i_valid, dut.n.i_ready,
              dut.n.o_valid, dut.p.o_ready] + \
-        [dut.p.data_i] + [dut.n.data_o]
+        [dut.p.i_data] + [dut.n.o_data]
     vl = rtlil.convert(dut, ports=ports)
     with open("test_bufunbuf16.il", "w") as f:
         f.write(vl)
@@ -1472,7 +1472,7 @@ def test17():
                    vcd_name="test_unbufpipe17.vcd")
     ports = [dut.p.i_valid, dut.n.i_ready,
              dut.n.o_valid, dut.p.o_ready] + \
-        [dut.p.data_i] + [dut.n.data_o]
+        [dut.p.i_data] + [dut.n.o_data]
     vl = rtlil.convert(dut, ports=ports)
     with open("test_unbufpipe17.il", "w") as f:
         f.write(vl)
@@ -1487,7 +1487,7 @@ def test18():
                    vcd_name="test_passthru18.vcd")
     ports = [dut.p.i_valid, dut.n.i_ready,
              dut.n.o_valid, dut.p.o_ready] + \
-        [dut.p.data_i] + [dut.n.data_o]
+        [dut.p.i_data] + [dut.n.o_data]
     vl = rtlil.convert(dut, ports=ports)
     with open("test_passthru18.il", "w") as f:
         f.write(vl)
@@ -1502,7 +1502,7 @@ def test19():
                    vcd_name="test_bufpass19.vcd")
     ports = [dut.p.i_valid, dut.n.i_ready,
              dut.n.o_valid, dut.p.o_ready] + \
-        [dut.p.data_i] + [dut.n.data_o]
+        [dut.p.i_data] + [dut.n.o_data]
     vl = rtlil.convert(dut, ports=ports)
     with open("test_bufpass19.il", "w") as f:
         f.write(vl)
@@ -1516,7 +1516,7 @@ def test20():
     run_simulation(dut, [test.send(), test.rcv()], vcd_name="test_fifo20.vcd")
     ports = [dut.p.i_valid, dut.n.i_ready,
              dut.n.o_valid, dut.p.o_ready] + \
-        [dut.p.data_i] + [dut.n.data_o]
+        [dut.p.i_data] + [dut.n.o_data]
     vl = rtlil.convert(dut, ports=ports)
     with open("test_fifo20.il", "w") as f:
         f.write(vl)
@@ -1531,7 +1531,7 @@ def test21():
                    vcd_name="test_fifopass21.vcd")
     ports = [dut.p.i_valid, dut.n.i_ready,
              dut.n.o_valid, dut.p.o_ready] + \
-        [dut.p.data_i] + [dut.n.data_o]
+        [dut.p.i_data] + [dut.n.o_data]
     vl = rtlil.convert(dut, ports=ports)
     with open("test_fifopass21.il", "w") as f:
         f.write(vl)
@@ -1546,8 +1546,8 @@ def test22():
                    vcd_name="test_addrecord22.vcd")
     ports = [dut.p.i_valid, dut.n.i_ready,
              dut.n.o_valid, dut.p.o_ready] + \
-        [dut.p.data_i.op1, dut.p.data_i.op2] + \
-        [dut.n.data_o]
+        [dut.p.i_data.op1, dut.p.i_data.op2] + \
+        [dut.n.o_data]
     vl = rtlil.convert(dut, ports=ports)
     with open("test_addrecord22.il", "w") as f:
         f.write(vl)
@@ -1562,8 +1562,8 @@ def test23():
                    vcd_name="test_addrecord23.vcd")
     ports = [dut.p.i_valid, dut.n.i_ready,
              dut.n.o_valid, dut.p.o_ready] + \
-        [dut.p.data_i.op1, dut.p.data_i.op2] + \
-        [dut.n.data_o]
+        [dut.p.i_data.op1, dut.p.i_data.op2] + \
+        [dut.n.o_data]
     vl = rtlil.convert(dut, ports=ports)
     with open("test_addrecord23.il", "w") as f:
         f.write(vl)
@@ -1576,8 +1576,8 @@ def test24():
     test = Test5(dut, resultfn_8, data=data)
     ports = [dut.p.i_valid, dut.n.i_ready,
              dut.n.o_valid, dut.p.o_ready] + \
-        [dut.p.data_i.op1, dut.p.data_i.op2] + \
-        [dut.n.data_o]
+        [dut.p.i_data.op1, dut.p.i_data.op2] + \
+        [dut.n.o_data]
     vl = rtlil.convert(dut, ports=ports)
     with open("test_addrecord24.il", "w") as f:
         f.write(vl)
@@ -1594,7 +1594,7 @@ def test25():
                    vcd_name="test_add2pipe25.vcd")
     ports = [dut.p.i_valid, dut.n.i_ready,
              dut.n.o_valid, dut.p.o_ready] + \
-        [dut.p.data_i] + [dut.n.data_o]
+        [dut.p.i_data] + [dut.n.o_data]
     vl = rtlil.convert(dut, ports=ports)
     with open("test_add2pipe25.il", "w") as f:
         f.write(vl)
@@ -1609,7 +1609,7 @@ def test997():
                    vcd_name="test_bufpass997.vcd")
     ports = [dut.p.i_valid, dut.n.i_ready,
              dut.n.o_valid, dut.p.o_ready] + \
-        [dut.p.data_i] + [dut.n.data_o]
+        [dut.p.i_data] + [dut.n.o_data]
     vl = rtlil.convert(dut, ports=ports)
     with open("test_bufpass997.il", "w") as f:
         f.write(vl)
@@ -1625,7 +1625,7 @@ def test998():
                    vcd_name="test_bufpipe14.vcd")
     ports = [dut.p.i_valid, dut.n.i_ready,
              dut.n.o_valid, dut.p.o_ready] + \
-        [dut.p.data_i] + [dut.n.data_o]
+        [dut.p.i_data] + [dut.n.o_data]
     vl = rtlil.convert(dut, ports=ports)
     with open("test_bufpipe14.il", "w") as f:
         f.write(vl)
@@ -1640,7 +1640,7 @@ def test999():
                    vcd_name="test_bufunbuf999.vcd")
     ports = [dut.p.i_valid, dut.n.i_ready,
              dut.n.o_valid, dut.p.o_ready] + \
-        [dut.p.data_i] + [dut.n.data_o]
+        [dut.p.i_data] + [dut.n.o_data]
     vl = rtlil.convert(dut, ports=ports)
     with open("test_bufunbuf999.il", "w") as f:
         f.write(vl)
index 14e34726720bebb2864a2e8f733d273680e14432..9e34379cc1d6b17d815657ab3e3357a8cc25bd8f 100644 (file)
@@ -1,44 +1,77 @@
-from nmigen import Module, Signal
-from nmigen.back.pysim import Simulator, Delay
-
-from nmutil.clz import CLZ
+from nmigen.sim import Delay
+from nmutil.clz import CLZ, clz
+from nmutil.sim_util import do_sim
 import unittest
-import math
-import random
 
 
-class CLZTestCase(unittest.TestCase):
-    def run_tst(self, inputs, width=8):
+def reference_clz(v, width):
+    assert isinstance(width, int) and 0 <= width
+    assert isinstance(v, int) and 0 <= v < 1 << width
+    msb = 1 << (width - 1)
+    retval = 0
+    while retval < width:
+        if v & msb:
+            break
+        v <<= 1
+        retval += 1
+    return retval
+
 
-        m = Module()
+class TestCLZ(unittest.TestCase):
+    def tst(self, width):
+        assert isinstance(width, int) and 0 <= width
+        dut = CLZ(width)
 
-        m.submodules.dut = dut = CLZ(width)
-        sig_in = Signal.like(dut.sig_in)
-        count = Signal.like(dut.lz)
+        def process():
+            for inp in range(1 << width):
+                expected = reference_clz(inp, width)
+                with self.subTest(inp=hex(inp), expected=expected):
+                    yield dut.sig_in.eq(inp)
+                    yield Delay(1e-6)
+                    sim_lz = yield dut.lz
+                    py_lz = clz(inp, width)
+                    with self.subTest(sim_lz=sim_lz, py_lz=py_lz):
+                        self.assertEqual(sim_lz, expected)
+                        self.assertEqual(py_lz, expected)
+        with do_sim(self, dut, [dut.sig_in, dut.lz]) as sim:
+            sim.add_process(process)
+            sim.run()
 
+    def test_1(self):
+        self.tst(1)
 
-        m.d.comb += [
-            dut.sig_in.eq(sig_in),
-            count.eq(dut.lz)]
+    def test_2(self):
+        self.tst(2)
 
-        sim = Simulator(m)
+    def test_3(self):
+        self.tst(3)
 
-        def process():
-            for i in inputs:
-                yield sig_in.eq(i)
-                yield Delay(1e-6)
-        sim.add_process(process)
-        with sim.write_vcd("clz.vcd", "clz.gtkw", traces=[
-                sig_in, count]):
-            sim.run()
+    def test_4(self):
+        self.tst(4)
+
+    def test_5(self):
+        self.tst(5)
+
+    def test_6(self):
+        self.tst(6)
+
+    def test_7(self):
+        self.tst(7)
+
+    def test_8(self):
+        self.tst(8)
+
+    def test_9(self):
+        self.tst(9)
+
+    def test_10(self):
+        self.tst(10)
 
-    def test_selected(self):
-        inputs = [0, 15, 10, 127]
-        self.run_tst(iter(inputs), width=8)
+    def test_11(self):
+        self.tst(11)
 
-    def test_non_power_2(self):
-        inputs = [0, 128, 512]
-        self.run_tst(iter(inputs), width=11)
+    def test_12(self):
+        self.tst(12)
 
 
 if __name__ == "__main__":
diff --git a/src/nmutil/test/test_deduped.py b/src/nmutil/test/test_deduped.py
new file mode 100644 (file)
index 0000000..42a4edf
--- /dev/null
@@ -0,0 +1,90 @@
+import unittest
+from nmutil.deduped import deduped
+
+
+class TestDeduped(unittest.TestCase):
+    def test_deduped1(self):
+        global_key = 1
+        call_count = 0
+
+        def call_counter():
+            nonlocal call_count
+            retval = call_count
+            call_count += 1
+            return retval
+
+        class C:
+            def __init__(self, name):
+                self.name = name
+
+            @deduped()
+            def method(self, a, *, b=1):
+                return self, a, b, call_counter()
+
+            @deduped(global_keys=[lambda: global_key])
+            def method_with_global(self, a, *, b=1):
+                return self, a, b, call_counter(), global_key
+
+            @staticmethod
+            @deduped()
+            def smethod(a, *, b=1):
+                return a, b, call_counter()
+
+            @classmethod
+            @deduped()
+            def cmethod(cls, a, *, b=1):
+                return cls, a, b, call_counter()
+
+            def __repr__(self):
+                return f"{self.__class__.__name__}({self.name})"
+
+        class D(C):
+            pass
+
+        c1 = C("c1")
+        c2 = C("c2")
+
+        # run everything twice to ensure caching works
+        for which_pass in ("first", "second"):
+            with self.subTest(which_pass=which_pass):
+                self.assertEqual(C.cmethod(1), (C, 1, 1, 0))
+                self.assertEqual(C.cmethod(2), (C, 2, 1, 1))
+                self.assertEqual(C.cmethod(1, b=5), (C, 1, 5, 2))
+                self.assertEqual(D.cmethod(1, b=5), (D, 1, 5, 3))
+                self.assertEqual(D.smethod(1, b=5), (1, 5, 4))
+                self.assertEqual(C.smethod(1, b=5), (1, 5, 4))
+                self.assertEqual(c1.method(None), (c1, None, 1, 5))
+                global_key = 2
+                self.assertEqual(c1.cmethod(1, b=5), (C, 1, 5, 2))
+                self.assertEqual(c1.smethod(1, b=5), (1, 5, 4))
+                self.assertEqual(c1.method(1, b=5), (c1, 1, 5, 6))
+                self.assertEqual(c2.method(1, b=5), (c2, 1, 5, 7))
+                self.assertEqual(c1.method_with_global(1), (c1, 1, 1, 8, 2))
+                global_key = 1
+                self.assertEqual(c1.cmethod(1, b=5), (C, 1, 5, 2))
+                self.assertEqual(c1.smethod(1, b=5), (1, 5, 4))
+                self.assertEqual(c1.method(1, b=5), (c1, 1, 5, 6))
+                self.assertEqual(c2.method(1, b=5), (c2, 1, 5, 7))
+                self.assertEqual(c1.method_with_global(1), (c1, 1, 1, 9, 1))
+        self.assertEqual(call_count, 10)
+
+    def test_bad_methods(self):
+        with self.assertRaisesRegex(TypeError,
+                                    ".*@staticmethod.*applied.*@deduped.*"):
+            class C:
+                @deduped()
+                @staticmethod
+                def f():
+                    pass
+
+        with self.assertRaisesRegex(TypeError,
+                                    ".*@classmethod.*applied.*@deduped.*"):
+            class C:
+                @deduped()
+                @classmethod
+                def f():
+                    pass
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/src/nmutil/test/test_grev.py b/src/nmutil/test/test_grev.py
new file mode 100644 (file)
index 0000000..780239d
--- /dev/null
@@ -0,0 +1,86 @@
+# SPDX-License-Identifier: LGPL-3-or-later
+# Copyright 2021 Jacob Lifshay
+
+# Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
+# of Horizon 2020 EU Programme 957073.
+
+import unittest
+from nmigen.hdl.ast import AnyConst, Assert
+from nmigen.hdl.dsl import Module
+from nmutil.formaltest import FHDLTestCase
+from nmutil.grev import GRev, grev
+from nmigen.sim import Delay
+from nmutil.sim_util import do_sim, hash_256
+
+
+class TestGrev(FHDLTestCase):
+    def tst(self, msb_first, log2_width=6):
+        width = 2 ** log2_width
+        dut = GRev(log2_width, msb_first)
+        self.assertEqual(width, dut.width)
+        self.assertEqual(len(dut._intermediates), log2_width + 1)
+
+        def case(inval, chunk_sizes):
+            expected = grev(inval, chunk_sizes, log2_width)
+            with self.subTest(inval=hex(inval), chunk_sizes=bin(chunk_sizes),
+                              expected=hex(expected)):
+                yield dut.input.eq(inval)
+                yield dut.chunk_sizes.eq(chunk_sizes)
+                yield Delay(1e-6)
+                output = yield dut.output
+                with self.subTest(output=hex(output)):
+                    self.assertEqual(expected, output)
+                for sig, expected in dut._sigs_and_expected(inval,
+                                                            chunk_sizes):
+                    value = yield sig
+                    with self.subTest(sig=sig.name, value=hex(value),
+                                      expected=hex(expected)):
+                        self.assertEqual(value, expected)
+
+        def process():
+            for count in range(width + 1):
+                inval = (1 << count) - 1
+                for chunk_sizes in range(2 ** log2_width):
+                    yield from case(inval, chunk_sizes)
+            for i in range(100):
+                inval = hash_256(f"grev input {i}")
+                inval &= 2 ** width - 1
+                chunk_sizes = hash_256(f"grev 2 {i}")
+                chunk_sizes &= 2 ** log2_width - 1
+                yield from case(inval, chunk_sizes)
+        with do_sim(self, dut, [dut.input, dut.chunk_sizes,
+                                dut.output]) as sim:
+            sim.add_process(process)
+            sim.run()
+
+    def test(self):
+        self.tst(msb_first=False)
+
+    def test_msb_first(self):
+        self.tst(msb_first=True)
+
+    def test_small(self):
+        self.tst(msb_first=False, log2_width=3)
+
+    def test_small_msb_first(self):
+        self.tst(msb_first=True, log2_width=3)
+
+    def tst_formal(self, msb_first):
+        log2_width = 4
+        dut = GRev(log2_width, msb_first)
+        m = Module()
+        m.submodules.dut = dut
+        m.d.comb += dut.input.eq(AnyConst(2 ** log2_width))
+        m.d.comb += dut.chunk_sizes.eq(AnyConst(log2_width))
+        # actual formal correctness proof is inside the module itself, now
+        self.assertFormal(m)
+
+    def test_formal(self):
+        self.tst_formal(msb_first=False)
+
+    def test_formal_msb_first(self):
+        self.tst_formal(msb_first=True)
+
+
+if __name__ == "__main__":
+    unittest.main()
index 8194880b8b0743623f606ee0083b737b0f983c5e..16ee2f1bf426dea76d13fef92949a2dd8e7b4c75 100644 (file)
@@ -126,10 +126,10 @@ class InputTest:
             op2 = self.di[muxid][i]
             rs = self.dut.p[muxid]
             yield rs.i_valid.eq(1)
-            yield rs.data_i.data.eq(op2)
-            yield rs.data_i.idx.eq(i)
-            yield rs.data_i.muxid.eq(muxid)
-            yield rs.data_i.operator.eq(1)
+            yield rs.i_data.data.eq(op2)
+            yield rs.i_data.idx.eq(i)
+            yield rs.i_data.muxid.eq(muxid)
+            yield rs.i_data.operator.eq(1)
             yield rs.mask_i.eq(1)
             yield
             o_p_ready = yield rs.o_ready
@@ -199,9 +199,9 @@ class InputTest:
             if not o_n_valid or not i_n_ready:
                 continue
 
-            out_muxid = yield n.data_o.muxid
-            out_i = yield n.data_o.idx
-            out_v = yield n.data_o.data
+            out_muxid = yield n.o_data.muxid
+            out_i = yield n.o_data.idx
+            out_v = yield n.o_data.data
 
             print("recv", out_muxid, out_i, hex(out_v), hex(out_v))
 
index bbb44afc4e6f0bf7e8fe651a1002e08692895956..149c6a5c05e473946df980ff404142fe56fd4d6e 100644 (file)
@@ -32,16 +32,15 @@ class PassData(Object):
         self.data = Signal(16, reset_less=True)
 
 
-
 class PassThroughStage:
     def ispec(self):
         return PassData()
+
     def ospec(self):
-        return self.ispec() # same as ospec
+        return self.ispec()  # same as ispec
 
     def process(self, i):
-        return i # pass-through
-
+        return i  # pass-through
 
 
 class PassThroughPipe(SimpleHandshake):
@@ -59,7 +58,7 @@ class InputTest:
             self.di[muxid] = {}
             self.do[muxid] = {}
             for i in range(self.tlen):
-                self.di[muxid][i] = randint(0, 255) + (muxid<<8)
+                self.di[muxid][i] = randint(0, 255) + (muxid << 8)
                 self.do[muxid][i] = self.di[muxid][i]
 
     def send(self, muxid):
@@ -67,16 +66,16 @@ class InputTest:
             op2 = self.di[muxid][i]
             rs = self.dut.p[muxid]
             yield rs.i_valid.eq(1)
-            yield rs.data_i.data.eq(op2)
-            yield rs.data_i.idx.eq(i)
-            yield rs.data_i.muxid.eq(muxid)
+            yield rs.i_data.data.eq(op2)
+            yield rs.i_data.idx.eq(i)
+            yield rs.i_data.muxid.eq(muxid)
             yield
             o_p_ready = yield rs.o_ready
             while not o_p_ready:
                 yield
                 o_p_ready = yield rs.o_ready
 
-            print ("send", muxid, i, hex(op2))
+            print("send", muxid, i, hex(op2))
 
             yield rs.i_valid.eq(0)
             # wait random period of time before queueing another value
@@ -86,22 +85,22 @@ class InputTest:
         yield rs.i_valid.eq(0)
         yield
 
-        print ("send ended", muxid)
+        print("send ended", muxid)
 
-        ## wait random period of time before queueing another value
-        #for i in range(randint(0, 3)):
+        # wait random period of time before queueing another value
+        # for i in range(randint(0, 3)):
         #    yield
 
         #send_range = randint(0, 3)
-        #if send_range == 0:
+        # if send_range == 0:
         #    send = True
-        #else:
+        # else:
         #    send = randint(0, send_range) != 0
 
     def rcv(self, muxid):
         while True:
             #stall_range = randint(0, 3)
-            #for j in range(randint(1,10)):
+            # for j in range(randint(1,10)):
             #    stall = randint(0, stall_range) != 0
             #    yield self.dut.n[0].i_ready.eq(stall)
             #    yield
@@ -113,24 +112,24 @@ class InputTest:
             if not o_n_valid or not i_n_ready:
                 continue
 
-            out_muxid = yield n.data_o.muxid
-            out_i = yield n.data_o.idx
-            out_v = yield n.data_o.data
+            out_muxid = yield n.o_data.muxid
+            out_i = yield n.o_data.idx
+            out_v = yield n.o_data.data
 
-            print ("recv", out_muxid, out_i, hex(out_v))
+            print("recv", out_muxid, out_i, hex(out_v))
 
             # see if this output has occurred already, delete it if it has
             assert muxid == out_muxid, \
-                    "out_muxid %d not correct %d" % (out_muxid, muxid)
+                "out_muxid %d not correct %d" % (out_muxid, muxid)
             assert out_i in self.do[muxid], "out_i %d not in array %s" % \
-                                          (out_i, repr(self.do[muxid]))
-            assert self.do[muxid][out_i] == out_v # pass-through data
+                (out_i, repr(self.do[muxid]))
+            assert self.do[muxid][out_i] == out_v  # pass-through data
             del self.do[muxid][out_i]
 
             # check if there's any more outputs
             if len(self.do[muxid]) == 0:
                 break
-        print ("recv ended", muxid)
+        print("recv ended", muxid)
 
 
 class TestPriorityMuxPipe(PriorityCombMuxInPipe):
@@ -151,7 +150,7 @@ class OutputTest:
                 muxid = i
             else:
                 muxid = randint(0, dut.num_rows-1)
-            data = randint(0, 255) + (muxid<<8)
+            data = randint(0, 255) + (muxid << 8)
 
     def send(self):
         for i in range(self.tlen * dut.num_rows):
@@ -159,15 +158,15 @@ class OutputTest:
             muxid = self.di[i][1]
             rs = dut.p
             yield rs.i_valid.eq(1)
-            yield rs.data_i.data.eq(op2)
-            yield rs.data_i.muxid.eq(muxid)
+            yield rs.i_data.data.eq(op2)
+            yield rs.i_data.muxid.eq(muxid)
             yield
             o_p_ready = yield rs.o_ready
             while not o_p_ready:
                 yield
                 o_p_ready = yield rs.o_ready
 
-            print ("send", muxid, i, hex(op2))
+            print("send", muxid, i, hex(op2))
 
             yield rs.i_valid.eq(0)
             # wait random period of time before queueing another value
@@ -187,13 +186,13 @@ class TestMuxOutPipe(CombMuxOutPipe):
 class TestInOutPipe(Elaboratable):
     def __init__(self, num_rows=4):
         self.num_rows = num_rows
-        self.inpipe = TestPriorityMuxPipe(num_rows) # fan-in (combinatorial)
+        self.inpipe = TestPriorityMuxPipe(num_rows)  # fan-in (combinatorial)
         self.pipe1 = PassThroughPipe()              # stage 1 (clock-sync)
         self.pipe2 = PassThroughPipe()              # stage 2 (clock-sync)
         self.outpipe = TestMuxOutPipe(num_rows)     # fan-out (combinatorial)
 
         self.p = self.inpipe.p  # kinda annoying,
-        self.n = self.outpipe.n # use pipe in/out as this class in/out
+        self.n = self.outpipe.n  # use pipe in/out as this class in/out
         self._ports = self.inpipe.ports() + self.outpipe.ports()
 
     def elaborate(self, platform):
@@ -225,8 +224,9 @@ def test1():
                          test.rcv(3), test.rcv(2),
                          test.send(0), test.send(1),
                          test.send(3), test.send(2),
-                        ],
+                         ],
                    vcd_name="test_inoutmux_pipe.vcd")
 
+
 if __name__ == '__main__':
     test1()
index 896002253b8ad0163c945ce760c90ac1c493b282..fd0570c744f63772921f9074e979faf9978a8fda 100644 (file)
@@ -33,16 +33,15 @@ class PassData(Object):
         self.data = Signal(16, reset_less=True)
 
 
-
 class PassThroughStage:
     def ispec(self):
         return PassData()
+
     def ospec(self):
-        return self.ispec() # same as ospec
+        return self.ispec()  # same as ispec
 
     def process(self, i):
-        return i # pass-through
-
+        return i  # pass-through
 
 
 class PassThroughPipe(MaskCancellable):
@@ -62,7 +61,7 @@ class InputTest:
             self.do[muxid] = {}
             self.sent[muxid] = []
             for i in range(self.tlen):
-                self.di[muxid][i] = randint(0, 255) + (muxid<<8)
+                self.di[muxid][i] = randint(0, 255) + (muxid << 8)
                 self.do[muxid][i] = self.di[muxid][i]
 
     def send(self, muxid):
@@ -70,9 +69,9 @@ class InputTest:
             op2 = self.di[muxid][i]
             rs = self.dut.p[muxid]
             yield rs.i_valid.eq(1)
-            yield rs.data_i.data.eq(op2)
-            yield rs.data_i.idx.eq(i)
-            yield rs.data_i.muxid.eq(muxid)
+            yield rs.i_data.data.eq(op2)
+            yield rs.i_data.idx.eq(i)
+            yield rs.i_data.muxid.eq(muxid)
             yield rs.mask_i.eq(1)
             yield
             o_p_ready = yield rs.o_ready
@@ -80,7 +79,7 @@ class InputTest:
                 yield
                 o_p_ready = yield rs.o_ready
 
-            print ("send", muxid, i, hex(op2), op2)
+            print("send", muxid, i, hex(op2), op2)
             self.sent[muxid].append(i)
 
             yield rs.i_valid.eq(0)
@@ -96,16 +95,16 @@ class InputTest:
         yield rs.i_valid.eq(0)
         yield
 
-        print ("send ended", muxid)
+        print("send ended", muxid)
 
-        ## wait random period of time before queueing another value
-        #for i in range(randint(0, 3)):
+        # wait random period of time before queueing another value
+        # for i in range(randint(0, 3)):
         #    yield
 
         #send_range = randint(0, 3)
-        #if send_range == 0:
+        # if send_range == 0:
         #    send = True
-        #else:
+        # else:
         #    send = randint(0, send_range) != 0
 
     def rcv(self, muxid):
@@ -115,16 +114,16 @@ class InputTest:
             # check cancellation
             if self.sent[muxid] and randint(0, 2) == 0:
                 todel = self.sent[muxid].pop()
-                print ("to delete", muxid, self.sent[muxid], todel)
+                print("to delete", muxid, self.sent[muxid], todel)
                 if todel in self.do[muxid]:
                     del self.do[muxid][todel]
                     yield rs.stop_i.eq(1)
-                print ("left", muxid, self.do[muxid])
+                print("left", muxid, self.do[muxid])
                 if len(self.do[muxid]) == 0:
                     break
 
             stall_range = randint(0, 3)
-            for j in range(randint(1,10)):
+            for j in range(randint(1, 10)):
                 stall = randint(0, stall_range) != 0
                 yield self.dut.n[0].i_ready.eq(stall)
                 yield
@@ -132,27 +131,27 @@ class InputTest:
             n = self.dut.n[muxid]
             yield n.i_ready.eq(1)
             yield
-            yield rs.stop_i.eq(0) # resets cancel mask
+            yield rs.stop_i.eq(0)  # resets cancel mask
             o_n_valid = yield n.o_valid
             i_n_ready = yield n.i_ready
             if not o_n_valid or not i_n_ready:
                 continue
 
-            out_muxid = yield n.data_o.muxid
-            out_i = yield n.data_o.idx
-            out_v = yield n.data_o.data
+            out_muxid = yield n.o_data.muxid
+            out_i = yield n.o_data.idx
+            out_v = yield n.o_data.data
 
-            print ("recv", out_muxid, out_i, hex(out_v), out_v)
+            print("recv", out_muxid, out_i, hex(out_v), out_v)
 
             # see if this output has occurred already, delete it if it has
             assert muxid == out_muxid, \
-                    "out_muxid %d not correct %d" % (out_muxid, muxid)
+                "out_muxid %d not correct %d" % (out_muxid, muxid)
             if out_i not in self.sent[muxid]:
-                print ("cancelled/recv", muxid, out_i)
+                print("cancelled/recv", muxid, out_i)
                 continue
             assert out_i in self.do[muxid], "out_i %d not in array %s" % \
-                                          (out_i, repr(self.do[muxid]))
-            assert self.do[muxid][out_i] == out_v # pass-through data
+                (out_i, repr(self.do[muxid]))
+            assert self.do[muxid][out_i] == out_v  # pass-through data
             del self.do[muxid][out_i]
             todel = self.sent[muxid].index(out_i)
             del self.sent[muxid][todel]
@@ -161,7 +160,7 @@ class InputTest:
             if len(self.do[muxid]) == 0:
                 break
 
-        print ("recv ended", muxid)
+        print("recv ended", muxid)
 
 
 class TestPriorityMuxPipe(PriorityCombMuxInPipe):
@@ -183,7 +182,7 @@ class TestMuxOutPipe(CombMuxOutPipe):
 class TestInOutPipe(Elaboratable):
     def __init__(self, num_rows=4):
         self.num_rows = nr = num_rows
-        self.inpipe = TestPriorityMuxPipe(nr) # fan-in (combinatorial)
+        self.inpipe = TestPriorityMuxPipe(nr)  # fan-in (combinatorial)
         self.pipe1 = PassThroughPipe(nr)      # stage 1 (clock-sync)
         self.pipe2 = PassThroughPipe(nr)      # stage 2 (clock-sync)
         self.pipe3 = PassThroughPipe(nr)      # stage 3 (clock-sync)
@@ -191,7 +190,7 @@ class TestInOutPipe(Elaboratable):
         self.outpipe = TestMuxOutPipe(nr)     # fan-out (combinatorial)
 
         self.p = self.inpipe.p  # kinda annoying,
-        self.n = self.outpipe.n # use pipe in/out as this class in/out
+        self.n = self.outpipe.n  # use pipe in/out as this class in/out
         self._ports = self.inpipe.ports() + self.outpipe.ports()
 
     def elaborate(self, platform):
@@ -228,8 +227,9 @@ def test1():
                          test.rcv(3), test.rcv(2),
                          test.send(0), test.send(1),
                          test.send(3), test.send(2),
-                        ],
+                         ],
                    vcd_name="test_inoutmux_unarycancel_pipe.vcd")
 
+
 if __name__ == '__main__':
     test1()
diff --git a/src/nmutil/test/test_lut.py b/src/nmutil/test/test_lut.py
new file mode 100644 (file)
index 0000000..e0a9809
--- /dev/null
@@ -0,0 +1,139 @@
+# SPDX-License-Identifier: LGPL-3-or-later
+# Copyright 2021 Jacob Lifshay
+
+# Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
+# of Horizon 2020 EU Programme 957073.
+
+import unittest
+from nmigen.hdl.ast import AnyConst, Assert, Signal
+from nmigen.hdl.dsl import Module
+from nmutil.formaltest import FHDLTestCase
+from nmutil.lut import BitwiseMux, BitwiseLut, TreeBitwiseLut
+from nmigen.sim import Delay
+from nmutil.sim_util import do_sim, hash_256
+
+
+class TestBitwiseMux(FHDLTestCase):
+    def test(self):
+        width = 2
+        dut = BitwiseMux(width)
+
+        def case(sel, t, f, expected):
+            with self.subTest(sel=bin(sel), t=bin(t), f=bin(f)):
+                yield dut.sel.eq(sel)
+                yield dut.t.eq(t)
+                yield dut.f.eq(f)
+                yield Delay(1e-6)
+                output = yield dut.output
+                with self.subTest(output=bin(output), expected=bin(expected)):
+                    self.assertEqual(expected, output)
+
+        def process():
+            for sel in range(2 ** width):
+                for t in range(2 ** width):
+                    for f in range(2**width):
+                        expected = 0
+                        for i in range(width):
+                            if sel & 2 ** i:
+                                if t & 2 ** i:
+                                    expected |= 2 ** i
+                            elif f & 2 ** i:
+                                expected |= 2 ** i
+                        yield from case(sel, t, f, expected)
+        with do_sim(self, dut, [dut.sel, dut.t, dut.f, dut.output]) as sim:
+            sim.add_process(process)
+            sim.run()
+
+    def test_formal(self):
+        width = 2
+        dut = BitwiseMux(width)
+        m = Module()
+        m.submodules.dut = dut
+        m.d.comb += dut.sel.eq(AnyConst(width))
+        m.d.comb += dut.f.eq(AnyConst(width))
+        m.d.comb += dut.t.eq(AnyConst(width))
+        for i in range(width):
+            with m.If(dut.sel[i]):
+                m.d.comb += Assert(dut.t[i] == dut.output[i])
+            with m.Else():
+                m.d.comb += Assert(dut.f[i] == dut.output[i])
+        self.assertFormal(m)
+
+
+class TestBitwiseLut(FHDLTestCase):
+    def tst(self, cls):
+        dut = cls(3, 16)
+        mask = 2 ** dut.width - 1
+        lut_mask = 2 ** dut.lut.width - 1
+
+        def case(in0, in1, in2, lut):
+            expected = 0
+            for i in range(dut.width):
+                lut_index = 0
+                if in0 & 2 ** i:
+                    lut_index |= 2 ** 0
+                if in1 & 2 ** i:
+                    lut_index |= 2 ** 1
+                if in2 & 2 ** i:
+                    lut_index |= 2 ** 2
+                if lut & 2 ** lut_index:
+                    expected |= 2 ** i
+            with self.subTest(in0=bin(in0), in1=bin(in1), in2=bin(in2),
+                              lut=bin(lut)):
+                yield dut.inputs[0].eq(in0)
+                yield dut.inputs[1].eq(in1)
+                yield dut.inputs[2].eq(in2)
+                yield dut.lut.eq(lut)
+                yield Delay(1e-6)
+                output = yield dut.output
+                with self.subTest(output=bin(output), expected=bin(expected)):
+                    self.assertEqual(expected, output)
+
+        def process():
+            for shift in range(dut.lut.width):
+                with self.subTest(shift=shift):
+                    yield from case(in0=0xAAAA, in1=0xCCCC, in2=0xF0F0,
+                                    lut=1 << shift)
+            for case_index in range(100):
+                with self.subTest(case_index=case_index):
+                    in0 = hash_256(f"{case_index} in0") & mask
+                    in1 = hash_256(f"{case_index} in1") & mask
+                    in2 = hash_256(f"{case_index} in2") & mask
+                    lut = hash_256(f"{case_index} lut") & lut_mask
+                    yield from case(in0, in1, in2, lut)
+        with do_sim(self, dut, [*dut.inputs, dut.lut, dut.output]) as sim:
+            sim.add_process(process)
+            sim.run()
+
+    def tst_formal(self, cls):
+        dut = cls(3, 16)
+        m = Module()
+        m.submodules.dut = dut
+        m.d.comb += dut.inputs[0].eq(AnyConst(dut.width))
+        m.d.comb += dut.inputs[1].eq(AnyConst(dut.width))
+        m.d.comb += dut.inputs[2].eq(AnyConst(dut.width))
+        m.d.comb += dut.lut.eq(AnyConst(dut.lut.width))
+        for i in range(dut.width):
+            lut_index = Signal(dut.input_count, name=f"lut_index_{i}")
+            for j in range(dut.input_count):
+                m.d.comb += lut_index[j].eq(dut.inputs[j][i])
+            for j in range(dut.lut.width):
+                with m.If(lut_index == j):
+                    m.d.comb += Assert(dut.lut[j] == dut.output[i])
+        self.assertFormal(m)
+
+    def test(self):
+        self.tst(BitwiseLut)
+
+    def test_tree(self):
+        self.tst(TreeBitwiseLut)
+
+    def test_formal(self):
+        self.tst_formal(BitwiseLut)
+
+    def test_tree_formal(self):
+        self.tst_formal(TreeBitwiseLut)
+
+
+if __name__ == "__main__":
+    unittest.main()
index 212d17963a4149f91b714b541ed6da09ea31cc9c..4624a5157d388c3f53d8ede9773a573714fffd60 100644 (file)
@@ -22,7 +22,7 @@ class PassThroughStage:
 
     def ospec(self, name):
         return Signal(16, name="%s_dout" % name, reset_less=True)
-                
+
     def process(self, i):
         return i.data
 
@@ -30,12 +30,12 @@ class PassThroughStage:
 class PassThroughDataStage:
     def ispec(self):
         return PassInData()
+
     def ospec(self):
-        return self.ispec() # same as ospec
+        return self.ispec()  # same as ispec
 
     def process(self, i):
-        return i # pass-through
-
+        return i  # pass-through
 
 
 class PassThroughPipe(PassThroughHandshake):
@@ -54,7 +54,7 @@ class OutputTest:
                 muxid = i
             else:
                 muxid = randint(0, dut.num_rows-1)
-            data = randint(0, 255) + (muxid<<8)
+            data = randint(0, 255) + (muxid << 8)
             if muxid not in self.do:
                 self.do[muxid] = []
             self.di.append((data, muxid))
@@ -66,15 +66,15 @@ class OutputTest:
             muxid = self.di[i][1]
             rs = self.dut.p
             yield rs.i_valid.eq(1)
-            yield rs.data_i.data.eq(op2)
-            yield rs.data_i.muxid.eq(muxid)
+            yield rs.i_data.data.eq(op2)
+            yield rs.i_data.muxid.eq(muxid)
             yield
             o_p_ready = yield rs.o_ready
             while not o_p_ready:
                 yield
                 o_p_ready = yield rs.o_ready
 
-            print ("send", muxid, i, hex(op2))
+            print("send", muxid, i, hex(op2))
 
             yield rs.i_valid.eq(0)
             # wait random period of time before queueing another value
@@ -98,11 +98,11 @@ class OutputTest:
             if not o_n_valid or not i_n_ready:
                 continue
 
-            out_v = yield n.data_o
+            out_v = yield n.o_data
 
-            print ("recv", muxid, out_i, hex(out_v))
+            print("recv", muxid, out_i, hex(out_v))
 
-            assert self.do[muxid][out_i] == out_v # pass-through data
+            assert self.do[muxid][out_i] == out_v  # pass-through data
 
             out_i += 1
 
@@ -140,11 +140,11 @@ class TestSyncToPriorityPipe(Elaboratable):
 
     def ports(self):
         res = [self.p.i_valid, self.p.o_ready] + \
-                self.p.data_i.ports()
+            self.p.i_data.ports()
         for i in range(len(self.n)):
             res += [self.n[i].i_ready, self.n[i].o_valid] + \
-                    [self.n[i].data_o]
-                    #self.n[i].data_o.ports()
+                [self.n[i].o_data]
+            # self.n[i].o_data.ports()
         return res
 
 
@@ -160,5 +160,6 @@ def test1():
                          test.send()],
                    vcd_name="test_outmux_pipe.vcd")
 
+
 if __name__ == '__main__':
     test1()
diff --git a/src/nmutil/test/test_plain_data.py b/src/nmutil/test/test_plain_data.py
new file mode 100644 (file)
index 0000000..f16faba
--- /dev/null
@@ -0,0 +1,283 @@
+# SPDX-License-Identifier: LGPL-3-or-later
+# Copyright 2022 Jacob Lifshay programmerjake@gmail.com
+
+import operator
+import pickle
+import unittest
+import typing
+from nmutil.plain_data import (FrozenPlainDataError, plain_data,
+                               fields, replace)
+
+try:
+    from typing import Protocol
+except ImportError:
+    try:
+        from typing_extensions import Protocol
+    except ImportError:
+        Protocol = None
+
+
+@plain_data(order=True)
+class PlainData0:
+    __slots__ = ()
+
+
+@plain_data(order=True)
+class PlainData1:
+    __slots__ = "a", "b", "x", "y"
+
+    def __init__(self, a, b, *, x, y):
+        self.a = a
+        self.b = b
+        self.x = x
+        self.y = y
+
+
+@plain_data(order=True)
+class PlainData2(PlainData1):
+    __slots__ = "a", "z"
+
+    def __init__(self, a, b, *, x, y, z):
+        super().__init__(a, b, x=x, y=y)
+        self.z = z
+
+
+@plain_data(order=True, frozen=True, unsafe_hash=True)
+class PlainDataF0:
+    __slots__ = ()
+
+
+@plain_data(order=True, frozen=True, unsafe_hash=True)
+class PlainDataF1:
+    __slots__ = "a", "b", "x", "y"
+
+    def __init__(self, a, b, *, x, y):
+        self.a = a
+        self.b = b
+        self.x = x
+        self.y = y
+
+
+@plain_data(order=True, frozen=True, unsafe_hash=True)
+class PlainDataF2(PlainDataF1):
+    __slots__ = "a", "z"
+
+    def __init__(self, a, b, *, x, y, z):
+        super().__init__(a, b, x=x, y=y)
+        self.z = z
+
+
+@plain_data()
+class UnsetField:
+    __slots__ = "a", "b"
+
+    def __init__(self, **kwargs):
+        for name, value in kwargs.items():
+            setattr(self, name, value)
+
+
+T = typing.TypeVar("T")
+
+
+@plain_data()
+class GenericClass(typing.Generic[T]):
+    __slots__ = "a",
+
+    def __init__(self, a):
+        self.a = a
+
+
+@plain_data()
+class MySet(typing.AbstractSet[int]):
+    __slots__ = ()
+
+    def __contains__(self, x):
+        raise NotImplementedError
+
+    def __iter__(self):
+        raise NotImplementedError
+
+    def __len__(self):
+        raise NotImplementedError
+
+
+@plain_data()
+class MyIntLike(typing.SupportsInt):
+    __slots__ = ()
+
+    def __int__(self):
+        return 1
+
+
+if Protocol is not None:
+    class MyProtocol(Protocol):
+        def my_method(self): ...
+
+    @plain_data()
+    class MyProtocolImpl(MyProtocol):
+        __slots__ = ()
+
+        def my_method(self):
+            pass
+
+
+class TestPlainData(unittest.TestCase):
+    def test_fields(self):
+        self.assertEqual(fields(PlainData0), ())
+        self.assertEqual(fields(PlainData0()), ())
+        self.assertEqual(fields(PlainData1), ("a", "b", "x", "y"))
+        self.assertEqual(fields(PlainData1(1, 2, x="x", y="y")),
+                         ("a", "b", "x", "y"))
+        self.assertEqual(fields(PlainData2), ("a", "b", "x", "y", "z"))
+        self.assertEqual(fields(PlainData2(1, 2, x="x", y="y", z=3)),
+                         ("a", "b", "x", "y", "z"))
+        self.assertEqual(fields(PlainDataF0), ())
+        self.assertEqual(fields(PlainDataF0()), ())
+        self.assertEqual(fields(PlainDataF1), ("a", "b", "x", "y"))
+        self.assertEqual(fields(PlainDataF1(1, 2, x="x", y="y")),
+                         ("a", "b", "x", "y"))
+        self.assertEqual(fields(PlainDataF2), ("a", "b", "x", "y", "z"))
+        self.assertEqual(fields(PlainDataF2(1, 2, x="x", y="y", z=3)),
+                         ("a", "b", "x", "y", "z"))
+        self.assertEqual(fields(GenericClass(1)), ("a",))
+        self.assertEqual(fields(MySet()), ())
+        self.assertEqual(fields(MyIntLike()), ())
+        if Protocol is not None:
+            self.assertEqual(fields(MyProtocolImpl()), ())
+        with self.assertRaisesRegex(
+                TypeError,
+                r"the passed-in object must be a class or an instance of a "
+                r"class decorated with @plain_data\(\)"):
+            fields(type)
+
+    def test_replace(self):
+        with self.assertRaisesRegex(
+                TypeError,
+                r"the passed-in object must be a class or an instance of a "
+                r"class decorated with @plain_data\(\)"):
+            replace(PlainData0)
+        with self.assertRaisesRegex(TypeError, "can't set unknown field 'a'"):
+            replace(PlainData0(), a=1)
+        with self.assertRaisesRegex(TypeError, "can't set unknown field 'z'"):
+            replace(PlainDataF1(1, 2, x="x", y="y"), a=3, z=1)
+        self.assertEqual(replace(PlainData0()), PlainData0())
+        self.assertEqual(replace(PlainDataF1(1, 2, x="x", y="y")),
+                         PlainDataF1(1, 2, x="x", y="y"))
+        self.assertEqual(replace(PlainDataF1(1, 2, x="x", y="y"), a=3),
+                         PlainDataF1(3, 2, x="x", y="y"))
+        self.assertEqual(replace(PlainDataF1(1, 2, x="x", y="y"), x=5, a=3),
+                         PlainDataF1(3, 2, x=5, y="y"))
+
+    def test_eq(self):
+        self.assertTrue(PlainData0() == PlainData0())
+        self.assertFalse('a' == PlainData0())
+        self.assertFalse(PlainDataF0() == PlainData0())
+        self.assertTrue(PlainData1(1, 2, x="x", y="y")
+                        == PlainData1(1, 2, x="x", y="y"))
+        self.assertFalse(PlainData1(1, 2, x="x", y="y")
+                         == PlainData1(1, 2, x="x", y="z"))
+        self.assertFalse(PlainData1(1, 2, x="x", y="y")
+                         == PlainData2(1, 2, x="x", y="y", z=3))
+
+    def test_hash(self):
+        def check_op(v, tuple_v):
+            with self.subTest(v=v, tuple_v=tuple_v):
+                self.assertEqual(hash(v), hash(tuple_v))
+
+        def check(a, b, x, y, z):
+            tuple_v = a, b, x, y, z
+            v = PlainDataF2(a=a, b=b, x=x, y=y, z=z)
+            check_op(v, tuple_v)
+
+        check(1, 2, "x", "y", "z")
+
+        check(1, 2, "x", "y", "a")
+        check(1, 2, "x", "y", "zz")
+
+        check(1, 2, "x", "a", "z")
+        check(1, 2, "x", "zz", "z")
+
+        check(1, 2, "a", "y", "z")
+        check(1, 2, "zz", "y", "z")
+
+        check(1, -10, "x", "y", "z")
+        check(1, 10, "x", "y", "z")
+
+        check(-10, 2, "x", "y", "z")
+        check(10, 2, "x", "y", "z")
+
+    def test_order(self):
+        def check_op(l, r, tuple_l, tuple_r, op):
+            with self.subTest(l=l, r=r,
+                              tuple_l=tuple_l, tuple_r=tuple_r, op=op):
+                self.assertEqual(op(l, r), op(tuple_l, tuple_r))
+                self.assertEqual(op(r, l), op(tuple_r, tuple_l))
+
+        def check(a, b, x, y, z):
+            tuple_l = 1, 2, "x", "y", "z"
+            l = PlainData2(a=1, b=2, x="x", y="y", z="z")
+            tuple_r = a, b, x, y, z
+            r = PlainData2(a=a, b=b, x=x, y=y, z=z)
+            check_op(l, r, tuple_l, tuple_r, operator.eq)
+            check_op(l, r, tuple_l, tuple_r, operator.ne)
+            check_op(l, r, tuple_l, tuple_r, operator.lt)
+            check_op(l, r, tuple_l, tuple_r, operator.le)
+            check_op(l, r, tuple_l, tuple_r, operator.gt)
+            check_op(l, r, tuple_l, tuple_r, operator.ge)
+
+        check(1, 2, "x", "y", "z")
+
+        check(1, 2, "x", "y", "a")
+        check(1, 2, "x", "y", "zz")
+
+        check(1, 2, "x", "a", "z")
+        check(1, 2, "x", "zz", "z")
+
+        check(1, 2, "a", "y", "z")
+        check(1, 2, "zz", "y", "z")
+
+        check(1, -10, "x", "y", "z")
+        check(1, 10, "x", "y", "z")
+
+        check(-10, 2, "x", "y", "z")
+        check(10, 2, "x", "y", "z")
+
+    def test_repr(self):
+        self.assertEqual(repr(PlainData0()), "PlainData0()")
+        self.assertEqual(repr(PlainData1(1, 2, x="x", y="y")),
+                         "PlainData1(a=1, b=2, x='x', y='y')")
+        self.assertEqual(repr(PlainData2(1, 2, x="x", y="y", z=3)),
+                         "PlainData2(a=1, b=2, x='x', y='y', z=3)")
+        self.assertEqual(repr(PlainDataF2(1, 2, x="x", y="y", z=3)),
+                         "PlainDataF2(a=1, b=2, x='x', y='y', z=3)")
+        self.assertEqual(repr(UnsetField()),
+                         "UnsetField(a=<not set>, b=<not set>)")
+        self.assertEqual(repr(UnsetField(a=2)), "UnsetField(a=2, b=<not set>)")
+        self.assertEqual(repr(UnsetField(b=3)), "UnsetField(a=<not set>, b=3)")
+        self.assertEqual(repr(UnsetField(a=5, b=3)), "UnsetField(a=5, b=3)")
+
+    def test_frozen(self):
+        not_frozen = PlainData0()
+        not_frozen.a = 1
+        frozen0 = PlainDataF0()
+        with self.assertRaises(AttributeError):
+            frozen0.a = 1
+        frozen1 = PlainDataF1(1, 2, x="x", y="y")
+        with self.assertRaises(FrozenPlainDataError):
+            frozen1.a = 1
+
+    def test_pickle(self):
+        def check(v):
+            with self.subTest(v=v):
+                self.assertEqual(v, pickle.loads(pickle.dumps(v)))
+
+        check(PlainData0())
+        check(PlainData1(a=1, b=2, x="x", y="y"))
+        check(PlainData2(a=1, b=2, x="x", y="y", z="z"))
+        check(PlainDataF0())
+        check(PlainDataF1(a=1, b=2, x="x", y="y"))
+        check(PlainDataF2(a=1, b=2, x="x", y="y", z="z"))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/src/nmutil/test/test_prefix_sum.py b/src/nmutil/test/test_prefix_sum.py
new file mode 100644 (file)
index 0000000..2b88407
--- /dev/null
@@ -0,0 +1,307 @@
+# SPDX-License-Identifier: LGPL-3-or-later
+# Copyright 2022 Jacob Lifshay programmerjake@gmail.com
+
+# Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
+# of Horizon 2020 EU Programme 957073.
+
+from functools import reduce
+from nmutil.formaltest import FHDLTestCase
+from nmutil.sim_util import write_il
+from itertools import accumulate
+import operator
+from nmutil.popcount import pop_count
+from nmutil.prefix_sum import (Op, prefix_sum,
+                               render_prefix_sum_diagram,
+                               tree_reduction, tree_reduction_ops)
+import unittest
+from nmigen.hdl.ast import Signal, AnyConst, Assert
+from nmigen.hdl.dsl import Module
+
+
+def reference_prefix_sum(items, fn):
+    return list(accumulate(items, fn))
+
+
+class TestPrefixSum(FHDLTestCase):
+    maxDiff = None
+
+    def test_prefix_sum_str(self):
+        input_items = ("a", "b", "c", "d", "e", "f", "g", "h", "i")
+        expected = reference_prefix_sum(input_items, operator.add)
+        with self.subTest(expected=repr(expected)):
+            non_work_efficient = prefix_sum(input_items, work_efficient=False)
+            self.assertEqual(expected, non_work_efficient)
+        with self.subTest(expected=repr(expected)):
+            work_efficient = prefix_sum(input_items, work_efficient=True)
+            self.assertEqual(expected, work_efficient)
+
+    def test_tree_reduction_str(self):
+        input_items = ("a", "b", "c", "d", "e", "f", "g", "h", "i")
+        expected = reduce(operator.add, input_items)
+        with self.subTest(expected=repr(expected)):
+            work_efficient = tree_reduction(input_items)
+            self.assertEqual(expected, work_efficient)
+
+    def test_tree_reduction_ops_9(self):
+        ops = list(tree_reduction_ops(9))
+        self.assertEqual(ops, [
+            Op(out=8, lhs=7, rhs=8, row=0),
+            Op(out=6, lhs=5, rhs=6, row=0),
+            Op(out=4, lhs=3, rhs=4, row=0),
+            Op(out=2, lhs=1, rhs=2, row=0),
+            Op(out=8, lhs=6, rhs=8, row=1),
+            Op(out=4, lhs=2, rhs=4, row=1),
+            Op(out=8, lhs=4, rhs=8, row=2),
+            Op(out=8, lhs=0, rhs=8, row=3),
+        ])
+
+    def test_tree_reduction_ops_8(self):
+        ops = list(tree_reduction_ops(8))
+        self.assertEqual(ops, [
+            Op(out=7, lhs=6, rhs=7, row=0),
+            Op(out=5, lhs=4, rhs=5, row=0),
+            Op(out=3, lhs=2, rhs=3, row=0),
+            Op(out=1, lhs=0, rhs=1, row=0),
+            Op(out=7, lhs=5, rhs=7, row=1),
+            Op(out=3, lhs=1, rhs=3, row=1),
+            Op(out=7, lhs=3, rhs=7, row=2),
+        ])
+
+    def tst_pop_count_int(self, width):
+        assert isinstance(width, int)
+        for v in range(1 << width):
+            expected = bin(v).count("1")  # converts to a string, counts 1s
+            with self.subTest(v=v, expected=expected):
+                self.assertEqual(expected, pop_count(v, width=width))
+
+    def test_pop_count_int_0(self):
+        self.tst_pop_count_int(0)
+
+    def test_pop_count_int_1(self):
+        self.tst_pop_count_int(1)
+
+    def test_pop_count_int_2(self):
+        self.tst_pop_count_int(2)
+
+    def test_pop_count_int_3(self):
+        self.tst_pop_count_int(3)
+
+    def test_pop_count_int_4(self):
+        self.tst_pop_count_int(4)
+
+    def test_pop_count_int_5(self):
+        self.tst_pop_count_int(5)
+
+    def test_pop_count_int_6(self):
+        self.tst_pop_count_int(6)
+
+    def test_pop_count_int_7(self):
+        self.tst_pop_count_int(7)
+
+    def test_pop_count_int_8(self):
+        self.tst_pop_count_int(8)
+
+    def test_pop_count_int_9(self):
+        self.tst_pop_count_int(9)
+
+    def test_pop_count_int_10(self):
+        self.tst_pop_count_int(10)
+
+    def tst_pop_count_formal(self, width):
+        assert isinstance(width, int)
+        m = Module()
+        v = Signal(width)
+        out = Signal(16)
+
+        def process_temporary(v):
+            sig = Signal.like(v)
+            m.d.comb += sig.eq(v)
+            return sig
+
+        m.d.comb += out.eq(pop_count(v, process_temporary=process_temporary))
+        write_il(self, m, [v, out])
+        m.d.comb += v.eq(AnyConst(width))
+        expected = Signal(16)
+        m.d.comb += expected.eq(reduce(operator.add,
+                                       (v[i] for i in range(width)),
+                                       0))
+        m.d.comb += Assert(out == expected)
+        self.assertFormal(m)
+
+    def test_pop_count_formal_0(self):
+        self.tst_pop_count_formal(0)
+
+    def test_pop_count_formal_1(self):
+        self.tst_pop_count_formal(1)
+
+    def test_pop_count_formal_2(self):
+        self.tst_pop_count_formal(2)
+
+    def test_pop_count_formal_3(self):
+        self.tst_pop_count_formal(3)
+
+    def test_pop_count_formal_4(self):
+        self.tst_pop_count_formal(4)
+
+    def test_pop_count_formal_5(self):
+        self.tst_pop_count_formal(5)
+
+    def test_pop_count_formal_6(self):
+        self.tst_pop_count_formal(6)
+
+    def test_pop_count_formal_7(self):
+        self.tst_pop_count_formal(7)
+
+    def test_pop_count_formal_8(self):
+        self.tst_pop_count_formal(8)
+
+    def test_pop_count_formal_9(self):
+        self.tst_pop_count_formal(9)
+
+    def test_pop_count_formal_10(self):
+        self.tst_pop_count_formal(10)
+
+    def test_render_work_efficient(self):
+        text = render_prefix_sum_diagram(16, work_efficient=True, plus="@")
+        expected = r"""
+ |  |  |  |  |  |  |  |  |  |  |  |  |  |  |  |
+ ●  |  ●  |  ●  |  ●  |  ●  |  ●  |  ●  |  ●  |
+ |\ |  |\ |  |\ |  |\ |  |\ |  |\ |  |\ |  |\ |
+ | \|  | \|  | \|  | \|  | \|  | \|  | \|  | \|
+ |  @  |  @  |  @  |  @  |  @  |  @  |  @  |  @
+ |  |\ |  |  |  |\ |  |  |  |\ |  |  |  |\ |  |
+ |  | \|  |  |  | \|  |  |  | \|  |  |  | \|  |
+ |  |  X  |  |  |  X  |  |  |  X  |  |  |  X  |
+ |  |  |\ |  |  |  |\ |  |  |  |\ |  |  |  |\ |
+ |  |  | \|  |  |  | \|  |  |  | \|  |  |  | \|
+ |  |  |  @  |  |  |  @  |  |  |  @  |  |  |  @
+ |  |  |  |\ |  |  |  |  |  |  |  |\ |  |  |  |
+ |  |  |  | \|  |  |  |  |  |  |  | \|  |  |  |
+ |  |  |  |  X  |  |  |  |  |  |  |  X  |  |  |
+ |  |  |  |  |\ |  |  |  |  |  |  |  |\ |  |  |
+ |  |  |  |  | \|  |  |  |  |  |  |  | \|  |  |
+ |  |  |  |  |  X  |  |  |  |  |  |  |  X  |  |
+ |  |  |  |  |  |\ |  |  |  |  |  |  |  |\ |  |
+ |  |  |  |  |  | \|  |  |  |  |  |  |  | \|  |
+ |  |  |  |  |  |  X  |  |  |  |  |  |  |  X  |
+ |  |  |  |  |  |  |\ |  |  |  |  |  |  |  |\ |
+ |  |  |  |  |  |  | \|  |  |  |  |  |  |  | \|
+ |  |  |  |  |  |  |  @  |  |  |  |  |  |  |  @
+ |  |  |  |  |  |  |  |\ |  |  |  |  |  |  |  |
+ |  |  |  |  |  |  |  | \|  |  |  |  |  |  |  |
+ |  |  |  |  |  |  |  |  X  |  |  |  |  |  |  |
+ |  |  |  |  |  |  |  |  |\ |  |  |  |  |  |  |
+ |  |  |  |  |  |  |  |  | \|  |  |  |  |  |  |
+ |  |  |  |  |  |  |  |  |  X  |  |  |  |  |  |
+ |  |  |  |  |  |  |  |  |  |\ |  |  |  |  |  |
+ |  |  |  |  |  |  |  |  |  | \|  |  |  |  |  |
+ |  |  |  |  |  |  |  |  |  |  X  |  |  |  |  |
+ |  |  |  |  |  |  |  |  |  |  |\ |  |  |  |  |
+ |  |  |  |  |  |  |  |  |  |  | \|  |  |  |  |
+ |  |  |  |  |  |  |  |  |  |  |  X  |  |  |  |
+ |  |  |  |  |  |  |  |  |  |  |  |\ |  |  |  |
+ |  |  |  |  |  |  |  |  |  |  |  | \|  |  |  |
+ |  |  |  |  |  |  |  |  |  |  |  |  X  |  |  |
+ |  |  |  |  |  |  |  |  |  |  |  |  |\ |  |  |
+ |  |  |  |  |  |  |  |  |  |  |  |  | \|  |  |
+ |  |  |  |  |  |  |  |  |  |  |  |  |  X  |  |
+ |  |  |  |  |  |  |  |  |  |  |  |  |  |\ |  |
+ |  |  |  |  |  |  |  |  |  |  |  |  |  | \|  |
+ |  |  |  |  |  |  |  |  |  |  |  |  |  |  X  |
+ |  |  |  |  |  |  |  |  |  |  |  |  |  |  |\ |
+ |  |  |  |  |  |  |  |  |  |  |  |  |  |  | \|
+ |  |  |  |  |  |  |  ●  |  |  |  |  |  |  |  @
+ |  |  |  |  |  |  |  |\ |  |  |  |  |  |  |  |
+ |  |  |  |  |  |  |  | \|  |  |  |  |  |  |  |
+ |  |  |  |  |  |  |  |  X  |  |  |  |  |  |  |
+ |  |  |  |  |  |  |  |  |\ |  |  |  |  |  |  |
+ |  |  |  |  |  |  |  |  | \|  |  |  |  |  |  |
+ |  |  |  |  |  |  |  |  |  X  |  |  |  |  |  |
+ |  |  |  |  |  |  |  |  |  |\ |  |  |  |  |  |
+ |  |  |  |  |  |  |  |  |  | \|  |  |  |  |  |
+ |  |  |  |  |  |  |  |  |  |  X  |  |  |  |  |
+ |  |  |  |  |  |  |  |  |  |  |\ |  |  |  |  |
+ |  |  |  |  |  |  |  |  |  |  | \|  |  |  |  |
+ |  |  |  ●  |  |  |  ●  |  |  |  @  |  |  |  |
+ |  |  |  |\ |  |  |  |\ |  |  |  |\ |  |  |  |
+ |  |  |  | \|  |  |  | \|  |  |  | \|  |  |  |
+ |  |  |  |  X  |  |  |  X  |  |  |  X  |  |  |
+ |  |  |  |  |\ |  |  |  |\ |  |  |  |\ |  |  |
+ |  |  |  |  | \|  |  |  | \|  |  |  | \|  |  |
+ |  ●  |  ●  |  @  |  ●  |  @  |  ●  |  @  |  |
+ |  |\ |  |\ |  |\ |  |\ |  |\ |  |\ |  |\ |  |
+ |  | \|  | \|  | \|  | \|  | \|  | \|  | \|  |
+ |  |  @  |  @  |  @  |  @  |  @  |  @  |  @  |
+ |  |  |  |  |  |  |  |  |  |  |  |  |  |  |  |
+"""
+        expected = expected[1:-1]  # trim newline at start and end
+        if text != expected:
+            print("text:")
+            print(text)
+            print()
+        self.assertEqual(expected, text)
+
+    def test_render_not_work_efficient(self):
+        text = render_prefix_sum_diagram(16, work_efficient=False, plus="@")
+        expected = r"""
+ |  |  |  |  |  |  |  |  |  |  |  |  |  |  |  |
+ ●  ●  ●  ●  ●  ●  ●  ●  ●  ●  ●  ●  ●  ●  ●  |
+ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |
+ | \| \| \| \| \| \| \| \| \| \| \| \| \| \| \|
+ ●  @  @  @  @  @  @  @  @  @  @  @  @  @  @  @
+ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |  |
+ | \| \| \| \| \| \| \| \| \| \| \| \| \| \|  |
+ |  X  X  X  X  X  X  X  X  X  X  X  X  X  X  |
+ |  |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |
+ |  | \| \| \| \| \| \| \| \| \| \| \| \| \| \|
+ ●  ●  @  @  @  @  @  @  @  @  @  @  @  @  @  @
+ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |  |  |  |
+ | \| \| \| \| \| \| \| \| \| \| \| \|  |  |  |
+ |  X  X  X  X  X  X  X  X  X  X  X  X  |  |  |
+ |  |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |  |  |
+ |  | \| \| \| \| \| \| \| \| \| \| \| \|  |  |
+ |  |  X  X  X  X  X  X  X  X  X  X  X  X  |  |
+ |  |  |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |  |
+ |  |  | \| \| \| \| \| \| \| \| \| \| \| \|  |
+ |  |  |  X  X  X  X  X  X  X  X  X  X  X  X  |
+ |  |  |  |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |
+ |  |  |  | \| \| \| \| \| \| \| \| \| \| \| \|
+ ●  ●  ●  ●  @  @  @  @  @  @  @  @  @  @  @  @
+ |\ |\ |\ |\ |\ |\ |\ |\ |  |  |  |  |  |  |  |
+ | \| \| \| \| \| \| \| \|  |  |  |  |  |  |  |
+ |  X  X  X  X  X  X  X  X  |  |  |  |  |  |  |
+ |  |\ |\ |\ |\ |\ |\ |\ |\ |  |  |  |  |  |  |
+ |  | \| \| \| \| \| \| \| \|  |  |  |  |  |  |
+ |  |  X  X  X  X  X  X  X  X  |  |  |  |  |  |
+ |  |  |\ |\ |\ |\ |\ |\ |\ |\ |  |  |  |  |  |
+ |  |  | \| \| \| \| \| \| \| \|  |  |  |  |  |
+ |  |  |  X  X  X  X  X  X  X  X  |  |  |  |  |
+ |  |  |  |\ |\ |\ |\ |\ |\ |\ |\ |  |  |  |  |
+ |  |  |  | \| \| \| \| \| \| \| \|  |  |  |  |
+ |  |  |  |  X  X  X  X  X  X  X  X  |  |  |  |
+ |  |  |  |  |\ |\ |\ |\ |\ |\ |\ |\ |  |  |  |
+ |  |  |  |  | \| \| \| \| \| \| \| \|  |  |  |
+ |  |  |  |  |  X  X  X  X  X  X  X  X  |  |  |
+ |  |  |  |  |  |\ |\ |\ |\ |\ |\ |\ |\ |  |  |
+ |  |  |  |  |  | \| \| \| \| \| \| \| \|  |  |
+ |  |  |  |  |  |  X  X  X  X  X  X  X  X  |  |
+ |  |  |  |  |  |  |\ |\ |\ |\ |\ |\ |\ |\ |  |
+ |  |  |  |  |  |  | \| \| \| \| \| \| \| \|  |
+ |  |  |  |  |  |  |  X  X  X  X  X  X  X  X  |
+ |  |  |  |  |  |  |  |\ |\ |\ |\ |\ |\ |\ |\ |
+ |  |  |  |  |  |  |  | \| \| \| \| \| \| \| \|
+ |  |  |  |  |  |  |  |  @  @  @  @  @  @  @  @
+ |  |  |  |  |  |  |  |  |  |  |  |  |  |  |  |
+"""
+        expected = expected[1:-1]  # trim newline at start and end
+        if text != expected:
+            print("text:")
+            print(text)
+            print()
+        self.assertEqual(expected, text)
+
+    # TODO: add more tests
+
+
+if __name__ == "__main__":
+    unittest.main()
index c5093a8df038f1583216e93559fbab88e766afdc..8bf0a9f2d298231279ab38e44b23fe0c5f14b630 100644 (file)
@@ -135,9 +135,9 @@ class InputTest:
             op2 = self.di[muxid][i]
             rs = self.dut.p[muxid]
             yield rs.i_valid.eq(1)
-            yield rs.data_i.data.eq(op2)
-            yield rs.data_i.idx.eq(i)
-            yield rs.data_i.muxid.eq(muxid)
+            yield rs.i_data.data.eq(op2)
+            yield rs.i_data.idx.eq(i)
+            yield rs.i_data.muxid.eq(muxid)
             yield
             o_p_ready = yield rs.o_ready
             step_limiter = StepLimiter(10000)
@@ -179,9 +179,9 @@ class InputTest:
             if not o_n_valid or not i_n_ready:
                 continue
 
-            muxid = yield n.data_o.muxid
-            out_i = yield n.data_o.idx
-            out_v = yield n.data_o.data
+            muxid = yield n.o_data.muxid
+            out_i = yield n.o_data.idx
+            out_v = yield n.o_data.data
 
             print("recv", muxid, out_i, hex(out_v))
 
diff --git a/src/nmutil/test/test_reservation_stations.py b/src/nmutil/test/test_reservation_stations.py
new file mode 100644 (file)
index 0000000..70c8495
--- /dev/null
@@ -0,0 +1,179 @@
+""" key strategic example showing how to do multi-input fan-in into a
+    multi-stage pipeline, then multi-output fanout.
+
+    the multiplex ID from the fan-in is passed in to the pipeline, preserved,
+    and used as a routing ID on the fanout.
+"""
+
+from random import randint
+from math import log
+from nmigen import Module, Signal, Cat, Value, Elaboratable
+from nmigen.compat.sim import run_simulation
+from nmigen.cli import verilog, rtlil
+
+from nmutil.concurrentunit import ReservationStations2
+from nmutil.singlepipe import SimpleHandshake, RecordObject, Object
+
+
+class PassData2(RecordObject):
+    def __init__(self):
+        RecordObject.__init__(self)
+        self.muxid = Signal(2, reset_less=True)
+        self.idx = Signal(8, reset_less=True)
+        self.data = Signal(16, reset_less=True)
+
+
+class PassData(Object):
+    def __init__(self, name=None):
+        Object.__init__(self)
+        if name is None:
+            name = ""
+        self.muxid = Signal(2, name="muxid"+name, reset_less=True)
+        self.idx = Signal(8, name="idx"+name, reset_less=True)
+        self.data = Signal(16, name="data"+name, reset_less=True)
+
+
+class PassThroughStage:
+    def ispec(self, name=None):
+        return PassData(name=name)
+
+    def ospec(self, name=None):
+        return self.ispec(name)  # same as ospec
+
+    def process(self, i):
+        return i  # pass-through
+
+
+class PassThroughPipe(SimpleHandshake):
+    def __init__(self):
+        SimpleHandshake.__init__(self, PassThroughStage())
+
+
+class InputTest:
+    def __init__(self, dut):
+        self.dut = dut
+        self.di = {}
+        self.do = {}
+        self.tlen = 100
+        for muxid in range(dut.num_rows):
+            self.di[muxid] = {}
+            self.do[muxid] = {}
+            for i in range(self.tlen):
+                self.di[muxid][i] = randint(0, 255) + (muxid << 8)
+                self.do[muxid][i] = self.di[muxid][i]
+
+    def send(self, muxid):
+        for i in range(self.tlen):
+            op2 = self.di[muxid][i]
+            rs = self.dut.p[muxid]
+            yield rs.i_valid.eq(1)
+            yield rs.i_data.data.eq(op2)
+            yield rs.i_data.idx.eq(i)
+            yield rs.i_data.muxid.eq(muxid)
+            yield
+            o_p_ready = yield rs.o_ready
+            while not o_p_ready:
+                yield
+                o_p_ready = yield rs.o_ready
+
+            print("send", muxid, i, hex(op2))
+
+            yield rs.i_valid.eq(0)
+            # wait random period of time before queueing another value
+            for i in range(randint(0, 3)):
+                yield
+
+        yield rs.i_valid.eq(0)
+        yield
+
+        print("send ended", muxid)
+
+        # wait random period of time before queueing another value
+        # for i in range(randint(0, 3)):
+        #    yield
+
+        #send_range = randint(0, 3)
+        # if send_range == 0:
+        #    send = True
+        # else:
+        #    send = randint(0, send_range) != 0
+
+    def rcv(self, muxid):
+        while True:
+            #stall_range = randint(0, 3)
+            # for j in range(randint(1,10)):
+            #    stall = randint(0, stall_range) != 0
+            #    yield self.dut.n[0].i_ready.eq(stall)
+            #    yield
+            n = self.dut.n[muxid]
+            yield n.i_ready.eq(1)
+            yield
+            o_n_valid = yield n.o_valid
+            i_n_ready = yield n.i_ready
+            if not o_n_valid or not i_n_ready:
+                continue
+
+            out_muxid = yield n.o_data.muxid
+            out_i = yield n.o_data.idx
+            out_v = yield n.o_data.data
+
+            print("recv", out_muxid, out_i, hex(out_v))
+
+            # see if this output has occurred already, delete it if it has
+            assert muxid == out_muxid, \
+                "out_muxid %d not correct %d" % (out_muxid, muxid)
+            assert out_i in self.do[muxid], "out_i %d not in array %s" % \
+                (out_i, repr(self.do[muxid]))
+            assert self.do[muxid][out_i] == out_v  # pass-through data
+            del self.do[muxid][out_i]
+
+            # check if there's any more outputs
+            if len(self.do[muxid]) == 0:
+                break
+        print("recv ended", muxid)
+
+
+class TestALU(Elaboratable):
+    def __init__(self):
+        self.pipe1 = PassThroughPipe()              # stage 1 (clock-sync)
+        self.pipe2 = PassThroughPipe()              # stage 2 (clock-sync)
+
+        self.p = self.pipe1.p
+        self.n = self.pipe2.n
+        self._ports = self.pipe1.ports() + self.pipe2.ports()
+
+    def elaborate(self, platform):
+        m = Module()
+        m.submodules.pipe1 = self.pipe1
+        m.submodules.pipe2 = self.pipe2
+
+        m.d.comb += self.pipe1.connect_to_next(self.pipe2)
+
+        return m
+
+    def new_specs(self, name):
+        return self.pipe1.ispec(name), self.pipe2.ospec(name)
+
+    def ports(self):
+        return self._ports
+
+
+def test1():
+    alu = TestALU()
+    dut = ReservationStations2(alu, num_rows=4)
+    vl = rtlil.convert(dut, ports=dut.ports())
+    with open("test_reservation_stations.il", "w") as f:
+        f.write(vl)
+    #run_simulation(dut, testbench(dut), vcd_name="test_inputgroup.vcd")
+
+    test = InputTest(dut)
+    run_simulation(dut, [test.rcv(1), test.rcv(0),
+                         test.rcv(3), test.rcv(2),
+                         test.send(0), test.send(1),
+                         test.send(3), test.send(2),
+                         ],
+                   vcd_name="test_reservation_stations.vcd")
+
+
+if __name__ == '__main__':
+    test1()
diff --git a/src/nmutil/toolchain.py b/src/nmutil/toolchain.py
new file mode 100644 (file)
index 0000000..123c4f0
--- /dev/null
@@ -0,0 +1,37 @@
+import os
+import shutil
+
+
+__all__ = ["ToolNotFound", "tool_env_var", "has_tool", "require_tool"]
+
+
+class ToolNotFound(Exception):
+    pass
+
+
+def tool_env_var(name):
+    return name.upper().replace("-", "_").replace("+", "X")
+
+
+def _get_tool(name):
+    return os.environ.get(tool_env_var(name), name)
+
+
+def has_tool(name):
+    return shutil.which(_get_tool(name)) is not None
+
+
+def require_tool(name):
+    env_var = tool_env_var(name)
+    path = _get_tool(name)
+    if shutil.which(path) is None:
+        if env_var in os.environ:
+            raise ToolNotFound("Could not find required tool {} in {} as "
+                               "specified via the {} environment variable".
+                               format(name, path, env_var))
+        else:
+            raise ToolNotFound("Could not find required tool {} in PATH. Place "
+                               "it directly in PATH or specify path explicitly "
+                               "via the {} environment variable".
+                               format(name, env_var))
+    return path
index ceb9a710df8ac7a0f42f319bae351dcb155fe0cd..5864d5edb92b8db47c752280c1fb3ae83145c698 100644 (file)
@@ -1,3 +1,4 @@
+# SPDX-License-Identifier: LGPL-3-or-later
 """
     This work is funded through NLnet under Grant 2019-02-012
 
@@ -8,6 +9,7 @@
 from collections.abc import Iterable
 from nmigen import Mux, Signal, Cat
 
+
 # XXX this already exists in nmigen._utils
 # see https://bugs.libre-soc.org/show_bug.cgi?id=297
 def flatten(v):
@@ -17,25 +19,36 @@ def flatten(v):
     else:
         yield v
 
+
 # tree reduction function.  operates recursively.
-def treereduce(tree, op, fn):
-    """treereduce: apply a map-reduce to a list.
+def treereduce(tree, op, fn=None):
+    """treereduce: apply a map-reduce to a list, reducing to a single item
+
+    this is *not* the same as "x = Signal(64) reduce(x, operator.add)",
+    which is a bit-wise reduction down to a single bit
+
+    it is "l = [Signal(w), ..., Signal(w)] reduce(l, operator.add)"
+    i.e. l[0] + l[1] ...
+
     examples: OR-reduction of one member of a list of Records down to a
-              single data point:
-              treereduce(tree, operator.or_, lambda x: getattr(x, "data_o"))
+              single value:
+              treereduce(tree, operator.or_, lambda x: getattr(x, "o_data"))
     """
-    #print ("treereduce", tree)
+    if fn is None:
+        def fn(x): return x
     if not isinstance(tree, list):
         return tree
     if len(tree) == 1:
         return fn(tree[0])
     if len(tree) == 2:
         return op(fn(tree[0]), fn(tree[1]))
-    s = len(tree) // 2 # splitpoint
+    s = len(tree) // 2  # splitpoint
     return op(treereduce(tree[:s], op, fn),
               treereduce(tree[s:], op, fn))
 
 # chooses assignment of 32 bit or full 64 bit depending on is_32bit
+
+
 def eq32(is_32bit, dest, src):
     return [dest[0:32].eq(src[0:32]),
             dest[32:64].eq(Mux(is_32bit, 0, src[32:64]))]
@@ -61,8 +74,8 @@ def rising_edge(m, sig):
     rising = Signal.like(sig)
     delay.name = "%s_dly" % sig.name
     rising.name = "%s_rise" % sig.name
-    m.d.sync += delay.eq(sig) # 1 clock delay
-    m.d.comb += rising.eq(sig & ~delay) # sig is hi but delay-sig is lo
+    m.d.sync += delay.eq(sig)  # 1 clock delay
+    m.d.comb += rising.eq(sig & ~delay)  # sig is hi but delay-sig is lo
     return rising