fix chunking to get correct order for PartitionedCat
author Luke Kenneth Casson Leighton <lkcl@lkcl.net>
Fri, 24 Sep 2021 19:17:34 +0000 (20:17 +0100)
committer Luke Kenneth Casson Leighton <lkcl@lkcl.net>
Fri, 24 Sep 2021 19:17:34 +0000 (20:17 +0100)
src/ieee754/part_cat/cat.py

index aceafbb19572905692afc0b16de859a6e6f7fa90..58aaecf8e768d8b8026691a085a2756cb257344f 100644 (file)
@@ -74,16 +74,16 @@ class PartitionedCat(Elaboratable):
         self.partition_points = self.output.partpoints
         self.mwidth = len(self.partition_points)+1
 
-    def get_chunk(self, y, idx):
+    def get_chunk(self, y, idx, numparts):
         x = self.catlist[idx]
         keys = [0] + list(x.partpoints.keys()) + [len(x.sig)]
         # get current index and increment it (for next Cat chunk)
         upto = y[idx]
-        y[idx] += 1
-        print ("getting", idx, upto, keys, len(x.sig))
+        y[idx] += numparts
+        print ("getting", idx, upto, numparts, keys, len(x.sig))
         # get the partition point as far as we are up to
         start = keys[upto]
-        end = keys[upto+1]
+        end = keys[upto+numparts]
         print ("start end", start, end, len(x.sig))
         return x.sig[start:end]
 
@@ -99,15 +99,15 @@ class PartitionedCat(Elaboratable):
                 # set up some indices pointing to where things have got
                 # then when called below in the inner nested loop they give
                 # the relevant sequential chunk
+                output = []
                 y = [0] * len(self.catlist)
-                for yidx in range(len(y)):
-                    # get a list of the length of each partition run
-                    runlengths = get_runlengths(pbit, len(keys)-1)
-                    output = []
-                    for i in runlengths: # for each partition
-                        for _ in range(i): # for the length of each partition
-                            thing = self.get_chunk(y, yidx)# sequential chunks
-                            output.append(thing)
+                # get a list of the length of each partition run
+                runlengths = get_runlengths(pbit, len(keys)-1)
+                print ("pbit", bin(pbit), "runs", runlengths)
+                for i in runlengths: # for each partition
+                    for yidx in range(len(y)):
+                        thing = self.get_chunk(y, yidx, i) # sequential chunks
+                        output.append(thing)
                 with m.Case(pbit):
                     # direct access to the underlying Signal
                     comb += self.output.sig.eq(Cat(*output))
@@ -125,7 +125,7 @@ if __name__ == "__main__":
     from ieee754.part_mul_add.partpoints import make_partition
     m = Module()
     mask = Signal(4)
-    a = PartitionedSignal(mask, 32)
+    a = PartitionedSignal(mask, 16)
     b = PartitionedSignal(mask, 16)
     catlist = [a, b]
     m.submodules.cat = cat = PartitionedCat(catlist, mask)
@@ -134,7 +134,7 @@ if __name__ == "__main__":
     sim = create_simulator(m, traces, "partcat")
 
     def process():
-        yield mask.eq(0b010)
+        yield mask.eq(0b000)
         yield a.eq(0x01234567)
         yield b.eq(0xfdbc)
         yield Settle()