# https://stackoverflow.com/questions/26367812/
dpoints = defaultdict(list) # if empty key, create a (empty) list
for i, c in vec_el_counts.items():
+ print ("dpoints", i, "count", c)
# calculate part_wid based on overall width divided by number
# of elements.
part_wid = width // c
- def add_p(p):
+ def add_p(msg, start, p):
+ print (" adding dpoint", msg, start, part_wid, i, c, p)
dpoints[p].append(i) # auto-creates list if key non-existent
# for each elwidth, create the required number of vector elements
for start in range(c):
- add_p(start * part_wid) # start of lane
- add_p(start * part_wid + lane_shapes[i]) # start of padding
+ add_p("start", start, start * part_wid) # start of lane
+ add_p("end ", start, start * part_wid + lane_shapes[i]) # end lane
# do not need the breakpoints at the very start or the very end
dpoints.pop(0, None)
- dpoints.pop(width, None)
+ if fixed_width is not None:
+ dpoints.pop(fixed_width, None)
+ else:
+ dpoints.pop(width, None)
plist = list(dpoints.keys())
plist.sort()
print("dpoints")
for i in range(4):
pprint((i, layout(i, vec_el_counts, fixed_width=32)))
+ # example "exponent"
+ # https://libre-soc.org/3d_gpu/architecture/dynamic_simd/shape/
+ # 1xFP64: 11 bits, one exponent
+ # 2xFP32: 8 bits, two exponents
+ # 4xFP16: 5 bits, four exponents
+ # 4xBF16: 8 bits, four exponents
+ vec_el_counts = {
+ 0: 1, # QTY 1x FP64
+ 1: 2, # QTY 2x FP32
+ 2: 4, # QTY 4x FP16
+ 3: 4, # QTY 4x BF16
+ }
+ widths_at_elwidth = {
+ 0: 11, # FP64 ew=0b00
+ 1: 8, # FP32 ew=0b01
+ 2: 5, # FP16 ew=0b10
+ 3: 8 # BF16 ew=0b11
+ }
+
+ # expected results:
+ #
+ # |31| | |24| 16|15 | | 8|7 0 |
+ # |31|28|26|24| |20|16| 12| |10|8|5|4 0 |
+ # 32bit | x| x| x| | x| x| x|10 .... 0 |
+ # 16bit | x| x|26 ... 16 | x| x|10 .... 0 |
+ # 8bit | x|28 .. 24| 20.16| x|11 .. 8|x|4.. 0 |
+ # unused x x
+
+ print ("11,8,5,8 elements (FP64/32/16/BF exponents)", widths_at_elwidth)
+ for i in range(4):
+ pp, bitp, bm, b, c, d = \
+ layout(i, vec_el_counts, widths_at_elwidth,
+ fixed_width=32)
+ pprint((i, (pp, bitp, bm, b, c, d)))
+ # now check that the expected partition points occur
+ print("11,8,5,8 pp keys", pp.keys())
+ #assert list(pp.keys()) == [5,6,12,18]
+