refactor layout to use SimdScope and XLEN
[ieee754fpu.git] / src / ieee754 / part / layout_experiment.py
1 #!/usr/bin/env python3
2 # SPDX-License-Identifier: LGPL-3-or-later
3 # See Notices.txt for copyright information
4 """
5 Links:
6 * https://libre-soc.org/3d_gpu/architecture/dynamic_simd/shape/
7 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c20
8 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c30
9 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c34
10 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c47
11 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c22
12 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c67
13 """
14
15 from nmigen import Signal, Module, Elaboratable, Mux, Cat, Shape, Repl
16 from nmigen.sim import Simulator, Delay, Settle
17 from nmigen.cli import rtlil
18 from enum import Enum
19
20 from collections.abc import Mapping
21 from functools import reduce
22 import operator
23 from collections import defaultdict
24 import dataclasses
25 from ieee754.part.util import XLEN, FpElWid, IntElWid, SimdMap, SimdScope
26 from ieee754.part_mul_add.partpoints import PartitionPoints
27
28
@dataclasses.dataclass
class LayoutResult:
    """Plain record bundling the values that describe one SIMD layout.

    The field names match attributes of the same name that `SimdLayout`
    (in this file) computes or reads from its scope; nothing in this class
    itself constrains their contents beyond the annotations below.
    """
    ppoints: PartitionPoints  # rendered entry-by-entry by __repr__
    bitp: dict
    bmask: int
    width: int
    lane_shapes: dict
    part_wid: int
    full_part_count: int

    def __repr__(self):
        """Return a multi-line `LayoutResult(...)` string, one field per line.

        Fields holding a `PartitionPoints` are expanded to `{key: value, ...}`
        with one entry per line; every other field uses its normal formatting.
        """
        def render(value):
            # only PartitionPoints gets the expanded, brace-wrapped form
            if not isinstance(value, PartitionPoints):
                return value
            entries = ',\n '.join(
                f"{point}: {enabled}" for point, enabled in value.items())
            return f"{{{entries}}}"

        rendered = [
            f"{spec.name}={render(getattr(self, spec.name))}"
            for spec in dataclasses.fields(LayoutResult)
        ]
        body = ",\n ".join(rendered)
        return f"LayoutResult({body})"
