big rename PartitionedSignal to SimdSignal (shorter)
[ieee754fpu.git] / src / ieee754 / part_cat / cat.py
index aceafbb19572905692afc0b16de859a6e6f7fa90..09170898ad7a93b0e68ab5123baec6a18d9fa218 100644 (file)
@@ -34,11 +34,10 @@ from nmigen import Signal, Module, Elaboratable, Cat, C
 from nmigen.back.pysim import Simulator, Settle
 
 from ieee754.part_mul_add.partpoints import PartitionPoints
-from ieee754.part.partsig import PartitionedSignal
+from ieee754.part.partsig import SimdSignal
 from ieee754.part.test.test_partsig import create_simulator
 
 
-
 def get_runlengths(pbit, size):
     res = []
     count = 1
@@ -60,30 +59,31 @@ def get_runlengths(pbit, size):
 
 
 class PartitionedCat(Elaboratable):
-    def __init__(self, catlist, mask):
+    def __init__(self, catlist, ctx):
         """Create a ``PartitionedCat`` operator
         """
-        # work out the length (total of all PartitionedSignals)
+        # work out the length (total of all SimdSignals)
         self.catlist = catlist
-        self.mask = mask
+        self.ptype = ctx
         width = 0
         for p in catlist:
             width += len(p.sig)
         self.width = width
-        self.output = PartitionedSignal(mask, self.width, reset_less=True)
+        mask = ctx.get_mask()
+        self.output = SimdSignal(mask, self.width, reset_less=True)
         self.partition_points = self.output.partpoints
         self.mwidth = len(self.partition_points)+1
 
-    def get_chunk(self, y, idx):
+    def get_chunk(self, y, idx, numparts):
         x = self.catlist[idx]
         keys = [0] + list(x.partpoints.keys()) + [len(x.sig)]
         # get current index and increment it (for next Cat chunk)
         upto = y[idx]
-        y[idx] += 1
-        print ("getting", idx, upto, keys, len(x.sig))
+        y[idx] += numparts
+        print ("getting", idx, upto, numparts, keys, len(x.sig))
         # get the partition point as far as we are up to
         start = keys[upto]
-        end = keys[upto+1]
+        end = keys[upto+numparts]
         print ("start end", start, end, len(x.sig))
         return x.sig[start:end]
 
@@ -93,21 +93,22 @@ class PartitionedCat(Elaboratable):
 
         keys = list(self.partition_points.keys())
         print ("keys", keys, "values", self.partition_points.values())
-        with m.Switch(self.mask[:-1]):
+        print ("ptype", self.ptype)
+        with m.Switch(self.ptype.get_switch()):
             # for each partition possibility, create a Cat sequence
-            for pbit in range(1<<len(keys)-1):
+            for pbit in self.ptype.get_cases():
                 # set up some indices pointing to where things have got
                 # then when called below in the inner nested loop they give
                 # the relevant sequential chunk
+                output = []
                 y = [0] * len(self.catlist)
-                for yidx in range(len(y)):
-                    # get a list of the length of each partition run
-                    runlengths = get_runlengths(pbit, len(keys)-1)
-                    output = []
-                    for i in runlengths: # for each partition
-                        for _ in range(i): # for the length of each partition
-                            thing = self.get_chunk(y, yidx)# sequential chunks
-                            output.append(thing)
+                # get a list of the length of each partition run
+                runlengths = get_runlengths(pbit, len(keys))
+                print ("pbit", bin(pbit), "runs", runlengths)
+                for i in runlengths: # for each partition
+                    for yidx in range(len(y)):
+                        thing = self.get_chunk(y, yidx, i) # sequential chunks
+                        output.append(thing)
                 with m.Case(pbit):
                     # direct access to the underlying Signal
                     comb += self.output.sig.eq(Cat(*output))
@@ -122,36 +123,41 @@ class PartitionedCat(Elaboratable):
 
 
 if __name__ == "__main__":
-    from ieee754.part_mul_add.partpoints import make_partition
     m = Module()
-    mask = Signal(4)
-    a = PartitionedSignal(mask, 32)
-    b = PartitionedSignal(mask, 16)
+    mask = Signal(3)
+    a = SimdSignal(mask, 32)
+    b = SimdSignal(mask, 16)
     catlist = [a, b]
-    m.submodules.cat = cat = PartitionedCat(catlist, mask)
+    m.submodules.cat = cat = PartitionedCat(catlist, a.ptype)
 
     traces = cat.ports()
     sim = create_simulator(m, traces, "partcat")
 
     def process():
-        yield mask.eq(0b010)
-        yield a.eq(0x01234567)
-        yield b.eq(0xfdbc)
+        yield mask.eq(0b000)
+        yield a.sig.eq(0x01234567)
+        yield b.sig.eq(0xfdbc)
         yield Settle()
         out = yield cat.output.sig
-        print("out", bin(out), hex(out))
-        yield mask.eq(0b111)
-        yield a.eq(0x01234567)
-        yield b.eq(0xfdbc)
+        print("out 000", bin(out), hex(out))
+        yield mask.eq(0b010)
+        yield a.sig.eq(0x01234567)
+        yield b.sig.eq(0xfdbc)
         yield Settle()
         out = yield cat.output.sig
-        print("out", bin(out), hex(out))
+        print("out 010", bin(out), hex(out))
         yield mask.eq(0b110)
-        yield a.eq(0x01234567)
-        yield b.eq(0xfdbc)
+        yield a.sig.eq(0x01234567)
+        yield b.sig.eq(0xfdbc)
+        yield Settle()
+        out = yield cat.output.sig
+        print("out 110", bin(out), hex(out))
+        yield mask.eq(0b111)
+        yield a.sig.eq(0x01234567)
+        yield b.sig.eq(0xfdbc)
         yield Settle()
         out = yield cat.output.sig
-        print("out", bin(out), hex(out))
+        print("out 111", bin(out), hex(out))
 
     sim.add_process(process)
     with sim.write_vcd("partition_cat.vcd", "partition_cat.gtkw",