return layout points from layout()
[ieee754fpu.git] / src / ieee754 / part / layout_experiment.py
1 #!/usr/bin/env python3
2 # SPDX-License-Identifier: LGPL-3-or-later
3 # See Notices.txt for copyright information
4 """
5 Links:
6 * https://libre-soc.org/3d_gpu/architecture/dynamic_simd/shape/
7 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c20
8 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c30
9 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c34
10 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c47
11 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c22
12 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c67
13 """
14
15 from nmigen import Signal, Module, Elaboratable, Mux, Cat, Shape, Repl
16 from nmigen.back.pysim import Simulator, Delay, Settle
17 from nmigen.cli import rtlil
18
19 from collections.abc import Mapping
20 from functools import reduce
21 import operator
22 from collections import defaultdict
23 from pprint import pprint
24
25 from ieee754.part_mul_add.partpoints import PartitionPoints
26
27
28 # XXX MAKE SURE TO PRESERVE ALL THESE COMMENTS XXX
29
30 # main fn, which started out here in the bugtracker:
31 # https://bugs.libre-soc.org/show_bug.cgi?id=713#c20
32 # note that signed is **NOT** part of the layout, and will NOT
33 # be added (because it is not relevant or appropriate).
34 # sign belongs in ast.Shape and is the only appropriate location.
35 # there is absolutely nothing within this function that in any
36 # way requires a sign. it is *purely* performing numerical width
37 # computations that have absolutely nothing to do with whether the
38 # actual data is signed or unsigned.
39 #
40 # context for parameters:
41 # http://lists.libre-soc.org/pipermail/libre-soc-dev/2021-October/003921.html
42 # XXX tempted to suggest that this function remain as a function, because
43 # it takes all the context it needs as parameters. its usefulness goes
44 # beyond a single class, and there is actually nothing realistically
45 # that it needs which is context-sensitive. therefore, on balance,
46 # it should remain a function
def layout(elwid,  # comes from SimdScope constructor
           vec_el_counts,  # comes from SimdScope constructor
           lane_shapes=None,  # from SimdScope.Signal via a SimdShape
           fixed_width=None):  # from SimdScope.Signal via a SimdShape
    """calculate a SIMD layout.

    Glossary:
    * element: a single scalar value that is an element of a SIMD vector.
      it has a width in bits. Every element is made of 1 or
      more parts.
    * ElWid: the element-width (really the element type) of an instruction.
      Either an integer or a FP type. Integer `ElWid`s are sign-agnostic.
      In Python, `ElWid` is either an enum type or is `int`.
      Example `ElWid` definition for integers:

        class ElWid(Enum):
            I64 = ...  # SVP64 value 0b00
            I32 = ...  # SVP64 value 0b01
            I16 = ...  # SVP64 value 0b10
            I8 = ...   # SVP64 value 0b11

      Example `ElWid` definition for floats:

        class ElWid(Enum):
            F64 = ...   # SVP64 value 0b00
            F32 = ...   # SVP64 value 0b01
            F16 = ...   # SVP64 value 0b10
            BF16 = ...  # SVP64 value 0b11

    * elwid: ElWid or nmigen Value with ElWid as the shape
      the current element-width

    * vec_el_counts: dict[ElWid, int]
      a map from `ElWid` values `k` to the number of vector elements
      required within a partition when `elwid == k`.

      Example:
        vec_el_counts = {ElWid.I8(==0b11): 8,   # 8 vector elements
                         ElWid.I16(==0b10): 4,  # 4 vector elements
                         ElWid.I32(==0b01): 2,  # 2 vector elements
                         ElWid.I64(==0b00): 1}  # 1 vector (aka scalar) element

      Another Example:
        vec_el_counts = {ElWid.BF16(==0b11): 4,  # 4 vector elements
                         ElWid.F16(==0b10): 4,   # 4 vector elements
                         ElWid.F32(==0b01): 2,   # 2 vector elements
                         ElWid.F64(==0b00): 1}   # 1 (aka scalar) vector element

    * lane_shapes: int or Mapping[ElWid, int] (optional)
      the bit-width of all elements in a SIMD layout.
      if not provided, the lane_shapes are computed from fixed_width
      and vec_el_counts at each elwidth.

    * fixed_width: int (optional)
      the total width of a SIMD vector. One or both of lane_shapes or
      fixed_width may be provided. Both may not be left out.

    Returns a 7-tuple
    (ppoints, bitp, lpoints, bmask, width, lane_shapes, part_wid):
    * ppoints: PartitionPoints mapping each interior breakpoint bit
      position p (0 < p < width, ascending) to `elwid == i` OR-ed over
      every elwid `i` that has an element starting or ending at p.
      When elwid is a plain int these values are plain bools.
    * bitp: dict elwid -> int. bit k is set when the k-th breakpoint
      (in ascending order) is used at that elwid.
    * lpoints: dict elwid -> list of `range(start_bit, end_bit)`,
      one entry per vector element.
    * bmask: int. bit k is set when the k-th partition (the bits between
      consecutive breakpoints, from bit 0 up to width) is padding at
      every elwid.
    * width: the total width in bits (equal to fixed_width if given).
    * lane_shapes: lane_shapes normalised to a dict elwid -> width.
    * part_wid: width // max(vec_el_counts.values()).

    Raises AssertionError if both lane_shapes and fixed_width are None,
    if the requested lanes do not fit in fixed_width, or if fixed_width
    is not a multiple of the maximum element count.
    """
    # when there are no lane_shapes specified, this indicates a
    # desire to use the maximum available space based on the fixed width
    # https://bugs.libre-soc.org/show_bug.cgi?id=713#c67
    if lane_shapes is None:
        assert fixed_width is not None, \
            "both fixed_width and lane_shapes cannot be None"
        lane_shapes = {i: fixed_width // vec_el_counts[i]
                       for i in vec_el_counts}
        print("lane_shapes", fixed_width, lane_shapes)

    # identify if the lane_shapes is a mapping (dict, etc.)
    # if not, then assume that it is an integer (width) that
    # needs to be requested across all partitions
    if not isinstance(lane_shapes, Mapping):
        lane_shapes = {i: lane_shapes for i in vec_el_counts}

    # compute a set of partition widths
    print("lane_shapes", lane_shapes, "vec_el_counts", vec_el_counts)
    cpart_wid = 0
    width = 0
    for i, lwid in lane_shapes.items():
        required_width = lwid * vec_el_counts[i]
        print(" required width", cpart_wid, i, lwid, required_width)
        if required_width > width:
            cpart_wid = lwid
            width = required_width

    # calculate the minimum width required if fixed_width specified
    part_count = max(vec_el_counts.values())
    print("width", width, cpart_wid, part_count)
    if fixed_width is not None:  # override the width and part_wid
        assert width <= fixed_width, "not enough space to fit partitions"
        part_wid = fixed_width // part_count
        assert part_wid * part_count == fixed_width, \
            "calculated width not aligned multiples"
        width = fixed_width
        print("part_wid", part_wid, "count", part_count, "width", width)
    else:
        # no fixed width: the (smallest) partition width is the computed
        # overall width split across the largest element count
        part_wid = width // part_count

    # create the breakpoints dictionary.
    # do multi-stage version https://bugs.libre-soc.org/show_bug.cgi?id=713#c34
    # https://stackoverflow.com/questions/26367812/

    # first phase: create a dictionary of partition points (dpoints) and
    # also keep a record of where each element starts and ends (lpoints)
    dpoints = defaultdict(list)  # if empty key, create a (empty) list
    lpoints = defaultdict(list)  # dict of list of start-end points
    always_padding_mask = (1 << width) - 1  # start with all bits padding
    for i, c in vec_el_counts.items():
        print("dpoints", i, "count", c)
        # calculate part_wid based on overall width divided by number
        # of elements.
        # (held in its own local so that it does not overwrite the
        # part_wid that is returned from this function)
        el_part_wid = width // c

        padding_mask = (1 << width) - 1  # start with all bits padding

        def add_p(msg, start, p):
            print(" adding dpoint", msg, start, el_part_wid, i, c, p)
            dpoints[p].append(i)  # auto-creates list if key non-existent
        # for each elwidth, create the required number of vector elements
        for start in range(c):
            start_bit = start * el_part_wid
            end_bit = start_bit + lane_shapes[i]
            element_mask = (1 << end_bit) - (1 << start_bit)
            padding_mask &= ~element_mask  # remove element from padding_mask
            lpoints[i].append(range(start_bit, end_bit))
            add_p("start", start, start_bit)  # start of lane
            add_p("end ", start, end_bit)  # end lane
        always_padding_mask &= padding_mask

    # deduplicate dpoints lists (dict.fromkeys keeps first-seen order)
    for k in dpoints:
        dpoints[k] = list(dict.fromkeys(dpoints[k]))

    # do not need the breakpoints at the very start or the very end
    dpoints.pop(0, None)
    dpoints.pop(width, None)

    # sort dpoints keys (keys are unique ints, so item order == key order)
    dpoints = dict(sorted(dpoints.items()))

    print("dpoints")
    pprint(dpoints)

    # second stage, add (map to) the elwidth==i expressions.
    # TODO: use nmutil.treereduce?
    points = {}
    for p, elwidths in dpoints.items():
        points[p] = reduce(operator.or_, (elwid == i for i in elwidths))

    # third stage, create the binary values which *if* elwidth is set to i
    # *would* result in the mask at that elwidth being set to this value
    # these can easily be double-checked through Assertion
    bitp = {}
    for i in vec_el_counts:
        bitp[i] = 0
        for bit_index, elwidths in enumerate(dpoints.values()):
            if i in elwidths:
                bitp[i] |= 1 << bit_index

    # fourth stage: determine which partitions are 100% unused.
    # these can then be "blanked out"

    # points are the partition separators, not partition indexes
    partition_ends = [*dpoints.keys(), width]
    bmask = 0
    partition_start = 0
    for bit_index, partition_end in enumerate(partition_ends):
        pmask = (1 << partition_end) - (1 << partition_start)
        always_padding = (always_padding_mask & pmask) == pmask
        if always_padding:
            bmask |= 1 << bit_index
        partition_start = partition_end
    return (PartitionPoints(points), bitp, lpoints, bmask, width, lane_shapes,
            part_wid)
