switch to exact version of cython
[ieee754fpu.git] / src / ieee754 / part / layout_experiment.py
1 #!/usr/bin/env python3
2 # SPDX-License-Identifier: LGPL-3-or-later
3 # See Notices.txt for copyright information
4 """
5 Links:
6 * https://libre-soc.org/3d_gpu/architecture/dynamic_simd/shape/
7 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c20
8 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c30
9 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c34
10 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c47
11 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c22
12 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c67
13 """
14
15 from nmigen import Signal, Module, Elaboratable, Mux, Cat, Shape, Repl
16 from nmigen.back.pysim import Simulator, Delay, Settle
17 from nmigen.cli import rtlil
18
19 from collections.abc import Mapping
20 from functools import reduce
21 import operator
22 from collections import defaultdict
23 from pprint import pprint
24
25 from ieee754.part_mul_add.partpoints import PartitionPoints
26
27
28 # main fn, which started out here in the bugtracker:
29 # https://bugs.libre-soc.org/show_bug.cgi?id=713#c20
30 # note that signed is **NOT** part of the layout, and will NOT
31 # be added (because it is not relevant or appropriate).
32 # sign belongs in ast.Shape and is the only appropriate location.
33 # there is absolutely nothing within this function that in any
34 # way requires a sign. it is *purely* performing numerical width
35 # computations that have absolutely nothing to do with whether the
36 # actual data is signed or unsigned.
def layout(elwid, vec_el_counts, lane_shapes=None, fixed_width=None):
    """calculate a SIMD layout.

    Glossary:
    * element: a single scalar value that is an element of a SIMD vector.
        it has a width in bits. Every element is made of 1 or
        more parts.
    * ElWid: the element-width (really the element type) of an instruction.
        Either an integer or a FP type. Integer `ElWid`s are sign-agnostic.
        In Python, `ElWid` is either an enum type or is `int`.
        Example `ElWid` definition for integers:

        class ElWid(Enum):
            I64 = ...       # SVP64 value 0b00
            I32 = ...       # SVP64 value 0b01
            I16 = ...       # SVP64 value 0b10
            I8 = ...        # SVP64 value 0b11

        Example `ElWid` definition for floats:

        class ElWid(Enum):
            F64 = ...    # SVP64 value 0b00
            F32 = ...    # SVP64 value 0b01
            F16 = ...    # SVP64 value 0b10
            BF16 = ...   # SVP64 value 0b11

    * elwid: ElWid or nmigen Value with ElWid as the shape
        the current element-width

    * vec_el_counts: dict[ElWid, int]
        a map from `ElWid` values `k` to the number of vector elements
        laid out across the whole SIMD vector when `elwid == k`.

        Example:
        vec_el_counts = {ElWid.I8(==0b11): 8,   # 8 vector elements
                ElWid.I16(==0b10): 4,  # 4 vector elements
                ElWid.I32(==0b01): 2,  # 2 vector elements
                ElWid.I64(==0b00): 1}  # 1 vector (aka scalar) element

        Another Example:
        vec_el_counts = {ElWid.BF16(==0b11): 4,  # 4 vector elements
                ElWid.F16(==0b10): 4,   # 4 vector elements
                ElWid.F32(==0b01): 2,   # 2 vector elements
                ElWid.F64(==0b00): 1}   # 1 (aka scalar) vector element

    * lane_shapes: int or Mapping[ElWid, int] (optional)
        the bit-width of the elements in a SIMD layout: an int gives the
        same element width at every elwidth, a Mapping gives one width
        per elwidth.
        if not provided, the lane_shapes are computed from fixed_width
        and vec_el_counts at each elwidth.

    * fixed_width: int (optional)
        the total width of a SIMD vector. One or both of lane_shapes or
        fixed_width may be provided. Both may not be left out.

    Returns a 6-tuple (pp, bitp, bmask, width, lane_shapes, part_wid):

    * pp: PartitionPoints mapping each breakpoint bit-position (sorted,
        excluding 0 and width) to the OR of `elwid == k` over every ElWid
        `k` whose elements start or end at that position.
    * bitp: dict[ElWid, int]. bit `n` of bitp[k] is set when the `n`-th
        breakpoint (in sorted order) is needed at elwidth `k`.
    * bmask: int. bit `n` is set when the `n`-th partition (counting up
        from bit 0 of the vector) is padding at every elwidth.
    * width: int. the total vector width: fixed_width if given, otherwise
        the largest (element width * element count) over all elwidths.
    * lane_shapes: dict[ElWid, int]. the per-elwidth element widths used.
    * part_wid: int. the smallest partition slot width,
        width // max(vec_el_counts.values()).

    Raises AssertionError if both lane_shapes and fixed_width are None,
    if fixed_width is too small for the requested elements, or if
    fixed_width is not a multiple of max(vec_el_counts.values()).
    """
    # when there are no lane_shapes specified, this indicates a
    # desire to use the maximum available space based on the fixed width
    # https://bugs.libre-soc.org/show_bug.cgi?id=713#c67
    if lane_shapes is None:
        assert fixed_width is not None, \
            "both fixed_width and lane_shapes cannot be None"
        lane_shapes = {i: fixed_width // vec_el_counts[i]
                       for i in vec_el_counts}
        print("lane_shapes", fixed_width, lane_shapes)

    # identify if the lane_shapes is a mapping (dict, etc.)
    # if not, then assume that it is an integer (width) that
    # needs to be requested across all partitions
    if not isinstance(lane_shapes, Mapping):
        lane_shapes = {i: lane_shapes for i in vec_el_counts}

    # first stage: the overall width is the largest total width
    # (element width * element count) needed at any elwidth.
    print("lane_shapes", lane_shapes, "vec_el_counts", vec_el_counts)
    cpart_wid = 0  # only reported in the debug output
    width = 0
    for i, lwid in lane_shapes.items():
        required_width = lwid * vec_el_counts[i]
        print(" required width", cpart_wid, i, lwid, required_width)
        if required_width > width:
            cpart_wid = lwid
            width = required_width

    # the smallest partition slot is set by the elwidth with the most
    # elements.  part_wid is computed once here and returned as-is: the
    # per-elwidth slot width further down deliberately uses its own name
    # so that it cannot overwrite part_wid.
    part_count = max(vec_el_counts.values())
    print("width", width, cpart_wid, part_count)
    if fixed_width is not None:  # override the width and part_wid
        assert width <= fixed_width, "not enough space to fit partitions"
        part_wid = fixed_width // part_count
        assert part_wid * part_count == fixed_width, \
            "calculated width not aligned multiples"
        width = fixed_width
        print("part_wid", part_wid, "count", part_count, "width", width)
    else:
        part_wid = width // part_count

    # second stage: for every bit position, collect the elwidths that need
    # a breakpoint there (the start and the end of each of their elements)
    # and track which bits are padding at every elwidth.
    # do multi-stage version https://bugs.libre-soc.org/show_bug.cgi?id=713#c34
    # https://stackoverflow.com/questions/26367812/
    dpoints = defaultdict(list)  # if empty key, create a (empty) list
    always_padding_mask = (1 << width) - 1  # start with all bits padding
    for i, c in vec_el_counts.items():
        print("dpoints", i, "count", c)
        # the c elements at this elwidth each start on a slot boundary
        # of (overall width / element count) bits.
        el_slot_wid = width // c
        padding_mask = (1 << width) - 1  # start with all bits padding
        for start in range(c):
            start_bit = start * el_slot_wid
            end_bit = start_bit + lane_shapes[i]
            # remove this element's bits from the padding mask
            padding_mask &= ~((1 << end_bit) - (1 << start_bit))
            for msg, p in (("start", start_bit), ("end ", end_bit)):
                print(" adding dpoint", msg, start, el_slot_wid, i, c, p)
                dpoints[p].append(i)
        always_padding_mask &= padding_mask

    # breakpoints at the very start or the very end are not needed.
    dpoints.pop(0, None)
    dpoints.pop(width, None)
    # sort by bit position and de-duplicate each elwidth list while keeping
    # first-seen order.
    dpoints = {p: list(dict.fromkeys(dpoints[p])) for p in sorted(dpoints)}

    print("dpoints")
    pprint(dpoints)

    # third stage, add (map to) the elwidth==i expressions.
    # TODO: use nmutil.treereduce?
    points = {p: reduce(operator.or_, (elwid == i for i in elwidths))
              for p, elwidths in dpoints.items()}

    # fourth stage, create the binary values which *if* elwidth is set to i
    # *would* result in the mask at that elwidth being set to this value.
    # these can easily be double-checked through Assertion
    bitp = {i: sum(1 << bit_index
                   for bit_index, elwidths in enumerate(dpoints.values())
                   if i in elwidths)
            for i in vec_el_counts}

    # fifth stage: determine which partitions are 100% unused.
    # these can then be "blanked out".
    # points are the partition separators, not partition indexes, so the
    # partitions are [0, p0), [p0, p1), ... [pN, width).
    bmask = 0
    partition_start = 0
    for bit_index, partition_end in enumerate([*dpoints, width]):
        pmask = (1 << partition_end) - (1 << partition_start)
        if (always_padding_mask & pmask) == pmask:
            bmask |= 1 << bit_index
        partition_start = partition_end

    return (PartitionPoints(points), bitp, bmask, width, lane_shapes,
            part_wid)
