change all uses of dataclass to plain_data
[ieee754fpu.git] / src / ieee754 / part / layout_experiment.py
1 #!/usr/bin/env python3
2 # SPDX-License-Identifier: LGPL-3-or-later
3 # See Notices.txt for copyright information
4 """
5 Links:
6 * https://libre-soc.org/3d_gpu/architecture/dynamic_simd/shape/
7 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c20
8 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c30
9 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c34
10 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c47
11 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c22
12 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c67
13 """
14
15 from nmigen import Signal, Module, Elaboratable, Mux, Cat, Shape, Repl
16 from nmigen.back.pysim import Simulator, Delay, Settle
17 from nmigen.cli import rtlil
18
19 from collections.abc import Mapping
20 from functools import reduce
21 import operator
22 from collections import defaultdict
23 from pprint import pprint
24
25 from ieee754.part_mul_add.partpoints import PartitionPoints
26
27
28 # XXX MAKE SURE TO PRESERVE ALL THESE COMMENTS XXX
29
30 # main fn, which started out here in the bugtracker:
31 # https://bugs.libre-soc.org/show_bug.cgi?id=713#c20
32 # note that signed is **NOT** part of the layout, and will NOT
33 # be added (because it is not relevant or appropriate).
34 # sign belongs in ast.Shape and is the only appropriate location.
35 # there is absolutely nothing within this function that in any
36 # way requires a sign. it is *purely* performing numerical width
37 # computations that have absolutely nothing to do with whether the
38 # actual data is signed or unsigned.
39 #
40 # context for parameters:
41 # http://lists.libre-soc.org/pipermail/libre-soc-dev/2021-October/003921.html
42 # XXX tempted to suggest that this function remain as a function, because
43 # it takes all the context it needs as parameters. its usefulness goes
44 # beyond a single class, and there is actually nothing realistically
45 # that it needs which is context-sensitive. therefore, on balance,
46 # it should remain a function
def layout(elwid,  # comes from SimdScope constructor
           vec_el_counts,  # comes from SimdScope constructor
           lane_shapes=None,  # from SimdScope.Signal via a SimdShape
           fixed_width=None):  # from SimdScope.Signal via a SimdShape
    """calculate a SIMD layout.

    Glossary:
    * element: a single scalar value that is an element of a SIMD vector.
        it has a width in bits. Every element is made of 1 or
        more parts.
    * ElWid: the element-width (really the element type) of an instruction.
        Either an integer or a FP type. Integer `ElWid`s are sign-agnostic.
        In Python, `ElWid` is either an enum type or is `int`.
        Example `ElWid` definition for integers:

        class ElWid(Enum):
            I64 = ...       # SVP64 value 0b00
            I32 = ...       # SVP64 value 0b01
            I16 = ...       # SVP64 value 0b10
            I8 = ...        # SVP64 value 0b11

        Example `ElWid` definition for floats:

        class ElWid(Enum):
            F64 = ...       # SVP64 value 0b00
            F32 = ...       # SVP64 value 0b01
            F16 = ...       # SVP64 value 0b10
            BF16 = ...      # SVP64 value 0b11

    Parameters:
    * elwid: ElWid or nmigen Value with ElWid as the shape
        the current element-width

    * vec_el_counts: dict[ElWid, int]
        a map from `ElWid` values `k` to the number of vector elements
        required within a partition when `elwid == k`.

        Example:
        vec_el_counts = {ElWid.I8(==0b11): 8,  # 8 vector elements
                         ElWid.I16(==0b10): 4,  # 4 vector elements
                         ElWid.I32(==0b01): 2,  # 2 vector elements
                         ElWid.I64(==0b00): 1}  # 1 vector (aka scalar) element

        Another Example:
        vec_el_counts = {ElWid.BF16(==0b11): 4,  # 4 vector elements
                         ElWid.F16(==0b10): 4,  # 4 vector elements
                         ElWid.F32(==0b01): 2,  # 2 vector elements
                         ElWid.F64(==0b00): 1}  # 1 (aka scalar) vector element

    * lane_shapes: int or Mapping[ElWid, int] (optional)
        the bit-width of all elements in a SIMD layout.
        if not provided, the lane_shapes are computed from fixed_width
        and vec_el_counts at each elwidth.

    * fixed_width: int (optional)
        the total width of a SIMD vector. One or both of lane_shapes or
        fixed_width may be provided. Both may not be left out.

    Returns a 7-tuple:
    * PartitionPoints(points): `points` maps each interior breakpoint
        (bit position) to the OR of `elwid == k` over every `k` that has
        an element starting or ending at that position.
    * bitp: dict mapping each elwid to an int in which bit `n` is set when
        the `n`-th breakpoint (in ascending order) applies at that elwid.
    * lpoints: dict mapping each elwid to a list of
        `range(start_bit, end_bit)`, one per vector element.
    * bmask: int in which bit `n` is set when the `n`-th partition (the
        span between consecutive breakpoints) is padding at every elwid.
    * width: the overall width in bits (fixed_width when that was given).
    * lane_shapes: the per-elwid mapping of element widths actually used.
    * part_wid: `width // count` for the last entry of vec_el_counts.
    """
    # no lane_shapes given: use the maximum available space, i.e. share
    # fixed_width equally between the elements at each elwidth
    # https://bugs.libre-soc.org/show_bug.cgi?id=713#c67
    if lane_shapes is None:
        assert fixed_width is not None, \
            "both fixed_width and lane_shapes cannot be None"
        lane_shapes = {ew: fixed_width // count
                       for ew, count in vec_el_counts.items()}
        print("lane_shapes", fixed_width, lane_shapes)

    # anything that is not a mapping (dict, etc.) is assumed to be an
    # integer width requested identically at every elwidth.
    # Note: back in SimdScope.Shape(), this is how code that was
    # formerly scalar can be "converted" to "look like" it is
    # still scalar, by not altering the width at any of the elwids.
    if not isinstance(lane_shapes, Mapping):
        lane_shapes = dict.fromkeys(vec_el_counts, lane_shapes)

    # find the widest total requirement (element width * element count)
    print("lane_shapes", lane_shapes, "vec_el_counts", vec_el_counts)
    cpart_wid = 0
    width = 0
    for ew, lane_wid in lane_shapes.items():
        needed = lane_wid * vec_el_counts[ew]
        print(" required width", cpart_wid, ew, lane_wid, needed)
        if needed <= width:
            continue
        cpart_wid, width = lane_wid, needed

    # calculate the minimum width required if fixed_width specified
    part_count = max(vec_el_counts.values())
    print("width", width, cpart_wid, part_count)
    if fixed_width is not None:  # override the width and part_wid
        assert width <= fixed_width, "not enough space to fit partitions"
        part_wid, leftover = divmod(fixed_width, part_count)
        assert leftover == 0, \
            "calculated width not aligned multiples"
        width = fixed_width
        print("part_wid", part_wid, "count", part_count, "width", width)

    # create the breakpoints dictionary.
    # do multi-stage version https://bugs.libre-soc.org/show_bug.cgi?id=713#c34
    # https://stackoverflow.com/questions/26367812/

    # first phase: record which elwidths need a breakpoint at each bit
    # position (dpoints) and where each element starts and ends (lpoints)
    dpoints = defaultdict(list)  # missing key => fresh empty list
    lpoints = defaultdict(list)  # elwid => list of start-end ranges
    padding_masks = {}
    all_bits = (1 << width) - 1
    always_padding_mask = all_bits  # start with all bits padding
    for ew, count in vec_el_counts.items():
        print("dpoints", ew, "count", count)
        # part_wid is the overall width divided by the number of elements.
        # (the value from the final iteration is what gets returned)
        part_wid = width // count

        # gather the bits occupied by elements; the rest is padding
        used_mask = 0
        for idx in range(count):
            lo = idx * part_wid
            hi = lo + lane_shapes[ew]
            used_mask |= (1 << hi) - (1 << lo)
            lpoints[ew].append(range(lo, hi))
            # a breakpoint at the start of the lane and one at its end
            for msg, point in (("start", lo), ("end ", hi)):
                print(" adding dpoint", msg, idx, part_wid, ew, count, point)
                dpoints[point].append(ew)
        padding_mask = all_bits & ~used_mask
        padding_masks[ew] = padding_mask
        always_padding_mask &= padding_mask

    # deduplicate dpoints lists (keeping first-seen order)
    for point in dpoints:
        dpoints[point] = list(dict.fromkeys(dpoints[point]))

    # do not need the breakpoints at the very start or the very end
    dpoints.pop(0, None)
    dpoints.pop(width, None)

    # sort dpoints keys
    dpoints = {point: dpoints[point] for point in sorted(dpoints)}

    print("dpoints")
    pprint(dpoints)

    # second stage, add (map to) the elwidth==i expressions.
    # TODO: use nmutil.treereduce?
    points = {}
    for point, ews in dpoints.items():
        points[point] = reduce(operator.or_, (elwid == ew for ew in ews))

    # third stage, create the binary values which *if* elwidth is set to i
    # *would* result in the mask at that elwidth being set to this value
    # these can easily be double-checked through Assertion
    bitp = {}
    for ew in vec_el_counts:
        # each breakpoint contributes a distinct bit, so summing is the
        # same as OR-ing them together
        bitp[ew] = sum(1 << n
                       for n, ews in enumerate(dpoints.values())
                       if ew in ews)

    # fourth stage: determine which partitions are 100% unused.
    # these can then be "blanked out"

    # points are the partition separators, not partition indexes:
    # partition n spans from separator n-1 (or bit 0) to separator n
    # (or the overall width)
    partition_starts = [0, *dpoints]
    partition_ends = [*dpoints, width]
    bmask = 0
    for n, (lo, hi) in enumerate(zip(partition_starts, partition_ends)):
        pmask = (1 << hi) - (1 << lo)
        if (always_padding_mask & pmask) == pmask:
            bmask |= 1 << n
    return (PartitionPoints(points), bitp, lpoints, bmask, width, lane_shapes,
            part_wid)
