improve code-comments
[ieee754fpu.git] / src / ieee754 / part / layout_experiment.py
1 #!/usr/bin/env python3
2 # SPDX-License-Identifier: LGPL-3-or-later
3 # See Notices.txt for copyright information
4 """
5 Links:
6 * https://libre-soc.org/3d_gpu/architecture/dynamic_simd/shape/
7 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c20
8 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c30
9 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c34
10 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c47
11 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c22
12 """
13
14 from nmigen import Signal, Module, Elaboratable, Mux, Cat, Shape, Repl
15 from nmigen.back.pysim import Simulator, Delay, Settle
16 from nmigen.cli import rtlil
17
18 from collections.abc import Mapping
19 from functools import reduce
20 import operator
21 from collections import defaultdict
22 from pprint import pprint
23
24 from ieee754.part_mul_add.partpoints import PartitionPoints
25
26
27 # main fn
def layout(elwid, signed, part_counts, lane_shapes, fixed_width=None):
    """Compute dynamic-SIMD partition points for a set of element widths.

    Arguments:
    * elwid: the element-width selector.  Either a plain integer (the
      partition points then fold to constants) or an nmigen Value such as
      a Signal (each partition point is then an nmigen expression of the
      form ``(elwid == a) | (elwid == b) | ...``).
    * signed: not used by this function (kept for interface compatibility).
    * part_counts: dict mapping each elwidth value to a count ``c``.  As
      implemented, lanes for that elwidth start at partitions 0, c, 2c, ...
      (below the total partition count), and each lane may use up to ``c``
      partitions' worth of bits.
    * lane_shapes: either a dict mapping each elwidth value to a lane width
      in bits, or a single integer width applied to every elwidth.
    * fixed_width: optional overall width overriding the computed minimum.
      It must be at least the minimum width and an exact multiple of the
      partition count (asserted).

    Returns a tuple ``(points, bitp, width, lane_shapes, part_wid,
    part_count)``:
    * points: PartitionPoints mapping each breakpoint bit position to the
      expression selecting the elwidths that need a boundary there
    * bitp: dict mapping each elwidth value to an int whose bit ``k`` is set
      when the k-th breakpoint (in ascending bit position) is active at
      that elwidth; useful for cross-checking ``points`` in simulation
    * width: total width in bits
    * lane_shapes: the lane-width dict (expanded if an int was given)
    * part_wid: width of each partition in bits
    * part_count: number of partitions (the maximum of part_counts)
    """
    # a non-mapping lane_shapes is a single width requested at every elwidth
    if not isinstance(lane_shapes, Mapping):
        lane_shapes = {i: lane_shapes for i in part_counts}

    # minimum partition width: a lane at elwidth i may span c partitions, so
    # each partition needs ceil(lane_shapes[i] / c) bits.  -(-a // c) is
    # integer ceiling division.
    cpart_wid = max(-(-lane_shapes[i] // c) for i, c in part_counts.items())
    part_count = max(part_counts.values())
    width = cpart_wid * part_count  # minimum width required

    if fixed_width is not None:
        # override width and part_wid.  "<=" rather than "<": a fixed width
        # exactly equal to the minimum still fits every lane.
        assert width <= fixed_width, "not enough space to fit partitions"
        part_wid = fixed_width // part_count
        assert part_wid * part_count == fixed_width, \
            "calculated width not aligned multiples"
        width = fixed_width
        print("part_wid", part_wid, "count", part_count)
    else:
        # go with the computed width
        part_wid = cpart_wid

    # first stage: map each candidate breakpoint (bit position) to the list
    # of elwidths needing a boundary there.  multi-stage approach, see
    # https://bugs.libre-soc.org/show_bug.cgi?id=713#c34
    # https://stackoverflow.com/questions/26367812/
    dpoints = defaultdict(list)  # missing key -> new empty list
    for i, c in part_counts.items():
        for start in range(0, part_count, c):
            lane_start = start * part_wid
            dpoints[lane_start].append(i)                   # start of lane
            dpoints[lane_start + lane_shapes[i]].append(i)  # start of padding
    # breakpoints at the very start or the very end are never needed
    dpoints.pop(0, None)
    dpoints.pop(width, None)
    plist = sorted(dpoints)
    print("dpoints")
    pprint(dict(dpoints))

    # second stage: a breakpoint is active when elwid equals any of the
    # elwidths that require it.
    # TODO: use nmutil.treereduce?
    points = {}
    for p in plist:
        points[p] = reduce(operator.or_, (elwid == i for i in dpoints[p]))

    # third stage: for each elwidth, the bitmask (bit k <-> plist[k]) of the
    # breakpoints that *would* be active if elwid were set to it.  these can
    # easily be double-checked through Assertion.  bits are distinct, so
    # summing them is the same as OR-ing them.
    bitp = {i: sum(1 << bitpos
                   for bitpos, p in enumerate(plist) if i in dpoints[p])
            for i in part_counts}

    return (PartitionPoints(points), bitp, width, lane_shapes,
            part_wid, part_count)
84
85
if __name__ == '__main__':

    # for each element-width (elwidth 0-3) a partition count is given:
    #   elwidth=0b00: 1
    #   elwidth=0b01: 1
    #   elwidth=0b10: 2
    #   elwidth=0b11: 4
    # as layout() is currently written, the total number of partitions is
    # the maximum (4), and a count of c means a new lane starts every c
    # partitions (so a larger count means fewer, wider lanes).
    # NOTE(review): earlier comments described these as "QTY n partitions"
    # per elwidth; confirm which interpretation is intended.
    # actual widths of Signals *within* those partitions are given separately
    part_counts = {
        0: 1,
        1: 1,
        2: 2,
        3: 4,
    }

    # width=3 indicates "we want the same width (3) at all elwidths"
    for i in range(4):
        pprint((i, layout(i, True, part_counts, 3)))

    # specify that the length is to be *different* at each of the elwidths.
    # combined with part_counts, layout() as written produces:
    #   elwidth=0b00: 4x 5-bit lanes
    #   elwidth=0b01: 4x 6-bit lanes
    #   elwidth=0b10: 2x 12-bit lanes
    #   elwidth=0b11: 1x 24-bit lane
    widths_at_elwidth = {0: 5, 1: 6, 2: 12, 3: 24}

    for i in range(4):
        pprint((i, layout(i, False, part_counts, widths_at_elwidth)))

    def check_layout_sim(fixed_width=None):
        """Simulate layout() with elwid as an actual Signal.

        For every elwid value, the simulated partition points are packed
        into an int (point k -> bit k) and compared against the static
        ``bitp`` mask that layout() computed for that elwid.
        https://bugs.libre-soc.org/show_bug.cgi?id=713#c47
        """
        elwid = Signal(2)
        pp, bitp, b, c, d, e = layout(elwid, False, part_counts,
                                      widths_at_elwidth,
                                      fixed_width=fixed_width)
        pprint((pp, b, c, d, e))
        for k, v in bitp.items():
            print("bitp elwidth=%d" % k, bin(v))

        m = Module()

        def process():
            for i in range(4):
                yield elwid.eq(i)
                yield Settle()
                ppt = []
                for pval in pp.values():
                    val = yield pval  # get nmigen to evaluate pp
                    ppt.append(val)
                print("test elwidth=%d" % i)
                pprint((i, (ppt, b, c, d, e)))
                # pack the evaluated points, ppt[k] -> bit k (same ordering
                # as bitp); also well-defined when there are no points
                ival = sum(bit << pos for pos, bit in enumerate(ppt))
                assert ival == bitp[i], "ival %s actual %s" % (bin(ival),
                                                               bin(bitp[i]))

        sim = Simulator(m)
        sim.add_process(process)
        sim.run()

    # layout is allowed to determine the overall length itself
    # https://bugs.libre-soc.org/show_bug.cgi?id=713#c30
    check_layout_sim()

    # layout is *not* allowed to determine the overall length: it is
    # fixed to 64
    # https://bugs.libre-soc.org/show_bug.cgi?id=713#c22
    check_layout_sim(fixed_width=64)