disable fadd f32 formal proofs by default -- they're too slow
[ieee754fpu.git] / src / ieee754 / part_ass / assign.py
index 00cebc6dab33a25f93d4e790dd7883fc859bae9e..9d411ba26c2c74d1c11439e207b42dbe3b3f3a5e 100644 (file)
@@ -19,7 +19,7 @@ from nmigen.back.pysim import Simulator, Settle
 from nmutil.extend import ext
 
 from ieee754.part_mul_add.partpoints import PartitionPoints
-from ieee754.part.partsig import PartitionedSignal
+from ieee754.part.partsig import SimdSignal
 
 
 def get_runlengths(pbit, size):
@@ -43,26 +43,25 @@ def get_runlengths(pbit, size):
 
 
 class PartitionedAssign(Elaboratable):
-    def __init__(self, shape, assign, mask):
+    def __init__(self, shape, assign, ctx):
         """Create a ``PartitionedAssign`` operator
         """
-        # work out the length (total of all PartitionedSignals)
+        # work out the length (total of all SimdSignals)
         self.assign = assign
-        if isinstance(mask, dict):
-            mask = list(mask.values())
-        self.mask = mask
+        self.ptype = ctx
         self.shape = shape
-        self.output = PartitionedSignal(mask, self.shape, reset_less=True)
+        mask = ctx.get_mask()
+        self.output = SimdSignal(mask, self.shape, reset_less=True)
         self.partition_points = self.output.partpoints
         self.mwidth = len(self.partition_points)+1
 
     def get_chunk(self, y, numparts):
         x = self.assign
-        if not isinstance(x, PartitionedSignal):
+        if not isinstance(x, SimdSignal):
             # assume Scalar. totally different rules
             end = numparts * (len(x) // self.mwidth)
             return x[:end]
-        # PartitionedSignal: start at partition point
+        # SimdSignal: start at partition point
         keys = [0] + list(x.partpoints.keys()) + [len(x)]
         # get current index and increment it (for next Assign chunk)
         upto = y[0]
@@ -72,22 +71,24 @@ class PartitionedAssign(Elaboratable):
         start = keys[upto]
         end = keys[upto+numparts]
         print ("start end", start, end, len(x))
-        return x[start:end]
+        # access the underlying signal of SimdSignal directly
+        return x.sig[start:end]
 
     def elaborate(self, platform):
+        print ("PartitionedAssign start")
         m = Module()
         comb = m.d.comb
 
         keys = list(self.partition_points.keys())
         print ("keys", keys, "values", self.partition_points.values())
-        print ("mask", self.mask)
+        print ("ptype", self.ptype)
         outpartsize = len(self.output) // self.mwidth
         width, signed = self.output.shape()
         print ("width, signed", width, signed)
 
-        with m.Switch(Cat(self.mask)):
+        with m.Switch(self.ptype.get_switch()):
             # for each partition possibility, create a Assign sequence
-            for pbit in range(1<<len(keys)):
+            for pbit in self.ptype.get_cases():
                 # set up some indices pointing to where things have got
                 # then when called below in the inner nested loop they give
                 # the relevant sequential chunk
@@ -107,10 +108,11 @@ class PartitionedAssign(Elaboratable):
                     # direct access to the underlying Signal
                     comb += self.output.sig.eq(Cat(*output))
 
+        print ("PartitionedAssign end")
         return m
 
     def ports(self):
-        if isinstance(self.assign, PartitionedSignal):
+        if isinstance(self.assign, SimdSignal):
             return [self.assign.lower(), self.output.lower()]
         return [self.assign, self.output.lower()]
 
@@ -119,8 +121,8 @@ if __name__ == "__main__":
     from ieee754.part.test.test_partsig import create_simulator
     m = Module()
     mask = Signal(3)
-    a = PartitionedSignal(mask, 32)
-    m.submodules.ass = ass = PartitionedAssign(signed(48), a, mask)
+    a = SimdSignal(mask, 32)
+    m.submodules.ass = ass = PartitionedAssign(signed(48), a, a.ptype)
     omask = (1<<len(ass.output))-1
 
     traces = ass.ports()
@@ -154,7 +156,20 @@ if __name__ == "__main__":
     m = Module()
     mask = Signal(3)
     a = Signal(32)
-    m.submodules.ass = ass = PartitionedAssign(signed(48), a, mask)
+    class PartType:
+        def __init__(self, mask):
+            self.mask = mask
+        def get_mask(self):
+            return mask
+        def get_switch(self):
+            return Cat(self.get_mask())
+        def get_cases(self):
+            return range(1<<len(self.get_mask()))
+        @property
+        def blanklanes(self):
+            return 0
+    ptype = PartType(mask)
+    m.submodules.ass = ass = PartitionedAssign(signed(48), a, ptype)
     omask = (1<<len(ass.output))-1
 
     traces = ass.ports()