221
222 # XXX XXX XXX XXX quick tests TODO convert to proper ones but kinda good
223 # enough for now. if adding new tests do not alter or delete the old ones
224 # XXX XXX XXX XXX
225
if __name__ == '__main__':

    # for each element-width (elwidth 0-3) the number of Vector Elements is:
    # elwidth=0b00 QTY 1 partitions: | ? |
    # elwidth=0b01 QTY 1 partitions: | ? |
    # elwidth=0b10 QTY 2 partitions: | ? | ? |
    # elwidth=0b11 QTY 4 partitions: | ? | ? | ? | ? |
    # actual widths of Signals *within* those partitions is given separately
    vec_el_counts = {
        0: 1,
        1: 1,
        2: 2,
        3: 4,
    }

    # width=3 indicates "same width Vector Elements (3) at all elwidths"
    # elwidth=0b00 1x 3-bit | unused xx ..3 |
    # elwidth=0b01 1x 3-bit | unused xx ..3 |
    # elwidth=0b10 2x 3-bit | xxx ..3 | xxx ..3 |
    # elwidth=0b11 4x 3-bit | ..3| ..3 | ..3 |..3 |
    # expected partitions (^) | | | (^)
    # to be at these points: (|) | | | |
    width_in_all_parts = 3

    for i in range(4):
        pprint((i, layout(i, vec_el_counts, width_in_all_parts)))

    # specify that the Vector Element lengths are to be *different* at
    # each of the elwidths.
    # combined with vec_el_counts we have:
    # elwidth=0b00 1x 5-bit |<----unused---------->....5|
    # elwidth=0b01 1x 6-bit |<----unused--------->.....6|
    # elwidth=0b10 2x 6-bit |unused>.....6|unused>.....6|
    # elwidth=0b11 4x 6-bit |.....6|.....6|.....6|.....6|
    # expected partitions (^) ^ ^ ^^ (^)
    # to be at these points: (|) | | || (|)
    # (24) 18 12 65 (0)
    widths_at_elwidth = {
        0: 5,
        1: 6,
        2: 6,
        3: 6
    }

    print("5,6,6,6 elements", widths_at_elwidth)
    for i in range(4):
        pp, bitp, lp, bm, b, c, d = \
            layout(i, vec_el_counts, widths_at_elwidth)
        pprint((i, (pp, bitp, lp, bm, b, c, d)))
        # now check that the expected partition points occur
        print("5,6,6,6 ppt keys", pp.keys())
        assert list(pp.keys()) == [5, 6, 12, 18]
        assert bm == 0  # no unused partitions

    # this example was probably what the 5,6,6,6 one was supposed to be.
    # combined with vec_el_counts {0:1, 1:1, 2:2, 3:4} we have:
    # elwidth=0b00 1x 24-bit |.........................24|
    # elwidth=0b01 1x 12-bit |<--unused--->|...........12|
    # elwidth=0b10 2x 5 -bit |unused>|....5|unused>|....5|
    # elwidth=0b11 4x 6 -bit |.....6|.....6|.....6|.....6|
    # expected partitions (^) ^^ ^ ^^ (^)
    # to be at these points: (|) || | || (|)
    # (24) 1817 12 65 (0)
    widths_at_elwidth = {
        0: 24,  # QTY 1x 24
        1: 12,  # QTY 1x 12
        2: 5,  # QTY 2x 5
        3: 6  # QTY 4x 6
    }

    print("24,12,5,6 elements", widths_at_elwidth)
    for i in range(4):
        pp, bitp, lp, bm, b, c, d = \
            layout(i, vec_el_counts, widths_at_elwidth)
        pprint((i, (pp, bitp, lp, bm, b, c, d)))
        # now check that the expected partition points occur
        print("24,12,5,6 ppt keys", pp.keys())
        assert list(pp.keys()) == [5, 6, 12, 17, 18]
        print("bmask", bin(bm))
        assert bm == 0  # no unused partitions

    # this tests elwidth as an actual Signal. layout is allowed to
    # determine arbitrarily the overall length
    # https://bugs.libre-soc.org/show_bug.cgi?id=713#c30

    elwid = Signal(2)
    pp, bitp, lp, bm, b, c, d = layout(
        elwid, vec_el_counts, widths_at_elwidth)
    pprint((pp, lp, b, c, d))
    for k, v in bitp.items():
        print("bitp elwidth=%d" % k, bin(v))
    print("bmask", bin(bm))
    assert bm == 0  # no unused partitions

    m = Module()

    def process():
        for i in range(4):
            yield elwid.eq(i)
            yield Settle()
            ppt = []
            for pval in list(pp.values()):
                val = yield pval  # get nmigen to evaluate pp
                ppt.append(val)
            pprint((i, (ppt, b, c, d)))
            # check the results against bitp static-expected partition points
            # https://bugs.libre-soc.org/show_bug.cgi?id=713#c47
            # https://stackoverflow.com/a/27165694
            ival = int(''.join(map(str, ppt[::-1])), 2)
            assert ival == bitp[i]

    sim = Simulator(m)
    sim.add_process(process)
    sim.run()

    # this tests elwidth as an actual Signal. layout is *not* allowed to
    # determine arbitrarily the overall length, it is fixed to 64
    # https://bugs.libre-soc.org/show_bug.cgi?id=713#c22

    # combined with vec_el_counts {0:1, 1:1, 2:2, 3:4} we have:
    # elwidth=0b00 1x 24-bit
    # elwidth=0b01 1x 12-bit
    # elwidth=0b10 2x 5-bit
    # elwidth=0b11 4x 6-bit
    #
    # bmask<--------1<----0<---------10<---0<-------1<0<----0<---0<----00<---0
    # always unused:| | | || | | | | | | || |
    # 1111111111000000 1111111111000000 1111111100000000 0000000000000000
    # | | | || | | | | | | || |
    # 0b00 xxxxxxxxxxxxxxxx xxxxxxxxxxxxxxxx xxxxxxxx........ ..............24|
    # 0b01 xxxxxxxxxxxxxxxx xxxxxxxxxxxxxxxx xxxxxxxxxxxxxxxx xxxx..........12|
    # 0b10 xxxxxxxxxxxxxxxx xxxxxxxxxxx....5|xxxxxxxxxxxxxxxx xxxxxxxxxxx....5|
    # 0b11 xxxxxxxxxx.....6|xxxxxxxxxx.....6|xxxxxxxxxx.....6|xxxxxxxxxx.....6|
    # ^ ^ ^^ ^ ^ ^ ^ ^ ^^
    # ppoints: | | || | | | | | ||
    # | bit-48 /\ | bit-24-/ | | bit-12 /\-bit-5
    # bit-54 bit-38-/ \ bit-32 | bit-16 /
    # bit-37 bit-22 bit-6

    elwid = Signal(2)
    pp, bitp, lp, bm, b, c, d = layout(elwid, vec_el_counts,
                                       widths_at_elwidth,
                                       fixed_width=64)
    pprint((pp, lp, b, c, d))
    for k, v in bitp.items():
        print("bitp elwidth=%d" % k, bin(v))
    print("bmask", bin(bm))
    assert bm == 0b101001000000

    m = Module()

    def process():
        for i in range(4):
            yield elwid.eq(i)
            yield Settle()
            ppt = []
            for pval in list(pp.values()):
                val = yield pval  # get nmigen to evaluate pp
                ppt.append(val)
            print("test elwidth=%d" % i)
            pprint((i, (ppt, b, c, d)))
            # check the results against bitp static-expected partition points
            # https://bugs.libre-soc.org/show_bug.cgi?id=713#c47
            # https://stackoverflow.com/a/27165694
            ival = int(''.join(map(str, ppt[::-1])), 2)
            assert ival == bitp[i], "ival %s actual %s" % (bin(ival),
                                                           bin(bitp[i]))

    sim = Simulator(m)
    sim.add_process(process)
    sim.run()

    # fixed_width=32 and no lane_shapes says "allocate maximum"
    # i.e. Vector Element Widths are auto-allocated
    # elwidth=0b00 1x 32-bit | .................32 |
    # elwidth=0b01 1x 32-bit | .................32 |
    # elwidth=0b10 2x 16-bit | ......16 | ......16 |
    # elwidth=0b11 4x 8-bit | ..8| ..8 | ..8 |..8 |
    # expected partitions (^) | | | (^)
    # to be at these points: (|) | | | |

    # TODO, fix this so that it is correct. put it at the end so it
    # shows that things break and doesn't stop the other tests.
    print("maximum allocation from fixed_width=32")
    for i in range(4):
        pprint((i, layout(i, vec_el_counts, fixed_width=32)))

    # example "exponent"
    # https://libre-soc.org/3d_gpu/architecture/dynamic_simd/shape/
    # 1xFP64: 11 bits, one exponent
    # 2xFP32: 8 bits, two exponents
    # 4xFP16: 5 bits, four exponents
    # 4xBF16: 8 bits, four exponents
    vec_el_counts = {
        0: 1,  # QTY 1x FP64
        1: 2,  # QTY 2x FP32
        2: 4,  # QTY 4x FP16
        3: 4,  # QTY 4x BF16
    }
    widths_at_elwidth = {
        0: 11,  # FP64 ew=0b00
        1: 8,  # FP32 ew=0b01
        2: 5,  # FP16 ew=0b10
        3: 8  # BF16 ew=0b11
    }

    # expected results:
    #
    # |31| | |24| 16|15 | | 8|7 0 |
    # |31|28|26|24| |20|16| 12| |10|8|5|4 0 |
    # 32bit | x| x| x| | x| x| x|10 .... 0 |
    # 16bit | x| x|26 ... 16 | x| x|10 .... 0 |
    # 8bit | x|28 .. 24| 20.16| x|11 .. 8|x|4.. 0 |
    # unused x x

    print("11,8,5,8 elements (FP64/32/16/BF exponents)", widths_at_elwidth)
    for i in range(4):
        pp, bitp, lp, bm, b, c, d = \
            layout(i, vec_el_counts, widths_at_elwidth,
                   fixed_width=32)
        pprint((i, (pp, lp, bitp, bin(bm), b, c, d)))
        # now check that the expected partition points occur
        print("11,8,5,8 pp keys", pp.keys())
        # NOTE(review): the expected keys below were copied from the
        # 5,6,6,6 test. worked by hand, layout() places breakpoints at
        # [5, 8, 11, 13, 16, 21, 24, 29] for this configuration.
        #assert list(pp.keys()) == [5,6,12,18]

    ######                                                          ######
    ###### 2nd test, different from the above, elwid=0b01 ==> 11 bit ######
    ######                                                          ######

    # example "exponent"
    vec_el_counts = {
        0: 1,  # QTY 1x FP64
        1: 2,  # QTY 2x FP32
        2: 4,  # QTY 4x FP16
        3: 4,  # QTY 4x BF16
    }
    widths_at_elwidth = {
        0: 11,  # FP64 ew=0b00
        1: 11,  # FP32 ew=0b01
        2: 5,  # FP16 ew=0b10
        3: 8  # BF16 ew=0b11
    }

    # expected results:
    # NOTE(review): this diagram is an unchanged copy of the one for the
    # 11,8,5,8 test above; it does not yet show the 11-bit FP32
    # (elwid=0b01) exponents.
    #
    # |31| | |24| 16|15 | | 8|7 0 |
    # |31|28|26|24| |20|16| 12| |10|8|5|4 0 |
    # 32bit | x| x| x| | x| x| x|10 .... 0 |
    # 16bit | x| x|26 ... 16 | x| x|10 .... 0 |
    # 8bit | x|28 .. 24| 20.16| x|11 .. 8|x|4.. 0 |
    # unused x x

    print("11,11,5,8 elements (FP64/32/16/BF exponents)", widths_at_elwidth)
    for i in range(4):
        pp, bitp, lp, bm, b, c, d = \
            layout(i, vec_el_counts, widths_at_elwidth,
                   fixed_width=32)
        pprint((i, (pp, lp, bitp, bin(bm), b, c, d)))
        # now check that the expected partition points occur
        print("11,11,5,8 pp keys", pp.keys())
        # NOTE(review): the expected keys below were copied from the
        # 5,6,6,6 test. worked by hand, layout() places breakpoints at
        # [5, 8, 11, 13, 16, 21, 24, 27, 29] for this configuration.
        #assert list(pp.keys()) == [5,6,12,18]