"""
from nmigen import Signal, Module, Elaboratable, Mux, Cat, Shape, Repl
-from nmigen.sim import Simulator, Delay, Settle
+from nmigen.back.pysim import Simulator, Delay, Settle
from nmigen.cli import rtlil
-from enum import Enum
from collections.abc import Mapping
from functools import reduce
import operator
from collections import defaultdict
-import dataclasses
+from pprint import pprint
from ieee754.part_mul_add.partpoints import PartitionPoints
-@dataclasses.dataclass
-class LayoutResult:
- ppoints: PartitionPoints
- bitp: dict
- bmask: int
- width: int
- lane_shapes: dict
- part_wid: int
- full_part_count: int
-
- def __repr__(self):
- fields = []
- for field in dataclasses.fields(LayoutResult):
- field_v = getattr(self, field.name)
- if isinstance(field_v, PartitionPoints):
- field_v = ',\n '.join(
- f"{k}: {v}" for k, v in field_v.items())
- field_v = f"{{{field_v}}}"
- fields.append(f"{field.name}={field_v}")
- fields = ",\n ".join(fields)
- return f"LayoutResult({fields})"
-
+# XXX MAKE SURE TO PRESERVE ALL THESE COMMENTS XXX
# main fn, which started out here in the bugtracker:
# https://bugs.libre-soc.org/show_bug.cgi?id=713#c20
-def layout(elwid, signed, part_counts, lane_shapes=None, fixed_width=None):
+# note that signed is **NOT** part of the layout, and will NOT
+# be added (because it is not relevant or appropriate).
+# sign belongs in ast.Shape and is the only appropriate location.
+# there is absolutely nothing within this function that in any
+# way requires a sign. it is *purely* performing numerical width
+# computations that have absolutely nothing to do with whether the
+# actual data is signed or unsigned.
+#
+# context for parameters:
+# http://lists.libre-soc.org/pipermail/libre-soc-dev/2021-October/003921.html
+def layout(elwid, # comes from SimdScope constructor
+ vec_el_counts, # comes from SimdScope constructor
+ lane_shapes=None, # from SimdScope.Signal via a SimdShape
+ fixed_width=None): # from SimdScope.Signal via a SimdShape
"""calculate a SIMD layout.
Glossary:
* element: a single scalar value that is an element of a SIMD vector.
- it has a width in bits, and a signedness. Every element is made of 1 or
- more parts. An element optionally includes the padding associated with
- it.
- * lane: an element. An element optionally includes the padding associated
- with it.
+ it has a width in bits. Every element is made of 1 or
+ more parts.
* ElWid: the element-width (really the element type) of an instruction.
Either an integer or a FP type. Integer `ElWid`s are sign-agnostic.
In Python, `ElWid` is either an enum type or is `int`.
Example `ElWid` definition for integers:
class ElWid(Enum):
- I8 = ...
- I16 = ...
- I32 = ...
- I64 = ...
+ I64 = ... # SVP64 value 0b00
+ I32 = ... # SVP64 value 0b01
+ I16 = ... # SVP64 value 0b10
+ I8 = ... # SVP64 value 0b11
Example `ElWid` definition for floats:
class ElWid(Enum):
- F16 = ...
- BF16 = ...
- F32 = ...
- F64 = ...
- * part: (not to be confused with a partition) A piece of a SIMD vector,
- every SIMD vector is made of a non-negative integer of parts. Elements
- are made of a power-of-two number of parts. A part is a fixed number
- of bits wide for each different SIMD layout, it doesn't vary when
- `elwid` changes. A part can have a bit width of any non-negative
- integer, it is not restricted to power-of-two. SIMD vectors should
- have as few parts as necessary, since some circuits have size
- proportional to the number of parts.
-
+ F64 = ... # SVP64 value 0b00
+ F32 = ... # SVP64 value 0b01
+ F16 = ... # SVP64 value 0b10
+ BF16 = ... # SVP64 value 0b11
* elwid: ElWid or nmigen Value with ElWid as the shape
the current element-width
- * signed: bool
- the signedness of all elements in a SIMD layout
- * part_counts: dict[ElWid, int]
- a map from `ElWid` values `k` to the number of parts in an element
- when `elwid == k`. Values should be minimized, since higher values
- often create bigger circuits.
+
+ * vec_el_counts: dict[ElWid, int]
+ a map from `ElWid` values `k` to the number of vector elements
+ required within a partition when `elwid == k`.
Example:
- # here, an I8 element is 1 part wide
- part_counts = {ElWid.I8: 1, ElWid.I16: 2, ElWid.I32: 4, ElWid.I64: 8}
+ vec_el_counts = {ElWid.I8(==0b11): 8, # 8 vector elements
+ ElWid.I16(==0b10): 4, # 4 vector elements
+ ElWid.I32(==0b01): 2, # 2 vector elements
+ ElWid.I64(==0b00): 1} # 1 vector (aka scalar) element
Another Example:
- # here, an F16 element is 1 part wide
- part_counts = {ElWid.F16: 1, ElWid.BF16: 1, ElWid.F32: 2, ElWid.F64: 4}
+ vec_el_counts = {ElWid.BF16(==0b11): 4, # 4 vector elements
+ ElWid.F16(==0b10): 4, # 4 vector elements
+ ElWid.F32(==0b01): 2, # 2 vector elements
+ ElWid.F64(==0b00): 1} # 1 (aka scalar) vector element
+
* lane_shapes: int or Mapping[ElWid, int] (optional)
the bit-width of all elements in a SIMD layout.
+ if not provided, the lane_shapes are computed from fixed_width
+ and vec_el_counts at each elwidth.
+
* fixed_width: int (optional)
- the total width of a SIMD vector. One of lane_shapes and fixed_width
- must be provided.
+ the total width of a SIMD vector. At least one of lane_shapes and
+ fixed_width must be provided; both may be given together.
"""
- print(f"layout(elwid={elwid},\n"
- f" signed={signed},\n"
- f" part_counts={part_counts},\n"
- f" lane_shapes={lane_shapes},\n"
- f" fixed_width={fixed_width})")
- assert isinstance(part_counts, Mapping)
- # assert all part_counts are powers of two
- assert all(v != 0 and (v & (v - 1)) == 0 for v in part_counts.values()),\
- "part_counts values must all be powers of two"
-
- full_part_count = max(part_counts.values())
-
# when there are no lane_shapes specified, this indicates a
# desire to use the maximum available space based on the fixed width
# https://bugs.libre-soc.org/show_bug.cgi?id=713#c67
if lane_shapes is None:
assert fixed_width is not None, \
"both fixed_width and lane_shapes cannot be None"
- lane_shapes = {}
- for k, cur_part_count in part_counts.items():
- cur_element_count = full_part_count // cur_part_count
- assert fixed_width % cur_element_count == 0, (
- f"fixed_width ({fixed_width}) can't be split evenly into "
- f"{cur_element_count} elements")
- lane_shapes[k] = fixed_width // cur_element_count
+ lane_shapes = {i: fixed_width // vec_el_counts[i]
+ for i in vec_el_counts}
print("lane_shapes", fixed_width, lane_shapes)
+
# identify if the lane_shapes is a mapping (dict, etc.)
# if not, then assume that it is an integer (width) that
# needs to be requested across all partitions
if not isinstance(lane_shapes, Mapping):
- lane_shapes = {i: lane_shapes for i in part_counts}
- # calculate the minimum possible bit-width of a part.
- # we divide each element's width by the number of parts in an element,
- # giving the number of bits needed per part.
- # we use `-min(-a // b for ...)` to get `max(ceil(a / b) for ...)`,
- # but using integers.
- min_part_wid = -min(-lane_shapes[i] // c for i, c in part_counts.items())
- # calculate the minimum bit-width required
- min_width = min_part_wid * full_part_count
- print("width", min_width, min_part_wid, full_part_count)
+ lane_shapes = {i: lane_shapes for i in vec_el_counts}
+
+ # compute a set of partition widths
+ print("lane_shapes", lane_shapes, "vec_el_counts", vec_el_counts)
+ cpart_wid = 0
+ width = 0
+ for i, lwid in lane_shapes.items():
+ required_width = lwid * vec_el_counts[i]
+ print(" required width", cpart_wid, i, lwid, required_width)
+ if required_width > width:
+ cpart_wid = lwid
+ width = required_width
+
+ # calculate the minimum width required if fixed_width specified
+ part_count = max(vec_el_counts.values())
+ print("width", width, cpart_wid, part_count)
if fixed_width is not None: # override the width and part_wid
- assert min_width <= fixed_width, "not enough space to fit partitions"
- part_wid = fixed_width // full_part_count
- assert fixed_width % full_part_count == 0, \
- "fixed_width must be a multiple of full_part_count"
+ assert width <= fixed_width, "not enough space to fit partitions"
+ part_wid = fixed_width // part_count
+ assert part_wid * part_count == fixed_width, \
+ "calculated width not aligned multiples"
width = fixed_width
- print("part_wid", part_wid, "count", full_part_count)
- else:
- # go with computed width
- width = min_width
- part_wid = min_part_wid
+ print("part_wid", part_wid, "count", part_count, "width", width)
+
# create the breakpoints dictionary.
# do multi-stage version https://bugs.libre-soc.org/show_bug.cgi?id=713#c34
# https://stackoverflow.com/questions/26367812/
- # dpoints: dict from bit-index to dict[ElWid, None]
- # we use a dict from ElWid to None as the values of dpoints in order to
- # get an ordered set
- dpoints = defaultdict(dict) # if empty key, create a (empty) dict
- for i, cur_part_count in part_counts.items():
- def add_p(bit_index):
- # auto-creates dict if key non-existent
- dpoints[bit_index][i] = None
- # go through all elements for elwid `i`, each element starts at
- # part index `start_part`, and goes for `cur_part_count` parts
- for start_part in range(0, full_part_count, cur_part_count):
- start_bit = start_part * part_wid
- add_p(start_bit) # start of lane
- add_p(start_bit + lane_shapes[i]) # start of padding
+ dpoints = defaultdict(list) # if empty key, create a (empty) list
+ padding_masks = {}
+ always_padding_mask = (1 << width) - 1 # start with all bits padding
+ for i, c in vec_el_counts.items():
+ print("dpoints", i, "count", c)
+ # calculate part_wid based on overall width divided by number
+ # of elements.
+ part_wid = width // c
+
+ padding_mask = (1 << width) - 1 # start with all bits padding
+
+ def add_p(msg, start, p):
+ print(" adding dpoint", msg, start, part_wid, i, c, p)
+ dpoints[p].append(i) # auto-creates list if key non-existent
+ # for each elwidth, create the required number of vector elements
+ for start in range(c):
+ start_bit = start * part_wid
+ end_bit = start_bit + lane_shapes[i]
+ element_mask = (1 << end_bit) - (1 << start_bit)
+ padding_mask &= ~element_mask # remove element from padding_mask
+ add_p("start", start, start_bit) # start of lane
+ add_p("end ", start, end_bit) # end lane
+ padding_masks[i] = padding_mask
+ always_padding_mask &= padding_mask
+
+ # deduplicate dpoints lists
+ for k in dpoints.keys():
+ dpoints[k] = list({i: None for i in dpoints[k]}.keys())
+
# do not need the breakpoints at the very start or the very end
dpoints.pop(0, None)
dpoints.pop(width, None)
- plist = list(dpoints.keys())
- plist.sort()
+
+ # sort dpoints keys
+ dpoints = dict(sorted(dpoints.items(), key=lambda i: i[0]))
+
print("dpoints")
- for k in plist:
- print(f"{k}: {list(dpoints[k].keys())}")
+ pprint(dpoints)
+
# second stage, add (map to) the elwidth==i expressions.
# TODO: use nmutil.treereduce?
points = {}
- for p in plist:
- it = map(lambda i: elwid == i, dpoints[p])
- points[p] = reduce(operator.or_, it)
+ for p in dpoints.keys():
+ points[p] = map(lambda i: elwid == i, dpoints[p])
+ points[p] = reduce(operator.or_, points[p])
+
# third stage, create the binary values which *if* elwidth is set to i
# *would* result in the mask at that elwidth being set to this value
# these can easily be double-checked through Assertion
bitp = {}
- for i in part_counts.keys():
+ for i in vec_el_counts.keys():
bitp[i] = 0
- for p, elwidths in dpoints.items():
+ for bit_index, (p, elwidths) in enumerate(dpoints.items()):
if i in elwidths:
- bitpos = plist.index(p)
- bitp[i] |= 1 << bitpos
+ bitp[i] |= 1 << bit_index
+
# fourth stage: determine which partitions are 100% unused.
# these can then be "blanked out"
- bmask = (1 << len(plist)) - 1
- for p in bitp.values():
- bmask &= ~p
- return LayoutResult(PartitionPoints(points), bitp, bmask, width,
- lane_shapes, part_wid, full_part_count)
+ # points are the partition separators, not partition indexes
+ partition_ends = [*dpoints.keys(), width]
+ bmask = 0
+ partition_start = 0
+ for bit_index, partition_end in enumerate(partition_ends):
+ pmask = (1 << partition_end) - (1 << partition_start)
+ always_padding = (always_padding_mask & pmask) == pmask
+ if always_padding:
+ bmask |= 1 << bit_index
+ partition_start = partition_end
+ return (PartitionPoints(points), bitp, bmask, width, lane_shapes,
+ part_wid)
+
+# XXX XXX XXX XXX quick tests TODO convert to proper ones but kinda good
+# enough for now. if adding new tests do not alter or delete the old ones
+# XXX XXX XXX XXX
if __name__ == '__main__':
- class FpElWid(Enum):
- F64 = 0
- F32 = 1
- F16 = 2
- BF16 = 3
-
- def __repr__(self):
- return super().__str__()
-
- class IntElWid(Enum):
- I64 = 0
- I32 = 1
- I16 = 2
- I8 = 3
-
- def __repr__(self):
- return super().__str__()
-
- # for each element-width (elwidth 0-3) the number of parts in an element
- # is given:
- # | part0 | part1 | part2 | part3 |
- # elwid=F64 4 parts per element: |<-------------F64------------->|
- # elwid=F32 2 parts per element: |<-----F32----->|<-----F32----->|
- # elwid=F16 1 part per element: |<-F16->|<-F16->|<-F16->|<-F16->|
- # elwid=BF16 1 part per element: |<BF16->|<BF16->|<BF16->|<BF16->|
+ # for each element-width (elwidth 0-3) the number of Vector Elements is:
+ # elwidth=0b00 QTY 1 partitions: | ? |
+ # elwidth=0b01 QTY 1 partitions: | ? |
+ # elwidth=0b10 QTY 2 partitions: | ? | ? |
+ # elwidth=0b11 QTY 4 partitions: | ? | ? | ? | ? |
# actual widths of Signals *within* those partitions is given separately
- part_counts = {
- FpElWid.F64: 4,
- FpElWid.F32: 2,
- FpElWid.F16: 1,
- FpElWid.BF16: 1,
+ vec_el_counts = {
+ 0: 1,
+ 1: 1,
+ 2: 2,
+ 3: 4,
}
- # width=3 indicates "we want the same element bit-width (3) at all elwids"
- # elwid=F64 1x 3-bit |<--------i3------->|
- # elwid=F32 2x 3-bit |<---i3-->|<---i3-->|
- # elwid=F16 4x 3-bit |<i3>|<i3>|<i3>|<i3>|
- # elwid=BF16 4x 3-bit |<i3>|<i3>|<i3>|<i3>|
- width_for_all_els = 3
-
- for i in FpElWid:
- print(i, layout(i, True, part_counts, width_for_all_els))
-
- # fixed_width=32 and no lane_widths says "allocate maximum"
- # elwid=F64 1x 32-bit |<-------i32------->|
- # elwid=F32 2x 16-bit |<--i16-->|<--i16-->|
- # elwid=F16 4x 8-bit |<i8>|<i8>|<i8>|<i8>|
- # elwid=BF16 4x 8-bit |<i8>|<i8>|<i8>|<i8>|
+ # width=3 indicates "same width Vector Elements (3) at all elwidths"
+ # elwidth=0b00 1x 3-bit | unused xx ..3 |
+ # elwidth=0b01 1x 3-bit | unused xx ..3 |
+ # elwidth=0b10 2x 3 -bit | xxx ..3 | xxx ..3 |
+ # elwidth=0b11 4x 3 -bit | ..3| ..3 | ..3 |..3 |
+ # expected partitions (^) | | | (^)
+ # to be at these points: (|) | | | |
+ width_in_all_parts = 3
+
+ for i in range(4):
+ pprint((i, layout(i, vec_el_counts, width_in_all_parts)))
+
+ # specify that the Vector Element lengths are to be *different* at
+ # each of the elwidths.
+ # combined with vec_el_counts we have:
+ # elwidth=0b00 1x 5-bit |<----unused---------->....5|
+ # elwidth=0b01 1x 6-bit |<----unused--------->.....6|
+ # elwidth=0b10 2x 6-bit |unused>.....6|unused>.....6|
+ # elwidth=0b11 4x 6-bit |.....6|.....6|.....6|.....6|
+ # expected partitions (^) ^ ^ ^^ (^)
+ # to be at these points: (|) | | || (|)
+ # (24) 18 12 65 (0)
+ widths_at_elwidth = {
+ 0: 5,
+ 1: 6,
+ 2: 6,
+ 3: 6
+ }
- print("maximum allocation from fixed_width=32")
- for i in FpElWid:
- print(i, layout(i, True, part_counts, fixed_width=32))
-
- # specify that the length is to be *different* at each of the elwidths.
- # combined with part_counts we have:
- # elwid=F64 1x 24-bit |<-------i24------->|
- # elwid=F32 2x 12-bit |<--i12-->|<--i12-->|
- # elwid=F16 4x 6-bit |<i6>|<i6>|<i6>|<i6>|
- # elwid=BF16 4x 5-bit |<i5>|<i5>|<i5>|<i5>|
+ print("5,6,6,6 elements", widths_at_elwidth)
+ for i in range(4):
+ pp, bitp, bm, b, c, d = \
+ layout(i, vec_el_counts, widths_at_elwidth)
+ pprint((i, (pp, bitp, bm, b, c, d)))
+ # now check that the expected partition points occur
+ print("5,6,6,6 ppt keys", pp.keys())
+ assert list(pp.keys()) == [5, 6, 12, 18]
+ assert bm == 0 # no unused partitions
+
+ # this example was probably what the 5,6,6,6 one was supposed to be.
+ # combined with vec_el_counts {0:1, 1:1, 2:2, 3:4} we have:
+ # elwidth=0b00 1x 24-bit |.........................24|
+ # elwidth=0b01 1x 12-bit |<--unused--->|...........12|
+ # elwidth=0b10 2x 5 -bit |unused>|....5|unused>|....5|
+ # elwidth=0b11 4x 6 -bit |.....6|.....6|.....6|.....6|
+ # expected partitions (^) ^^ ^ ^^ (^)
+ # to be at these points: (|) || | || (|)
+ # (24) 1817 12 65 (0)
widths_at_elwidth = {
- FpElWid.F64: 24,
- FpElWid.F32: 12,
- FpElWid.F16: 6,
- FpElWid.BF16: 5,
+ 0: 24, # QTY 1x 24
+ 1: 12, # QTY 1x 12
+ 2: 5, # QTY 2x 5
+ 3: 6 # QTY 4x 6
}
- for i in FpElWid:
- print(i, layout(i, False, part_counts, widths_at_elwidth))
+ print("24,12,5,6 elements", widths_at_elwidth)
+ for i in range(4):
+ pp, bitp, bm, b, c, d = \
+ layout(i, vec_el_counts, widths_at_elwidth)
+ pprint((i, (pp, bitp, bm, b, c, d)))
+ # now check that the expected partition points occur
+ print("24,12,5,6 ppt keys", pp.keys())
+ assert list(pp.keys()) == [5, 6, 12, 17, 18]
+ print("bmask", bin(bm))
+ assert bm == 0 # no unused partitions
# this tests elwidth as an actual Signal. layout is allowed to
# determine arbitrarily the overall length
# https://bugs.libre-soc.org/show_bug.cgi?id=713#c30
- elwid = Signal(FpElWid)
- lr = layout(elwid, False, part_counts, widths_at_elwidth)
- print(lr)
- for k, v in lr.bitp.items():
- print(f"bitp elwidth={k}", bin(v))
- print("bmask", bin(lr.bmask))
+ elwid = Signal(2)
+ pp, bitp, bm, b, c, d = layout(
+ elwid, vec_el_counts, widths_at_elwidth)
+ pprint((pp, b, c, d))
+ for k, v in bitp.items():
+ print("bitp elwidth=%d" % k, bin(v))
+ print("bmask", bin(bm))
+ assert bm == 0 # no unused partitions
m = Module()
def process():
- for i in FpElWid:
+ for i in range(4):
yield elwid.eq(i)
yield Settle()
ppt = []
- for pval in lr.ppoints.values():
+ for pval in list(pp.values()):
val = yield pval # get nmigen to evaluate pp
ppt.append(val)
- print(i, ppt)
+ pprint((i, (ppt, b, c, d)))
# check the results against bitp static-expected partition points
# https://bugs.libre-soc.org/show_bug.cgi?id=713#c47
# https://stackoverflow.com/a/27165694
ival = int(''.join(map(str, ppt[::-1])), 2)
- assert ival == lr.bitp[i]
+ assert ival == bitp[i]
sim = Simulator(m)
sim.add_process(process)
# determine arbitrarily the overall length, it is fixed to 64
# https://bugs.libre-soc.org/show_bug.cgi?id=713#c22
- elwid = Signal(FpElWid)
- lr = layout(elwid, False, part_counts, widths_at_elwidth, fixed_width=64)
- print(lr)
- for k, v in lr.bitp.items():
- print(f"bitp elwidth={k}", bin(v))
- print("bmask", bin(lr.bmask))
+ # combined with vec_el_counts {0:1, 1:1, 2:2, 3:4} we have:
+ # elwidth=0b00 1x 24-bit
+ # elwidth=0b01 1x 12-bit
+ # elwidth=0b10 2x 5-bit
+ # elwidth=0b11 4x 6-bit
+ #
+ # bmask<--------1<----0<---------10<---0<-------1<0<----0<---0<----00<---0
+ # always unused:| | | || | | | | | | || |
+ # 1111111111000000 1111111111000000 1111111100000000 0000000000000000
+ # | | | || | | | | | | || |
+ # 0b00 xxxxxxxxxxxxxxxx xxxxxxxxxxxxxxxx xxxxxxxx........ ..............24|
+ # 0b01 xxxxxxxxxxxxxxxx xxxxxxxxxxxxxxxx xxxxxxxxxxxxxxxx xxxx..........12|
+ # 0b10 xxxxxxxxxxxxxxxx xxxxxxxxxxx....5|xxxxxxxxxxxxxxxx xxxxxxxxxxx....5|
+ # 0b11 xxxxxxxxxx.....6|xxxxxxxxxx.....6|xxxxxxxxxx.....6|xxxxxxxxxx.....6|
+ # ^ ^ ^^ ^ ^ ^ ^ ^ ^^
+ # ppoints: | | || | | | | | ||
+ # | bit-48 /\ | bit-24-/ | | bit-12 /\-bit-5
+ # bit-54 bit-38-/ \ bit-32 | bit-16 /
+ # bit-37 bit-22 bit-6
+
+ elwid = Signal(2)
+ pp, bitp, bm, b, c, d = layout(elwid, vec_el_counts,
+ widths_at_elwidth,
+ fixed_width=64)
+ pprint((pp, b, c, d))
+ for k, v in bitp.items():
+ print("bitp elwidth=%d" % k, bin(v))
+ print("bmask", bin(bm))
+ assert bm == 0b101001000000
m = Module()
def process():
- for i in FpElWid:
+ for i in range(4):
yield elwid.eq(i)
yield Settle()
ppt = []
- for pval in list(lr.ppoints.values()):
+ for pval in list(pp.values()):
val = yield pval # get nmigen to evaluate pp
ppt.append(val)
- print(f"test elwidth={i}")
- print(i, ppt)
+ print("test elwidth=%d" % i)
+ pprint((i, (ppt, b, c, d)))
# check the results against bitp static-expected partition points
# https://bugs.libre-soc.org/show_bug.cgi?id=713#c47
# https://stackoverflow.com/a/27165694
ival = int(''.join(map(str, ppt[::-1])), 2)
- assert ival == lr.bitp[i], \
- f"ival {bin(ival)} actual {bin(lr.bitp[i])}"
+ assert ival == bitp[i], "ival %s actual %s" % (bin(ival),
+ bin(bitp[i]))
sim = Simulator(m)
sim.add_process(process)
sim.run()
+
+ # fixed_width=32 and no lane_widths says "allocate maximum"
+ # i.e. Vector Element Widths are auto-allocated
+ # elwidth=0b00 1x 32-bit | .................32 |
+ # elwidth=0b01 1x 32-bit | .................32 |
+ # elwidth=0b10 2x 16-bit | ......16 | ......16 |
+ # elwidth=0b11 4x 8 -bit | ..8| ..8 | ..8 |..8 |
+ # expected partitions (^) | | | (^)
+ # to be at these points: (|) | | | |
+
+ # TODO, fix this so that it is correct. put it at the end so it
+ # shows that things break and doesn't stop the other tests.
+ print("maximum allocation from fixed_width=32")
+ for i in range(4):
+ pprint((i, layout(i, vec_el_counts, fixed_width=32)))
+
+ # example "exponent"
+ # https://libre-soc.org/3d_gpu/architecture/dynamic_simd/shape/
+ # 1xFP64: 11 bits, one exponent
+ # 2xFP32: 8 bits, two exponents
+ # 4xFP16: 5 bits, four exponents
+ # 4xBF16: 8 bits, four exponents
+ vec_el_counts = {
+ 0: 1, # QTY 1x FP64
+ 1: 2, # QTY 2x FP32
+ 2: 4, # QTY 4x FP16
+ 3: 4, # QTY 4x BF16
+ }
+ widths_at_elwidth = {
+ 0: 11, # FP64 ew=0b00
+ 1: 8, # FP32 ew=0b01
+ 2: 5, # FP16 ew=0b10
+ 3: 8 # BF16 ew=0b11
+ }
+
+ # expected results:
+ #
+ # |31| | |24| 16|15 | | 8|7 0 |
+ # |31|28|26|24| |20|16| 12| |10|8|5|4 0 |
+ # 32bit | x| x| x| | x| x| x|10 .... 0 |
+ # 16bit | x| x|26 ... 16 | x| x|10 .... 0 |
+ # 8bit | x|28 .. 24| 20.16| x|11 .. 8|x|4.. 0 |
+ # unused x x
+
+ print("11,8,5,8 elements (FP64/32/16/BF exponents)", widths_at_elwidth)
+ for i in range(4):
+ pp, bitp, bm, b, c, d = \
+ layout(i, vec_el_counts, widths_at_elwidth,
+ fixed_width=32)
+ pprint((i, (pp, bitp, bin(bm), b, c, d)))
+ # now check that the expected partition points occur
+ print("11,8,5,8 pp keys", pp.keys())
+ #assert list(pp.keys()) == [5,6,12,18]
+
+ ###### ######
+ ###### 2nd test, different from the above, elwid=0b01 ==> 11 bit ######
+ ###### ######
+
+ # example "exponent"
+ vec_el_counts = {
+ 0: 1, # QTY 1x FP64
+ 1: 2, # QTY 2x FP32
+ 2: 4, # QTY 4x FP16
+ 3: 4, # QTY 4x BF16
+ }
+ widths_at_elwidth = {
+ 0: 11, # FP64 ew=0b00
+ 1: 11, # FP32 ew=0b01
+ 2: 5, # FP16 ew=0b10
+ 3: 8 # BF16 ew=0b11
+ }
+
+ # expected results:
+ #
+ # |31| | |24| 16|15 | | 8|7 0 |
+ # |31|28|26|24| |20|16| 12| |10|8|5|4 0 |
+ # 32bit | x| x| x| | x| x| x|10 .... 0 |
+ # 16bit | x| x|26 ... 16 | x| x|10 .... 0 |
+ # 8bit | x|28 .. 24| 20.16| x|11 .. 8|x|4.. 0 |
+ # unused x x
+
+ print("11,8,5,8 elements (FP64/32/16/BF exponents)", widths_at_elwidth)
+ for i in range(4):
+ pp, bitp, bm, b, c, d = \
+ layout(i, vec_el_counts, widths_at_elwidth,
+ fixed_width=32)
+ pprint((i, (pp, bitp, bin(bm), b, c, d)))
+ # now check that the expected partition points occur
+ print("11,8,5,8 pp keys", pp.keys())
+ #assert list(pp.keys()) == [5,6,12,18]