fix layout bugs
[ieee754fpu.git] / src / ieee754 / part / layout_experiment.py
1 #!/usr/bin/env python3
2 # SPDX-License-Identifier: LGPL-3-or-later
3 # See Notices.txt for copyright information
4 """
5 Links:
6 * https://libre-soc.org/3d_gpu/architecture/dynamic_simd/shape/
7 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c20
8 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c30
9 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c34
10 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c47
11 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c22
12 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c67
13 """
14
15 from nmigen import Signal, Module, Elaboratable, Mux, Cat, Shape, Repl
16 from nmigen.sim import Simulator, Delay, Settle
17 from nmigen.cli import rtlil
18 from enum import Enum
19
20 from collections.abc import Mapping
21 from functools import reduce
22 import operator
23 from collections import defaultdict
24 import dataclasses
25
26 from ieee754.part_mul_add.partpoints import PartitionPoints
27
28
@dataclasses.dataclass
class LayoutResult:
    """Everything `layout()` works out for one SIMD layout, bundled together.

    The custom `__repr__` prints one field per line and expands a
    `PartitionPoints` field into its individual `key: value` entries.
    """
    # bit-index -> expression built from `elwid == k` comparisons
    ppoints: PartitionPoints
    # per-elwid bitmask over the sorted partition points
    bitp: dict
    # bitmask of the partition points that no elwid uses
    bmask: int
    # total width of the SIMD vector, in bits
    width: int
    # per-elwid element width, in bits
    lane_shapes: dict
    # width of a single part, in bits
    part_wid: int
    # number of parts in the whole SIMD vector
    full_part_count: int

    def __repr__(self):
        def render(value):
            # PartitionPoints is expanded entry by entry inside braces;
            # any other value is formatted as-is
            if not isinstance(value, PartitionPoints):
                return value
            entries = ',\n '.join(
                f"{k}: {v}" for k, v in value.items())
            return f"{{{entries}}}"

        rendered = [
            f"{fld.name}={render(getattr(self, fld.name))}"
            for fld in dataclasses.fields(LayoutResult)
        ]
        body = ",\n ".join(rendered)
        return f"LayoutResult({body})"
