change all uses of dataclass to plain_data
[ieee754fpu.git] / src / ieee754 / part_repl / repl.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3
4 """
5 Copyright (C) 2021 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
6
7 dynamically-partitionable "repl" class, directly equivalent
8 to nmigen Repl
9
10 See:
11
12 * http://libre-riscv.org/3d_gpu/architecture/dynamic_simd/repl
13 * http://bugs.libre-riscv.org/show_bug.cgi?id=709
14
15 """
16
17 from nmigen import Signal, Module, Elaboratable, Cat, Repl
18 from nmigen.back.pysim import Simulator, Settle
19 from nmigen.cli import rtlil
20
21 from ieee754.part_mul_add.partpoints import PartitionPoints
22 from ieee754.part.partsig import PartitionedSignal
23
24
def get_runlengths(pbit, size):
    """Return the list of partition-run lengths encoded by ``pbit``.

    Bits ``0 .. size-1`` of ``pbit`` are examined in order.  A set bit
    marks "start of a new partition": it closes the run being counted and
    opens a fresh one.  The result therefore has one more entry than there
    are set bits in that range, and its entries always sum to ``size + 1``.
    Bits at or above ``size`` are ignored.
    """
    runs = []
    length = 1
    for bit in range(size):
        if not (pbit & (1 << bit)):
            # a 0: the current partition keeps growing
            length += 1
            continue
        # a 1: the current run is finished, begin counting a new one
        runs.append(length)
        length = 1
    # every bit consumed: flush the run that is still open
    runs.append(length)

    print ("get_runlengths", bin(pbit), size, runs)

    return runs