203
204
if __name__ == '__main__':

    def check_static_layout(label, el_counts, lane_shapes, expected_points,
                            fixed_width=None):
        """lay out with a plain-int elwid at each of 0..3 and check that
        the partition points are exactly expected_points and that no
        partition is padding at every elwidth (bmask == 0)."""
        print(label, "lane_shapes", lane_shapes)
        for i in range(4):
            pp, bitp, bm, b, c, d = layout(i, el_counts, lane_shapes,
                                           fixed_width=fixed_width)
            pprint((i, (pp, bitp, bin(bm), b, c, d)))
            # now check that the expected partition points occur
            print(label, "pp keys", pp.keys())
            assert list(pp.keys()) == expected_points, \
                "%s: got %s" % (label, list(pp.keys()))
            assert bm == 0, "%s: unexpected bmask %s" % (label, bin(bm))

    def check_signal_layout(el_counts, lane_shapes, expected_bmask,
                            fixed_width=None):
        """lay out with a 2-bit Signal as elwid, check the blanking mask,
        then simulate: for each elwid value 0..3 the evaluated partition
        points (first point = bit 0) must equal the static bitp value.
        https://bugs.libre-soc.org/show_bug.cgi?id=713#c47
        """
        elwid = Signal(2)
        pp, bitp, bm, b, c, d = layout(elwid, el_counts, lane_shapes,
                                       fixed_width=fixed_width)
        pprint((pp, b, c, d))
        for k, v in bitp.items():
            print("bitp elwidth=%d" % k, bin(v))
        print("bmask", bin(bm))
        assert bm == expected_bmask, "bmask %s" % bin(bm)

        def process():
            for i in range(4):
                yield elwid.eq(i)
                yield Settle()
                ppt = []
                for pval in pp.values():
                    val = yield pval  # get nmigen to evaluate pp
                    ppt.append(val)
                print("test elwidth=%d" % i)
                pprint((i, (ppt, b, c, d)))
                # pack the evaluated points into an int, first point = bit 0
                ival = sum(bit << n for n, bit in enumerate(ppt))
                assert ival == bitp[i], "ival %s actual %s" % (bin(ival),
                                                                bin(bitp[i]))

        sim = Simulator(Module())
        sim.add_process(process)
        sim.run()

    # for each element-width (elwidth 0-3) the number of Vector Elements is:
    # elwidth=0b00 QTY 1 partitions: |           ?           |
    # elwidth=0b01 QTY 1 partitions: |           ?           |
    # elwidth=0b10 QTY 2 partitions: |     ?     |     ?     |
    # elwidth=0b11 QTY 4 partitions: |  ?  |  ?  |  ?  |  ?  |
    # actual widths of Signals *within* those partitions is given separately
    vec_el_counts = {
        0: 1,
        1: 1,
        2: 2,
        3: 4,
    }

    # width=3 indicates "same width Vector Elements (3) at all elwidths",
    # giving an overall width of 4x 3 = 12 bits.
    # bit 11 on the left, x = padding, . = element bit:
    # elwidth=0b00 1x 3-bit |xxxxxxxxx...|
    # elwidth=0b01 1x 3-bit |xxxxxxxxx...|
    # elwidth=0b10 2x 3-bit |xxx...xxx...|
    # elwidth=0b11 4x 3-bit |............|
    # expected partition points: 3, 6, 9 (0 and 12 are implicit)
    width_in_all_parts = 3
    check_static_layout("3,3,3,3", vec_el_counts, width_in_all_parts,
                        [3, 6, 9])

    # specify that the Vector Element lengths are to be *different* at
    # each of the elwidths.
    # combined with vec_el_counts we have:
    # elwidth=0b00 1x 5-bit |<----unused---------->....5|
    # elwidth=0b01 1x 6-bit |<----unused--------->.....6|
    # elwidth=0b10 2x 6-bit |unused>.....6|unused>.....6|
    # elwidth=0b11 4x 6-bit |.....6|.....6|.....6|.....6|
    # expected partition points: 5, 6, 12, 18 (0 and 24 are implicit)
    widths_at_elwidth = {
        0: 5,
        1: 6,
        2: 6,
        3: 6
    }
    check_static_layout("5,6,6,6", vec_el_counts, widths_at_elwidth,
                        [5, 6, 12, 18])

    # this example was probably what the 5,6,6,6 one was supposed to be.
    # combined with vec_el_counts {0:1, 1:1, 2:2, 3:4} we have:
    # elwidth=0b00 1x 24-bit |.........................24|
    # elwidth=0b01 1x 12-bit |<--unused--->|...........12|
    # elwidth=0b10 2x 5 -bit |unused>|....5|unused>|....5|
    # elwidth=0b11 4x 6 -bit |.....6|.....6|.....6|.....6|
    # expected partition points: 5, 6, 12, 17, 18 (0 and 24 are implicit)
    widths_at_elwidth = {
        0: 24,  # QTY 1x 24
        1: 12,  # QTY 1x 12
        2: 5,   # QTY 2x 5
        3: 6    # QTY 4x 6
    }
    check_static_layout("24,12,5,6", vec_el_counts, widths_at_elwidth,
                        [5, 6, 12, 17, 18])

    # this tests elwidth as an actual Signal. layout is allowed to
    # determine arbitrarily the overall length
    # https://bugs.libre-soc.org/show_bug.cgi?id=713#c30
    check_signal_layout(vec_el_counts, widths_at_elwidth,
                        expected_bmask=0)  # no unused partitions

    # this tests elwidth as an actual Signal. layout is *not* allowed to
    # determine arbitrarily the overall length, it is fixed to 64
    # https://bugs.libre-soc.org/show_bug.cgi?id=713#c22
    #
    # combined with vec_el_counts {0:1, 1:1, 2:2, 3:4} we have:
    # elwidth=0b00 1x 24-bit
    # elwidth=0b01 1x 12-bit
    # elwidth=0b10 2x 5-bit
    # elwidth=0b11 4x 6-bit
    #
    # bit 63 on the left, x = padding, . = element bit.
    # always unused:
    #      1111111111000000 1111111111000000 1111111100000000 0000000000000000
    # 0b00 xxxxxxxxxxxxxxxx xxxxxxxxxxxxxxxx xxxxxxxx........ ..............24|
    # 0b01 xxxxxxxxxxxxxxxx xxxxxxxxxxxxxxxx xxxxxxxxxxxxxxxx xxxx..........12|
    # 0b10 xxxxxxxxxxxxxxxx xxxxxxxxxxx....5|xxxxxxxxxxxxxxxx xxxxxxxxxxx....5|
    # 0b11 xxxxxxxxxx.....6|xxxxxxxxxx.....6|xxxxxxxxxx.....6|xxxxxxxxxx.....6|
    #
    # ppoints at bits: 5, 6, 12, 16, 22, 24, 32, 37, 38, 48, 54
    # partitions [24,32), [38,48) and [54,64) are unused at every elwidth,
    # i.e. partitions 6, 9 and 11 (partition 0 is [0,5)) => bmask below.
    check_signal_layout(vec_el_counts, widths_at_elwidth,
                        expected_bmask=0b101001000000, fixed_width=64)

    # fixed_width=32 and no lane_widths says "allocate maximum"
    # i.e. Vector Element Widths are auto-allocated
    # elwidth=0b00 1x 32-bit | .................32 |
    # elwidth=0b01 1x 32-bit | .................32 |
    # elwidth=0b10 2x 16-bit | ......16 | ......16 |
    # elwidth=0b11 4x 8-bit  | ..8| ..8 | ..8 |..8 |
    # expected partition points: 8, 16, 24 (0 and 32 are implicit)
    check_static_layout("maximum allocation from fixed_width=32",
                        vec_el_counts, None, [8, 16, 24], fixed_width=32)

    # example "exponent"
    # https://libre-soc.org/3d_gpu/architecture/dynamic_simd/shape/
    # 1xFP64: 11 bits, one exponent
    # 2xFP32: 8 bits, two exponents
    # 4xFP16: 5 bits, four exponents
    # 4xBF16: 8 bits, four exponents
    vec_el_counts = {
        0: 1,  # QTY 1x FP64
        1: 2,  # QTY 2x FP32
        2: 4,  # QTY 4x FP16
        3: 4,  # QTY 4x BF16
    }
    widths_at_elwidth = {
        0: 11,  # FP64 ew=0b00
        1: 8,   # FP32 ew=0b01
        2: 5,   # FP16 ew=0b10
        3: 8    # BF16 ew=0b11
    }

    # expected results, bit 31 on the left, x = padding, . = element bit:
    # ew=0b00 1x 11-bit |xxxxxxxx xxxxxxxx xxxxx... ........|
    # ew=0b01 2x  8-bit |xxxxxxxx ........ xxxxxxxx ........|
    # ew=0b10 4x  5-bit |xxx..... xxx..... xxx..... xxx.....|
    # ew=0b11 4x  8-bit |........ ........ ........ ........|
    # partition points: 5, 8, 11, 13, 16, 21, 24, 29
    check_static_layout("11,8,5,8 (FP64/32/16/BF exponents)",
                        vec_el_counts, widths_at_elwidth,
                        [5, 8, 11, 13, 16, 21, 24, 29], fixed_width=32)

    ######                                                          ######
    ###### 2nd test, different from the above, elwid=0b01 ==> 11 bit ######
    ######                                                          ######

    # example "exponent", same element counts as above but FP32 lanes
    # are 11 bits wide
    widths_at_elwidth = {
        0: 11,  # FP64 ew=0b00
        1: 11,  # FP32 ew=0b01
        2: 5,   # FP16 ew=0b10
        3: 8    # BF16 ew=0b11
    }

    # expected results, bit 31 on the left, x = padding, . = element bit:
    # ew=0b00 1x 11-bit |xxxxxxxx xxxxxxxx xxxxx... ........|
    # ew=0b01 2x 11-bit |xxxxx... ........ xxxxx... ........|
    # ew=0b10 4x  5-bit |xxx..... xxx..... xxx..... xxx.....|
    # ew=0b11 4x  8-bit |........ ........ ........ ........|
    # partition points: 5, 8, 11, 13, 16, 21, 24, 27, 29
    check_static_layout("11,11,5,8 (FP64/32/16/BF exponents)",
                        vec_el_counts, widths_at_elwidth,
                        [5, 8, 11, 13, 16, 21, 24, 27, 29], fixed_width=32)