50
51
52 # main fn, which started out here in the bugtracker:
53 # https://bugs.libre-soc.org/show_bug.cgi?id=713#c20
def layout(elwid, signed, part_counts, lane_shapes=None, fixed_width=None):
    """Work out the partition layout of a SIMD vector.

    Terminology:
    * element: one scalar value inside a SIMD vector. It has a bit-width
        and a signedness, and is built from 1 or more parts. An element
        may include the padding that goes with it.
    * lane: another word for an element (again, optionally including its
        padding).
    * ElWid: the element-width (really the element type) of an
        instruction: either an integer type or a FP type. Integer
        `ElWid`s are sign-agnostic. In Python an `ElWid` is either an
        enum type or `int`.
        Integer example:

            class ElWid(Enum):
                I8 = ...
                I16 = ...
                I32 = ...
                I64 = ...

        Floating-point example:

            class ElWid(Enum):
                F16 = ...
                BF16 = ...
                F32 = ...
                F64 = ...
    * part: (not the same thing as a partition) a piece of a SIMD vector.
        A vector consists of a whole number of parts and an element of a
        power-of-two number of parts. For a given layout every part has
        the same bit-width, whatever `elwid` is; that width need not be a
        power of two. Use as few parts as possible, because some circuits
        grow in proportion to the number of parts.

    Arguments:
    * elwid: ElWid or nmigen Value with ElWid as the shape
        the current element-width
    * signed: bool
        the signedness of all elements in the layout (only echoed in the
        debug print here; the calculation does not depend on it)
    * part_counts: dict[ElWid, int]
        maps each `ElWid` value `k` to the number of parts in one element
        when `elwid == k`. Every value must be a power of two; keep them
        small, since larger values often create bigger circuits.

        Example (an I8 element is 1 part wide):
        part_counts = {ElWid.I8: 1, ElWid.I16: 2, ElWid.I32: 4, ElWid.I64: 8}

        Example (an F16 element is 1 part wide):
        part_counts = {ElWid.F16: 1, ElWid.BF16: 1, ElWid.F32: 2, ElWid.F64: 4}
    * lane_shapes: int or Mapping[ElWid, int] (optional)
        the bit-width of the elements: a single int applies to every
        elwid, a mapping gives one width per elwid. When omitted, each
        elwid gets the largest element that `fixed_width` allows.
    * fixed_width: int (optional)
        the total width of the SIMD vector. At least one of `lane_shapes`
        and `fixed_width` must be given.

    Returns a `LayoutResult`. Violated preconditions are reported through
    `assert`, i.e. as AssertionError.
    """
    print(f"layout(elwid={elwid},\n"
          f" signed={signed},\n"
          f" part_counts={part_counts},\n"
          f" lane_shapes={lane_shapes},\n"
          f" fixed_width={fixed_width})")
    assert isinstance(part_counts, Mapping)
    # a power of two is non-zero and has exactly one bit set
    assert all(v != 0 and (v & (v - 1)) == 0
               for v in part_counts.values()), \
        "part_counts values must all be powers of two"

    # the widest element type decides how many parts the vector has
    full_part_count = max(part_counts.values())

    # no lane_shapes at all means "fill the fixed width completely"
    # https://bugs.libre-soc.org/show_bug.cgi?id=713#c67
    if lane_shapes is None:
        assert fixed_width is not None, \
            "both fixed_width and lane_shapes cannot be None"
        lane_shapes = {}
        for ew, parts_per_element in part_counts.items():
            element_count = full_part_count // parts_per_element
            assert fixed_width % element_count == 0, (
                f"fixed_width ({fixed_width}) can't be split evenly into "
                f"{element_count} elements")
            lane_shapes[ew] = fixed_width // element_count
        print("lane_shapes", fixed_width, lane_shapes)

    # anything that is not a mapping is taken to be one integer width
    # that is requested for every elwid
    if not isinstance(lane_shapes, Mapping):
        lane_shapes = {ew: lane_shapes for ew in part_counts}

    # smallest usable part width: the largest, over all elwids, of
    # ceil(element width / parts per element).  -(-a // b) is ceil(a / b)
    # in pure integer arithmetic.
    min_part_wid = max(-(-lane_shapes[ew] // count)
                       for ew, count in part_counts.items())
    # ... and the smallest total width that goes with it
    min_width = min_part_wid * full_part_count
    print("width", min_width, min_part_wid, full_part_count)

    if fixed_width is None:
        # nothing imposed by the caller: use the computed minimum
        width = min_width
        part_wid = min_part_wid
    else:
        # the caller's width overrides both width and part_wid
        assert min_width <= fixed_width, "not enough space to fit partitions"
        assert fixed_width % full_part_count == 0, \
            "fixed_width must be a multiple of full_part_count"
        part_wid = fixed_width // full_part_count
        width = fixed_width
        print("part_wid", part_wid, "count", full_part_count)

    # first stage: collect the breakpoints.
    # multi-stage version: https://bugs.libre-soc.org/show_bug.cgi?id=713#c34
    # https://stackoverflow.com/questions/26367812/
    # dpoints maps bit-index -> dict[ElWid, None]; the inner dict is used
    # purely as an insertion-ordered set of elwids.
    dpoints = defaultdict(dict)
    for ew, parts_per_element in part_counts.items():
        lane_wid = lane_shapes[ew]
        # each element of this elwid starts on a multiple of
        # parts_per_element parts
        for first_part in range(0, full_part_count, parts_per_element):
            lane_start = first_part * part_wid
            dpoints[lane_start][ew] = None             # start of lane
            dpoints[lane_start + lane_wid][ew] = None  # start of padding
    # breakpoints at the very bottom and very top are of no use
    dpoints.pop(0, None)
    dpoints.pop(width, None)

    plist = sorted(dpoints.keys())
    print("dpoints")
    for bit_index in plist:
        print(f"{bit_index}: {list(dpoints[bit_index].keys())}")

    # second stage: turn each breakpoint into the OR of `elwid == k` over
    # every elwid k that needs it.
    # TODO: use nmutil.treereduce?
    points = {
        bit_index: reduce(operator.or_,
                          (elwid == ew for ew in dpoints[bit_index]))
        for bit_index in plist
    }

    # third stage: for each elwid, the constant mask the partition points
    # *would* take if elwid had that value (bit n <-> n-th sorted
    # breakpoint).  Handy for cross-checking via assertions.
    bitp = {}
    for ew in part_counts:
        mask = 0
        for bitpos, bit_index in enumerate(plist):
            if ew in dpoints[bit_index]:
                mask |= 1 << bitpos
        bitp[ew] = mask

    # fourth stage: whatever is left after clearing every bit that some
    # elwid uses is 100% unused and can be "blanked out"
    bmask = (1 << len(plist)) - 1
    for used in bitp.values():
        bmask &= ~used

    return LayoutResult(PartitionPoints(points), bitp, bmask, width,
                        lane_shapes, part_wid, full_part_count)
213
214
if __name__ == '__main__':
    # Self-test / demonstration of layout(): prints a number of layouts and
    # checks, in simulation, that the dynamic partition points agree with
    # the statically computed per-elwid masks.

    class FpElWid(Enum):
        F64 = 0
        F32 = 1
        F16 = 2
        BF16 = 3

        def __repr__(self):
            return super().__str__()

    class IntElWid(Enum):
        I64 = 0
        I32 = 1
        I16 = 2
        I8 = 3

        def __repr__(self):
            return super().__str__()

    def report_and_check(lr, elwid, announce=False):
        """Print a LayoutResult, then verify its ppoints against its bitp.

        * lr: the LayoutResult returned by layout() for `elwid`
        * elwid: the Signal(FpElWid) that was passed to layout()
        * announce: when true, additionally print a "test elwidth=..."
          line before each elwidth's results

        Runs a simulation that sets `elwid` to each FpElWid value in turn,
        reads back every partition point and asserts that, taken together
        as a binary number, they equal lr.bitp for that value.
        """
        print(lr)
        for k, v in lr.bitp.items():
            print(f"bitp elwidth={k}", bin(v))
        print("bmask", bin(lr.bmask))

        def process():
            for i in FpElWid:
                yield elwid.eq(i)
                yield Settle()
                ppt = []
                for pval in lr.ppoints.values():
                    val = yield pval  # get nmigen to evaluate pp
                    ppt.append(val)
                if announce:
                    print(f"test elwidth={i}")
                print(i, ppt)
                # check the results against bitp static-expected
                # partition points
                # https://bugs.libre-soc.org/show_bug.cgi?id=713#c47
                # https://stackoverflow.com/a/27165694
                # ppt[0] is the least significant bit, hence the reversal
                ival = int(''.join(map(str, ppt[::-1])), 2)
                assert ival == lr.bitp[i], \
                    f"ival {bin(ival)} actual {bin(lr.bitp[i])}"

        # an empty Module suffices: only the expressions are evaluated
        sim = Simulator(Module())
        sim.add_process(process)
        sim.run()

    # for each element-width (elwidth 0-3) the number of parts in an element
    # is given:
    #                                | part0 | part1 | part2 | part3 |
    # elwid=F64 4 parts per element: |<-------------F64------------->|
    # elwid=F32 2 parts per element: |<-----F32----->|<-----F32----->|
    # elwid=F16 1 part per element:  |<-F16->|<-F16->|<-F16->|<-F16->|
    # elwid=BF16 1 part per element: |<BF16->|<BF16->|<BF16->|<BF16->|
    # actual widths of Signals *within* those partitions is given separately
    part_counts = {
        FpElWid.F64: 4,
        FpElWid.F32: 2,
        FpElWid.F16: 1,
        FpElWid.BF16: 1,
    }

    # width=3 indicates "we want the same element bit-width (3) at all elwids"
    # elwid=F64 1x 3-bit  |<--------i3------->|
    # elwid=F32 2x 3-bit  |<---i3-->|<---i3-->|
    # elwid=F16 4x 3-bit  |<i3>|<i3>|<i3>|<i3>|
    # elwid=BF16 4x 3-bit |<i3>|<i3>|<i3>|<i3>|
    width_for_all_els = 3

    for i in FpElWid:
        print(i, layout(i, True, part_counts, width_for_all_els))

    # fixed_width=32 and no lane_widths says "allocate maximum"
    # elwid=F64 1x 32-bit |<-------i32------->|
    # elwid=F32 2x 16-bit |<--i16-->|<--i16-->|
    # elwid=F16 4x 8-bit  |<i8>|<i8>|<i8>|<i8>|
    # elwid=BF16 4x 8-bit |<i8>|<i8>|<i8>|<i8>|

    print("maximum allocation from fixed_width=32")
    for i in FpElWid:
        print(i, layout(i, True, part_counts, fixed_width=32))

    # specify that the length is to be *different* at each of the elwidths.
    # combined with part_counts we have:
    # elwid=F64 1x 24-bit |<-------i24------->|
    # elwid=F32 2x 12-bit |<--i12-->|<--i12-->|
    # elwid=F16 4x 6-bit  |<i6>|<i6>|<i6>|<i6>|
    # elwid=BF16 4x 5-bit |<i5>|<i5>|<i5>|<i5>|
    widths_at_elwidth = {
        FpElWid.F64: 24,
        FpElWid.F32: 12,
        FpElWid.F16: 6,
        FpElWid.BF16: 5,
    }

    for i in FpElWid:
        print(i, layout(i, False, part_counts, widths_at_elwidth))

    # this tests elwidth as an actual Signal. layout is allowed to
    # determine arbitrarily the overall length
    # https://bugs.libre-soc.org/show_bug.cgi?id=713#c30

    elwid = Signal(FpElWid)
    lr = layout(elwid, False, part_counts, widths_at_elwidth)
    report_and_check(lr, elwid)

    # this tests elwidth as an actual Signal. layout is *not* allowed to
    # determine arbitrarily the overall length, it is fixed to 64
    # https://bugs.libre-soc.org/show_bug.cgi?id=713#c22

    elwid = Signal(FpElWid)
    lr = layout(elwid, False, part_counts, widths_at_elwidth, fixed_width=64)
    report_and_check(lr, elwid, announce=True)