50
51
class SimdLayout(Shape):
    """A `Shape` whose width is split into SIMD lanes selected by `elwid`.

    Attributes set by `__init__`:
    * scope: the `SimdScope` the layout was built against
    * lane_shapes: dict from ElWid to the requested element bit-width
    * part_wid: the bit-width of a single part
    * dpoints: dict from bit-index to the ElWid values that need a
      partition point at that bit-index (sorted by bit-index)
    * bitp: dict from ElWid to an int mask, bit `n` set when the `n`-th
      partition point (in sorted order) is in use at that ElWid
    * bmask: int mask of partition points not used at any ElWid
    * ppoints: `PartitionPoints` mapping bit-index to the expression
      "elwid is one of the ElWid values needing this point"
    """

    def __init__(self, lane_shapes=None, signed=None, *, fixed_width=None,
                 width_follows_hint=True, scope=None):
        """calculate a SIMD layout.

        Glossary:
        * element: a single scalar value that is an element of a SIMD vector.
            it has a width in bits, and a signedness. Every element is made of
            1 or more parts. An element optionally includes the padding
            associated with it.
        * lane: another name for an element (optionally including the padding
            associated with it).
        * ElWid: the element-width (really the element type) of an
            instruction. Either an integer or a FP type. Integer `ElWid`s are
            sign-agnostic.
            In Python, `ElWid` is either an enum type or is `int`.
            Example `ElWid` definition for integers:

            class ElWid(Enum):
                I8 = ...
                I16 = ...
                I32 = ...
                I64 = ...

            Example `ElWid` definition for floats:

            class ElWid(Enum):
                F16 = ...
                BF16 = ...
                F32 = ...
                F64 = ...
        * part: (not to be confused with a partition) A piece of a SIMD
            vector, every SIMD vector is made of a non-negative integer
            number of parts. Elements are made of a power-of-two number of
            parts. A part is a fixed number of bits wide for each different
            SIMD layout, it doesn't vary when `elwid` changes. A part can
            have a bit width of any non-negative integer, it is not
            restricted to power-of-two.

        Arguments:
        * lane_shapes: int or Mapping[ElWid, int] or SimdMap (optional)
            the bit-width of all elements in this SIMD layout. When omitted,
            `fixed_width` must be given and each ElWid gets the largest
            element width that evenly divides `fixed_width`.
        * signed: bool
            the signedness of all elements in this SIMD layout
        * fixed_width: int (optional)
            the total width of a SIMD vector. One of lane_shapes and
            fixed_width must be provided.
        * width_follows_hint: bool
            when True and `fixed_width` is not given, `fixed_width` defaults
            to the scope's `simd_full_width_hint`, provided the minimum
            required width fits within that hint.
        * scope: SimdScope (optional)
            the scope supplying the values listed below. Defaults to
            `SimdScope.get()`.

        Values used from SimdScope:
        * elwid: ElWid or nmigen Value with ElWid as the shape
            the current ElWid value
        * part_counts: SimdMap
            a map from `ElWid` values `k` to the number of parts in an element
            when `elwid == k`. Values should be minimized, since higher values
            often create bigger circuits.

            Example:
            # here, an I8 element is 1 part wide
            part_counts = SimdMap({
                IntElWid.I8: 1,
                IntElWid.I16: 2,
                IntElWid.I32: 4,
                IntElWid.I64: 8,
            })

            Another Example:
            # here, an F16 element is 1 part wide
            part_counts = SimdMap({
                FpElWid.F16: 1,
                FpElWid.BF16: 1,
                FpElWid.F32: 2,
                FpElWid.F64: 4,
            })
        * simd_full_width_hint: the preferred total width, see
            `width_follows_hint`
        * full_part_count: the number of parts in the whole SIMD vector
        """
        if scope is None:
            scope = SimdScope.get()
        assert isinstance(scope, SimdScope)
        self.scope = scope
        elwid = self.scope.elwid
        part_counts = self.scope.part_counts
        assert isinstance(part_counts, SimdMap)
        simd_full_width_hint = self.scope.simd_full_width_hint
        full_part_count = self.scope.full_part_count
        print(f"layout(elwid={elwid},\n"
              f" signed={signed},\n"
              f" part_counts={part_counts},\n"
              f" lane_shapes={lane_shapes},\n"
              f" fixed_width={fixed_width},\n"
              f" simd_full_width_hint={simd_full_width_hint},\n"
              f" width_follows_hint={width_follows_hint})")

        # when there are no lane_shapes specified, this indicates a
        # desire to use the maximum available space based on the fixed width
        # https://bugs.libre-soc.org/show_bug.cgi?id=713#c67
        if lane_shapes is None:
            assert fixed_width is not None, \
                "both fixed_width and lane_shapes cannot be None"
            lane_shapes = {}
            for k, cur_part_count in part_counts.items():
                cur_element_count = full_part_count // cur_part_count
                assert fixed_width % cur_element_count == 0, (
                    f"fixed_width ({fixed_width}) can't be split evenly into "
                    f"{cur_element_count} elements")
                lane_shapes[k] = fixed_width // cur_element_count
            print("lane_shapes", fixed_width, lane_shapes)
        # convert lane_shapes to a Mapping[ElWid, Any]
        lane_shapes = SimdMap(lane_shapes).mapping
        # filter out unsupported elwidths
        lane_shapes = {i: lane_shapes[i] for i in part_counts.keys()}
        self.lane_shapes = lane_shapes
        # calculate the minimum possible bit-width of a part.
        # we divide each element's width by the number of parts in an element,
        # giving the number of bits needed per part.
        min_part_wid = 0
        for i, c in part_counts.items():
            # double negate to get ceil division
            needed = -(-lane_shapes[i] // c)
            min_part_wid = max(min_part_wid, needed)
        # calculate the minimum bit-width required
        min_width = min_part_wid * full_part_count
        print("width", min_width, min_part_wid, full_part_count)
        if width_follows_hint \
                and min_width <= simd_full_width_hint \
                and fixed_width is None:
            fixed_width = simd_full_width_hint

        if fixed_width is not None:  # override the width and part_wid
            assert min_width <= fixed_width, \
                "not enough space to fit partitions"
            self.part_wid = fixed_width // full_part_count
            assert fixed_width % full_part_count == 0, \
                "fixed_width must be a multiple of full_part_count"
            width = fixed_width
            print("part_wid", self.part_wid, "count", full_part_count)
        else:
            # go with computed width
            width = min_width
            self.part_wid = min_part_wid
        super().__init__(width, signed)
        # create the breakpoints dictionary.
        # do multi-stage version
        # https://bugs.libre-soc.org/show_bug.cgi?id=713#c34
        # https://stackoverflow.com/questions/26367812/
        # dpoints: dict from bit-index to dict[ElWid, None]
        # we use a dict from ElWid to None as the values of dpoints in order
        # to get an ordered set
        dpoints = defaultdict(dict)  # if empty key, create a (empty) dict
        for i, cur_part_count in part_counts.items():
            # go through all elements for elwid `i`, each element starts at
            # part index `start_part`, and goes for `cur_part_count` parts
            for start_part in range(0, full_part_count, cur_part_count):
                start_bit = start_part * self.part_wid
                # indexing auto-creates the inner dict if key non-existent
                dpoints[start_bit][i] = None  # start of lane
                dpoints[start_bit + lane_shapes[i]][i] = None  # padding start
        # do not need the breakpoints at the very start or the very end
        dpoints.pop(0, None)
        dpoints.pop(self.width, None)
        plist = sorted(dpoints)
        # rebuilt in sorted order: position in this dict == index in plist
        dpoints = {k: dpoints[k].keys() for k in plist}
        self.dpoints = dpoints
        print("dpoints")
        for k in plist:
            print(f"{k}: {list(dpoints[k])}")
        # second stage, add (map to) the elwidth==i expressions.
        # TODO: use nmutil.treereduce?
        points = {}
        for p in plist:
            points[p] = reduce(operator.or_,
                               (elwid == i for i in dpoints[p]))
        # third stage, create the binary values which *if* elwidth is set to i
        # *would* result in the mask at that elwidth being set to this value
        # these can easily be double-checked through Assertion
        self.bitp = {}
        for i in part_counts.keys():
            mask = 0
            # enumerate gives the sorted position directly, avoiding a
            # linear plist.index() search for every (elwid, point) pair
            for bitpos, elwidths in enumerate(dpoints.values()):
                if i in elwidths:
                    mask |= 1 << bitpos
            self.bitp[i] = mask
        # fourth stage: determine which partitions are 100% unused.
        # these can then be "blanked out"
        self.bmask = (1 << len(plist)) - 1
        for p in self.bitp.values():
            self.bmask &= ~p
        self.ppoints = PartitionPoints(points)

    def __repr__(self):
        """Return a multi-line dump of the layout and its derived tables."""
        bitp = ", ".join(f"{k}: {bin(v)}" for k, v in self.bitp.items())
        dpoints = ",\n ".join(
            f"{k}: {list(v)}" for k, v in self.dpoints.items())
        ppoints = ",\n ".join(
            f"{k}: {list(v)}" for k, v in self.ppoints.items())
        return (f"SimdLayout(lane_shapes={self.lane_shapes},\n"
                f" signed={self.signed},\n"
                f" fixed_width={self.width},\n"
                f" scope={self.scope},\n"
                f" bitp={{{bitp}}},\n"
                f" bmask={bin(self.bmask)},\n"
                f" dpoints={{\n"
                f" {dpoints}}},\n"
                f" part_wid={self.part_wid},\n"
                f" ppoints=PartitionPoints({{\n"
                f" {ppoints}}}))")
263
264
if __name__ == '__main__':
    def _simulate_and_check(layout, elwid, *, announce=False):
        """Drive `elwid` through every FpElWid and check `layout.ppoints`.

        For each ElWid value the partition-point expressions are evaluated
        by the simulator, printed, packed into an int (first point is bit 0)
        and compared against the statically computed `layout.bitp` entry.

        * layout: the SimdLayout under test
        * elwid: the scope's elwid (must support `.eq()`, i.e. a Signal)
        * announce: when True, print a "test elwidth=..." line per ElWid
        """
        def process():
            for i in FpElWid:
                yield elwid.eq(i)
                yield Settle()
                ppt = []
                for pval in layout.ppoints.values():
                    val = yield pval  # get nmigen to evaluate pp
                    ppt.append(val)
                if announce:
                    print(f"test elwidth={i}")
                print(i, ppt)
                # check the results against bitp static-expected partition
                # points
                # https://bugs.libre-soc.org/show_bug.cgi?id=713#c47
                # arithmetic fold instead of int(str, 2): also works when
                # there are no partition points at all
                ival = sum(bit << pos for pos, bit in enumerate(ppt))
                assert ival == layout.bitp[i], \
                    f"ival {bin(ival)} actual {bin(layout.bitp[i])}"

        sim = Simulator(Module())
        sim.add_process(process)
        sim.run()

    # for each element-width (elwidth 0-3) the number of parts in an element
    # is given:
    #                                 | part0 | part1 | part2 | part3 |
    # elwid=F64 4 parts per element:  |<-------------F64------------->|
    # elwid=F32 2 parts per element:  |<-----F32----->|<-----F32----->|
    # elwid=F16 1 part per element:   |<-F16->|<-F16->|<-F16->|<-F16->|
    # elwid=BF16 1 part per element:  |<BF16->|<BF16->|<BF16->|<BF16->|
    # actual widths of Signals *within* those partitions is given separately
    part_counts = {
        FpElWid.F64: 4,
        FpElWid.F32: 2,
        FpElWid.F16: 1,
        FpElWid.BF16: 1,
    }

    # width=3 indicates "we want the same element bit-width (3) at all elwids"
    # elwid=F64 1x 3-bit   |<--------i3------->|
    # elwid=F32 2x 3-bit   |<---i3-->|<---i3-->|
    # elwid=F16 4x 3-bit   |<i3>|<i3>|<i3>|<i3>|
    # elwid=BF16 4x 3-bit  |<i3>|<i3>|<i3>|<i3>|
    width_for_all_els = 3

    for i in FpElWid:
        with SimdScope(elwid=i, part_counts=part_counts):
            print(i, SimdLayout(width_for_all_els, True,
                                width_follows_hint=False))

    # fixed_width=32 and no lane_widths says "allocate maximum"
    # elwid=F64 1x 32-bit  |<-------i32------->|
    # elwid=F32 2x 16-bit  |<--i16-->|<--i16-->|
    # elwid=F16 4x 8-bit   |<i8>|<i8>|<i8>|<i8>|
    # elwid=BF16 4x 8-bit  |<i8>|<i8>|<i8>|<i8>|

    print("maximum allocation from fixed_width=32")
    for i in FpElWid:
        with SimdScope(elwid=i, part_counts=part_counts):
            print(i, SimdLayout(signed=True, fixed_width=32))

    # specify that the length is to be *different* at each of the elwidths.
    # combined with part_counts we have:
    # elwid=F64 1x 24-bit  |<-------i24------->|
    # elwid=F32 2x 12-bit  |<--i12-->|<--i12-->|
    # elwid=F16 4x 6-bit   |<i6>|<i6>|<i6>|<i6>|
    # elwid=BF16 4x 5-bit  |<i5>|<i5>|<i5>|<i5>|
    widths_at_elwidth = {
        FpElWid.F64: 24,
        FpElWid.F32: 12,
        FpElWid.F16: 6,
        FpElWid.BF16: 5,
    }

    for i in FpElWid:
        with SimdScope(elwid=i, part_counts=part_counts):
            print(i, SimdLayout(widths_at_elwidth,
                                False, width_follows_hint=False))

    # this tests elwidth as an actual Signal. layout is allowed to
    # determine arbitrarily the overall length
    # https://bugs.libre-soc.org/show_bug.cgi?id=713#c30

    with SimdScope(elwid_type=FpElWid, part_counts=part_counts) as scope:
        layout = SimdLayout(widths_at_elwidth, False,
                            width_follows_hint=False)
        print(layout)

    _simulate_and_check(layout, scope.elwid)

    # this tests elwidth as an actual Signal. layout uses the width hint
    # https://bugs.libre-soc.org/show_bug.cgi?id=713#c30

    with SimdScope(elwid_type=FpElWid, part_counts=part_counts) as scope:
        layout = SimdLayout(widths_at_elwidth, False)
        print(layout)

    _simulate_and_check(layout, scope.elwid)

    # this tests elwidth as an actual Signal. layout is *not* allowed to
    # determine arbitrarily the overall length, it is fixed to 64
    # https://bugs.libre-soc.org/show_bug.cgi?id=713#c22

    with SimdScope(elwid_type=FpElWid, part_counts=part_counts) as scope:
        layout = SimdLayout(widths_at_elwidth, False, fixed_width=64)
        print(layout)

    _simulate_and_check(layout, scope.elwid, announce=True)

    # test XLEN
    with SimdScope(elwid_type=IntElWid):
        print("\nSimdLayout(XLEN):")
        l1 = SimdLayout(XLEN)
        print(l1)
        print("\nSimdLayout(XLEN // 2):")
        l2 = SimdLayout(XLEN // 2)
        print(l2)