add first cut at PartitionedSignal "Cat"
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Fri, 24 Sep 2021 13:59:01 +0000 (14:59 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Fri, 24 Sep 2021 13:59:01 +0000 (14:59 +0100)
src/ieee754/part_cat/cat.py [new file with mode: 0644]
src/ieee754/part_mul_add/partpoints.py

diff --git a/src/ieee754/part_cat/cat.py b/src/ieee754/part_cat/cat.py
new file mode 100644 (file)
index 0000000..aceafbb
--- /dev/null
@@ -0,0 +1,159 @@
+# SPDX-License-Identifier: LGPL-2.1-or-later
+# See Notices.txt for copyright information
+
+"""
+Copyright (C) 2021 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
+
+dynamically-partitionable "cat" class, directly equivalent
+to nmigen Cat
+
+See:
+
+* http://libre-riscv.org/3d_gpu/architecture/dynamic_simd/cat
+* http://bugs.libre-riscv.org/show_bug.cgi?id=707
+
+m.Switch()
+for pbits cases: 0b000 to 0b111
+  output = []
+  # set up some yielders which will retain where they each got to
+  # then when called below in the inner nested loop they give
+  # the relevant sequential chunk
+  yielders = [Yielder(a), Yielder(b), ....]
+  runlist = split pbits into runs of zeros
+  for y in yielders: # for each signal a b c d ...
+     for i in runlist: # for each partition
+        for _ in range(i)+1: # for the length of each partition
+            thing = yield from y # grab sequential chunks
+            output.append(thing)
+  with m.Case(pbits):
+     comb += out.eq(Cat(*output)
+
+"""
+
+from nmigen import Signal, Module, Elaboratable, Cat, C
+from nmigen.back.pysim import Simulator, Settle
+
+from ieee754.part_mul_add.partpoints import PartitionPoints
+from ieee754.part.partsig import PartitionedSignal
+from ieee754.part.test.test_partsig import create_simulator
+
+
+
+def get_runlengths(pbit, size):
+    res = []
+    count = 1
+    # identify where the 1s are, which indicates "start of a new partition"
+    # we want a list of the lengths of all partitions
+    for i in range(size):
+        if pbit & (1<<i): # it's a 1: ends old partition, starts new
+            res.append(count) # add partition
+            count = 1 # start again
+        else:
+            count += 1
+    # end reached, add whatever is left. could have done this by creating
+    # "fake" extra bit on the partitions, but hey
+    res.append(count)
+
+    print ("get_runlengths", bin(pbit), size, res)
+
+    return res
+
+
+class PartitionedCat(Elaboratable):
+    def __init__(self, catlist, mask):
+        """Create a ``PartitionedCat`` operator
+        """
+        # work out the length (total of all PartitionedSignals)
+        self.catlist = catlist
+        self.mask = mask
+        width = 0
+        for p in catlist:
+            width += len(p.sig)
+        self.width = width
+        self.output = PartitionedSignal(mask, self.width, reset_less=True)
+        self.partition_points = self.output.partpoints
+        self.mwidth = len(self.partition_points)+1
+
+    def get_chunk(self, y, idx):
+        x = self.catlist[idx]
+        keys = [0] + list(x.partpoints.keys()) + [len(x.sig)]
+        # get current index and increment it (for next Cat chunk)
+        upto = y[idx]
+        y[idx] += 1
+        print ("getting", idx, upto, keys, len(x.sig))
+        # get the partition point as far as we are up to
+        start = keys[upto]
+        end = keys[upto+1]
+        print ("start end", start, end, len(x.sig))
+        return x.sig[start:end]
+
+    def elaborate(self, platform):
+        m = Module()
+        comb = m.d.comb
+
+        keys = list(self.partition_points.keys())
+        print ("keys", keys, "values", self.partition_points.values())
+        with m.Switch(self.mask[:-1]):
+            # for each partition possibility, create a Cat sequence
+            for pbit in range(1<<len(keys)-1):
+                # set up some indices pointing to where things have got
+                # then when called below in the inner nested loop they give
+                # the relevant sequential chunk
+                y = [0] * len(self.catlist)
+                for yidx in range(len(y)):
+                    # get a list of the length of each partition run
+                    runlengths = get_runlengths(pbit, len(keys)-1)
+                    output = []
+                    for i in runlengths: # for each partition
+                        for _ in range(i): # for the length of each partition
+                            thing = self.get_chunk(y, yidx)# sequential chunks
+                            output.append(thing)
+                with m.Case(pbit):
+                    # direct access to the underlying Signal
+                    comb += self.output.sig.eq(Cat(*output))
+
+        return m
+
+    def ports(self):
+        res = []
+        for p in self.catlist + [self.output]:
+            res.append(p.sig)
+        return res
+
+
+if __name__ == "__main__":
+    from ieee754.part_mul_add.partpoints import make_partition
+    m = Module()
+    mask = Signal(4)
+    a = PartitionedSignal(mask, 32)
+    b = PartitionedSignal(mask, 16)
+    catlist = [a, b]
+    m.submodules.cat = cat = PartitionedCat(catlist, mask)
+
+    traces = cat.ports()
+    sim = create_simulator(m, traces, "partcat")
+
+    def process():
+        yield mask.eq(0b010)
+        yield a.eq(0x01234567)
+        yield b.eq(0xfdbc)
+        yield Settle()
+        out = yield cat.output.sig
+        print("out", bin(out), hex(out))
+        yield mask.eq(0b111)
+        yield a.eq(0x01234567)
+        yield b.eq(0xfdbc)
+        yield Settle()
+        out = yield cat.output.sig
+        print("out", bin(out), hex(out))
+        yield mask.eq(0b110)
+        yield a.eq(0x01234567)
+        yield b.eq(0xfdbc)
+        yield Settle()
+        out = yield cat.output.sig
+        print("out", bin(out), hex(out))
+
+    sim.add_process(process)
+    with sim.write_vcd("partition_cat.vcd", "partition_cat.gtkw",
+                        traces=traces):
+        sim.run()
index f7fbd2cb5017eaa734c01bc3014c1666a5e97191..887372af62cef7bb7b213cd63e98dbb10aae85ce 100644 (file)
@@ -4,6 +4,7 @@
 
 from nmigen import Signal, Value, Cat, C
 
+
 def make_partition(mask, width):
     """ from a mask and a bitwidth, create partition points.
         note that the assumption is that the mask indicates the
@@ -25,6 +26,28 @@ def make_partition(mask, width):
     return ppoints
 
 
+def make_partition2(mask, width):
+    """ from a mask and a bitwidth, create partition points.
+        note that the assumption is that the mask indicates the
+        breakpoints in regular intervals, and that the last bit (MSB)
+        of the mask is therefore *ignored*.
+        mask len = 4, width == 16 will return:
+            {4: mask[0], 8: mask[1], 12: mask[2]}
+        mask len = 8, width == 64 will return:
+            {8: mask[0], 16: mask[1], 24: mask[2], .... 56: mask[6]}
+    """
+    mlen = len(mask)
+    jumpsize = width // mlen # amount to jump by (size of each partition)
+    ppoints = {}
+    ppos = jumpsize
+    midx = 0
+    while ppos < width and midx < mlen: # -1, ignore last bit
+        ppoints[ppos] = mask[midx]
+        ppos += jumpsize
+        midx += 1
+    return ppoints
+
+
 class PartitionPoints(dict):
     """Partition points and corresponding ``Value``s.