derive PartitionedSignal from UserValue (temporarily) and add lower()
[ieee754fpu.git] / src / ieee754 / part / partsig.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3
4 """
5 Copyright (C) 2020 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
6
7 dynamic-partitionable class similar to Signal, which, when the partition
8 is fully open will be identical to Signal. when partitions are closed,
9 the class turns into a SIMD variant of Signal. *this is dynamic*.
10
11 the basic fundamental idea is: write code once, and if you want a SIMD
12 version of it, use PartitionedSignal in place of Signal. job done.
13 this however requires the code to *not* be designed to use nmigen.If,
14 nmigen.Case, or other constructs: only Mux and other logic.
15
16 * http://bugs.libre-riscv.org/show_bug.cgi?id=132
17 """
18
19 from ieee754.part_mul_add.adder import PartitionedAdder
20 from ieee754.part_cmp.eq_gt_ge import PartitionedEqGtGe
21 from ieee754.part_bits.xor import PartitionedXOR
22 from ieee754.part_shift.part_shift_dynamic import PartitionedDynamicShift
23 from ieee754.part_shift.part_shift_scalar import PartitionedScalarShift
24 from ieee754.part_mul_add.partpoints import make_partition, PartitionPoints
25 from operator import or_, xor, and_, not_
26
27 from nmigen import (Signal, Const)
28 from nmigen.hdl.ast import UserValue
29
30
def getsig(op1):
    """Unwrap a PartitionedSignal to the plain Signal it stores.

    Any other operand (Signal, Const, int, nmigen expression) is handed
    back untouched, so callers can pass mixed operand types freely.
    """
    return op1.sig if isinstance(op1, PartitionedSignal) else op1
35
36
def applyop(op1, op2, op):
    """Apply a binary operator that needs no partition awareness.

    The result is a new PartitionedSignal built with
    ``PartitionedSignal.like`` from whichever operand is a
    PartitionedSignal (op1 is preferred). It is driven combinatorially,
    in that operand's module, by ``op`` applied to the unwrapped operands.
    """
    template = op1 if isinstance(op1, PartitionedSignal) else op2
    result = PartitionedSignal.like(template)
    lhs, rhs = getsig(op1), getsig(op2)
    result.m.d.comb += result.sig.eq(op(lhs, rhs))
    return result
