redefine part_counts to be "number of vector elements in a partition"
[ieee754fpu.git] / src / ieee754 / part / layout_experiment.py
1 #!/usr/bin/env python3
2 # SPDX-License-Identifier: LGPL-3-or-later
3 # See Notices.txt for copyright information
4 """
5 Links:
6 * https://libre-soc.org/3d_gpu/architecture/dynamic_simd/shape/
7 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c20
8 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c30
9 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c34
10 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c47
11 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c22
12 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c67
13 """
14
15 from nmigen import Signal, Module, Elaboratable, Mux, Cat, Shape, Repl
16 from nmigen.back.pysim import Simulator, Delay, Settle
17 from nmigen.cli import rtlil
18
19 from collections.abc import Mapping
20 from functools import reduce
21 import operator
22 from collections import defaultdict
23 from pprint import pprint
24
25 from ieee754.part_mul_add.partpoints import PartitionPoints
26
27
28 # main fn, which started out here in the bugtracker:
29 # https://bugs.libre-soc.org/show_bug.cgi?id=713#c20
def layout(elwid, signed, part_counts, lane_shapes=None, fixed_width=None):
    """Calculate a SIMD layout.

    Glossary:
    * element: a single scalar value that is an element of a SIMD vector.
        it has a width in bits, and a signedness. Every element is made of
        1 or more parts.
    * ElWid: the element-width (really the element type) of an instruction.
        Either an integer or a FP type. Integer `ElWid`s are sign-agnostic.
        In Python, `ElWid` is either an enum type or is `int`.
        Example `ElWid` definition for integers:

        class ElWid(Enum):
            I64 = ...       # SVP64 value 0b00
            I32 = ...       # SVP64 value 0b01
            I16 = ...       # SVP64 value 0b10
            I8 = ...        # SVP64 value 0b11

        Example `ElWid` definition for floats:

        class ElWid(Enum):
            F64 = ...       # SVP64 value 0b00
            F32 = ...       # SVP64 value 0b01
            F16 = ...       # SVP64 value 0b10
            BF16 = ...      # SVP64 value 0b11

    * part: A piece of a SIMD vector. Every SIMD vector produced here is
        made of `max(part_counts.values())` parts, and an element at
        element-width `k` spans a whole number of parts
        (`max(part_counts.values()) // part_counts[k]`). A part is a fixed
        number of bits wide for each different SIMD layout, it doesn't vary
        when `elwid` changes. A part can have a bit width of any
        non-negative integer, it is not restricted to power-of-two.

    Arguments:
    * elwid: ElWid or nmigen Value with ElWid as the shape
        the current element-width. It is only ever compared with `==`
        against the keys of `part_counts`.
    * signed: bool
        the signedness of all elements in a SIMD layout. Accepted for
        interface compatibility; the calculation below does not use it.
    * part_counts: dict[ElWid, int]
        a map from `ElWid` values `k` to the number of vector elements
        present in the SIMD vector when `elwid == k`. Every count must
        evenly divide the largest count.

        Example:
        # here, the vector holds eight I8 elements or one I64 element
        part_counts = {ElWid.I8: 8,
                       ElWid.I16: 4,
                       ElWid.I32: 2,
                       ElWid.I64: 1}

        Another Example:
        # here, the vector holds four F16/BF16 elements or one F64 element
        part_counts = {ElWid.F16: 4, ElWid.BF16: 4, ElWid.F32: 2, ElWid.F64: 1}
    * lane_shapes: int or Mapping[ElWid, int] (optional)
        the bit-width of one element at each element-width. A plain int
        requests the same element width at every element-width. When
        omitted, each element gets `fixed_width // part_counts[k]` bits.
    * fixed_width: int (optional)
        the total width of a SIMD vector. One of lane_shapes and fixed_width
        must be provided. When given it must be a multiple of the largest
        element count and large enough to hold every requested element.

    Returns a 7-tuple:
        (PartitionPoints mapping each breakpoint bit-position to the
         expression "elwid selects a layout with a breakpoint here",
         bitp: dict of elwid -> breakpoint bitmask (bit n is the n-th
               breakpoint in ascending bit-position order),
         bmask: mask of breakpoints not used at any elwid,
         width: total vector width in bits,
         lane_shapes: the per-elwid element widths actually used (a dict),
         part_wid: width of one part in bits,
         part_count: number of parts in the vector)

    Raises AssertionError when the arguments are inconsistent (see the
    individual assertion messages).
    """
    # when there are no lane_shapes specified, this indicates a
    # desire to use the maximum available space based on the fixed width
    # https://bugs.libre-soc.org/show_bug.cgi?id=713#c67
    if lane_shapes is None:
        assert fixed_width is not None, \
            "both fixed_width and lane_shapes cannot be None"
        lane_shapes = {i: fixed_width // part_counts[i] for i in part_counts}
        print("lane_shapes", fixed_width, lane_shapes)
    # identify if the lane_shapes is a mapping (dict, etc.)
    # if not, then assume that it is an integer (width) that
    # needs to be requested across all partitions
    if not isinstance(lane_shapes, Mapping):
        lane_shapes = {i: lane_shapes for i in part_counts}

    # the vector is cut into as many parts as the largest element count;
    # an elwid with fewer elements gives each element several whole parts
    part_count = max(part_counts.values())
    parts_per_el = {}
    for i, c in part_counts.items():
        assert c > 0 and part_count % c == 0, \
            "part_counts[%r]=%r must evenly divide %d" % (i, c, part_count)
        parts_per_el[i] = part_count // c

    # compute the part width each elwid needs: its element width spread
    # over the parts one element spans, rounded up (ceiling division)
    cpart_wid = [-(-lane_shapes[i] // parts_per_el[i]) for i in part_counts]
    print("cpart_wid", cpart_wid, "part_counts", part_counts)
    cpart_wid = max(cpart_wid)
    # calculate the minimum width required
    width = cpart_wid * part_count
    print("width", width, cpart_wid, part_count)
    if fixed_width is not None:  # override the width and part_wid
        # an exact fit is acceptable, hence <= rather than <
        assert width <= fixed_width, "not enough space to fit partitions"
        part_wid = fixed_width // part_count
        assert part_wid * part_count == fixed_width, \
            "calculated width not aligned multiples"
        width = fixed_width
        print("part_wid", part_wid, "count", part_count)
    else:
        # go with computed width
        part_wid = cpart_wid

    # create the breakpoints dictionary.
    # do multi-stage version https://bugs.libre-soc.org/show_bug.cgi?id=713#c34
    # https://stackoverflow.com/questions/26367812/
    dpoints = defaultdict(list)  # if empty key, create a (empty) list
    for i, step in parts_per_el.items():
        # one lane per vector element: lanes start every `step` parts
        for start in range(0, part_count, step):
            lane_start = start * part_wid
            dpoints[lane_start].append(i)  # start of lane
            dpoints[lane_start + lane_shapes[i]].append(i)  # start of padding
    # do not need the breakpoints at the very start or the very end
    dpoints.pop(0, None)
    dpoints.pop(width, None)
    plist = sorted(dpoints)
    print("dpoints")
    pprint(dict(dpoints))

    # second stage, add (map to) the elwidth==i expressions.
    # TODO: use nmutil.treereduce?
    points = {}
    for p in plist:
        points[p] = reduce(operator.or_, (elwid == i for i in dpoints[p]))

    # third stage, create the binary values which *if* elwidth is set to i
    # *would* result in the mask at that elwidth being set to this value
    # these can easily be double-checked through Assertion
    bitpos = {p: n for n, p in enumerate(plist)}  # breakpoint -> bit number
    bitp = {}
    for i in part_counts:
        bitp[i] = 0
        for p, elwidths in dpoints.items():
            if i in elwidths:
                bitp[i] |= 1 << bitpos[p]

    # fourth stage: determine which partitions are 100% unused.
    # these can then be "blanked out"
    bmask = (1 << len(plist)) - 1
    for p in bitp.values():
        bmask &= ~p
    return (PartitionPoints(points), bitp, bmask, width, lane_shapes,
            part_wid, part_count)
161
162
if __name__ == '__main__':

    def read_partition_bits(ppoints):
        """Sub-generator for a simulator process (use with `yield from`).

        Yields every partition-point expression in turn so that the
        simulator evaluates it, and returns the list of sampled values
        in the iteration order of `ppoints`.
        """
        sampled = []
        for expr in ppoints.values():
            sampled.append((yield expr))  # get nmigen to evaluate the point
        return sampled

    def bits_to_int(bits):
        """Pack a list of bit values (index 0 = least significant) into an int.

        https://stackoverflow.com/a/27165694
        """
        return int(''.join(str(bit) for bit in reversed(bits)), 2)

    # number of Vector Elements at each element-width (elwidth 0-3):
    #   elwidth=0b00  QTY 1 elements:  |             ?             |
    #   elwidth=0b01  QTY 1 elements:  |             ?             |
    #   elwidth=0b10  QTY 2 elements:  |      ?      |      ?      |
    #   elwidth=0b11  QTY 4 elements:  |  ?   |  ?   |  ?   |  ?   |
    # the actual widths of the Signals *within* those elements are
    # given separately
    part_counts = {
        0: 1,
        1: 1,
        2: 2,
        3: 4,
    }

    # a plain integer (3) requests "same width Vector Elements (3 bits)
    # at all elwidths".  intended arrangement:
    #   elwidth=0b00  1x 3-bit  |         unused        ..3 |
    #   elwidth=0b01  1x 3-bit  |         unused        ..3 |
    #   elwidth=0b10  2x 3-bit  |  unused ..3 |  unused ..3 |
    #   elwidth=0b11  4x 3-bit  |  ..3 |  ..3 |  ..3 |  ..3 |
    width_in_all_parts = 3

    for sel in range(4):
        pprint((sel, layout(sel, True, part_counts, width_in_all_parts)))

    # fixed_width=32 with no lane widths says "allocate maximum",
    # i.e. Vector Element widths are auto-allocated.  intended arrangement:
    #   elwidth=0b00  1x 32-bit  | .......................32 |
    #   elwidth=0b01  1x 32-bit  | .......................32 |
    #   elwidth=0b10  2x 16-bit  | .........16 | .........16 |
    #   elwidth=0b11  4x 8-bit   |  ..8 |  ..8 |  ..8 |  ..8 |

    # TODO, fix this so that it is correct
    #print ("maximum allocation from fixed_width=32")
    # for i in range(4):
    #     pprint((i, layout(i, True, part_counts, fixed_width=32)))

    # request *different* Vector Element widths at each of the elwidths.
    # combined with part_counts the intended arrangement is:
    #   elwidth=0b00  1x 5-bit  |  <--  unused  -->      ....5 |
    #   elwidth=0b01  1x 6-bit  |  <--  unused  -->     .....6 |
    #   elwidth=0b10  2x 6-bit  | unused .....6 | unused .....6 |
    #   elwidth=0b11  4x 6-bit  | .....6 | .....6 | .....6 | .....6 |
    widths_at_elwidth = {
        0: 5,
        1: 6,
        2: 6,
        3: 6,
    }

    print("5,6,6,6 elements", widths_at_elwidth)
    for sel in range(4):
        pprint((sel, layout(sel, False, part_counts, widths_at_elwidth)))

    # this tests elwidth as an actual Signal. layout is allowed to
    # determine arbitrarily the overall length
    # https://bugs.libre-soc.org/show_bug.cgi?id=713#c30

    elwid = Signal(2)
    pp, bitp, bm, width, lane_shapes, part_wid, part_count = layout(
        elwid, False, part_counts, widths_at_elwidth)
    pprint((pp, width, lane_shapes, part_wid, part_count))
    for key, mask in bitp.items():
        print("bitp elwidth=%d" % key, bin(mask))
    print("bmask", bin(bm))

    m = Module()

    def check_free_width():
        # drive every elwidth value and compare the simulated partition
        # points with the statically computed bitp masks
        # https://bugs.libre-soc.org/show_bug.cgi?id=713#c47
        for sel in range(4):
            yield elwid.eq(sel)
            yield Settle()
            ppt = yield from read_partition_bits(pp)
            pprint((sel, (ppt, width, lane_shapes, part_wid, part_count)))
            assert bits_to_int(ppt) == bitp[sel]

    sim = Simulator(m)
    sim.add_process(check_free_width)
    sim.run()

    # this tests elwidth as an actual Signal. layout is *not* allowed to
    # determine arbitrarily the overall length, it is fixed to 64
    # https://bugs.libre-soc.org/show_bug.cgi?id=713#c22

    elwid = Signal(2)
    pp, bitp, bm, width, lane_shapes, part_wid, part_count = layout(
        elwid, False, part_counts, widths_at_elwidth, fixed_width=64)
    pprint((pp, width, lane_shapes, part_wid, part_count))
    for key, mask in bitp.items():
        print("bitp elwidth=%d" % key, bin(mask))
    print("bmask", bin(bm))

    m = Module()

    def check_fixed_width():
        # same check as above, with extra reporting of the mismatch
        # https://bugs.libre-soc.org/show_bug.cgi?id=713#c47
        for sel in range(4):
            yield elwid.eq(sel)
            yield Settle()
            ppt = yield from read_partition_bits(pp)
            print("test elwidth=%d" % sel)
            pprint((sel, (ppt, width, lane_shapes, part_wid, part_count)))
            ival = bits_to_int(ppt)
            assert ival == bitp[sel], "ival %s actual %s" % (bin(ival),
                                                             bin(bitp[sel]))

    sim = Simulator(m)
    sim.add_process(check_fixed_width)
    sim.run()