2 # SPDX-License-Identifier: LGPL-3.0-or-later
3 # See Notices.txt for copyright information
6 * https://libre-soc.org/3d_gpu/architecture/dynamic_simd/shape/
7 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c20
8 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c30
9 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c34
10 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c47
11 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c22
12 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c67
15 from nmigen
import Signal
, Module
, Elaboratable
, Mux
, Cat
, Shape
, Repl
16 from nmigen
.sim
import Simulator
, Delay
, Settle
17 from nmigen
.cli
import rtlil
20 from collections
.abc
import Mapping
21 from functools
import reduce
23 from collections
import defaultdict
26 from ieee754
.part_mul_add
.partpoints
import PartitionPoints
# NOTE(review): this excerpt of the LayoutResult dataclass is extraction
# residue: each line carries its upstream line number and statements are
# split at commas/dots. Upstream lines 30 and 32-40 are absent, i.e. the
# `class` line, every field after `ppoints`, the `def` line of the method
# that builds the repr string (presumably `__repr__`) and the creation of
# the `fields` list that is appended to below. Restore from upstream
# before editing.
#
# Visible behaviour of the repr-building method: iterate
# dataclasses.fields(LayoutResult), read each attribute with
# getattr(self, field.name), render a PartitionPoints value as
# "{k: v,\n ...}", collect "name=value" strings, join them with ",\n "
# and return "LayoutResult(...)".
#
# NOTE(review): layout() constructs this positionally as
# LayoutResult(PartitionPoints(points), bitp, bmask, width, lane_shapes,
# part_wid, full_part_count), so the missing fields presumably follow that
# order — confirm upstream.
# NOTE(review): `dataclasses` is not imported on any visible line; the
# import is presumably on an upstream line missing from this excerpt.
29 @dataclasses.dataclass
31 ppoints
: PartitionPoints
41 for field
in dataclasses
.fields(LayoutResult
):
42 field_v
= getattr(self
, field
.name
)
43 if isinstance(field_v
, PartitionPoints
):
44 field_v
= ',\n '.join(
45 f
"{k}: {v}" for k
, v
in field_v
.items())
46 field_v
= f
"{{{field_v}}}"
47 fields
.append(f
"{field.name}={field_v}")
48 fields
= ",\n ".join(fields
)
49 return f
"LayoutResult({fields})"
# NOTE(review): this excerpt of layout() is extraction residue: each line
# carries its upstream line number, statements are split at commas/dots,
# and several upstream lines are absent (among the code lines: 132, 159,
# 161, 163, 183, 186-188, 192-193, 199, 201, 203 and 210). In particular
# `width`, `points` and `bitp` are never assigned on a visible line.
# Restore from upstream before changing any logic.
#
# layout(elwid, signed, part_counts, lane_shapes=None, fixed_width=None)
# returns a LayoutResult. Visible stages:
#   * validate: part_counts must be a Mapping whose values are all powers
#     of two (checked with `assert`);
#   * full_part_count = max(part_counts.values());
#   * lane_shapes is None -> fixed_width is required, and each elwid's lane
#     width becomes fixed_width // (full_part_count // part_count), which
#     must divide evenly;
#   * lane_shapes not a Mapping -> that single width is used for every elwid;
#   * min_part_wid = max over elwids of ceil(lane_width / part_count);
#     with fixed_width, part_wid = fixed_width // full_part_count (asserted
#     to hold min_part_wid * full_part_count bits and to divide evenly);
#     otherwise (the `else` line is not shown) part_wid = min_part_wid;
#   * dpoints: bit index -> ordered set (dict keys) of the elwids whose lane
#     or padding starts at that bit;
#   * points[p]: OR of (elwid == i) over the elwids in dpoints[p];
#   * bitp[i]: bitmask over plist positions; bmask starts as all-ones over
#     len(plist) bits.
# NOTE(review): validation uses `assert`, which `python -O` strips; consider
# raising ValueError instead. The function also print()s debug output on
# every call.
# NOTE(review): `operator` is used below but not imported on any visible line.
52 # main fn, which started out here in the bugtracker:
53 # https://bugs.libre-soc.org/show_bug.cgi?id=713#c20
54 def layout(elwid
, signed
, part_counts
, lane_shapes
=None, fixed_width
=None):
55 """calculate a SIMD layout.
58 * element: a single scalar value that is an element of a SIMD vector.
59 it has a width in bits, and a signedness. Every element is made of 1 or
60 more parts. An element optionally includes the padding associated with
62 * lane: an element. An element optionally includes the padding associated
64 * ElWid: the element-width (really the element type) of an instruction.
65 Either an integer or a FP type. Integer `ElWid`s are sign-agnostic.
66 In Python, `ElWid` is either an enum type or is `int`.
67 Example `ElWid` definition for integers:
75 Example `ElWid` definition for floats:
82 * part: (not to be confused with a partition) A piece of a SIMD vector,
83 every SIMD vector is made of a non-negative integer number of parts. Elements
84 are made of a power-of-two number of parts. A part is a fixed number
85 of bits wide for each different SIMD layout, it doesn't vary when
86 `elwid` changes. A part can have a bit width of any non-negative
87 integer, it is not restricted to power-of-two. SIMD vectors should
88 have as few parts as necessary, since some circuits have size
89 proportional to the number of parts.
92 * elwid: ElWid or nmigen Value with ElWid as the shape
93 the current element-width
95 the signedness of all elements in a SIMD layout
96 * part_counts: dict[ElWid, int]
97 a map from `ElWid` values `k` to the number of parts in an element
98 when `elwid == k`. Values should be minimized, since higher values
99 often create bigger circuits.
102 # here, an I8 element is 1 part wide
103 part_counts = {ElWid.I8: 1, ElWid.I16: 2, ElWid.I32: 4, ElWid.I64: 8}
106 # here, an F16 element is 1 part wide
107 part_counts = {ElWid.F16: 1, ElWid.BF16: 1, ElWid.F32: 2, ElWid.F64: 4}
108 * lane_shapes: int or Mapping[ElWid, int] (optional)
109 the bit-width of all elements in a SIMD layout.
110 * fixed_width: int (optional)
111 the total width of a SIMD vector. One of lane_shapes and fixed_width
114 print(f
"layout(elwid={elwid},\n"
115 f
" signed={signed},\n"
116 f
" part_counts={part_counts},\n"
117 f
" lane_shapes={lane_shapes},\n"
118 f
" fixed_width={fixed_width})")
119 assert isinstance(part_counts
, Mapping
)
120 # assert all part_counts are powers of two
121 assert all(v
!= 0 and (v
& (v
- 1)) == 0 for v
in part_counts
.values()),\
122 "part_counts values must all be powers of two"
124 full_part_count
= max(part_counts
.values())
126 # when there are no lane_shapes specified, this indicates a
127 # desire to use the maximum available space based on the fixed width
128 # https://bugs.libre-soc.org/show_bug.cgi?id=713#c67
129 if lane_shapes
is None:
130 assert fixed_width
is not None, \
131 "both fixed_width and lane_shapes cannot be None"
133 for k
, cur_part_count
in part_counts
.items():
134 cur_element_count
= full_part_count
// cur_part_count
135 assert fixed_width
% cur_element_count
== 0, (
136 f
"fixed_width ({fixed_width}) can't be split evenly into "
137 f
"{cur_element_count} elements")
# NOTE(review): lane_shapes is None on entry to this branch; the dict it is
# indexed as below is presumably created on upstream line 132 (not shown).
138 lane_shapes
[k
] = fixed_width
// cur_element_count
139 print("lane_shapes", fixed_width
, lane_shapes
)
140 # identify if the lane_shapes is a mapping (dict, etc.)
141 # if not, then assume that it is an integer (width) that
142 # needs to be requested across all partitions
143 if not isinstance(lane_shapes
, Mapping
):
144 lane_shapes
= {i
: lane_shapes
for i
in part_counts
}
145 # calculate the minimum possible bit-width of a part.
146 # we divide each element's width by the number of parts in an element,
147 # giving the number of bits needed per part.
148 # we use `-min(-a // b for ...)` to get `max(ceil(a / b) for ...)`,
149 # but using integers.
150 min_part_wid
= -min(-lane_shapes
[i
] // c
for i
, c
in part_counts
.items())
151 # calculate the minimum bit-width required
152 min_width
= min_part_wid
* full_part_count
153 print("width", min_width
, min_part_wid
, full_part_count
)
154 if fixed_width
is not None: # override the width and part_wid
155 assert min_width
<= fixed_width
, "not enough space to fit partitions"
156 part_wid
= fixed_width
// full_part_count
157 assert fixed_width
% full_part_count
== 0, \
158 "fixed_width must be a multiple of full_part_count"
160 print("part_wid", part_wid
, "count", full_part_count
)
162 # go with computed width
164 part_wid
= min_part_wid
165 # create the breakpoints dictionary.
166 # do multi-stage version https://bugs.libre-soc.org/show_bug.cgi?id=713#c34
167 # https://stackoverflow.com/questions/26367812/
168 # dpoints: dict from bit-index to dict[ElWid, None]
169 # we use a dict from ElWid to None as the values of dpoints in order to
171 dpoints
= defaultdict(dict) # if empty key, create a (empty) dict
172 for i
, cur_part_count
in part_counts
.items():
173 def add_p(bit_index
):
174 # auto-creates dict if key non-existent
175 dpoints
[bit_index
][i
] = None
176 # go through all elements for elwid `i`, each element starts at
177 # part index `start_part`, and goes for `cur_part_count` parts
178 for start_part
in range(0, full_part_count
, cur_part_count
):
179 start_bit
= start_part
* part_wid
180 add_p(start_bit
) # start of lane
181 add_p(start_bit
+ lane_shapes
[i
]) # start of padding
182 # do not need the breakpoints at the very start or the very end
# NOTE(review): only the pop at `width` is visible; the pop at the very
# start (bit 0) presumably sits on upstream line 183, and `width` itself is
# assigned on lines not shown (presumably fixed_width or min_width).
184 dpoints
.pop(width
, None)
185 plist
= list(dpoints
.keys())
# NOTE(review): upstream lines 186-188 (not shown) presumably order plist
# and/or open the loop over `k` that the print below runs in — confirm.
189 print(f
"{k}: {list(dpoints[k].keys())}")
190 # second stage, add (map to) the elwidth==i expressions.
191 # TODO: use nmutil.treereduce?
# NOTE(review): upstream lines 192-193 (not shown) presumably create
# `points` and open the loop over breakpoints `p` used below.
194 it
= map(lambda i
: elwid
== i
, dpoints
[p
])
195 points
[p
] = reduce(operator
.or_
, it
)
196 # third stage, create the binary values which *if* elwidth is set to i
197 # *would* result in the mask at that elwidth being set to this value
198 # these can easily be double-checked through Assertion
# NOTE(review): as visible, every breakpoint bit would be set for every
# elwid; upstream lines 199, 201 and 203 (not shown) presumably create
# `bitp`, zero bitp[i] and filter on `i in elwidths` — confirm.
200 for i
in part_counts
.keys():
202 for p
, elwidths
in dpoints
.items():
204 bitpos
= plist
.index(p
)
205 bitp
[i
] |
= 1 << bitpos
206 # fourth stage: determine which partitions are 100% unused.
207 # these can then be "blanked out"
208 bmask
= (1 << len(plist
)) - 1
# NOTE(review): the body of the loop below (upstream line 210) is not
# shown; per the stage comment it presumably clears from bmask every bit
# used by some elwid (e.g. `bmask &= ~p`) — confirm.
209 for p
in bitp
.values():
211 return LayoutResult(PartitionPoints(points
), bitp
, bmask
, width
,
212 lane_shapes
, part_wid
, full_part_count
)
# NOTE(review): self-test/demo excerpt, extraction residue: upstream line
# numbers are embedded, statements are split at commas/dots, and many
# upstream lines are absent. Missing pieces include:
#   * the FpElWid enum;
#   * the class/def lines around the two `return super().__str__()`
#     statements;
#   * part_counts;
#   * the loops binding `i`;
#   * the contents of widths_at_elwidth;
#   * Module/Simulator construction;
#   * the `def process()` lines;
#   * the code building `ppt` (presumably from the yielded values).
# The block may also continue past the end of this excerpt.
#
# Visible behaviour:
#   * print layout() results for a uniform 3-bit lane width, for
#     fixed_width=32 with no lane widths, and for per-elwid widths;
#   * then, with elwid as a Signal(FpElWid), build a layout twice (free
#     width, then fixed_width=64) and print bitp and bmask;
#   * register simulation processes that yield every ppoints value and
#     assert that the bits in `ppt`, read with ppt[0] as the least
#     significant bit, equal lr.bitp[i].
# NOTE(review): `Enum` is not imported on any visible line.
215 if __name__
== '__main__':
224 return super().__str
__()
226 class IntElWid(Enum
):
233 return super().__str
__()
235 # for each element-width (elwidth 0-3) the number of parts in an element
237 # | part0 | part1 | part2 | part3 |
238 # elwid=F64 4 parts per element: |<-------------F64------------->|
239 # elwid=F32 2 parts per element: |<-----F32----->|<-----F32----->|
240 # elwid=F16 1 part per element: |<-F16->|<-F16->|<-F16->|<-F16->|
241 # elwid=BF16 1 part per element: |<BF16->|<BF16->|<BF16->|<BF16->|
242 # actual widths of Signals *within* those partitions is given separately
250 # width=3 indicates "we want the same element bit-width (3) at all elwids"
251 # elwid=F64 1x 3-bit |<--------i3------->|
252 # elwid=F32 2x 3-bit |<---i3-->|<---i3-->|
253 # elwid=F16 4x 3-bit |<i3>|<i3>|<i3>|<i3>|
254 # elwid=BF16 4x 3-bit |<i3>|<i3>|<i3>|<i3>|
255 width_for_all_els
= 3
258 print(i
, layout(i
, True, part_counts
, width_for_all_els
))
260 # fixed_width=32 and no lane_widths says "allocate maximum"
261 # elwid=F64 1x 32-bit |<-------i32------->|
262 # elwid=F32 2x 16-bit |<--i16-->|<--i16-->|
263 # elwid=F16 4x 8-bit |<i8>|<i8>|<i8>|<i8>|
264 # elwid=BF16 4x 8-bit |<i8>|<i8>|<i8>|<i8>|
266 print("maximum allocation from fixed_width=32")
268 print(i
, layout(i
, True, part_counts
, fixed_width
=32))
270 # specify that the length is to be *different* at each of the elwidths.
271 # combined with part_counts we have:
272 # elwid=F64 1x 24-bit |<-------i24------->|
273 # elwid=F32 2x 12-bit |<--i12-->|<--i12-->|
274 # elwid=F16 4x 6-bit |<i6>|<i6>|<i6>|<i6>|
275 # elwid=BF16 4x 5-bit |<i5>|<i5>|<i5>|<i5>|
276 widths_at_elwidth
= {
284 print(i
, layout(i
, False, part_counts
, widths_at_elwidth
))
286 # this tests elwidth as an actual Signal. layout is allowed to
287 # determine arbitrarily the overall length
288 # https://bugs.libre-soc.org/show_bug.cgi?id=713#c30
290 elwid
= Signal(FpElWid
)
291 lr
= layout(elwid
, False, part_counts
, widths_at_elwidth
)
293 for k
, v
in lr
.bitp
.items():
294 print(f
"bitp elwidth={k}", bin(v
))
295 print("bmask", bin(lr
.bmask
))
304 for pval
in lr
.ppoints
.values():
305 val
= yield pval
# get nmigen to evaluate pp
308 # check the results against bitp static-expected partition points
309 # https://bugs.libre-soc.org/show_bug.cgi?id=713#c47
310 # https://stackoverflow.com/a/27165694
311 ival
= int(''.join(map(str, ppt
[::-1])), 2)
312 assert ival
== lr
.bitp
[i
]
315 sim
.add_process(process
)
318 # this tests elwidth as an actual Signal. layout is *not* allowed to
319 # determine arbitrarily the overall length, it is fixed to 64
320 # https://bugs.libre-soc.org/show_bug.cgi?id=713#c22
322 elwid
= Signal(FpElWid
)
323 lr
= layout(elwid
, False, part_counts
, widths_at_elwidth
, fixed_width
=64)
325 for k
, v
in lr
.bitp
.items():
326 print(f
"bitp elwidth={k}", bin(v
))
327 print("bmask", bin(lr
.bmask
))
336 for pval
in list(lr
.ppoints
.values()):
337 val
= yield pval
# get nmigen to evaluate pp
339 print(f
"test elwidth={i}")
341 # check the results against bitp static-expected partition points
342 # https://bugs.libre-soc.org/show_bug.cgi?id=713#c47
343 # https://stackoverflow.com/a/27165694
344 ival
= int(''.join(map(str, ppt
[::-1])), 2)
345 assert ival
== lr
.bitp
[i
], \
346 f
"ival {bin(ival)} actual {bin(lr.bitp[i])}"
349 sim
.add_process(process
)