fix syntax errors for test_partsig
[ieee754fpu.git] / src / ieee754 / part_mux / part_mux.py
index 54f34812d4cd06bc3195ea1aca0b198d2707c43d..b0cd90b4744d6bb80175a2cad0b63a797699bd72 100644 (file)
@@ -15,13 +15,17 @@ See:
 
 from nmigen import Signal, Module, Elaboratable, Mux
 from ieee754.part_mul_add.partpoints import PartitionPoints
+from ieee754.part_mul_add.partpoints import make_partition
 
 modcount = 0 # global for now
-def PMux(m, sel, a, b):
+def PMux(m, mask, sel, a, b):
+    global modcount
     modcount += 1
-    pm = PartitionedMux(a.shape()[0])
-    m.d.comb += pm.a.eq(a)
-    m.d.comb += pm.b.eq(b)
+    width = a.sig.shape()[0] # get width
+    part_pts = make_partition(mask, width) # create partition points
+    pm = PartitionedMux(width, part_pts)
+    m.d.comb += pm.a.eq(a.sig)
+    m.d.comb += pm.b.eq(b.sig)
     m.d.comb += pm.sel.eq(sel)
     setattr(m.submodules, "pmux%d" % modcount, pm)
     return pm.output
@@ -35,7 +39,7 @@ class PartitionedMux(Elaboratable):
     consequently the incoming selector (sel) can completely
     ignore what the *actual* partition bits are.
     """
-    def __init__(self, width):
+    def __init__(self, width, partition_points):
         self.width = width
         self.partition_points = PartitionPoints(partition_points)
         self.mwidth = len(self.partition_points)+1
@@ -43,8 +47,8 @@ class PartitionedMux(Elaboratable):
         self.b = Signal(width, reset_less=True)
         self.sel = Signal(self.mwidth, reset_less=True)
         self.output = Signal(width, reset_less=True)
-        assert (self.partition_points.fits_in_width(width),
-                    "partition_points doesn't fit in width")
+        assert self.partition_points.fits_in_width(width), \
+                    "partition_points doesn't fit in width"
 
     def elaborate(self, platform):
         m = Module()
@@ -56,12 +60,12 @@ class PartitionedMux(Elaboratable):
         start = 0
         for i in range(len(keys)):
             end = keys[i]
-            mux = output[start:end]
-            mux.append(self.a[start:end] == self.b[start:end])
+            mux = self.output[start:end]
+            comb += mux.eq(self.a[start:end] == self.b[start:end])
             start = end  # for next time round loop
 
         return m
 
     def ports(self):
-        return [self.a, self.b, self.sel, self.output]
+        return [self.a.sig, self.b.sig, self.sel, self.output]