44
45
class PartitionedSignal(UserValue):
    """Dynamically-partitionable (SIMD-capable) counterpart of ``Signal``.

    The value is stored in a plain ``Signal`` (``self.sig``), and the
    partition layout is kept in ``self.partpoints``. Operators that must
    respect partition boundaries (add, sub, neg, shifts, comparisons,
    xor-reduce) add a partition-aware submodule to ``self.m``. That
    attribute must be set with ``set_module`` before such an operator is
    used. Deriving from ``UserValue``, with ``lower()`` returning
    ``self.sig``, lets instances be used where nmigen expects a Value.
    """

    # Submodule-name counters shared by *all* instances. ``like`` copies
    # ``m`` into every result, so many PartitionedSignals add submodules to
    # the same Module. Per-instance counters would hand out duplicate names
    # (e.g. "add_1" for both ``a + b`` and ``(a + b) + c``), and nmigen
    # rejects duplicate submodule names.
    _modnames = dict.fromkeys(['add', 'eq', 'gt', 'ge', 'ls', 'xor'], 0)

    def __init__(self, mask, *args, src_loc_at=0, **kwargs):
        """Create a partitioned signal.

        mask: a ``PartitionPoints`` (used as-is), or a value passed to
              ``make_partition`` together with the signal width.
        *args, **kwargs: forwarded to ``Signal`` to build ``self.sig``.
        """
        super().__init__(src_loc_at=src_loc_at)
        self.sig = Signal(*args, **kwargs)
        width = len(self.sig)  # get signal width
        # create partition points
        if isinstance(mask, PartitionPoints):
            self.partpoints = mask
        else:
            self.partpoints = make_partition(mask, width)
        # kept as an attribute for backwards compatibility; it aliases the
        # class-wide counters so names stay unique across instances.
        self.modnames = PartitionedSignal._modnames

    def lower(self):
        """Return the plain Signal nmigen uses in place of this object."""
        return self.sig

    def set_module(self, m):
        """Attach the Module into which operator submodules are added."""
        self.m = m

    def get_modname(self, category):
        """Return a fresh submodule name such as ``"add_3"``."""
        self.modnames[category] += 1
        return "%s_%d" % (category, self.modnames[category])

    def eq(self, val):
        """Assign *val* (PartitionedSignal or plain value) to self.sig."""
        return self.sig.eq(getsig(val))

    @staticmethod
    def like(other, *args, **kwargs):
        """Builds a new PartitionedSignal with the same PartitionPoints and
        Signal properties as the other"""
        result = PartitionedSignal(other.partpoints)
        result.sig = Signal.like(other.sig, *args, **kwargs)
        result.m = other.m
        return result

    # unary ops that do not require partitioning

    def __invert__(self):
        result = PartitionedSignal.like(self)
        self.m.d.comb += result.sig.eq(~self.sig)
        return result

    # unary ops that require partitioning

    def __neg__(self):
        # -x computed as 0 - x, per partition
        z = Const(0, len(self.sig))
        result, _ = self.sub_op(z, self)
        return result

    # binary ops that don't require partitioning

    def __and__(self, other):
        return applyop(self, other, and_)

    def __rand__(self, other):
        return applyop(other, self, and_)

    def __or__(self, other):
        return applyop(self, other, or_)

    def __ror__(self, other):
        return applyop(other, self, or_)

    def __xor__(self, other):
        return applyop(self, other, xor)

    def __rxor__(self, other):
        return applyop(other, self, xor)

    # binary ops that need partitioning

    # TODO: detect if the 2nd operand is a Const, a Signal or a
    # PartitionedSignal. if it's a Const or a Signal, a global shift
    # can occur. if it's a PartitionedSignal, that's much more interesting.
    def ls_op(self, op1, op2, carry, shr_flag=0):
        """Partitioned shift of *op1* by *op2* (right shift if shr_flag).

        A Const/Signal shift amount uses PartitionedScalarShift. Anything
        else (e.g. a PartitionedSignal) uses PartitionedDynamicShift.
        *carry* is not used yet.

        Returns ``(output, 0)``, where output is the shifter's raw output
        (not wrapped in a PartitionedSignal).
        """
        op1 = getsig(op1)
        scalar = isinstance(op2, (Const, Signal))
        if scalar:
            pa = PartitionedScalarShift(len(op1), self.partpoints)
        else:
            op2 = getsig(op2)
            pa = PartitionedDynamicShift(len(op1), self.partpoints)
        setattr(self.m.submodules, self.get_modname('ls'), pa)
        comb = self.m.d.comb
        if scalar:
            comb += pa.data.eq(op1)
            comb += pa.shifter.eq(op2)
        else:
            comb += pa.a.eq(op1)
            comb += pa.b.eq(op2)
        comb += pa.shift_right.eq(shr_flag)
        # XXX TODO: carry-in, carry-out
        #comb += pa.carry_in.eq(carry)
        return (pa.output, 0)

    def __lshift__(self, other):
        z = Const(0, len(self.partpoints)+1)
        result, _ = self.ls_op(self, other, carry=z) # TODO, carry
        return result

    def __rlshift__(self, other):
        raise NotImplementedError

    def __rshift__(self, other):
        z = Const(0, len(self.partpoints)+1)
        result, _ = self.ls_op(self, other, carry=z, shr_flag=1) # TODO, carry
        return result

    def __rrshift__(self, other):
        raise NotImplementedError

    def _partitioned_add(self, a, b, carry):
        """Add a PartitionedAdder computing a + b + carry per partition.

        The adder is sized to self.sig, whose layout self.partpoints
        describes. This is what lets reflected operators pass a plain int
        as the first operand. Returns ``(result, carry_out)``.
        """
        pa = PartitionedAdder(len(self.sig), self.partpoints)
        setattr(self.m.submodules, self.get_modname('add'), pa)
        comb = self.m.d.comb
        comb += pa.a.eq(a)
        comb += pa.b.eq(b)
        comb += pa.carry_in.eq(carry)
        result = PartitionedSignal.like(self)
        comb += result.sig.eq(pa.output)
        return result, pa.carry_out

    def add_op(self, op1, op2, carry):
        """Partitioned op1 + op2 + carry; returns (result, carry_out)."""
        return self._partitioned_add(getsig(op1), getsig(op2), carry)

    def sub_op(self, op1, op2, carry=~0):
        """Partitioned op1 - op2, computed as op1 + ~op2 + carry.

        The default all-ones carry-in supplies the +1 of the two's
        complement in every partition. Returns (result, carry_out).
        """
        return self._partitioned_add(getsig(op1), ~getsig(op2), carry)

    def __add__(self, other):
        result, _ = self.add_op(self, other, carry=0)
        return result

    def __radd__(self, other):
        # carry must be passed explicitly: add_op has no default for it
        result, _ = self.add_op(other, self, carry=0)
        return result

    def __sub__(self, other):
        result, _ = self.sub_op(self, other)
        return result

    def __rsub__(self, other):
        result, _ = self.sub_op(other, self)
        return result

    def __mul__(self, other):
        # partitioned multiply is not implemented yet
        raise NotImplementedError

    def __rmul__(self, other):
        raise NotImplementedError

    def __check_divisor(self):
        width, signed = self.shape()
        if signed:
            # Python's division semantics and Verilog's division semantics
            # differ for negative divisors (Python uses div/mod, Verilog
            # uses quo/rem); for now, avoid the issue
            # completely by prohibiting such division operations.
            raise NotImplementedError(
                    "Division by a signed value is not supported")

    def __mod__(self, other):
        raise NotImplementedError

    def __rmod__(self, other):
        raise NotImplementedError

    def __floordiv__(self, other):
        raise NotImplementedError

    def __rfloordiv__(self, other):
        raise NotImplementedError

    # binary comparison ops that need partitioning

    def _compare(self, width, op1, op2, opname, optype):
        """Add a PartitionedEqGtGe comparing op1 with op2 (opcode optype).

        Returns the comparator's output signal.
        """
        pa = PartitionedEqGtGe(width, self.partpoints)
        setattr(self.m.submodules, self.get_modname(opname), pa)
        comb = self.m.d.comb
        comb += pa.opcode.eq(optype)  # set opcode
        comb += pa.a.eq(getsig(op1))
        comb += pa.b.eq(getsig(op2))
        return pa.output

    def __eq__(self, other):
        width = len(self.sig)
        return self._compare(width, self, other, "eq", PartitionedEqGtGe.EQ)

    def __ne__(self, other):
        width = len(self.sig)
        eq = self._compare(width, self, other, "eq", PartitionedEqGtGe.EQ)
        ne = Signal(len(eq))
        self.m.d.comb += ne.eq(~eq)
        return ne

    def __gt__(self, other):
        width = len(self.sig)
        return self._compare(width, self, other, "gt", PartitionedEqGtGe.GT)

    def __lt__(self, other):
        width = len(self.sig)
        # swap operands, use gt to do lt
        return self._compare(width, other, self, "gt", PartitionedEqGtGe.GT)

    def __ge__(self, other):
        width = len(self.sig)
        return self._compare(width, self, other, "ge", PartitionedEqGtGe.GE)

    def __le__(self, other):
        width = len(self.sig)
        # swap operands, use ge to do le
        return self._compare(width, other, self, "ge", PartitionedEqGtGe.GE)

    # useful operators

    def bool(self):
        """Conversion to boolean.

        Returns
        -------
        Value, out
            ``1`` if any bits are set, ``0`` otherwise.
        """
        return self.any() # have to see how this goes

    def any(self):
        """Check if any bits are ``1``.

        Returns
        -------
        Value, out
            ``1`` if any bits are set, ``0`` otherwise.
        """
        return self != Const(0) # leverage the __ne__ operator here

    def all(self):
        """Check if all bits are ``1``.

        Returns
        -------
        Value, out
            ``1`` if all bits are set, ``0`` otherwise.
        """
        return self == Const(-1) # leverage the __eq__ operator here

    def xor(self):
        """Compute pairwise exclusive-or of every bit.

        Returns
        -------
        Value, out
            ``1`` if an odd number of bits are set, ``0`` if an
            even number of bits are set.
        """
        width = len(self.sig)
        pa = PartitionedXOR(width, self.partpoints)
        setattr(self.m.submodules, self.get_modname("xor"), pa)
        self.m.d.comb += pa.a.eq(self.sig)
        return pa.output

    def implies(premise, conclusion):
        """Implication.

        Returns
        -------
        Value, out
            ``0`` if ``premise`` is true and ``conclusion`` is not,
            ``1`` otherwise.
        """
        # amazingly, this should actually work.
        return ~premise | conclusion