2 # SPDX-License-Identifier: LGPL-3-or-later
3 # See Notices.txt for copyright information
6 * https://libre-soc.org/3d_gpu/architecture/dynamic_simd/shape/
7 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c20
8 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c30
9 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c34
10 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c47
11 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c22
12 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c67
15 from nmigen
import Signal
, Module
, Elaboratable
, Mux
, Cat
, Shape
, Repl
16 from nmigen
.sim
import Simulator
, Delay
, Settle
17 from nmigen
.cli
import rtlil
20 from collections
.abc
import Mapping
21 from functools
import reduce
23 from collections
import defaultdict
25 from ieee754
.part
.util
import XLEN
, FpElWid
, IntElWid
, SimdMap
, SimdScope
26 from ieee754
.part_mul_add
.partpoints
import PartitionPoints
29 @dataclasses.dataclass
31 ppoints
: PartitionPoints
41 for field
in dataclasses
.fields(LayoutResult
):
42 field_v
= getattr(self
, field
.name
)
43 if isinstance(field_v
, PartitionPoints
):
44 field_v
= ',\n '.join(
45 f
"{k}: {v}" for k
, v
in field_v
.items())
46 field_v
= f
"{{{field_v}}}"
47 fields
.append(f
"{field.name}={field_v}")
48 fields
= ",\n ".join(fields
)
49 return f
"LayoutResult({fields})"
52 class SimdLayout(Shape
):
53 def __init__(self
, lane_shapes
=None, signed
=None, *, fixed_width
=None,
54 width_follows_hint
=True, scope
=None):
55 """calculate a SIMD layout.
58 * element: a single scalar value that is an element of a SIMD vector.
59 it has a width in bits, and a signedness. Every element is made of
60 1 or more parts. An element optionally includes the padding
62 * lane: an element. An element optionally includes the padding
64 * ElWid: the element-width (really the element type) of an instruction.
65 Either an integer or a FP type. Integer `ElWid`s are sign-agnostic.
66 In Python, `ElWid` is either an enum type or is `int`.
67 Example `ElWid` definition for integers:
75 Example `ElWid` definition for floats:
82 * part: (not to be confused with a partition) A piece of a SIMD vector,
83 every SIMD vector is made of a non-negative integer of parts.
84 Elements are made of a power-of-two number of parts. A part is a
85 fixed number of bits wide for each different SIMD layout, it
86 doesn't vary when `elwid` changes. A part can have a bit width of
87 any non-negative integer, it is not restricted to power-of-two.
91 * lane_shapes: int or Mapping[ElWid, int] or SimdMap (optional)
92 the bit-width of all elements in this SIMD layout.
94 the signedness of all elements in this SIMD layout
95 * fixed_width: int (optional)
96 the total width of a SIMD vector. One of lane_shapes and fixed_width
98 * width_follows_hint: bool
99 if fixed_width defaults to SimdScope.get().simd_full_width_hint
101 Values used from SimdScope:
102 * elwid: ElWid or nmigen Value with ElWid as the shape
103 the current ElWid value
104 * part_counts: SimdMap
105 a map from `ElWid` values `k` to the number of parts in an element
106 when `elwid == k`. Values should be minimized, since higher values
107 often create bigger circuits.
110 # here, an I8 element is 1 part wide
111 part_counts = SimdMap({
119 # here, an F16 element is 1 part wide
120 part_counts = SimdMap({
128 scope
= SimdScope
.get()
129 assert isinstance(scope
, SimdScope
)
131 elwid
= self
.scope
.elwid
132 part_counts
= self
.scope
.part_counts
133 assert isinstance(part_counts
, SimdMap
)
134 simd_full_width_hint
= self
.scope
.simd_full_width_hint
135 full_part_count
= self
.scope
.full_part_count
136 print(f
"layout(elwid={elwid},\n"
137 f
" signed={signed},\n"
138 f
" part_counts={part_counts},\n"
139 f
" lane_shapes={lane_shapes},\n"
140 f
" fixed_width={fixed_width},\n"
141 f
" simd_full_width_hint={simd_full_width_hint},\n"
142 f
" width_follows_hint={width_follows_hint})")
144 # when there are no lane_shapes specified, this indicates a
145 # desire to use the maximum available space based on the fixed width
146 # https://bugs.libre-soc.org/show_bug.cgi?id=713#c67
147 if lane_shapes
is None:
148 assert fixed_width
is not None, \
149 "both fixed_width and lane_shapes cannot be None"
151 for k
, cur_part_count
in part_counts
.items():
152 cur_element_count
= full_part_count
// cur_part_count
153 assert fixed_width
% cur_element_count
== 0, (
154 f
"fixed_width ({fixed_width}) can't be split evenly into "
155 f
"{cur_element_count} elements")
156 lane_shapes
[k
] = fixed_width
// cur_element_count
157 print("lane_shapes", fixed_width
, lane_shapes
)
158 # convert lane_shapes to a Mapping[ElWid, Any]
159 lane_shapes
= SimdMap(lane_shapes
).mapping
160 # filter out unsupported elwidths
161 lane_shapes
= {i
: lane_shapes
[i
] for i
in part_counts
.keys()}
162 self
.lane_shapes
= lane_shapes
163 # calculate the minimum possible bit-width of a part.
164 # we divide each element's width by the number of parts in an element,
165 # giving the number of bits needed per part.
167 for i
, c
in part_counts
.items():
168 # double negate to get ceil division
169 needed
= -(-lane_shapes
[i
] // c
)
170 min_part_wid
= max(min_part_wid
, needed
)
171 # calculate the minimum bit-width required
172 min_width
= min_part_wid
* full_part_count
173 print("width", min_width
, min_part_wid
, full_part_count
)
174 if width_follows_hint \
175 and min_width
<= simd_full_width_hint \
176 and fixed_width
is None:
177 fixed_width
= simd_full_width_hint
179 if fixed_width
is not None: # override the width and part_wid
180 assert min_width
<= fixed_width
, \
181 "not enough space to fit partitions"
182 self
.part_wid
= fixed_width
// full_part_count
183 assert fixed_width
% full_part_count
== 0, \
184 "fixed_width must be a multiple of full_part_count"
186 print("part_wid", self
.part_wid
, "count", full_part_count
)
188 # go with computed width
190 self
.part_wid
= min_part_wid
191 super().__init
__(width
, signed
)
192 # create the breakpoints dictionary.
193 # do multi-stage version https://bugs.libre-soc.org/show_bug.cgi?id=713#c34
194 # https://stackoverflow.com/questions/26367812/
195 # dpoints: dict from bit-index to dict[ElWid, None]
196 # we use a dict from ElWid to None as the values of dpoints in order to
198 dpoints
= defaultdict(dict) # if empty key, create a (empty) dict
199 for i
, cur_part_count
in part_counts
.items():
200 def add_p(bit_index
):
201 # auto-creates dict if key non-existent
202 dpoints
[bit_index
][i
] = None
203 # go through all elements for elwid `i`, each element starts at
204 # part index `start_part`, and goes for `cur_part_count` parts
205 for start_part
in range(0, full_part_count
, cur_part_count
):
206 start_bit
= start_part
* self
.part_wid
207 add_p(start_bit
) # start of lane
208 add_p(start_bit
+ lane_shapes
[i
]) # start of padding
209 # do not need the breakpoints at the very start or the very end
211 dpoints
.pop(self
.width
, None)
212 plist
= list(dpoints
.keys())
214 dpoints
= {k
: dpoints
[k
].keys() for k
in plist
}
215 self
.dpoints
= dpoints
218 print(f
"{k}: {list(dpoints[k])}")
219 # second stage, add (map to) the elwidth==i expressions.
220 # TODO: use nmutil.treereduce?
223 it
= map(lambda i
: elwid
== i
, dpoints
[p
])
224 points
[p
] = reduce(operator
.or_
, it
)
225 # third stage, create the binary values which *if* elwidth is set to i
226 # *would* result in the mask at that elwidth being set to this value
227 # these can easily be double-checked through Assertion
229 for i
in part_counts
.keys():
231 for p
, elwidths
in dpoints
.items():
233 bitpos
= plist
.index(p
)
234 self
.bitp
[i
] |
= 1 << bitpos
235 # fourth stage: determine which partitions are 100% unused.
236 # these can then be "blanked out"
237 self
.bmask
= (1 << len(plist
)) - 1
238 for p
in self
.bitp
.values():
240 self
.ppoints
= PartitionPoints(points
)
243 bitp
= ", ".join(f
"{k}: {bin(v)}" for k
, v
in self
.bitp
.items())
245 for k
, v
in self
.dpoints
.items():
246 dpoints
.append(f
"{k}: {list(v)}")
247 dpoints
= ",\n ".join(dpoints
)
249 for k
, v
in self
.ppoints
.items():
250 ppoints
.append(f
"{k}: {list(v)}")
251 ppoints
= ",\n ".join(ppoints
)
252 return (f
"SimdLayout(lane_shapes={self.lane_shapes},\n"
253 f
" signed={self.signed},\n"
254 f
" fixed_width={self.width},\n"
255 f
" scope={self.scope},\n"
256 f
" bitp={{{bitp}}},\n"
257 f
" bmask={bin(self.bmask)},\n"
260 f
" part_wid={self.part_wid},\n"
261 f
" ppoints=PartitionPoints({{\n"
265 if __name__
== '__main__':
266 # for each element-width (elwidth 0-3) the number of parts in an element
268 # | part0 | part1 | part2 | part3 |
269 # elwid=F64 4 parts per element: |<-------------F64------------->|
270 # elwid=F32 2 parts per element: |<-----F32----->|<-----F32----->|
271 # elwid=F16 1 part per element: |<-F16->|<-F16->|<-F16->|<-F16->|
272 # elwid=BF16 1 part per element: |<BF16->|<BF16->|<BF16->|<BF16->|
273 # actual widths of Signals *within* those partitions is given separately
281 # width=3 indicates "we want the same element bit-width (3) at all elwids"
282 # elwid=F64 1x 3-bit |<--------i3------->|
283 # elwid=F32 2x 3-bit |<---i3-->|<---i3-->|
284 # elwid=F16 4x 3-bit |<i3>|<i3>|<i3>|<i3>|
285 # elwid=BF16 4x 3-bit |<i3>|<i3>|<i3>|<i3>|
286 width_for_all_els
= 3
289 with
SimdScope(elwid
=i
, part_counts
=part_counts
):
290 print(i
, SimdLayout(width_for_all_els
, True, width_follows_hint
=False))
292 # fixed_width=32 and no lane_widths says "allocate maximum"
293 # elwid=F64 1x 32-bit |<-------i32------->|
294 # elwid=F32 2x 16-bit |<--i16-->|<--i16-->|
295 # elwid=F16 4x 8-bit |<i8>|<i8>|<i8>|<i8>|
296 # elwid=BF16 4x 8-bit |<i8>|<i8>|<i8>|<i8>|
298 print("maximum allocation from fixed_width=32")
300 with
SimdScope(elwid
=i
, part_counts
=part_counts
):
301 print(i
, SimdLayout(signed
=True, fixed_width
=32))
303 # specify that the length is to be *different* at each of the elwidths.
304 # combined with part_counts we have:
305 # elwid=F64 1x 24-bit |<-------i24------->|
306 # elwid=F32 2x 12-bit |<--i12-->|<--i12-->|
307 # elwid=F16 4x 6-bit |<i6>|<i6>|<i6>|<i6>|
308 # elwid=BF16 4x 5-bit |<i5>|<i5>|<i5>|<i5>|
309 widths_at_elwidth
= {
317 with
SimdScope(elwid
=i
, part_counts
=part_counts
):
318 print(i
, SimdLayout(widths_at_elwidth
,
319 False, width_follows_hint
=False))
321 # this tests elwidth as an actual Signal. layout is allowed to
322 # determine arbitrarily the overall length
323 # https://bugs.libre-soc.org/show_bug.cgi?id=713#c30
325 with
SimdScope(elwid_type
=FpElWid
, part_counts
=part_counts
) as scope
:
326 l
= SimdLayout(widths_at_elwidth
, False, width_follows_hint
=False)
337 for pval
in l
.ppoints
.values():
338 val
= yield pval
# get nmigen to evaluate pp
341 # check the results against bitp static-expected partition points
342 # https://bugs.libre-soc.org/show_bug.cgi?id=713#c47
343 # https://stackoverflow.com/a/27165694
344 ival
= int(''.join(map(str, ppt
[::-1])), 2)
345 assert ival
== l
.bitp
[i
]
348 sim
.add_process(process
)
351 # this tests elwidth as an actual Signal. layout uses the width hint
352 # https://bugs.libre-soc.org/show_bug.cgi?id=713#c30
354 with
SimdScope(elwid_type
=FpElWid
, part_counts
=part_counts
) as scope
:
355 l
= SimdLayout(widths_at_elwidth
, False)
366 for pval
in l
.ppoints
.values():
367 val
= yield pval
# get nmigen to evaluate pp
370 # check the results against bitp static-expected partition points
371 # https://bugs.libre-soc.org/show_bug.cgi?id=713#c47
372 # https://stackoverflow.com/a/27165694
373 ival
= int(''.join(map(str, ppt
[::-1])), 2)
374 assert ival
== l
.bitp
[i
]
377 sim
.add_process(process
)
380 # this tests elwidth as an actual Signal. layout is *not* allowed to
381 # determine arbitrarily the overall length, it is fixed to 64
382 # https://bugs.libre-soc.org/show_bug.cgi?id=713#c22
384 with
SimdScope(elwid_type
=FpElWid
, part_counts
=part_counts
) as scope
:
385 l
= SimdLayout(widths_at_elwidth
, False, fixed_width
=64)
396 for pval
in list(l
.ppoints
.values()):
397 val
= yield pval
# get nmigen to evaluate pp
399 print(f
"test elwidth={i}")
401 # check the results against bitp static-expected partition points
402 # https://bugs.libre-soc.org/show_bug.cgi?id=713#c47
403 # https://stackoverflow.com/a/27165694
404 ival
= int(''.join(map(str, ppt
[::-1])), 2)
405 assert ival
== l
.bitp
[i
], \
406 f
"ival {bin(ival)} actual {bin(l.bitp[i])}"
409 sim
.add_process(process
)
413 with
SimdScope(elwid_type
=IntElWid
):
414 print("\nSimdLayout(XLEN):")
415 l1
= SimdLayout(XLEN
)
417 print("\nSimdLayout(XLEN // 2):")
418 l2
= SimdLayout(XLEN
// 2)