sort dpoints keys
[ieee754fpu.git] / src / ieee754 / part / layout_experiment.py
1 #!/usr/bin/env python3
2 # SPDX-License-Identifier: LGPL-3-or-later
3 # See Notices.txt for copyright information
4 """
5 Links:
6 * https://libre-soc.org/3d_gpu/architecture/dynamic_simd/shape/
7 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c20
8 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c30
9 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c34
10 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c47
11 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c22
12 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c67
13 """
14
15 from nmigen import Signal, Module, Elaboratable, Mux, Cat, Shape, Repl
16 from nmigen.back.pysim import Simulator, Delay, Settle
17 from nmigen.cli import rtlil
18
19 from collections.abc import Mapping
20 from functools import reduce
21 import operator
22 from collections import defaultdict
23 from pprint import pprint
24
25 from ieee754.part_mul_add.partpoints import PartitionPoints
26
27
28 # main fn, which started out here in the bugtracker:
29 # https://bugs.libre-soc.org/show_bug.cgi?id=713#c20
30 # note that signed is **NOT** part of the layout, and will NOT
31 # be added (because it is not relevant or appropriate).
32 # sign belongs in ast.Shape and is the only appropriate location.
33 # there is absolutely nothing within this function that in any
34 # way requires a sign. it is *purely* performing numerical width
35 # computations that have absolutely nothing to do with whether the
36 # actual data is signed or unsigned.
def layout(elwid, vec_el_counts, lane_shapes=None, fixed_width=None):
    """calculate a SIMD layout.

    Glossary:
    * element: a single scalar value that is an element of a SIMD vector.
        it has a width in bits. Every element is made of 1 or
        more parts.
    * ElWid: the element-width (really the element type) of an instruction.
        Either an integer or a FP type. Integer `ElWid`s are sign-agnostic.
        In Python, `ElWid` is either an enum type or is `int`.
        Example `ElWid` definition for integers:

        class ElWid(Enum):
            I64 = ...       # SVP64 value 0b00
            I32 = ...       # SVP64 value 0b01
            I16 = ...       # SVP64 value 0b10
            I8 = ...        # SVP64 value 0b11

        Example `ElWid` definition for floats:

        class ElWid(Enum):
            F64 = ...       # SVP64 value 0b00
            F32 = ...       # SVP64 value 0b01
            F16 = ...       # SVP64 value 0b10
            BF16 = ...      # SVP64 value 0b11

    * elwid: ElWid or nmigen Value with ElWid as the shape
        the current element-width

    * vec_el_counts: dict[ElWid, int]
        a map from `ElWid` values `k` to the number of vector elements
        required within a partition when `elwid == k`.

        Example:
        vec_el_counts = {ElWid.I8(==0b11): 8, # 8 vector elements
                         ElWid.I16(==0b10): 4, # 4 vector elements
                         ElWid.I32(==0b01): 2, # 2 vector elements
                         ElWid.I64(==0b00): 1} # 1 vector (aka scalar) element

        Another Example:
        vec_el_counts = {ElWid.BF16(==0b11): 4, # 4 vector elements
                         ElWid.F16(==0b10): 4, # 4 vector elements
                         ElWid.F32(==0b01): 2, # 2 vector elements
                         ElWid.F64(==0b00): 1} # 1 (aka scalar) vector element

    * lane_shapes: int or Mapping[ElWid, int] (optional)
        the bit-width of all elements in a SIMD layout.
        if not provided, the lane_shapes are computed from fixed_width
        and vec_el_counts at each elwidth.

    * fixed_width: int (optional)
        the total width of a SIMD vector. One or both of lane_shapes or
        fixed_width may be provided. Both may not be left out.

    Returns a 6-tuple:
    * PartitionPoints mapping each breakpoint (bit position, ascending,
      excluding 0 and the total width) to the OR of `elwid == k` over
      every `k` that starts or ends a lane at that position.
    * bitp: dict[ElWid, int]. bit `n` of `bitp[k]` is set when the n-th
      breakpoint (in ascending order) is in use at `elwid == k`.
    * bmask: int with a bit set for each breakpoint used by no elwidth.
      NOTE(review): every breakpoint created here is used by at least
      one elwidth, so this looks like it is always 0 - confirm intent.
    * width: total width in bits (fixed_width when given, otherwise the
      largest `lane_shapes[k] * vec_el_counts[k]`).
    * lane_shapes: the per-elwidth lane widths as a dict.
    * part_wid: `width // max(vec_el_counts.values())`.

    Raises AssertionError when neither lane_shapes nor fixed_width is
    given, when the lanes do not fit in fixed_width, or when fixed_width
    is not a multiple of the largest element count.
    """
    # when there are no lane_shapes specified, this indicates a
    # desire to use the maximum available space based on the fixed width
    # https://bugs.libre-soc.org/show_bug.cgi?id=713#c67
    if lane_shapes is None:
        assert fixed_width is not None, \
            "both fixed_width and lane_shapes cannot be None"
        lane_shapes = {i: fixed_width // vec_el_counts[i]
                       for i in vec_el_counts}
        print("lane_shapes", fixed_width, lane_shapes)

    # identify if the lane_shapes is a mapping (dict, etc.)
    # if not, then assume that it is an integer (width) that
    # needs to be requested across all partitions
    if not isinstance(lane_shapes, Mapping):
        lane_shapes = {i: lane_shapes for i in vec_el_counts}

    # find the elwidth needing the most bits (lane width * lane count):
    # that sets the minimum overall width.
    print("lane_shapes", lane_shapes, "vec_el_counts", vec_el_counts)
    cpart_wid = 0  # lane width of the widest requirement (debug only)
    width = 0
    for i, lwid in lane_shapes.items():
        required_width = lwid * vec_el_counts[i]
        print(" required width", cpart_wid, i, lwid, required_width)
        if required_width > width:
            cpart_wid = lwid
            width = required_width

    # calculate the minimum width required if fixed_width specified
    part_count = max(vec_el_counts.values())
    print("width", width, cpart_wid, part_count)
    if fixed_width is not None:  # override the width and part_wid
        assert width <= fixed_width, "not enough space to fit partitions"
        part_wid = fixed_width // part_count
        assert part_wid * part_count == fixed_width, \
            "calculated width not aligned multiples"
        width = fixed_width
        print("part_wid", part_wid, "count", part_count, "width", width)
    else:
        part_wid = width // part_count

    # create the breakpoints dictionary.
    # do multi-stage version https://bugs.libre-soc.org/show_bug.cgi?id=713#c34
    # https://stackoverflow.com/questions/26367812/
    dpoints = defaultdict(list)  # if empty key, create a (empty) list
    for i, c in vec_el_counts.items():
        print("dpoints", i, "count", c)
        # distance between the starts of adjacent elements at this
        # elwidth.  deliberately NOT called part_wid: reusing that name
        # made the returned part_wid depend on dict iteration order.
        el_stride = width // c
        # for each elwidth, create the required number of vector elements
        for start in range(c):
            lane_start = start * el_stride
            for msg, p in (("start", lane_start),
                           ("end ", lane_start + lane_shapes[i])):
                print(" adding dpoint", msg, start, el_stride, i, c, p)
                dpoints[p].append(i)  # auto-creates list if key missing

    # sort by position, deduplicate each elwidth list (keeping first-seen
    # order) and drop the breakpoints at the very start and the very end,
    # which are not needed.
    dpoints = {p: list(dict.fromkeys(elwidths))
               for p, elwidths in sorted(dpoints.items(),
                                         key=operator.itemgetter(0))
               if p not in (0, width)}

    plist = list(dpoints)
    print("dpoints")
    pprint(dpoints)

    # second stage, add (map to) the elwidth==i expressions.
    # TODO: use nmutil.treereduce?
    points = {p: reduce(operator.or_, (elwid == i for i in dpoints[p]))
              for p in plist}

    # third stage, create the binary values which *if* elwidth is set to i
    # *would* result in the mask at that elwidth being set to this value
    # these can easily be double-checked through Assertion
    bitp = {i: 0 for i in vec_el_counts}
    for bitpos, p in enumerate(plist):
        for i in dpoints[p]:
            bitp[i] |= 1 << bitpos

    # fourth stage: determine which partitions are 100% unused.
    # these can then be "blanked out"
    bmask = (1 << len(plist)) - 1
    for used in bitp.values():
        bmask &= ~used
    return (PartitionPoints(points), bitp, bmask, width, lane_shapes,
            part_wid)