43
44
class PartitionedRepl(Elaboratable):
    """Dynamically-partitionable counterpart of nmigen ``Repl``.

    For every possible value of the partition mask, the input is split into
    runs of merged partitions (see ``get_runlengths``).  Each run's chunk is
    replicated ``qty`` times with ``Repl``, and the replicated chunks are
    concatenated (``Cat``) into ``self.output``.  The output is ``qty``
    times as wide as the input.
    """

    def __init__(self, repl, qty, mask):
        """Create a ``PartitionedRepl`` operator

        :param repl: the value to replicate.  Either a ``PartitionedSignal``
            (chunks follow its own partition points) or a scalar (any
            other value; see ``get_chunk`` for the scalar slicing rule).
        :param qty: replication count handed to nmigen ``Repl``.
        :param mask: partition mask used to build the output
            ``PartitionedSignal``.  A dict is reduced to a list of its
            values.
        """
        self.repl = repl
        self.qty = qty
        width, signed = repl.shape()
        if isinstance(mask, dict):
            mask = list(mask.values())
        self.mask = mask
        # output: qty copies of the input width, same signedness
        self.shape = (width * qty), signed
        self.output = PartitionedSignal(mask, self.shape, reset_less=True)
        self.partition_points = self.output.partpoints
        # number of output partitions (one more than the partition points)
        self.mwidth = len(self.partition_points)+1

    def get_chunk(self, y, numparts):
        """Return the slice of ``self.repl`` that covers ``numparts`` partitions.

        :param y: single-element list used as a mutable cursor.  When
            ``self.repl`` is a ``PartitionedSignal``, ``y[0]`` is the index of
            the first partition to take and is advanced by ``numparts``, so
            successive calls return successive chunks.  For a scalar input
            ``y`` is neither read nor modified.
        :param numparts: number of consecutive partitions in this run.
        :returns: for a ``PartitionedSignal``, the bits between its partition
            points ``upto`` and ``upto + numparts``.  For a scalar, the low
            ``numparts * (len(x) // self.mwidth)`` bits, always starting at
            bit 0.
        """
        x = self.repl
        if not isinstance(x, PartitionedSignal):
            # assume Scalar. totally different rules
            end = numparts * (len(x) // self.mwidth)
            return x[:end]
        # PartitionedSignal: start at partition point
        keys = [0] + list(x.partpoints.keys()) + [len(x)]
        # get current index and increment it (for next Repl chunk)
        upto = y[0]
        y[0] += numparts
        print ("getting", upto, numparts, keys, len(x))
        # get the partition point as far as we are up to
        start = keys[upto]
        end = keys[upto+numparts]
        print ("start end", start, end, len(x))
        return x[start:end]

    def elaborate(self, platform):
        """Build the mask-driven ``Switch`` that selects the Repl layout.

        One ``Case`` is emitted for each of the ``2 ** len(keys)`` values of
        ``Cat(self.mask)``, so every mask value drives ``self.output.sig``.
        """
        m = Module()
        comb = m.d.comb

        keys = list(self.partition_points.keys())
        print ("keys", keys, "values", self.partition_points.values())
        print ("mask", self.mask)
        width, signed = self.output.shape()
        print ("width, signed", width, signed)

        with m.Switch(Cat(self.mask)):
            # for each partition possibility, create a Repl sequence
            for pbit in range(1<<len(keys)):
                # y is a fresh cursor per case: get_chunk advances it so
                # each call in the loop below yields the next sequential
                # chunk
                output = []
                y = [0]
                # get a list of the length of each partition run
                runlengths = get_runlengths(pbit, len(keys))
                print ("pbit", bin(pbit), "runs", runlengths)
                for i in runlengths: # for each partition
                    thing = self.get_chunk(y, i) # get sequential chunk
                    output.append(Repl(thing, self.qty)) # and replicate it
                with m.Case(pbit):
                    # direct access to the underlying Signal
                    comb += self.output.sig.eq(Cat(*output)) # cat all chunks

        return m

    def ports(self):
        """Return the signals to expose as ports (input first, then output).

        A ``PartitionedSignal`` input is exposed via ``lower()``.  A scalar
        input is returned as-is.
        """
        if isinstance(self.repl, PartitionedSignal):
            return [self.repl.lower(), self.output.lower()]
        return [self.repl, self.output.lower()]
114
115
if __name__ == "__main__":
    from ieee754.part.test.test_partsig import create_simulator

    # mask settings applied in sequence.  Each is paired with the label
    # that prefixes the printed result.
    mask_steps = ((0b000, "out 000"),
                  (0b010, "out 010"),
                  (0b110, "out 110"),
                  (0b111, "out 111"))

    def stimulus(mask, drive, result, omask):
        """Make a simulator process that walks ``mask`` through mask_steps.

        ``drive`` is loaded with 0xa12345c7 together with the first mask
        value.  After every step the design settles and ``result``
        (masked by ``omask``) is printed in binary and hex.
        """
        def process():
            for step, (value, label) in enumerate(mask_steps):
                yield mask.eq(value)
                if step == 0:
                    yield drive.eq(0xa12345c7)
                yield Settle()
                out = yield result
                print(label, bin(out&omask), hex(out&omask))
        return process

    # PartitionedSignal input; also dump the RTLIL netlist
    m = Module()
    mask = Signal(3)
    a = PartitionedSignal(mask, 32)
    m.submodules.repl = repl = PartitionedRepl(a, 2, mask)
    omask = (1<<len(repl.output))-1

    traces = repl.ports()
    vl = rtlil.convert(repl, ports=traces)
    with open("part_repl.il", "w") as f:
        f.write(vl)

    sim = create_simulator(m, traces, "partrepl")
    sim.add_process(stimulus(mask, a.sig, repl.output.sig, omask))
    with sim.write_vcd("partition_repl.vcd", "partition_repl.gtkw",
                       traces=traces):
        sim.run()

    # Scalar (plain Signal) input
    m = Module()
    mask = Signal(3)
    a = Signal(32)
    m.submodules.ass = ass = PartitionedRepl(a, 2, mask)
    omask = (1<<len(ass.output))-1

    traces = ass.ports()
    sim = create_simulator(m, traces, "partass")
    sim.add_process(stimulus(mask, a, ass.output.sig, omask))
    with sim.write_vcd("partition_repl_scalar.vcd",
                       "partition_repl_scalar.gtkw",
                       traces=traces):
        sim.run()