add comments / docstrings for layout function to illustrate where
[ieee754fpu.git] / src / ieee754 / part / layout_experiment.py
index 26801a10712b043d22f06112828d799d3993654d..c3d78f020b5ed2af7669dafd4297f89918561b66 100644 (file)
@@ -13,365 +13,316 @@ Links:
 """
 
 from nmigen import Signal, Module, Elaboratable, Mux, Cat, Shape, Repl
-from nmigen.sim import Simulator, Delay, Settle
+from nmigen.back.pysim import Simulator, Delay, Settle
 from nmigen.cli import rtlil
-from enum import Enum
 
 from collections.abc import Mapping
 from functools import reduce
 import operator
 from collections import defaultdict
-import dataclasses
-from ieee754.part.util import XLEN, FpElWid, IntElWid, SimdMap, SimdScope
-from ieee754.part_mul_add.partpoints import PartitionPoints
+from pprint import pprint
 
+from ieee754.part_mul_add.partpoints import PartitionPoints
 
-@dataclasses.dataclass
-class LayoutResult:
-    ppoints: PartitionPoints
-    bitp: dict
-    bmask: int
-    width: int
-    lane_shapes: dict
-    part_wid: int
-    full_part_count: int
-
-    def __repr__(self):
-        fields = []
-        for field in dataclasses.fields(LayoutResult):
-            field_v = getattr(self, field.name)
-            if isinstance(field_v, PartitionPoints):
-                field_v = ',\n        '.join(
-                    f"{k}: {v}" for k, v in field_v.items())
-                field_v = f"{{{field_v}}}"
-            fields.append(f"{field.name}={field_v}")
-        fields = ",\n    ".join(fields)
-        return f"LayoutResult({fields})"
-
-
-class SimdLayout(Shape):
-    def __init__(self, lane_shapes=None, signed=None, *, fixed_width=None,
-                 width_follows_hint=True, scope=None):
-        """calculate a SIMD layout.
-
-        Glossary:
-        * element: a single scalar value that is an element of a SIMD vector.
-            it has a width in bits, and a signedness. Every element is made of
-            1 or more parts. An element optionally includes the padding
-            associated with it.
-        * lane: an element. An element optionally includes the padding
-            associated with it.
-        * ElWid: the element-width (really the element type) of an instruction.
-            Either an integer or a FP type. Integer `ElWid`s are sign-agnostic.
-            In Python, `ElWid` is either an enum type or is `int`.
-            Example `ElWid` definition for integers:
-
-            class ElWid(Enum):
-                I8 = ...
-                I16 = ...
-                I32 = ...
-                I64 = ...
-
-            Example `ElWid` definition for floats:
-
-            class ElWid(Enum):
-                F16 = ...
-                BF16 = ...
-                F32 = ...
-                F64 = ...
-        * part: (not to be confused with a partition) A piece of a SIMD vector,
-            every SIMD vector is made of a non-negative integer of parts.
-            Elements are made of a power-of-two number of parts. A part is a
-            fixed number of bits wide for each different SIMD layout, it
-            doesn't vary when `elwid` changes. A part can have a bit width of
-            any non-negative integer, it is not restricted to power-of-two.
-
-
-        Arguments:
-        * lane_shapes: int or Mapping[ElWid, int] or SimdMap (optional)
-            the bit-width of all elements in this SIMD layout.
-        * signed: bool
-            the signedness of all elements in this SIMD layout
-        * fixed_width: int (optional)
-            the total width of a SIMD vector. One of lane_shapes and fixed_width
-            must be provided.
-        * width_follows_hint: bool
-            if fixed_width defaults to SimdScope.get().simd_full_width_hint
-
-        Values used from SimdScope:
-        * elwid: ElWid or nmigen Value with ElWid as the shape
-            the current ElWid value
-        * part_counts: SimdMap
-            a map from `ElWid` values `k` to the number of parts in an element
-            when `elwid == k`. Values should be minimized, since higher values
-            often create bigger circuits.
-
-            Example:
-            # here, an I8 element is 1 part wide
-            part_counts = SimdMap({
-                IntElWid.I8: 1,
-                IntElWid.I16: 2,
-                IntElWid.I32: 4, 
-                IntElWid.I64: 8,
-            })
-
-            Another Example:
-            # here, an F16 element is 1 part wide
-            part_counts = SimdMap({
-                FpElWid.F16: 1,
-                FpElWid.BF16: 1,
-                FpElWid.F32: 2,
-                FpElWid.F64: 4,
-            })
-        """
-        if scope is None:
-            scope = SimdScope.get()
-        assert isinstance(scope, SimdScope)
-        self.scope = scope
-        elwid = self.scope.elwid
-        part_counts = self.scope.part_counts
-        assert isinstance(part_counts, SimdMap)
-        simd_full_width_hint = self.scope.simd_full_width_hint
-        full_part_count = self.scope.full_part_count
-        print(f"layout(elwid={elwid},\n"
-              f"    signed={signed},\n"
-              f"    part_counts={part_counts},\n"
-              f"    lane_shapes={lane_shapes},\n"
-              f"    fixed_width={fixed_width},\n"
-              f"    simd_full_width_hint={simd_full_width_hint},\n"
-              f"    width_follows_hint={width_follows_hint})")
-
-        # when there are no lane_shapes specified, this indicates a
-        # desire to use the maximum available space based on the fixed width
-        # https://bugs.libre-soc.org/show_bug.cgi?id=713#c67
-        if lane_shapes is None:
-            assert fixed_width is not None, \
-                "both fixed_width and lane_shapes cannot be None"
-            lane_shapes = {}
-            for k, cur_part_count in part_counts.items():
-                cur_element_count = full_part_count // cur_part_count
-                assert fixed_width % cur_element_count == 0, (
-                    f"fixed_width ({fixed_width}) can't be split evenly into "
-                    f"{cur_element_count} elements")
-                lane_shapes[k] = fixed_width // cur_element_count
-            print("lane_shapes", fixed_width, lane_shapes)
-        # convert lane_shapes to a Mapping[ElWid, Any]
-        lane_shapes = SimdMap(lane_shapes).mapping
-        # filter out unsupported elwidths
-        lane_shapes = {i: lane_shapes[i] for i in part_counts.keys()}
-        self.lane_shapes = lane_shapes
-        # calculate the minimum possible bit-width of a part.
-        # we divide each element's width by the number of parts in an element,
-        # giving the number of bits needed per part.
-        min_part_wid = 0
-        for i, c in part_counts.items():
-            # double negate to get ceil division
-            needed = -(-lane_shapes[i] // c)
-            min_part_wid = max(min_part_wid, needed)
-        # calculate the minimum bit-width required
-        min_width = min_part_wid * full_part_count
-        print("width", min_width, min_part_wid, full_part_count)
-        if width_follows_hint \
-                and min_width <= simd_full_width_hint \
-                and fixed_width is None:
-            fixed_width = simd_full_width_hint
-
-        if fixed_width is not None:  # override the width and part_wid
-            assert min_width <= fixed_width, \
-                "not enough space to fit partitions"
-            self.part_wid = fixed_width // full_part_count
-            assert fixed_width % full_part_count == 0, \
-                "fixed_width must be a multiple of full_part_count"
-            width = fixed_width
-            print("part_wid", self.part_wid, "count", full_part_count)
-        else:
-            # go with computed width
-            width = min_width
-            self.part_wid = min_part_wid
-        super().__init__(width, signed)
-        # create the breakpoints dictionary.
-        # do multi-stage version https://bugs.libre-soc.org/show_bug.cgi?id=713#c34
-        # https://stackoverflow.com/questions/26367812/
-        # dpoints: dict from bit-index to dict[ElWid, None]
-        # we use a dict from ElWid to None as the values of dpoints in order to
-        # get an ordered set
-        dpoints = defaultdict(dict)  # if empty key, create a (empty) dict
-        for i, cur_part_count in part_counts.items():
-            def add_p(bit_index):
-                # auto-creates dict if key non-existent
-                dpoints[bit_index][i] = None
-            # go through all elements for elwid `i`, each element starts at
-            # part index `start_part`, and goes for `cur_part_count` parts
-            for start_part in range(0, full_part_count, cur_part_count):
-                start_bit = start_part * self.part_wid
-                add_p(start_bit)  # start of lane
-                add_p(start_bit + lane_shapes[i])  # start of padding
-        # do not need the breakpoints at the very start or the very end
-        dpoints.pop(0, None)
-        dpoints.pop(self.width, None)
-        plist = list(dpoints.keys())
-        plist.sort()
-        dpoints = {k: dpoints[k].keys() for k in plist}
-        self.dpoints = dpoints
-        print("dpoints")
-        for k in plist:
-            print(f"{k}: {list(dpoints[k])}")
-        # second stage, add (map to) the elwidth==i expressions.
-        # TODO: use nmutil.treereduce?
-        points = {}
-        for p in plist:
-            it = map(lambda i: elwid == i, dpoints[p])
-            points[p] = reduce(operator.or_, it)
-        # third stage, create the binary values which *if* elwidth is set to i
-        # *would* result in the mask at that elwidth being set to this value
-        # these can easily be double-checked through Assertion
-        self.bitp = {}
-        for i in part_counts.keys():
-            self.bitp[i] = 0
-            for p, elwidths in dpoints.items():
-                if i in elwidths:
-                    bitpos = plist.index(p)
-                    self.bitp[i] |= 1 << bitpos
-        # fourth stage: determine which partitions are 100% unused.
-        # these can then be "blanked out"
-        self.bmask = (1 << len(plist)) - 1
-        for p in self.bitp.values():
-            self.bmask &= ~p
-        self.ppoints = PartitionPoints(points)
-
-    def __repr__(self):
-        bitp = ", ".join(f"{k}: {bin(v)}" for k, v in self.bitp.items())
-        dpoints = []
-        for k, v in self.dpoints.items():
-            dpoints.append(f"{k}: {list(v)}")
-        dpoints = ",\n        ".join(dpoints)
-        ppoints = []
-        for k, v in self.ppoints.items():
-            ppoints.append(f"{k}: {list(v)}")
-        ppoints = ",\n        ".join(ppoints)
-        return (f"SimdLayout(lane_shapes={self.lane_shapes},\n"
-                f"    signed={self.signed},\n"
-                f"    fixed_width={self.width},\n"
-                f"    scope={self.scope},\n"
-                f"    bitp={{{bitp}}},\n"
-                f"    bmask={bin(self.bmask)},\n"
-                f"    dpoints={{\n"
-                f"        {dpoints}}},\n"
-                f"    part_wid={self.part_wid},\n"
-                f"    ppoints=PartitionPoints({{\n"
-                f"        {ppoints}}}))")
 
+# XXX MAKE SURE TO PRESERVE ALL THESE COMMENTS XXX
+
+# main fn, which started out here in the bugtracker:
+# https://bugs.libre-soc.org/show_bug.cgi?id=713#c20
+# note that signed is **NOT** part of the layout, and will NOT
+# be added (because it is not relevant or appropriate).
+# sign belongs in ast.Shape and is the only appropriate location.
+# there is absolutely nothing within this function that in any
+# way requires a sign.  it is *purely* performing numerical width
+# computations that have absolutely nothing to do with whether the
+# actual data is signed or unsigned.
+#
+# context for parameters:
+# http://lists.libre-soc.org/pipermail/libre-soc-dev/2021-October/003921.html
+def layout(elwid,            # comes from SimdScope constructor
+           vec_el_counts,    # comes from SimdScope constructor
+           lane_shapes=None,   # from SimdScope.Signal via a SimdShape
+           fixed_width=None):  # from SimdScope.Signal via a SimdShape
+    """calculate a SIMD layout.
+
+    Glossary:
+    * element: a single scalar value that is an element of a SIMD vector.
+        it has a width in bits. Every element is made of 1 or
+        more parts.
+    * ElWid: the element-width (really the element type) of an instruction.
+        Either an integer or a FP type. Integer `ElWid`s are sign-agnostic.
+        In Python, `ElWid` is either an enum type or is `int`.
+        Example `ElWid` definition for integers:
+
+        class ElWid(Enum):
+            I64 = ...       # SVP64 value 0b00
+            I32 = ...       # SVP64 value 0b01
+            I16 = ...       # SVP64 value 0b10
+            I8 = ...        # SVP64 value 0b11
+
+        Example `ElWid` definition for floats:
+
+        class ElWid(Enum):
+            F64 = ...    # SVP64 value 0b00
+            F32 = ...    # SVP64 value 0b01
+            F16 = ...    # SVP64 value 0b10
+            BF16 = ...   # SVP64 value 0b11
+
+    * elwid: ElWid or nmigen Value with ElWid as the shape
+        the current element-width
+
+    * vec_el_counts: dict[ElWid, int]
+        a map from `ElWid` values `k` to the number of vector elements
+        required within a partition when `elwid == k`.
+
+        Example:
+        vec_el_counts = {ElWid.I8(==0b11): 8,   # 8 vector elements
+                         ElWid.I16(==0b10): 4,  # 4 vector elements
+                         ElWid.I32(==0b01): 2,  # 2 vector elements
+                         ElWid.I64(==0b00): 1}  # 1 vector (aka scalar) element
+
+        Another Example:
+        vec_el_counts = {ElWid.BF16(==0b11): 4, # 4 vector elements
+                         ElWid.F16(==0b10): 4,  # 4 vector elements
+                         ElWid.F32(==0b01): 2,  # 2 vector elements
+                         ElWid.F64(==0b00): 1}  # 1 (aka scalar) vector element
+
+    * lane_shapes: int or Mapping[ElWid, int] (optional)
+        the bit-width of all elements in a SIMD layout.
+        if not provided, the lane_shapes are computed from fixed_width
+        and vec_el_counts at each elwidth.
+
+    * fixed_width: int (optional)
+        the total width of a SIMD vector. At least one of lane_shapes
+        and fixed_width must be provided: both may not be left out.
+    """
+    # when there are no lane_shapes specified, this indicates a
+    # desire to use the maximum available space based on the fixed width
+    # https://bugs.libre-soc.org/show_bug.cgi?id=713#c67
+    if lane_shapes is None:
+        assert fixed_width is not None, \
+            "both fixed_width and lane_shapes cannot be None"
+        lane_shapes = {i: fixed_width // vec_el_counts[i]
+                       for i in vec_el_counts}
+        print("lane_shapes", fixed_width, lane_shapes)
+
+    # identify if the lane_shapes is a mapping (dict, etc.)
+    # if not, then assume that it is an integer (width) that
+    # needs to be requested across all partitions
+    if not isinstance(lane_shapes, Mapping):
+        lane_shapes = {i: lane_shapes for i in vec_el_counts}
+
+    # compute a set of partition widths
+    print("lane_shapes", lane_shapes, "vec_el_counts", vec_el_counts)
+    cpart_wid = 0
+    width = 0
+    for i, lwid in lane_shapes.items():
+        required_width = lwid * vec_el_counts[i]
+        print("     required width", cpart_wid, i, lwid, required_width)
+        if required_width > width:
+            cpart_wid = lwid
+            width = required_width
+
+    # calculate the minimum width required if fixed_width specified
+    part_count = max(vec_el_counts.values())
+    print("width", width, cpart_wid, part_count)
+    if fixed_width is not None:  # override the width and part_wid
+        assert width <= fixed_width, "not enough space to fit partitions"
+        part_wid = fixed_width // part_count
+        assert part_wid * part_count == fixed_width, \
+            "calculated width not aligned multiples"
+        width = fixed_width
+        print("part_wid", part_wid, "count", part_count, "width", width)
+
+    # create the breakpoints dictionary.
+    # do multi-stage version https://bugs.libre-soc.org/show_bug.cgi?id=713#c34
+    # https://stackoverflow.com/questions/26367812/
+    dpoints = defaultdict(list)  # if empty key, create a (empty) list
+    padding_masks = {}
+    always_padding_mask = (1 << width) - 1  # start with all bits padding
+    for i, c in vec_el_counts.items():
+        print("dpoints", i, "count", c)
+        # calculate part_wid based on overall width divided by number
+        # of elements.
+        part_wid = width // c
+
+        padding_mask = (1 << width) - 1  # start with all bits padding
+
+        def add_p(msg, start, p):
+            print("    adding dpoint", msg, start, part_wid, i, c, p)
+            dpoints[p].append(i)  # auto-creates list if key non-existent
+        # for each elwidth, create the required number of vector elements
+        for start in range(c):
+            start_bit = start * part_wid
+            end_bit = start_bit + lane_shapes[i]
+            element_mask = (1 << end_bit) - (1 << start_bit)
+            padding_mask &= ~element_mask  # remove element from padding_mask
+            add_p("start", start, start_bit)  # start of lane
+            add_p("end  ", start, end_bit)  # end lane
+        padding_masks[i] = padding_mask
+        always_padding_mask &= padding_mask
+
+    # deduplicate dpoints lists
+    for k in dpoints.keys():
+        dpoints[k] = list({i: None for i in dpoints[k]}.keys())
+
+    # do not need the breakpoints at the very start or the very end
+    dpoints.pop(0, None)
+    dpoints.pop(width, None)
+
+    # sort dpoints keys
+    dpoints = dict(sorted(dpoints.items(), key=lambda i: i[0]))
+
+    print("dpoints")
+    pprint(dpoints)
+
+    # second stage, add (map to) the elwidth==i expressions.
+    # TODO: use nmutil.treereduce?
+    points = {}
+    for p in dpoints.keys():
+        points[p] = map(lambda i: elwid == i, dpoints[p])
+        points[p] = reduce(operator.or_, points[p])
+
+    # third stage, create the binary values which *if* elwidth is set to i
+    # *would* result in the mask at that elwidth being set to this value
+    # these can easily be double-checked through Assertion
+    bitp = {}
+    for i in vec_el_counts.keys():
+        bitp[i] = 0
+        for bit_index, (p, elwidths) in enumerate(dpoints.items()):
+            if i in elwidths:
+                bitp[i] |= 1 << bit_index
+
+    # fourth stage: determine which partitions are 100% unused.
+    # these can then be "blanked out"
+
+    # points are the partition separators, not partition indexes
+    partition_ends = [*dpoints.keys(), width]
+    bmask = 0
+    partition_start = 0
+    for bit_index, partition_end in enumerate(partition_ends):
+        pmask = (1 << partition_end) - (1 << partition_start)
+        always_padding = (always_padding_mask & pmask) == pmask
+        if always_padding:
+            bmask |= 1 << bit_index
+        partition_start = partition_end
+    return (PartitionPoints(points), bitp, bmask, width, lane_shapes,
+            part_wid)
+
+# XXX XXX XXX XXX quick tests TODO convert to proper ones but kinda good
+# enough for now.  if adding new tests do not alter or delete the old ones
+# XXX XXX XXX XXX
 
 if __name__ == '__main__':
-    # for each element-width (elwidth 0-3) the number of parts in an element
-    # is given:
-    #                                | part0 | part1 | part2 | part3 |
-    # elwid=F64 4 parts per element: |<-------------F64------------->|
-    # elwid=F32 2 parts per element: |<-----F32----->|<-----F32----->|
-    # elwid=F16 1 part per element:  |<-F16->|<-F16->|<-F16->|<-F16->|
-    # elwid=BF16 1 part per element: |<BF16->|<BF16->|<BF16->|<BF16->|
+
+    # for each element-width (elwidth 0-3) the number of Vector Elements is:
+    # elwidth=0b00 QTY 1 partitions:   |          ?          |
+    # elwidth=0b01 QTY 1 partitions:   |          ?          |
+    # elwidth=0b10 QTY 2 partitions:   |    ?     |     ?    |
+    # elwidth=0b11 QTY 4 partitions:   | ?  |  ?  |  ?  | ?  |
     # actual widths of Signals *within* those partitions is given separately
-    part_counts = {
-        FpElWid.F64: 4,
-        FpElWid.F32: 2,
-        FpElWid.F16: 1,
-        FpElWid.BF16: 1,
+    vec_el_counts = {
+        0: 1,
+        1: 1,
+        2: 2,
+        3: 4,
     }
 
-    # width=3 indicates "we want the same element bit-width (3) at all elwids"
-    # elwid=F64 1x 3-bit     |<--------i3------->|
-    # elwid=F32 2x 3-bit     |<---i3-->|<---i3-->|
-    # elwid=F16 4x 3-bit     |<i3>|<i3>|<i3>|<i3>|
-    # elwid=BF16 4x 3-bit    |<i3>|<i3>|<i3>|<i3>|
-    width_for_all_els = 3
-
-    for i in FpElWid:
-        with SimdScope(elwid=i, part_counts=part_counts):
-            print(i, SimdLayout(width_for_all_els, True, width_follows_hint=False))
-
-    # fixed_width=32 and no lane_widths says "allocate maximum"
-    # elwid=F64 1x 32-bit    |<-------i32------->|
-    # elwid=F32 2x 16-bit    |<--i16-->|<--i16-->|
-    # elwid=F16 4x 8-bit     |<i8>|<i8>|<i8>|<i8>|
-    # elwid=BF16 4x 8-bit    |<i8>|<i8>|<i8>|<i8>|
+    # width=3 indicates "same width Vector Elements (3) at all elwidths"
+    # elwidth=0b00 1x 3-bit     |  unused xx      ..3 |
+    # elwidth=0b01 1x 3-bit     |  unused xx      ..3 |
+    # elwidth=0b10 2x 3-bit     | xxx  ..3 | xxx  ..3 |
+    # elwidth=0b11 4x 3-bit     | ..3| ..3 | ..3 |..3 |
+    # expected partitions      (^)   |     |     |   (^)
+    # to be at these points:   (|)   |     |     |    |
+    width_in_all_parts = 3
+
+    for i in range(4):
+        pprint((i, layout(i, vec_el_counts, width_in_all_parts)))
+
+    # specify that the Vector Element lengths are to be *different* at
+    # each of the elwidths.
+    # combined with vec_el_counts we have:
+    # elwidth=0b00 1x 5-bit    |<----unused---------->....5|
+    # elwidth=0b01 1x 6-bit    |<----unused--------->.....6|
+    # elwidth=0b10 2x 6-bit    |unused>.....6|unused>.....6|
+    # elwidth=0b11 4x 6-bit    |.....6|.....6|.....6|.....6|
+    # expected partitions     (^)     ^      ^      ^^    (^)
+    # to be at these points:  (|)     |      |      ||    (|)
+    #                         (24)   18     12      65    (0)
+    widths_at_elwidth = {
+        0: 5,
+        1: 6,
+        2: 6,
+        3: 6
+    }
 
-    print("maximum allocation from fixed_width=32")
-    for i in FpElWid:
-        with SimdScope(elwid=i, part_counts=part_counts):
-            print(i, SimdLayout(signed=True, fixed_width=32))
-
-    # specify that the length is to be *different* at each of the elwidths.
-    # combined with part_counts we have:
-    # elwid=F64 1x 24-bit    |<-------i24------->|
-    # elwid=F32 2x 12-bit    |<--i12-->|<--i12-->|
-    # elwid=F16 4x 6-bit     |<i6>|<i6>|<i6>|<i6>|
-    # elwid=BF16 4x 5-bit    |<i5>|<i5>|<i5>|<i5>|
+    print("5,6,6,6 elements", widths_at_elwidth)
+    for i in range(4):
+        pp, bitp, bm, b, c, d = \
+            layout(i, vec_el_counts, widths_at_elwidth)
+        pprint((i, (pp, bitp, bm, b, c, d)))
+    # now check that the expected partition points occur
+    print("5,6,6,6 ppt keys", pp.keys())
+    assert list(pp.keys()) == [5, 6, 12, 18]
+    assert bm == 0  # no unused partitions
+
+    # this example was probably what the 5,6,6,6 one was supposed to be.
+    # combined with vec_el_counts {0:1, 1:1, 2:2, 3:4} we have:
+    # elwidth=0b00 1x 24-bit    |.........................24|
+    # elwidth=0b01 1x 12-bit    |<--unused--->|...........12|
+    # elwidth=0b10 2x 5 -bit    |unused>|....5|unused>|....5|
+    # elwidth=0b11 4x 6 -bit    |.....6|.....6|.....6|.....6|
+    # expected partitions      (^)     ^^     ^       ^^    (^)
+    # to be at these points:   (|)     ||     |       ||    (|)
+    #                          (24)   1817   12       65    (0)
     widths_at_elwidth = {
-        FpElWid.F64: 24,
-        FpElWid.F32: 12,
-        FpElWid.F16: 6,
-        FpElWid.BF16: 5,
+        0: 24,  # QTY 1x 24
+        1: 12,  # QTY 1x 12
+        2: 5,   # QTY 2x 5
+        3: 6    # QTY 4x 6
     }
 
-    for i in FpElWid:
-        with SimdScope(elwid=i, part_counts=part_counts):
-            print(i, SimdLayout(widths_at_elwidth,
-                                False, width_follows_hint=False))
+    print("24,12,5,6 elements", widths_at_elwidth)
+    for i in range(4):
+        pp, bitp, bm, b, c, d = \
+            layout(i, vec_el_counts, widths_at_elwidth)
+        pprint((i, (pp, bitp, bm, b, c, d)))
+    # now check that the expected partition points occur
+    print("24,12,5,6 ppt keys", pp.keys())
+    assert list(pp.keys()) == [5, 6, 12, 17, 18]
+    print("bmask", bin(bm))
+    assert bm == 0  # no unused partitions
 
     # this tests elwidth as an actual Signal. layout is allowed to
     # determine arbitrarily the overall length
     # https://bugs.libre-soc.org/show_bug.cgi?id=713#c30
 
-    with SimdScope(elwid_type=FpElWid, part_counts=part_counts) as scope:
-        l = SimdLayout(widths_at_elwidth, False, width_follows_hint=False)
-        elwid = scope.elwid
-    print(l)
+    elwid = Signal(2)
+    pp, bitp, bm, b, c, d = layout(
+        elwid, vec_el_counts, widths_at_elwidth)
+    pprint((pp, b, c, d))
+    for k, v in bitp.items():
+        print("bitp elwidth=%d" % k, bin(v))
+    print("bmask", bin(bm))
+    assert bm == 0  # no unused partitions
 
     m = Module()
 
     def process():
-        for i in FpElWid:
+        for i in range(4):
             yield elwid.eq(i)
             yield Settle()
             ppt = []
-            for pval in l.ppoints.values():
+            for pval in list(pp.values()):
                 val = yield pval  # get nmigen to evaluate pp
                 ppt.append(val)
-            print(i, ppt)
+            pprint((i, (ppt, b, c, d)))
             # check the results against bitp static-expected partition points
             # https://bugs.libre-soc.org/show_bug.cgi?id=713#c47
             # https://stackoverflow.com/a/27165694
             ival = int(''.join(map(str, ppt[::-1])), 2)
-            assert ival == l.bitp[i]
-
-    sim = Simulator(m)
-    sim.add_process(process)
-    sim.run()
-
-    # this tests elwidth as an actual Signal. layout uses the width hint
-    # https://bugs.libre-soc.org/show_bug.cgi?id=713#c30
-
-    with SimdScope(elwid_type=FpElWid, part_counts=part_counts) as scope:
-        l = SimdLayout(widths_at_elwidth, False)
-        elwid = scope.elwid
-    print(l)
-
-    m = Module()
-
-    def process():
-        for i in FpElWid:
-            yield elwid.eq(i)
-            yield Settle()
-            ppt = []
-            for pval in l.ppoints.values():
-                val = yield pval  # get nmigen to evaluate pp
-                ppt.append(val)
-            print(i, ppt)
-            # check the results against bitp static-expected partition points
-            # https://bugs.libre-soc.org/show_bug.cgi?id=713#c47
-            # https://stackoverflow.com/a/27165694
-            ival = int(''.join(map(str, ppt[::-1])), 2)
-            assert ival == l.bitp[i]
+            assert ival == bitp[i]
 
     sim = Simulator(m)
     sim.add_process(process)
@@ -381,39 +332,145 @@ if __name__ == '__main__':
     # determine arbitrarily the overall length, it is fixed to 64
     # https://bugs.libre-soc.org/show_bug.cgi?id=713#c22
 
-    with SimdScope(elwid_type=FpElWid, part_counts=part_counts) as scope:
-        l = SimdLayout(widths_at_elwidth, False, fixed_width=64)
-        elwid = scope.elwid
-    print(l)
+    # combined with vec_el_counts {0:1, 1:1, 2:2, 3:4} we have:
+    # elwidth=0b00 1x 24-bit
+    # elwidth=0b01 1x 12-bit
+    # elwidth=0b10 2x 5-bit
+    # elwidth=0b11 4x 6-bit
+    #
+    # bmask<--------1<----0<---------10<---0<-------1<0<----0<---0<----00<---0
+    # always unused:|     |     |    ||    |    |   | |     |    |     ||    |
+    #      1111111111000000 1111111111000000 1111111100000000 0000000000000000
+    #               |     |     |    ||    |    |   | |     |    |     ||    |
+    # 0b00 xxxxxxxxxxxxxxxx xxxxxxxxxxxxxxxx xxxxxxxx........ ..............24|
+    # 0b01 xxxxxxxxxxxxxxxx xxxxxxxxxxxxxxxx xxxxxxxxxxxxxxxx xxxx..........12|
+    # 0b10 xxxxxxxxxxxxxxxx xxxxxxxxxxx....5|xxxxxxxxxxxxxxxx xxxxxxxxxxx....5|
+    # 0b11 xxxxxxxxxx.....6|xxxxxxxxxx.....6|xxxxxxxxxx.....6|xxxxxxxxxx.....6|
+    #               ^     ^          ^^    ^        ^ ^     ^    ^     ^^
+    #     ppoints:  |     |          ||    |        | |     |    |     ||
+    #               |  bit-48        /\    | bit-24-/ |     | bit-12   /\-bit-5
+    #            bit-54      bit-38-/  \ bit-32       |   bit-16      /
+    #                                 bit-37       bit-22          bit-6
+
+    elwid = Signal(2)
+    pp, bitp, bm, b, c, d = layout(elwid, vec_el_counts,
+                                   widths_at_elwidth,
+                                   fixed_width=64)
+    pprint((pp, b, c, d))
+    for k, v in bitp.items():
+        print("bitp elwidth=%d" % k, bin(v))
+    print("bmask", bin(bm))
+    assert bm == 0b101001000000
 
     m = Module()
 
     def process():
-        for i in FpElWid:
+        for i in range(4):
             yield elwid.eq(i)
             yield Settle()
             ppt = []
-            for pval in list(l.ppoints.values()):
+            for pval in list(pp.values()):
                 val = yield pval  # get nmigen to evaluate pp
                 ppt.append(val)
-            print(f"test elwidth={i}")
-            print(i, ppt)
+            print("test elwidth=%d" % i)
+            pprint((i, (ppt, b, c, d)))
             # check the results against bitp static-expected partition points
             # https://bugs.libre-soc.org/show_bug.cgi?id=713#c47
             # https://stackoverflow.com/a/27165694
             ival = int(''.join(map(str, ppt[::-1])), 2)
-            assert ival == l.bitp[i], \
-                f"ival {bin(ival)} actual {bin(l.bitp[i])}"
+            assert ival == bitp[i], "ival %s actual %s" % (bin(ival),
+                                                           bin(bitp[i]))
 
     sim = Simulator(m)
     sim.add_process(process)
     sim.run()
 
-    # test XLEN
-    with SimdScope(elwid_type=IntElWid):
-        print("\nSimdLayout(XLEN):")
-        l1 = SimdLayout(XLEN)
-        print(l1)
-        print("\nSimdLayout(XLEN // 2):")
-        l2 = SimdLayout(XLEN // 2)
-        print(l2)
+    # fixed_width=32 and no lane_widths says "allocate maximum"
+    # i.e. Vector Element Widths are auto-allocated
+    # elwidth=0b00 1x 32-bit    | .................32 |
+    # elwidth=0b01 1x 32-bit    | .................32 |
+    # elwidth=0b10 2x 16-bit    | ......16 | ......16 |
+    # elwidth=0b11 4x 8-bit     | ..8| ..8 | ..8 |..8 |
+    # expected partitions      (^)   |     |     |   (^)
+    # to be at these points:   (|)   |     |     |    |
+
+    # TODO, fix this so that it is correct.  put it at the end so it
+    # shows that things break and doesn't stop the other tests.
+    print("maximum allocation from fixed_width=32")
+    for i in range(4):
+        pprint((i, layout(i, vec_el_counts, fixed_width=32)))
+
+    # example "exponent"
+    #  https://libre-soc.org/3d_gpu/architecture/dynamic_simd/shape/
+    # 1xFP64: 11 bits, one exponent
+    # 2xFP32: 8 bits, two exponents
+    # 4xFP16: 5 bits, four exponents
+    # 4xBF16: 8 bits, four exponents
+    vec_el_counts = {
+        0: 1,  # QTY 1x FP64
+        1: 2,  # QTY 2x FP32
+        2: 4,  # QTY 4x FP16
+        3: 4,  # QTY 4x BF16
+    }
+    widths_at_elwidth = {
+        0: 11,  # FP64 ew=0b00
+        1: 8,  # FP32 ew=0b01
+        2: 5,  # FP16 ew=0b10
+        3: 8   # BF16 ew=0b11
+    }
+
+    # expected results:
+    #
+    #        |31|  |  |24|     16|15  |  |   8|7     0 |
+    #        |31|28|26|24| |20|16|  12|  |10|8|5|4   0 |
+    #  32bit | x| x| x|  |      x|   x| x|10 ....    0 |
+    #  16bit | x| x|26    ... 16 |   x| x|10 ....    0 |
+    #  8bit  | x|28 .. 24|  20.16|   x|11 .. 8|x|4.. 0 |
+    #  unused  x                     x
+
+    print("11,8,5,8 elements (FP64/32/16/BF exponents)", widths_at_elwidth)
+    for i in range(4):
+        pp, bitp, bm, b, c, d = \
+            layout(i, vec_el_counts, widths_at_elwidth,
+                   fixed_width=32)
+        pprint((i, (pp, bitp, bin(bm), b, c, d)))
+    # now check that the expected partition points occur
+    print("11,8,5,8 pp keys", pp.keys())
+    #assert list(pp.keys()) == [5,6,12,18]
+
+    ######                                                           ######
+    ###### 2nd test, different from the above, elwid=0b01 ==> 11 bit ######
+    ######                                                           ######
+
+    # example "exponent"
+    vec_el_counts = {
+        0: 1,  # QTY 1x FP64
+        1: 2,  # QTY 2x FP32
+        2: 4,  # QTY 4x FP16
+        3: 4,  # QTY 4x BF16
+    }
+    widths_at_elwidth = {
+        0: 11,  # FP64 ew=0b00
+        1: 11,  # FP32 ew=0b01
+        2: 5,  # FP16 ew=0b10
+        3: 8   # BF16 ew=0b11
+    }
+
+    # expected results:
+    #
+    #        |31|  |  |24|     16|15  |  |   8|7     0 |
+    #        |31|28|26|24| |20|16|  12|  |10|8|5|4   0 |
+    #  32bit | x| x| x|  |      x|   x| x|10 ....    0 |
+    #  16bit | x| x|26    ... 16 |   x| x|10 ....    0 |
+    #  8bit  | x|28 .. 24|  20.16|   x|11 .. 8|x|4.. 0 |
+    #  unused  x                     x
+
+    print("11,8,5,8 elements (FP64/32/16/BF exponents)", widths_at_elwidth)
+    for i in range(4):
+        pp, bitp, bm, b, c, d = \
+            layout(i, vec_el_counts, widths_at_elwidth,
+                   fixed_width=32)
+        pprint((i, (pp, bitp, bin(bm), b, c, d)))
+    # now check that the expected partition points occur
+    print("11,8,5,8 pp keys", pp.keys())
+    #assert list(pp.keys()) == [5,6,12,18]