188
189
if __name__ == '__main__':

    def check_bitp_by_simulation(elwid, pp, bitp, width, lane_shapes,
                                 part_wid):
        """drive elwid through every key of bitp in simulation and check
        that the evaluated partition points match the static bitp mask.

        * elwid: the Signal that was passed to layout()
        * pp, bitp, width, lane_shapes, part_wid: values returned by
          layout() (width, lane_shapes and part_wid are only printed)

        https://bugs.libre-soc.org/show_bug.cgi?id=713#c47
        """
        m = Module()

        def process():
            for i in bitp:
                yield elwid.eq(i)
                yield Settle()
                ppt = []
                for pval in pp.values():
                    val = yield pval  # get nmigen to evaluate pp
                    ppt.append(val)
                print("test elwidth=%d" % i)
                pprint((i, (ppt, width, lane_shapes, part_wid)))
                # pack the evaluated points into an int, first point = LSB,
                # to match the bit numbering used by bitp
                ival = sum(bit << pos for pos, bit in enumerate(ppt))
                assert ival == bitp[i], "ival %s actual %s" % (bin(ival),
                                                               bin(bitp[i]))

        sim = Simulator(m)
        sim.add_process(process)
        sim.run()

    # for each element-width (elwidth 0-3) the number of Vector Elements is:
    #  elwidth=0b00 QTY 1 partitions:  |             ?             |
    #  elwidth=0b01 QTY 1 partitions:  |             ?             |
    #  elwidth=0b10 QTY 2 partitions:  |      ?      |      ?      |
    #  elwidth=0b11 QTY 4 partitions:  |   ?  |   ?  |   ?  |   ?  |
    # actual widths of Signals *within* those partitions is given separately
    vec_el_counts = {
        0: 1,
        1: 1,
        2: 2,
        3: 4,
    }

    # width=3 indicates "same width Vector Elements (3) at all elwidths"
    # the overall width is therefore 4 x 3 = 12 bits
    #  elwidth=0b00 1x 3-bit     |<------unused------>|.....3|
    #  elwidth=0b01 1x 3-bit     |<------unused------>|.....3|
    #  elwidth=0b10 2x 3-bit     |unused|.....3|unused|.....3|
    #  elwidth=0b11 4x 3-bit     |.....3|.....3|.....3|.....3|
    #  expected partitions      (^)     ^      ^      ^     (^)
    #  to be at these points:   (|)     |      |      |     (|)
    #                           (12)    9      6      3     (0)
    width_in_all_parts = 3

    for i in range(4):
        pprint((i, layout(i, vec_el_counts, width_in_all_parts)))

    # specify that the Vector Element lengths are to be *different* at
    # each of the elwidths.
    # combined with vec_el_counts we have:
    #  elwidth=0b00 1x 5-bit     |<----unused---------->....5|
    #  elwidth=0b01 1x 6-bit     |<----unused--------->.....6|
    #  elwidth=0b10 2x 6-bit     |unused>.....6|unused>.....6|
    #  elwidth=0b11 4x 6-bit     |.....6|.....6|.....6|.....6|
    #  expected partitions      (^)     ^      ^      ^^    (^)
    #  to be at these points:   (|)     |      |      ||    (|)
    #                           (24)   18     12      65    (0)
    widths_at_elwidth = {
        0: 5,
        1: 6,
        2: 6,
        3: 6
    }

    print("5,6,6,6 elements", widths_at_elwidth)
    for i in range(4):
        pp, bitp, bm, width, lane_shapes, part_wid = \
            layout(i, vec_el_counts, widths_at_elwidth)
        pprint((i, (pp, bitp, bm, width, lane_shapes, part_wid)))
    # now check that the expected partition points occur
    print("5,6,6,6 ppt keys", pp.keys())
    assert list(pp.keys()) == [5, 6, 12, 18]

    # this example was probably what the 5,6,6,6 one was supposed to be.
    # combined with vec_el_counts {0:1, 1:1, 2:2, 3:4} we have:
    #  elwidth=0b00 1x 24-bit    |.........................24|
    #  elwidth=0b01 1x 12-bit    |<--unused--->|...........12|
    #  elwidth=0b10 2x 5-bit     |unused>|....5|unused>|....5|
    #  elwidth=0b11 4x 6-bit     |.....6|.....6|.....6|.....6|
    #  expected partitions      (^)     ^^     ^      ^^    (^)
    #  to be at these points:   (|)     ||     |      ||    (|)
    #                           (24)   1817   12      65    (0)
    widths_at_elwidth = {
        0: 24,  # QTY 1x 24
        1: 12,  # QTY 1x 12
        2: 5,   # QTY 2x 5
        3: 6    # QTY 4x 6
    }

    print("24,12,5,6 elements", widths_at_elwidth)
    for i in range(4):
        pp, bitp, bm, width, lane_shapes, part_wid = \
            layout(i, vec_el_counts, widths_at_elwidth)
        pprint((i, (pp, bitp, bm, width, lane_shapes, part_wid)))
    # now check that the expected partition points occur
    print("24,12,5,6 ppt keys", pp.keys())
    assert list(pp.keys()) == [5, 6, 12, 17, 18]

    # this tests elwidth as an actual Signal. layout is allowed to
    # determine arbitrarily the overall length
    # https://bugs.libre-soc.org/show_bug.cgi?id=713#c30

    elwid = Signal(2)
    pp, bitp, bm, width, lane_shapes, part_wid = layout(
        elwid, vec_el_counts, widths_at_elwidth)
    pprint((pp, width, lane_shapes, part_wid))
    for k, v in bitp.items():
        print("bitp elwidth=%d" % k, bin(v))
    print("bmask", bin(bm))

    check_bitp_by_simulation(elwid, pp, bitp, width, lane_shapes, part_wid)

    # this tests elwidth as an actual Signal. layout is *not* allowed to
    # determine arbitrarily the overall length, it is fixed to 64
    # https://bugs.libre-soc.org/show_bug.cgi?id=713#c22

    elwid = Signal(2)
    pp, bitp, bm, width, lane_shapes, part_wid = layout(
        elwid, vec_el_counts, widths_at_elwidth, fixed_width=64)
    pprint((pp, width, lane_shapes, part_wid))
    for k, v in bitp.items():
        print("bitp elwidth=%d" % k, bin(v))
    print("bmask", bin(bm))

    check_bitp_by_simulation(elwid, pp, bitp, width, lane_shapes, part_wid)

    # fixed_width=32 and no lane_shapes says "allocate maximum"
    # i.e. Vector Element Widths are auto-allocated
    #  elwidth=0b00 1x 32-bit    |.........................32|
    #  elwidth=0b01 1x 32-bit    |.........................32|
    #  elwidth=0b10 2x 16-bit    |...........16|...........16|
    #  elwidth=0b11 4x 8-bit     |.....8|.....8|.....8|.....8|
    #  expected partitions      (^)     ^      ^      ^     (^)
    #  to be at these points:   (|)     |      |      |     (|)
    #                           (32)   24     16      8     (0)

    # TODO, fix this so that it is correct. put it at the end so it
    # shows that things break and doesn't stop the other tests.
    print("maximum allocation from fixed_width=32")
    for i in range(4):
        pprint((i, layout(i, vec_el_counts, fixed_width=32)))

    # example "exponent"
    # https://libre-soc.org/3d_gpu/architecture/dynamic_simd/shape/
    #  1xFP64: 11 bits, one exponent
    #  2xFP32: 8 bits, two exponents
    #  4xFP16: 5 bits, four exponents
    #  4xBF16: 8 bits, four exponents
    vec_el_counts = {
        0: 1,  # QTY 1x FP64
        1: 2,  # QTY 2x FP32
        2: 4,  # QTY 4x FP16
        3: 4,  # QTY 4x BF16
    }
    widths_at_elwidth = {
        0: 11,  # FP64 ew=0b00
        1: 8,   # FP32 ew=0b01
        2: 5,   # FP16 ew=0b10
        3: 8    # BF16 ew=0b11
    }

    # expected lanes within the 32-bit vector (bit ranges, inclusive):
    #  ew=0b00 1x11: 10..0
    #  ew=0b01 2x8:  23..16, 7..0
    #  ew=0b10 4x5:  28..24, 20..16, 12..8, 4..0
    #  ew=0b11 4x8:  31..24, 23..16, 15..8, 7..0
    # every lane start/end except 0 and 32 becomes a partition point:
    #  5, 8, 11, 13, 16, 21, 24, 29

    print("11,8,5,8 elements (FP64/32/16/BF exponents)", widths_at_elwidth)
    for i in range(4):
        pp, bitp, bm, width, lane_shapes, part_wid = \
            layout(i, vec_el_counts, widths_at_elwidth,
                   fixed_width=32)
        pprint((i, (pp, bitp, bin(bm), width, lane_shapes, part_wid)))
    # now check that the expected partition points occur
    print("11,8,5,8 pp keys", pp.keys())
    assert list(pp.keys()) == [5, 8, 11, 13, 16, 21, 24, 29]

    ######                                                           ######
    ###### 2nd test, different from the above, elwid=0b01 ==> 11 bit ######
    ######                                                           ######

    # example "exponent"
    vec_el_counts = {
        0: 1,  # QTY 1x FP64
        1: 2,  # QTY 2x FP32
        2: 4,  # QTY 4x FP16
        3: 4,  # QTY 4x BF16
    }
    widths_at_elwidth = {
        0: 11,  # FP64 ew=0b00
        1: 11,  # FP32 ew=0b01
        2: 5,   # FP16 ew=0b10
        3: 8    # BF16 ew=0b11
    }

    # expected lanes within the 32-bit vector (bit ranges, inclusive):
    #  ew=0b00 1x11: 10..0
    #  ew=0b01 2x11: 26..16, 10..0
    #  ew=0b10 4x5:  28..24, 20..16, 12..8, 4..0
    #  ew=0b11 4x8:  31..24, 23..16, 15..8, 7..0
    # every lane start/end except 0 and 32 becomes a partition point:
    #  5, 8, 11, 13, 16, 21, 24, 27, 29

    print("11,11,5,8 elements (FP64/32/16/BF exponents)", widths_at_elwidth)
    for i in range(4):
        pp, bitp, bm, width, lane_shapes, part_wid = \
            layout(i, vec_el_counts, widths_at_elwidth,
                   fixed_width=32)
        pprint((i, (pp, bitp, bin(bm), width, lane_shapes, part_wid)))
    # now check that the expected partition points occur
    print("11,11,5,8 pp keys", pp.keys())
    assert list(pp.keys()) == [5, 8, 11, 13, 16, 21, 24, 27, 29]