00cebc6dab33a25f93d4e790dd7883fc859bae9e
[ieee754fpu.git] / src / ieee754 / part_ass / assign.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3
4 """
5 Copyright (C) 2021 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
6
7 dynamically-partitionable "assign" class, directly equivalent
8 to nmigen Assign
9
10 See:
11
12 * http://libre-riscv.org/3d_gpu/architecture/dynamic_simd/assign
13 * http://bugs.libre-riscv.org/show_bug.cgi?id=709
14
15 """
16
17 from nmigen import Signal, Module, Elaboratable, Cat, Const, signed
18 from nmigen.back.pysim import Simulator, Settle
19 from nmutil.extend import ext
20
21 from ieee754.part_mul_add.partpoints import PartitionPoints
22 from ieee754.part.partsig import PartitionedSignal
23
24
def get_runlengths(pbit, size):
    """Return the lengths of the partition runs selected by ``pbit``.

    ``pbit`` is a bit-mask over ``size`` partition points: a 1 at bit ``i``
    ends the current partition and starts a new one.  With ``size`` points
    there are ``size + 1`` unit-width parts, so the returned run lengths
    always sum to ``size + 1``.  Bits of ``pbit`` at or above ``size`` are
    ignored.

    >>> get_runlengths(0b000, 3)
    [4]
    >>> get_runlengths(0b010, 3)
    [2, 2]
    >>> get_runlengths(0b111, 3)
    [1, 1, 1, 1]
    """
    res = []
    count = 1
    for i in range(size):
        if pbit & (1 << i):
            # a 1 closes the run so far and opens a new one
            res.append(count)
            count = 1
        else:
            count += 1
    # the final run is closed implicitly by the end of the points
    res.append(count)
    return res
43
44
class PartitionedAssign(Elaboratable):
    """Dynamically-partitionable equivalent of an nmigen ``Assign``.

    Copies ``assign`` into a new PartitionedSignal ``self.output`` of the
    given ``shape``.  For every possible value of ``mask`` the source is cut
    into runs of partitions (see ``get_runlengths``), and each run is
    truncated or extended to the width of the matching output run.
    """

    def __init__(self, shape, assign, mask):
        """Create a ``PartitionedAssign`` operator.

        :param shape: shape passed to the output PartitionedSignal
                      (e.g. ``signed(48)``).
        :param assign: source value.  A PartitionedSignal is sliced at its
                       own partition points; anything else is treated as a
                       scalar (see ``get_chunk``).
        :param mask: partition mask; if a dict is given, its values are
                     used (as a list).
        """
        self.assign = assign
        if isinstance(mask, dict):
            mask = list(mask.values())
        self.mask = mask
        self.shape = shape
        self.output = PartitionedSignal(mask, self.shape, reset_less=True)
        self.partition_points = self.output.partpoints
        # number of partitions: one more than the number of partition points
        self.mwidth = len(self.partition_points) + 1

    def get_chunk(self, y, numparts):
        """Return the slice of the source for the next run of partitions.

        :param y: one-element list used as a mutable cursor.  For a
                  PartitionedSignal source, ``y[0]`` is the index of the
                  next partition to take and is advanced by ``numparts``.
        :param numparts: number of consecutive partitions in this run.
        :returns: a slice of ``self.assign``.
        """
        x = self.assign
        if not isinstance(x, PartitionedSignal):
            # Scalar: totally different rules.  Every run takes the low
            # bits of the scalar, in proportion to the run length; the
            # cursor ``y`` is neither read nor advanced.
            end = numparts * (len(x) // self.mwidth)
            return x[:end]
        # PartitionedSignal: slice between its partition points, with the
        # signal's start and end added as the outermost boundaries.
        keys = [0] + list(x.partpoints.keys()) + [len(x)]
        upto = y[0]
        y[0] += numparts  # advance the cursor for the next run
        start = keys[upto]
        end = keys[upto + numparts]
        return x[start:end]

    def elaborate(self, platform):
        """Build one ``Case`` per mask value, each a concatenation of runs.

        For each possible mask value, the source is split into the runs
        given by ``get_runlengths``.  Each run is resized with ``ext`` to
        ``runlength * outpartsize`` bits, and the results are concatenated
        into the output's underlying Signal.
        """
        m = Module()
        comb = m.d.comb

        keys = list(self.partition_points.keys())
        outpartsize = len(self.output) // self.mwidth
        # local name chosen so nmigen's ``signed`` (imported) is not shadowed
        _, out_signed = self.output.shape()

        with m.Switch(Cat(self.mask)):
            for pbit in range(1 << len(keys)):
                output = []
                y = [0]  # shared cursor consumed by get_chunk
                for runlen in get_runlengths(pbit, len(keys)):
                    thing = self.get_chunk(y, runlen)
                    # truncate, extend or leave alone to fit the output run.
                    # NOTE(review): extension follows the *output's*
                    # signedness, whereas nmigen's Assign extends according
                    # to the source; confirm this is intended.
                    outlen = runlen * outpartsize
                    output.append(ext(thing, (len(thing), out_signed), outlen))
                with m.Case(pbit):
                    # direct access to the underlying Signal
                    comb += self.output.sig.eq(Cat(*output))

        return m

    def ports(self):
        """Return the source and output as plain signals for tracing."""
        if isinstance(self.assign, PartitionedSignal):
            return [self.assign.lower(), self.output.lower()]
        return [self.assign, self.output.lower()]
116
117
if __name__ == "__main__":
    from ieee754.part.test.test_partsig import create_simulator

    def run_demo(make_source, sim_name, vcd_name):
        """Simulate a PartitionedAssign into a signed(48) output and print it.

        :param make_source: callable taking the 3-bit mask Signal and
                            returning the 32-bit source to assign from
                            (a PartitionedSignal or a plain Signal).
        :param sim_name: name passed to ``create_simulator``.
        :param vcd_name: base filename for the .vcd/.gtkw traces; each run
                         uses its own so traces are not overwritten.
        """
        m = Module()
        mask = Signal(3)
        a = make_source(mask)
        m.submodules.ass = ass = PartitionedAssign(signed(48), a, mask)
        omask = (1 << len(ass.output)) - 1

        traces = ass.ports()
        sim = create_simulator(m, traces, sim_name)

        # a PartitionedSignal is driven through its underlying Signal
        drive = a.sig if isinstance(a, PartitionedSignal) else a

        def process():
            yield drive.eq(0xa12345c7)
            for maskval in (0b000, 0b010, 0b110, 0b111):
                yield mask.eq(maskval)
                yield Settle()
                out = yield ass.output.sig
                print("out %s" % format(maskval, "03b"),
                      bin(out & omask), hex(out & omask))

        sim.add_process(process)
        with sim.write_vcd(vcd_name + ".vcd", vcd_name + ".gtkw",
                           traces=traces):
            sim.run()

    # PartitionedSignal source
    run_demo(lambda mask: PartitionedSignal(mask, 32),
             "partass", "partition_ass")

    # Scalar source: distinct names keep the partitioned run's outputs
    run_demo(lambda mask: Signal(32),
             "partass_scalar", "partition_ass_scalar")