"""
from nmigen import Signal, Module, Elaboratable, Mux, Cat, Shape, Repl
-from nmigen.back.pysim import Simulator, Delay, Settle
+from nmigen.sim import Simulator, Delay, Settle
from nmigen.cli import rtlil
+from enum import Enum
from collections.abc import Mapping
from functools import reduce
import operator
from collections import defaultdict
-from pprint import pprint
+import dataclasses
from ieee754.part_mul_add.partpoints import PartitionPoints
+@dataclasses.dataclass
+class LayoutResult:
+ ppoints: PartitionPoints
+ bitp: dict
+ bmask: int
+ width: int
+ lane_shapes: dict
+ part_wid: int
+ full_part_count: int
+
+ def __repr__(self):
+ fields = []
+ for field in dataclasses.fields(LayoutResult):
+ field_v = getattr(self, field.name)
+ if isinstance(field_v, PartitionPoints):
+ field_v = ',\n '.join(
+ f"{k}: {v}" for k, v in field_v.items())
+ field_v = f"{{{field_v}}}"
+ fields.append(f"{field.name}={field_v}")
+ fields = ",\n ".join(fields)
+ return f"LayoutResult({fields})"
+
+
# main fn, which started out here in the bugtracker:
# https://bugs.libre-soc.org/show_bug.cgi?id=713#c20
def layout(elwid, signed, part_counts, lane_shapes=None, fixed_width=None):
Glossary:
* element: a single scalar value that is an element of a SIMD vector.
it has a width in bits, and a signedness. Every element is made of 1 or
- more parts.
+ more parts. An element optionally includes the padding associated with
+ it.
+ * lane: an element. An element optionally includes the padding associated
+ with it.
* ElWid: the element-width (really the element type) of an instruction.
Either an integer or a FP type. Integer `ElWid`s are sign-agnostic.
In Python, `ElWid` is either an enum type or is `int`.
BF16 = ...
F32 = ...
F64 = ...
- * part: A piece of a SIMD vector, every SIMD vector is made of a
- non-negative integer of parts. Elements are made of a power-of-two
- number of parts. A part is a fixed number of bits wide for each
- different SIMD layout, it doesn't vary when `elwid` changes. A part
- can have a bit width of any non-negative integer, it is not restricted
- to power-of-two. SIMD vectors should have as few parts as necessary,
- since some circuits have size proportional to the number of parts.
+ * part: (not to be confused with a partition) A piece of a SIMD vector,
+ every SIMD vector is made of a non-negative integer of parts. Elements
+ are made of a power-of-two number of parts. A part is a fixed number
+ of bits wide for each different SIMD layout, it doesn't vary when
+ `elwid` changes. A part can have a bit width of any non-negative
+ integer, it is not restricted to power-of-two. SIMD vectors should
+ have as few parts as necessary, since some circuits have size
+ proportional to the number of parts.
* elwid: ElWid or nmigen Value with ElWid as the shape
the total width of a SIMD vector. One of lane_shapes and fixed_width
must be provided.
"""
+ print(f"layout(elwid={elwid},\n"
+ f" signed={signed},\n"
+ f" part_counts={part_counts},\n"
+ f" lane_shapes={lane_shapes},\n"
+ f" fixed_width={fixed_width})")
+ assert isinstance(part_counts, Mapping)
+ # assert all part_counts are powers of two
+ assert all(v != 0 and (v & (v - 1)) == 0 for v in part_counts.values()),\
+ "part_counts values must all be powers of two"
+
+ full_part_count = max(part_counts.values())
+
# when there are no lane_shapes specified, this indicates a
# desire to use the maximum available space based on the fixed width
# https://bugs.libre-soc.org/show_bug.cgi?id=713#c67
if lane_shapes is None:
assert fixed_width is not None, \
"both fixed_width and lane_shapes cannot be None"
- lane_shapes = {i: fixed_width // part_counts[i] for i in part_counts}
+ lane_shapes = {}
+ for k, cur_part_count in part_counts.items():
+ cur_element_count = full_part_count // cur_part_count
+ assert fixed_width % cur_element_count == 0, (
+ f"fixed_width ({fixed_width}) can't be split evenly into "
+ f"{cur_element_count} elements")
+ lane_shapes[k] = fixed_width // cur_element_count
print("lane_shapes", fixed_width, lane_shapes)
# identify if the lane_shapes is a mapping (dict, etc.)
# if not, then assume that it is an integer (width) that
# needs to be requested across all partitions
if not isinstance(lane_shapes, Mapping):
lane_shapes = {i: lane_shapes for i in part_counts}
- # compute a set of partition widths
- cpart_wid = [-lane_shapes[i] // c for i, c in part_counts.items()]
- print("cpart_wid", cpart_wid, "part_counts", part_counts)
- cpart_wid = -min(cpart_wid)
- part_count = max(part_counts.values())
- # calculate the minumum width required
- width = cpart_wid * part_count
- print("width", width, cpart_wid, part_count)
+ # calculate the minimum possible bit-width of a part.
+ # we divide each element's width by the number of parts in an element,
+ # giving the number of bits needed per part.
+ # we use `-min(-a // b for ...)` to get `max(ceil(a / b) for ...)`,
+ # but using integers.
+ min_part_wid = -min(-lane_shapes[i] // c for i, c in part_counts.items())
+ # calculate the minimum bit-width required
+ min_width = min_part_wid * full_part_count
+ print("width", min_width, min_part_wid, full_part_count)
if fixed_width is not None: # override the width and part_wid
- assert width < fixed_width, "not enough space to fit partitions"
- part_wid = fixed_width // part_count
- assert part_wid * part_count == fixed_width, \
- "calculated width not aligned multiples"
+ assert min_width <= fixed_width, "not enough space to fit partitions"
+ part_wid = fixed_width // full_part_count
+ assert fixed_width % full_part_count == 0, \
+ "fixed_width must be a multiple of full_part_count"
width = fixed_width
- print("part_wid", part_wid, "count", part_count)
+ print("part_wid", part_wid, "count", full_part_count)
else:
# go with computed width
- part_wid = cpart_wid
+ width = min_width
+ part_wid = min_part_wid
# create the breakpoints dictionary.
# do multi-stage version https://bugs.libre-soc.org/show_bug.cgi?id=713#c34
# https://stackoverflow.com/questions/26367812/
- dpoints = defaultdict(list) # if empty key, create a (empty) list
- for i, c in part_counts.items():
- def add_p(p):
- dpoints[p].append(i) # auto-creates list if key non-existent
- for start in range(0, part_count, c):
- add_p(start * part_wid) # start of lane
- add_p(start * part_wid + lane_shapes[i]) # start of padding
+ # dpoints: dict from bit-index to dict[ElWid, None]
+ # we use a dict from ElWid to None as the values of dpoints in order to
+ # get an ordered set
+ dpoints = defaultdict(dict) # if empty key, create a (empty) dict
+ for i, cur_part_count in part_counts.items():
+ def add_p(bit_index):
+ # auto-creates dict if key non-existent
+ dpoints[bit_index][i] = None
+ # go through all elements for elwid `i`, each element starts at
+ # part index `start_part`, and goes for `cur_part_count` parts
+ for start_part in range(0, full_part_count, cur_part_count):
+ start_bit = start_part * part_wid
+ add_p(start_bit) # start of lane
+ add_p(start_bit + lane_shapes[i]) # start of padding
# do not need the breakpoints at the very start or the very end
dpoints.pop(0, None)
dpoints.pop(width, None)
plist = list(dpoints.keys())
plist.sort()
print("dpoints")
- pprint(dict(dpoints))
+ for k in plist:
+ print(f"{k}: {list(dpoints[k].keys())}")
# second stage, add (map to) the elwidth==i expressions.
# TODO: use nmutil.treereduce?
points = {}
for p in plist:
- points[p] = map(lambda i: elwid == i, dpoints[p])
- points[p] = reduce(operator.or_, points[p])
+ it = map(lambda i: elwid == i, dpoints[p])
+ points[p] = reduce(operator.or_, it)
# third stage, create the binary values which *if* elwidth is set to i
# *would* result in the mask at that elwidth being set to this value
# these can easily be double-checked through Assertion
bitp[i] |= 1 << bitpos
# fourth stage: determine which partitions are 100% unused.
# these can then be "blanked out"
- bmask = (1 << len(plist))-1
+ bmask = (1 << len(plist)) - 1
for p in bitp.values():
bmask &= ~p
- return (PartitionPoints(points), bitp, bmask, width, lane_shapes,
- part_wid, part_count)
+ return LayoutResult(PartitionPoints(points), bitp, bmask, width,
+ lane_shapes, part_wid, full_part_count)
if __name__ == '__main__':
- # for each element-width (elwidth 0-3) the number of partitions is given
- # elwidth=0b00 QTY 1 partitions: | ? |
- # elwidth=0b01 QTY 1 partitions: | ? |
- # elwidth=0b10 QTY 2 partitions: | ? | ? |
- # elwidth=0b11 QTY 4 partitions: | ? | ? | ? | ? |
+ class FpElWid(Enum):
+ F64 = 0
+ F32 = 1
+ F16 = 2
+ BF16 = 3
+
+ def __repr__(self):
+ return super().__str__()
+
+ class IntElWid(Enum):
+ I64 = 0
+ I32 = 1
+ I16 = 2
+ I8 = 3
+
+ def __repr__(self):
+ return super().__str__()
+
+ # for each element-width (elwidth 0-3) the number of parts in an element
+ # is given:
+ # | part0 | part1 | part2 | part3 |
+ # elwid=F64 4 parts per element: |<-------------F64------------->|
+ # elwid=F32 2 parts per element: |<-----F32----->|<-----F32----->|
+ # elwid=F16 1 part per element: |<-F16->|<-F16->|<-F16->|<-F16->|
+ # elwid=BF16 1 part per element: |<BF16->|<BF16->|<BF16->|<BF16->|
# actual widths of Signals *within* those partitions is given separately
part_counts = {
- 0: 1,
- 1: 1,
- 2: 2,
- 3: 4,
+ FpElWid.F64: 4,
+ FpElWid.F32: 2,
+ FpElWid.F16: 1,
+ FpElWid.BF16: 1,
}
- # width=3 indicates "we want the same width (3) at all elwidths"
- # elwidth=0b00 1x 5-bit | ..3 |
- # elwidth=0b01 1x 6-bit | ..3 |
- # elwidth=0b10 2x 12-bit | ..3 | ..3 |
- # elwidth=0b11 3x 24-bit | ..3| ..3 | ..3 |..3 |
- width_in_all_parts = 3
+ # width=3 indicates "we want the same element bit-width (3) at all elwids"
+ # elwid=F64 1x 3-bit |<--------i3------->|
+ # elwid=F32 2x 3-bit |<---i3-->|<---i3-->|
+ # elwid=F16 4x 3-bit |<i3>|<i3>|<i3>|<i3>|
+ # elwid=BF16 4x 3-bit |<i3>|<i3>|<i3>|<i3>|
+ width_for_all_els = 3
- for i in range(4):
- pprint((i, layout(i, True, part_counts, width_in_all_parts)))
+ for i in FpElWid:
+ print(i, layout(i, True, part_counts, width_for_all_els))
# fixed_width=32 and no lane_widths says "allocate maximum"
- # elwidth=0b00 1x 32-bit | .................32 |
- # elwidth=0b01 1x 32-bit | .................32 |
- # elwidth=0b10 2x 12-bit | ......16 | ......16 |
- # elwidth=0b11 3x 24-bit | ..8| ..8 | ..8 |..8 |
+ # elwid=F64 1x 32-bit |<-------i32------->|
+ # elwid=F32 2x 16-bit |<--i16-->|<--i16-->|
+ # elwid=F16 4x 8-bit |<i8>|<i8>|<i8>|<i8>|
+ # elwid=BF16 4x 8-bit |<i8>|<i8>|<i8>|<i8>|
- #print ("maximum allocation from fixed_width=32")
- # for i in range(4):
- # pprint((i, layout(i, True, part_counts, fixed_width=32)))
+ print("maximum allocation from fixed_width=32")
+ for i in FpElWid:
+ print(i, layout(i, True, part_counts, fixed_width=32))
# specify that the length is to be *different* at each of the elwidths.
# combined with part_counts we have:
- # elwidth=0b00 1x 5-bit | ....5 |
- # elwidth=0b01 1x 6-bit | .....6 |
- # elwidth=0b10 2x 12-bit | ....12 | .....12 |
- # elwidth=0b11 3x 24-bit | 24 | 24 | 24 | 24 |
+ # elwid=F64 1x 24-bit |<-------i24------->|
+ # elwid=F32 2x 12-bit |<--i12-->|<--i12-->|
+ # elwid=F16 4x 6-bit |<i6>|<i6>|<i6>|<i6>|
+ # elwid=BF16 4x 5-bit |<i5>|<i5>|<i5>|<i5>|
widths_at_elwidth = {
- 0: 5,
- 1: 6,
- 2: 12,
- 3: 24
+ FpElWid.F64: 24,
+ FpElWid.F32: 12,
+ FpElWid.F16: 6,
+ FpElWid.BF16: 5,
}
- for i in range(4):
- pprint((i, layout(i, False, part_counts, widths_at_elwidth)))
+ for i in FpElWid:
+ print(i, layout(i, False, part_counts, widths_at_elwidth))
# this tests elwidth as an actual Signal. layout is allowed to
# determine arbitrarily the overall length
# https://bugs.libre-soc.org/show_bug.cgi?id=713#c30
- elwid = Signal(2)
- pp, bitp, bm, b, c, d, e = layout(
- elwid, False, part_counts, widths_at_elwidth)
- pprint((pp, b, c, d, e))
- for k, v in bitp.items():
- print("bitp elwidth=%d" % k, bin(v))
- print("bmask", bin(bm))
+ elwid = Signal(FpElWid)
+ lr = layout(elwid, False, part_counts, widths_at_elwidth)
+ print(lr)
+ for k, v in lr.bitp.items():
+ print(f"bitp elwidth={k}", bin(v))
+ print("bmask", bin(lr.bmask))
m = Module()
def process():
- for i in range(4):
+ for i in FpElWid:
yield elwid.eq(i)
yield Settle()
ppt = []
- for pval in list(pp.values()):
+ for pval in lr.ppoints.values():
val = yield pval # get nmigen to evaluate pp
ppt.append(val)
- pprint((i, (ppt, b, c, d, e)))
+ print(i, ppt)
# check the results against bitp static-expected partition points
# https://bugs.libre-soc.org/show_bug.cgi?id=713#c47
# https://stackoverflow.com/a/27165694
ival = int(''.join(map(str, ppt[::-1])), 2)
- assert ival == bitp[i]
+ assert ival == lr.bitp[i]
sim = Simulator(m)
sim.add_process(process)
# determine arbitrarily the overall length, it is fixed to 64
# https://bugs.libre-soc.org/show_bug.cgi?id=713#c22
- elwid = Signal(2)
- pp, bitp, bm, b, c, d, e = layout(elwid, False, part_counts, widths_at_elwidth,
- fixed_width=64)
- pprint((pp, b, c, d, e))
- for k, v in bitp.items():
- print("bitp elwidth=%d" % k, bin(v))
- print("bmask", bin(bm))
+ elwid = Signal(FpElWid)
+ lr = layout(elwid, False, part_counts, widths_at_elwidth, fixed_width=64)
+ print(lr)
+ for k, v in lr.bitp.items():
+ print(f"bitp elwidth={k}", bin(v))
+ print("bmask", bin(lr.bmask))
m = Module()
def process():
- for i in range(4):
+ for i in FpElWid:
yield elwid.eq(i)
yield Settle()
ppt = []
- for pval in list(pp.values()):
+ for pval in list(lr.ppoints.values()):
val = yield pval # get nmigen to evaluate pp
ppt.append(val)
- print("test elwidth=%d" % i)
- pprint((i, (ppt, b, c, d, e)))
+ print(f"test elwidth={i}")
+ print(i, ppt)
# check the results against bitp static-expected partition points
# https://bugs.libre-soc.org/show_bug.cgi?id=713#c47
# https://stackoverflow.com/a/27165694
ival = int(''.join(map(str, ppt[::-1])), 2)
- assert ival == bitp[i], "ival %s actual %s" % (bin(ival),
- bin(bitp[i]))
+ assert ival == lr.bitp[i], \
+ f"ival {bin(ival)} actual {bin(lr.bitp[i])}"
sim = Simulator(m)
sim.add_process(process)