224
225 # XXX XXX XXX XXX quick tests TODO convert to proper ones but kinda good
226 # enough for now. if adding new tests do not alter or delete the old ones
227 # XXX XXX XXX XXX
228
if __name__ == '__main__':

    # Ad-hoc self-test: runs layout() on several hand-written configurations,
    # pretty-prints the results and asserts on the expected partition points.

    # for each element-width (elwidth 0-3) the number of Vector Elements is:
    # elwidth=0b00 QTY 1 partitions: | ? |
    # elwidth=0b01 QTY 1 partitions: | ? |
    # elwidth=0b10 QTY 2 partitions: | ? | ? |
    # elwidth=0b11 QTY 4 partitions: | ? | ? | ? | ? |
    # actual widths of Signals *within* those partitions is given separately
    vec_el_counts = {
        0: 1,
        1: 1,
        2: 2,
        3: 4,
    }

    # width=3 indicates "same width Vector Elements (3) at all elwidths"
    # NOTE(review): the "Nx M-bit" labels in the diagram below do not match
    # width_in_all_parts=3 with the vec_el_counts above - they look stale.
    # elwidth=0b00 1x 5-bit | unused xx ..3 |
    # elwidth=0b01 1x 6-bit | unused xx ..3 |
    # elwidth=0b10 2x 12-bit | xxx ..3 | xxx ..3 |
    # elwidth=0b11 3x 24-bit | ..3| ..3 | ..3 |..3 |
    # expected partitions (^) | | | (^)
    # to be at these points: (|) | | | |
    width_in_all_parts = 3

    # lane_shapes is passed as a plain int here (not a mapping)
    for i in range(4):
        pprint((i, layout(i, vec_el_counts, width_in_all_parts)))

    # specify that the Vector Element lengths are to be *different* at
    # each of the elwidths.
    # combined with vec_el_counts we have:
    # elwidth=0b00 1x 5-bit |<----unused---------->....5|
    # elwidth=0b01 1x 6-bit |<----unused--------->.....6|
    # elwidth=0b10 2x 6-bit |unused>.....6|unused>.....6|
    # elwidth=0b11 4x 6-bit |.....6|.....6|.....6|.....6|
    # expected partitions (^) ^ ^ ^^ (^)
    # to be at these points: (|) | | || (|)
    # (24) 18 12 65 (0)
    widths_at_elwidth = {
        0: 5,
        1: 6,
        2: 6,
        3: 6
    }

    print("5,6,6,6 elements", widths_at_elwidth)
    for i in range(4):
        pp, bitp, lp, bm, b, c, d = \
            layout(i, vec_el_counts, widths_at_elwidth)
        pprint((i, (pp, bitp, lp, bm, b, c, d)))
    # now check that the expected partition points occur
    # (the values checked are those left over from the last loop iteration)
    print("5,6,6,6 ppt keys", pp.keys())
    assert list(pp.keys()) == [5, 6, 12, 18]
    assert bm == 0  # no unused partitions

    # this example was probably what the 5,6,6,6 one was supposed to be.
    # combined with vec_el_counts {0:1, 1:1, 2:2, 3:4} we have:
    # elwidth=0b00 1x 24-bit |.........................24|
    # elwidth=0b01 1x 12-bit |<--unused--->|...........12|
    # elwidth=0b10 2x 5 -bit |unused>|....5|unused>|....5|
    # elwidth=0b11 4x 6 -bit |.....6|.....6|.....6|.....6|
    # expected partitions (^) ^^ ^ ^^ (^)
    # to be at these points: (|) || | || (|)
    # (24) 1817 12 65 (0)
    widths_at_elwidth = {
        0: 24,  # QTY 1x 24
        1: 12,  # QTY 1x 12
        2: 5,  # QTY 2x 5
        3: 6  # QTY 4x 6
    }

    print("24,12,5,6 elements", widths_at_elwidth)
    for i in range(4):
        pp, bitp, lp, bm, b, c, d = \
            layout(i, vec_el_counts, widths_at_elwidth)
        pprint((i, (pp, bitp, lp, bm, b, c, d)))
    # now check that the expected partition points occur
    print("24,12,5,6 ppt keys", pp.keys())
    assert list(pp.keys()) == [5, 6, 12, 17, 18]
    print("bmask", bin(bm))
    assert bm == 0  # no unused partitions

    # this tests elwidth as an actual Signal. layout is allowed to
    # determine arbitrarily the overall length
    # https://bugs.libre-soc.org/show_bug.cgi?id=713#c30

    elwid = Signal(2)
    pp, bitp, lp, bm, b, c, d = layout(
        elwid, vec_el_counts, widths_at_elwidth)
    pprint((pp, lp, b, c, d))
    for k, v in bitp.items():
        print("bitp elwidth=%d" % k, bin(v))
    print("bmask", bin(bm))
    assert bm == 0  # no unused partitions

    # empty module: nothing is added to it, the simulator is only used
    # to evaluate the partition-point expressions held in pp
    m = Module()

    # simulator process: for each elwidth value, drive elwid, read back
    # every partition point and compare against the static bitp value
    def process():
        for i in range(4):
            yield elwid.eq(i)
            yield Settle()
            ppt = []
            for pval in list(pp.values()):
                val = yield pval  # get nmigen to evaluate pp
                ppt.append(val)
            pprint((i, (ppt, b, c, d)))
            # check the results against bitp static-expected partition points
            # https://bugs.libre-soc.org/show_bug.cgi?id=713#c47
            # https://stackoverflow.com/a/27165694
            # ppt is reversed so that the first partition point ends up
            # as the least-significant bit of the binary string
            ival = int(''.join(map(str, ppt[::-1])), 2)
            assert ival == bitp[i]

    sim = Simulator(m)
    sim.add_process(process)
    sim.run()

    # this tests elwidth as an actual Signal. layout is *not* allowed to
    # determine arbitrarily the overall length, it is fixed to 64
    # https://bugs.libre-soc.org/show_bug.cgi?id=713#c22

    # combined with vec_el_counts {0:1, 1:1, 2:2, 3:4} we have:
    # elwidth=0b00 1x 24-bit
    # elwidth=0b01 1x 12-bit
    # elwidth=0b10 2x 5-bit
    # elwidth=0b11 4x 6-bit
    #
    # bmask<--------1<----0<---------10<---0<-------1<0<----0<---0<----00<---0
    # always unused:| | | || | | | | | | || |
    # 1111111111000000 1111111111000000 1111111100000000 0000000000000000
    # | | | || | | | | | | || |
    # 0b00 xxxxxxxxxxxxxxxx xxxxxxxxxxxxxxxx xxxxxxxx........ ..............24|
    # 0b01 xxxxxxxxxxxxxxxx xxxxxxxxxxxxxxxx xxxxxxxxxxxxxxxx xxxx..........12|
    # 0b10 xxxxxxxxxxxxxxxx xxxxxxxxxxx....5|xxxxxxxxxxxxxxxx xxxxxxxxxxx....5|
    # 0b11 xxxxxxxxxx.....6|xxxxxxxxxx.....6|xxxxxxxxxx.....6|xxxxxxxxxx.....6|
    # ^ ^ ^^ ^ ^ ^ ^ ^ ^^
    # ppoints: | | || | | | | | ||
    # | bit-48 /\ | bit-24-/ | | bit-12 /\-bit-5
    # bit-54 bit-38-/ \ bit-32 | bit-16 /
    # bit-37 bit-22 bit-6

    elwid = Signal(2)
    pp, bitp, lp, bm, b, c, d = layout(elwid, vec_el_counts,
                                       widths_at_elwidth,
                                       fixed_width=64)
    pprint((pp, lp, b, c, d))
    for k, v in bitp.items():
        print("bitp elwidth=%d" % k, bin(v))
    print("bmask", bin(bm))
    # expected mask of partitions that are unused at every elwidth
    assert bm == 0b101001000000

    m = Module()

    # same check as the previous process(), against the fixed_width=64
    # layout, with an extra progress print and an assertion message
    def process():
        for i in range(4):
            yield elwid.eq(i)
            yield Settle()
            ppt = []
            for pval in list(pp.values()):
                val = yield pval  # get nmigen to evaluate pp
                ppt.append(val)
            print("test elwidth=%d" % i)
            pprint((i, (ppt, b, c, d)))
            # check the results against bitp static-expected partition points
            # https://bugs.libre-soc.org/show_bug.cgi?id=713#c47
            # https://stackoverflow.com/a/27165694
            ival = int(''.join(map(str, ppt[::-1])), 2)
            assert ival == bitp[i], "ival %s actual %s" % (bin(ival),
                                                           bin(bitp[i]))

    sim = Simulator(m)
    sim.add_process(process)
    sim.run()

    # fixed_width=32 and no lane_widths says "allocate maximum"
    # i.e. Vector Element Widths are auto-allocated
    # NOTE(review): the "2x 12-bit" / "3x 24-bit" labels below disagree
    # with the 16 and 8 shown inside the diagram - they look stale.
    # elwidth=0b00 1x 32-bit | .................32 |
    # elwidth=0b01 1x 32-bit | .................32 |
    # elwidth=0b10 2x 12-bit | ......16 | ......16 |
    # elwidth=0b11 3x 24-bit | ..8| ..8 | ..8 |..8 |
    # expected partitions (^) | | | (^)
    # to be at these points: (|) | | | |

    # TODO, fix this so that it is correct. put it at the end so it
    # shows that things break and doesn't stop the other tests.
    print("maximum allocation from fixed_width=32")
    for i in range(4):
        pprint((i, layout(i, vec_el_counts, fixed_width=32)))

    # example "exponent"
    # https://libre-soc.org/3d_gpu/architecture/dynamic_simd/shape/
    # 1xFP64: 11 bits, one exponent
    # 2xFP32: 8 bits, two exponents
    # 4xFP16: 5 bits, four exponents
    # 4xBF16: 8 bits, four exponents
    vec_el_counts = {
        0: 1,  # QTY 1x FP64
        1: 2,  # QTY 2x FP32
        2: 4,  # QTY 4x FP16
        3: 4,  # QTY 4x BF16
    }
    widths_at_elwidth = {
        0: 11,  # FP64 ew=0b00
        1: 8,  # FP32 ew=0b01
        2: 5,  # FP16 ew=0b10
        3: 8  # BF16 ew=0b11
    }

    # expected results:
    #
    # |31| | |24| 16|15 | | 8|7 0 |
    # |31|28|26|24| |20|16| 12| |10|8|5|4 0 |
    # 32bit | x| x| x| | x| x| x|10 .... 0 |
    # 16bit | x| x|26 ... 16 | x| x|10 .... 0 |
    # 8bit | x|28 .. 24| 20.16| x|11 .. 8|x|4.. 0 |
    # unused x x

    print("11,8,5,8 elements (FP64/32/16/BF exponents)", widths_at_elwidth)
    for i in range(4):
        pp, bitp, lp, bm, b, c, d = \
            layout(i, vec_el_counts, widths_at_elwidth,
                   fixed_width=32)
        pprint((i, (pp, lp, bitp, bin(bm), b, c, d)))
    # now check that the expected partition points occur
    # (only printed: the assertion below is disabled)
    print("11,8,5,8 pp keys", pp.keys())
    #assert list(pp.keys()) == [5,6,12,18]

    ###### ######
    ###### 2nd test, different from the above, elwid=0b10 ==> 11 bit ######
    ###### ######
    # NOTE(review): the dict below changes the entry for key 1 (0b01) to 11,
    # not key 2 (0b10) as the banner above says - confirm which is intended.

    # example "exponent"
    vec_el_counts = {
        0: 1,  # QTY 1x FP64
        1: 2,  # QTY 2x FP32
        2: 4,  # QTY 4x FP16
        3: 4,  # QTY 4x BF16
    }
    widths_at_elwidth = {
        0: 11,  # FP64 ew=0b00
        1: 11,  # FP32 ew=0b01
        2: 5,  # FP16 ew=0b10
        3: 8  # BF16 ew=0b11
    }

    # expected results:
    # NOTE(review): this diagram is identical to the one for the previous
    # (11,8,5,8) test - presumably not yet updated for these widths.
    #
    # |31| | |24| 16|15 | | 8|7 0 |
    # |31|28|26|24| |20|16| 12| |10|8|5|4 0 |
    # 32bit | x| x| x| | x| x| x|10 .... 0 |
    # 16bit | x| x|26 ... 16 | x| x|10 .... 0 |
    # 8bit | x|28 .. 24| 20.16| x|11 .. 8|x|4.. 0 |
    # unused x x

    # NOTE(review): the printed labels still say "11,8,5,8" although the
    # widths used here are 11,11,5,8
    print("11,8,5,8 elements (FP64/32/16/BF exponents)", widths_at_elwidth)
    for i in range(4):
        pp, bitp, lp, bm, b, c, d = \
            layout(i, vec_el_counts, widths_at_elwidth,
                   fixed_width=32)
        pprint((i, (pp, lp, bitp, bin(bm), b, c, d)))
    # now check that the expected partition points occur
    # (only printed: the assertion below is disabled)
    print("11,8,5,8 pp keys", pp.keys())
    #assert list(pp.keys()) == [5,6,12,18]