add option to specify fixed_width and no lane_shaps only to find
[ieee754fpu.git] / src / ieee754 / part_ass / assign.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3
4 """
5 Copyright (C) 2021 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
6
7 dynamically-partitionable "assign" class, directly equivalent
8 to nmigen Assign
9
10 See:
11
12 * http://libre-riscv.org/3d_gpu/architecture/dynamic_simd/assign
13 * http://bugs.libre-riscv.org/show_bug.cgi?id=709
14
15 """
16
17 from nmigen import Signal, Module, Elaboratable, Cat, Const, signed
18 from nmigen.back.pysim import Simulator, Settle
19 from nmutil.extend import ext
20
21 from ieee754.part_mul_add.partpoints import PartitionPoints
22 from ieee754.part.partsig import SimdSignal
23
24
def get_runlengths(pbit, size):
    """Return the list of partition run lengths encoded by a mask value.

    The low ``size`` bits of ``pbit`` are scanned from bit 0 upwards.
    A set bit closes the run that is currently being counted and opens
    a new one; a clear bit makes the current run one longer.  Whatever
    is still being counted once all bits have been examined is appended
    as the final run, so the result always contains one more entry than
    there are set bits in the examined range.

    * pbit: integer partition mask (bits at or above ``size`` are ignored)
    * size: number of mask bits to examine

    Returns a list of integers (run lengths, lowest partition first).
    """
    lengths = []
    run = 1
    for bitpos in range(size):
        if not (pbit >> bitpos) & 1:
            # no boundary at this position: the current run keeps growing
            run += 1
            continue
        # boundary: record the finished run and begin the next one
        lengths.append(run)
        run = 1
    # the last run is never terminated by a mask bit, so add it explicitly
    lengths.append(run)

    print ("get_runlengths", bin(pbit), size, lengths)

    return lengths
