"""
 
 from ieee754.part_mul_add.adder import PartitionedAdder
+from ieee754.part_mul_add.partpoints import make_partition
 from nmigen import (Signal,
                     )
 
 class PartitionedSignal:
-    def __init__(self, partition_points, *args, **kwargs):
-        self.partpoints = partition_points
+    def __init__(self, mask, *args, **kwargs):
         self.sig = Signal(*args, **kwargs)
+        width = self.sig.shape()[0] # get signal width
+        self.partpoints = make_partition(mask, width) # create partition points
         self.modnames = {}
         for name in ['add']:
             self.modnames[name] = 0
 
 class TestPartitionPoints(unittest.TestCase):
     def test(self):
         width = 16
-        partition_nibbles = Signal() # divide into 4-bits
-        partition_bytes = Signal()   # divide on 8-bits
-        partpoints = {0x4: partition_nibbles,
-                      0x8: partition_bytes | partition_nibbles,
-                      0xC: partition_nibbles}
-        module = TestAddMod(width, partpoints)
+        part_mask = Signal(4) # divide into 4-bits
+        module = TestAddMod(width, part_mask)
 
         sim = create_simulator(module,
-                              [partition_nibbles,
-                               partition_bytes,
+                              [part_mask,
                                module.a.sig,
                                module.b.sig,
                                module.add_output],
                     msg = f"{msg_prefix}: 0x{a:X} + 0x{b:X}" + \
                         f" => 0x{y:X} != 0x{outval:X}"
                     self.assertEqual(y, outval, msg)
-            yield partition_nibbles.eq(0)
-            yield partition_bytes.eq(0)
+            yield part_mask.eq(0)
             yield from test_add("16-bit", 0xFFFF)
-            yield partition_nibbles.eq(0)
-            yield partition_bytes.eq(1)
+            yield part_mask.eq(0b10)
             yield from test_add("8-bit", 0xFF00, 0x00FF)
-            yield partition_nibbles.eq(1)
-            yield partition_bytes.eq(0)
+            yield part_mask.eq(0b1111)
             yield from test_add("4-bit", 0xF000, 0x0F00, 0x00F0, 0x000F)
 
         sim.add_process(async_process)
 
 
 from nmigen import Signal, Value, Cat, C
 
+def make_partition(mask, width):
+    """ from a mask and a bitwidth, create partition points.
+        note that the assumption is that the mask indicates the
+        breakpoints in regular intervals, and that the last bit (MSB)
+        of the mask is therefore *ignored*.
+        mask len = 4, width == 16 will return:
+            {4: mask[0], 8: mask[1], 12: mask[2]}
+        mask len = 8, width == 64 will return:
+            {8: mask[0], 16: mask[1], 24: mask[2], .... 56: mask[6]}
+    """
+    ppoints = {}
+    mlen = mask.shape()[0]
+    ppos = mlen
+    midx = 0
+    while ppos < width:
+        ppoints[ppos] = mask[midx]
+        ppos += mlen
+        midx += 1
+    return ppoints
+
 
 class PartitionPoints(dict):
     """Partition points and corresponding ``Value``s.