add comments / docstrings for layout function to illustrate where
[ieee754fpu.git] / src / ieee754 / part / layout_experiment.py
index dc827cb2cd8653ea8d665de8b653812f3f5ce12b..c3d78f020b5ed2af7669dafd4297f89918561b66 100644 (file)
@@ -25,6 +25,8 @@ from pprint import pprint
 from ieee754.part_mul_add.partpoints import PartitionPoints
 
 
+# XXX MAKE SURE TO PRESERVE ALL THESE COMMENTS XXX
+
 # main fn, which started out here in the bugtracker:
 # https://bugs.libre-soc.org/show_bug.cgi?id=713#c20
 # note that signed is **NOT** part of the layout, and will NOT
@@ -34,7 +36,13 @@ from ieee754.part_mul_add.partpoints import PartitionPoints
 # way requires a sign.  it is *purely* performing numerical width
 # computations that have absolutely nothing to do with whether the
 # actual data is signed or unsigned.
-def layout(elwid, vec_el_counts, lane_shapes=None, fixed_width=None):
+#
+# context for parameters:
+# http://lists.libre-soc.org/pipermail/libre-soc-dev/2021-October/003921.html
+def layout(elwid,            # comes from SimdScope constructor
+           vec_el_counts,    # comes from SimdScope constructor
+           lane_shapes=None,   # from SimdScope.Signal via a SimdShape
+           fixed_width=None):  # from SimdScope.Signal via a SimdShape
     """calculate a SIMD layout.
 
     Glossary:
@@ -60,19 +68,6 @@ def layout(elwid, vec_el_counts, lane_shapes=None, fixed_width=None):
             F16 = ...    # SVP64 value 0b10
             BF16 = ...   # SVP64 value 0b11
 
-    # XXX this is redundant and out-of-date with respect to the
-    # clarification that the input is in counts of *elements*
-    # *NOT* "fixed width parts".
-    # fixed-width parts results in 14 such parts being created
-    # when 5 will do, for a simple example 5-6-6-6
-    * part: A piece of a SIMD vector, every SIMD vector is made of a
-        non-negative integer of parts. Elements are made of a power-of-two
-        number of parts. A part is a fixed number of bits wide for each
-        different SIMD layout, it doesn't vary when `elwid` changes. A part
-        can have a bit width of any non-negative integer, it is not restricted
-        to power-of-two. SIMD vectors should have as few parts as necessary,
-        since some circuits have size proportional to the number of parts.
-
     * elwid: ElWid or nmigen Value with ElWid as the shape
         the current element-width
 
@@ -143,36 +138,48 @@ def layout(elwid, vec_el_counts, lane_shapes=None, fixed_width=None):
     # do multi-stage version https://bugs.libre-soc.org/show_bug.cgi?id=713#c34
     # https://stackoverflow.com/questions/26367812/
     dpoints = defaultdict(list)  # if empty key, create a (empty) list
+    padding_masks = {}
+    always_padding_mask = (1 << width) - 1  # start with all bits padding
     for i, c in vec_el_counts.items():
         print("dpoints", i, "count", c)
         # calculate part_wid based on overall width divided by number
         # of elements.
         part_wid = width // c
 
+        padding_mask = (1 << width) - 1  # start with all bits padding
+
         def add_p(msg, start, p):
             print("    adding dpoint", msg, start, part_wid, i, c, p)
             dpoints[p].append(i)  # auto-creates list if key non-existent
         # for each elwidth, create the required number of vector elements
         for start in range(c):
-            add_p("start", start, start * part_wid)  # start of lane
-            add_p("end  ", start, start * part_wid +
-                  lane_shapes[i])  # end lane
+            start_bit = start * part_wid
+            end_bit = start_bit + lane_shapes[i]
+            element_mask = (1 << end_bit) - (1 << start_bit)
+            padding_mask &= ~element_mask  # remove element from padding_mask
+            add_p("start", start, start_bit)  # start of lane
+            add_p("end  ", start, end_bit)  # end lane
+        padding_masks[i] = padding_mask
+        always_padding_mask &= padding_mask
+
+    # deduplicate dpoints lists
+    for k in dpoints.keys():
+        dpoints[k] = list({i: None for i in dpoints[k]}.keys())
 
     # do not need the breakpoints at the very start or the very end
     dpoints.pop(0, None)
-    if fixed_width is not None:
-        dpoints.pop(fixed_width, None)
-    else:
-        dpoints.pop(width, None)
-    plist = list(dpoints.keys())
-    plist.sort()
+    dpoints.pop(width, None)
+
+    # sort dpoints keys
+    dpoints = dict(sorted(dpoints.items(), key=lambda i: i[0]))
+
     print("dpoints")
-    pprint(dict(dpoints))
+    pprint(dpoints)
 
     # second stage, add (map to) the elwidth==i expressions.
     # TODO: use nmutil.treereduce?
     points = {}
-    for p in plist:
+    for p in dpoints.keys():
         points[p] = map(lambda i: elwid == i, dpoints[p])
         points[p] = reduce(operator.or_, points[p])
 
@@ -182,19 +189,29 @@ def layout(elwid, vec_el_counts, lane_shapes=None, fixed_width=None):
     bitp = {}
     for i in vec_el_counts.keys():
         bitp[i] = 0
-        for p, elwidths in dpoints.items():
+        for bit_index, (p, elwidths) in enumerate(dpoints.items()):
             if i in elwidths:
-                bitpos = plist.index(p)
-                bitp[i] |= 1 << bitpos
+                bitp[i] |= 1 << bit_index
 
     # fourth stage: determine which partitions are 100% unused.
     # these can then be "blanked out"
-    bmask = (1 << len(plist))-1
-    for p in bitp.values():
-        bmask &= ~p
+
+    # points are the partition separators, not partition indexes
+    partition_ends = [*dpoints.keys(), width]
+    bmask = 0
+    partition_start = 0
+    for bit_index, partition_end in enumerate(partition_ends):
+        pmask = (1 << partition_end) - (1 << partition_start)
+        always_padding = (always_padding_mask & pmask) == pmask
+        if always_padding:
+            bmask |= 1 << bit_index
+        partition_start = partition_end
     return (PartitionPoints(points), bitp, bmask, width, lane_shapes,
             part_wid)
 
+# XXX XXX XXX XXX quick tests TODO convert to proper ones but kinda good
+# enough for now.  if adding new tests do not alter or delete the old ones
+# XXX XXX XXX XXX
 
 if __name__ == '__main__':
 
@@ -248,6 +265,7 @@ if __name__ == '__main__':
     # now check that the expected partition points occur
     print("5,6,6,6 ppt keys", pp.keys())
     assert list(pp.keys()) == [5, 6, 12, 18]
+    assert bm == 0  # no unused partitions
 
     # this example was probably what the 5,6,6,6 one was supposed to be.
     # combined with vec_el_counts {0:1, 1:1, 2:2, 3:4} we have:
@@ -273,6 +291,8 @@ if __name__ == '__main__':
     # now check that the expected partition points occur
     print("24,12,5,6 ppt keys", pp.keys())
     assert list(pp.keys()) == [5, 6, 12, 17, 18]
+    print("bmask", bin(bm))
+    assert bm == 0  # no unused partitions
 
     # this tests elwidth as an actual Signal. layout is allowed to
     # determine arbitrarily the overall length
@@ -285,6 +305,7 @@ if __name__ == '__main__':
     for k, v in bitp.items():
         print("bitp elwidth=%d" % k, bin(v))
     print("bmask", bin(bm))
+    assert bm == 0  # no unused partitions
 
     m = Module()
 
@@ -311,6 +332,26 @@ if __name__ == '__main__':
     # determine arbitrarily the overall length, it is fixed to 64
     # https://bugs.libre-soc.org/show_bug.cgi?id=713#c22
 
+    # combined with vec_el_counts {0:1, 1:1, 2:2, 3:4} we have:
+    # elwidth=0b00 1x 24-bit
+    # elwidth=0b01 1x 12-bit
+    # elwidth=0b10 2x 5-bit
+    # elwidth=0b11 4x 6-bit
+    #
+    # bmask<--------1<----0<---------10<---0<-------1<0<----0<---0<----00<---0
+    # always unused:|     |     |    ||    |    |   | |     |    |     ||    |
+    #      1111111111000000 1111111111000000 1111111100000000 0000000000000000
+    #               |     |     |    ||    |    |   | |     |    |     ||    |
+    # 0b00 xxxxxxxxxxxxxxxx xxxxxxxxxxxxxxxx xxxxxxxx........ ..............24|
+    # 0b01 xxxxxxxxxxxxxxxx xxxxxxxxxxxxxxxx xxxxxxxxxxxxxxxx xxxx..........12|
+    # 0b10 xxxxxxxxxxxxxxxx xxxxxxxxxxx....5|xxxxxxxxxxxxxxxx xxxxxxxxxxx....5|
+    # 0b11 xxxxxxxxxx.....6|xxxxxxxxxx.....6|xxxxxxxxxx.....6|xxxxxxxxxx.....6|
+    #               ^     ^          ^^    ^        ^ ^     ^    ^     ^^
+    #     ppoints:  |     |          ||    |        | |     |    |     ||
+    #               |  bit-48        /\    | bit-24-/ |     | bit-12   /\-bit-5
+    #            bit-54      bit-38-/  \ bit-32       |   bit-16      /
+    #                                 bit-37       bit-22          bit-6
+
     elwid = Signal(2)
     pp, bitp, bm, b, c, d = layout(elwid, vec_el_counts,
                                    widths_at_elwidth,
@@ -319,6 +360,7 @@ if __name__ == '__main__':
     for k, v in bitp.items():
         print("bitp elwidth=%d" % k, bin(v))
     print("bmask", bin(bm))
+    assert bm == 0b101001000000
 
     m = Module()