43
44
class PartitionedAssign(Elaboratable):
    """Dynamically-partitionable assignment.

    An output ``SimdSignal`` is created from the context's mask and the
    requested shape.  At elaboration time a Switch is generated over the
    context's switch value, with one Case per value returned by the
    context's ``get_cases()``.  In each Case the source is cut into
    sequential chunks (one per partition run), every chunk is resized
    with ``ext`` to the width of the matching run of output lanes, and
    the concatenation of those chunks is assigned to the output.
    """

    def __init__(self, shape, assign, ctx):
        """Create a ``PartitionedAssign`` operator

        * shape: shape handed to the output ``SimdSignal``
        * assign: the value being assigned.  A ``SimdSignal`` is cut at
          its own partition points; anything else is treated as a scalar
        * ctx: partition context.  ``get_mask()`` is used here, and
          ``get_switch()`` / ``get_cases()`` are used by ``elaborate``
        """
        self.shape = shape
        self.assign = assign
        self.ptype = ctx
        self.output = SimdSignal(ctx.get_mask(), shape, reset_less=True)
        self.partition_points = self.output.partpoints
        # number of lanes: one more than the number of partition points
        self.mwidth = len(self.partition_points)+1

    def get_chunk(self, y, numparts):
        """Return the slice of the source feeding the next partition run.

        * y: single-element list used as a mutable cursor into the
          source's partitions.  It is advanced by ``numparts`` only when
          the source is a ``SimdSignal``
        * numparts: how many lanes the run covers
        """
        src = self.assign
        if isinstance(src, SimdSignal):
            # SimdSignal: slice boundaries are its partition points,
            # bracketed by bit 0 and the full length
            boundaries = [0] + list(src.partpoints.keys()) + [len(src)]
            # take the current position, then move the cursor on so the
            # next call picks up where this one stops
            first = y[0]
            y[0] += numparts
            print ("getting", first, numparts, boundaries, len(src))
            lo = boundaries[first]
            hi = boundaries[first+numparts]
            print ("start end", lo, hi, len(src))
            return src[lo:hi]
        # anything else is assumed to be a Scalar, which follows totally
        # different rules: the cursor is left alone and the slice always
        # starts at bit 0
        return src[:numparts * (len(src) // self.mwidth)]

    def elaborate(self, platform):
        """Build the Switch/Case logic that drives ``self.output``."""
        m = Module()
        comb = m.d.comb

        ppoint_keys = list(self.partition_points.keys())
        print ("keys", ppoint_keys, "values", self.partition_points.values())
        print ("ptype", self.ptype)
        # width of one output lane
        lane_width = len(self.output) // self.mwidth
        width, is_signed = self.output.shape()
        print ("width, signed", width, is_signed)

        with m.Switch(self.ptype.get_switch()):
            # one Case (one Assign sequence) per partition possibility
            for pbit in self.ptype.get_cases():
                # cursor shared with get_chunk so that successive calls
                # hand back successive chunks of the source
                cursor = [0]
                pieces = []
                # length, in lanes, of every partition run for this case
                runlengths = get_runlengths(pbit, len(ppoint_keys))
                print ("pbit", bin(pbit), "runs", runlengths)
                for run in runlengths:
                    chunk = self.get_chunk(cursor, run)
                    # resize the chunk to the run's output width:
                    # truncate, extend or leave-alone
                    resized = ext(chunk, (len(chunk), is_signed),
                                  run * lane_width)
                    pieces.append(resized)
                with m.Case(pbit):
                    # direct access to the underlying Signal
                    comb += self.output.sig.eq(Cat(*pieces))

        return m

    def ports(self):
        """Return the source and the output, ``SimdSignal``s lowered."""
        src = self.assign
        if isinstance(src, SimdSignal):
            src = src.lower()
        return [src, self.output.lower()]
115
116
if __name__ == "__main__":
    from ieee754.part.test.test_partsig import create_simulator

    class PartType:
        """Minimal partition context for driving a scalar source.

        Provides the methods ``PartitionedAssign`` calls on its context
        (``get_mask``, ``get_switch``, ``get_cases``) on top of a plain
        mask Signal.
        """
        def __init__(self, mask):
            self.mask = mask
        def get_mask(self):
            # use the mask this instance was built with, rather than
            # whatever module-level "mask" happens to exist
            return self.mask
        def get_switch(self):
            return Cat(self.get_mask())
        def get_cases(self):
            return range(1<<len(self.get_mask()))
        @property
        def blanklanes(self):
            return 0

    def run_demo(mask, src, src_sig, ctx):
        """Simulate one PartitionedAssign and print its output per mask.

        * mask: the partition mask Signal driven by the test process
        * src: value handed to PartitionedAssign as the assign source
        * src_sig: the Signal that is actually driven with test data
          (the underlying ``.sig`` for a SimdSignal, ``src`` itself for
          a scalar)
        * ctx: partition context passed to PartitionedAssign
        """
        m = Module()
        m.submodules.ass = ass = PartitionedAssign(signed(48), src, ctx)
        # keeps only the bits belonging to the output when printing
        omask = (1<<len(ass.output))-1

        traces = ass.ports()
        sim = create_simulator(m, traces, "partass")

        def process():
            yield src_sig.eq(0xa12345c7)
            for pattern in (0b000, 0b010, 0b110, 0b111):
                yield mask.eq(pattern)
                yield Settle()
                out = yield ass.output.sig
                print("out", format(pattern, "03b"),
                      bin(out&omask), hex(out&omask))

        sim.add_process(process)
        # NOTE: every run writes to the same VCD/GTKW file names, so a
        # later run overwrites the trace of an earlier one
        with sim.write_vcd("partition_ass.vcd", "partition_ass.gtkw",
                           traces=traces):
            sim.run()

    # SimdSignal source
    mask = Signal(3)
    a = SimdSignal(mask, 32)
    run_demo(mask, a, a.sig, a.ptype)

    # Scalar
    mask = Signal(3)
    a = Signal(32)
    run_demo(mask, a, a, PartType(mask))