f09b972a4c93e3a07e1f869c1f9a8ff16952743a
[ieee754fpu.git] / src / ieee754 / part / partsig.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3
4 """
5 Copyright (C) 2020 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
6
7 dynamic-partitionable class similar to Signal, which, when the partition
8 is fully open will be identical to Signal. when partitions are closed,
9 the class turns into a SIMD variant of Signal. *this is dynamic*.
10
11 the basic fundamental idea is: write code once, and if you want a SIMD
12 version of it, use PartitionedSignal in place of Signal. job done.
13 this however requires the code to *not* be designed to use nmigen.If,
14 nmigen.Case, or other constructs: only Mux and other logic.
15
16 * http://bugs.libre-riscv.org/show_bug.cgi?id=132
17 """
18
19 from ieee754.part_mul_add.adder import PartitionedAdder
20 from ieee754.part_cmp.eq_gt_ge import PartitionedEqGtGe
21 from ieee754.part_bits.xor import PartitionedXOR
22 from ieee754.part_shift.part_shift_dynamic import PartitionedDynamicShift
23 from ieee754.part_shift.part_shift_scalar import PartitionedScalarShift
24 from ieee754.part_mul_add.partpoints import make_partition, PartitionPoints
25 from operator import or_, xor, and_, not_
26
27 from nmigen import (Signal, Const)
28
29
def getsig(op1):
    """Return the plain signal behind *op1*.

    A PartitionedSignal is unwrapped to its ``sig`` attribute; any other
    operand is handed back unchanged.
    """
    return op1.sig if isinstance(op1, PartitionedSignal) else op1
34
35
def applyop(op1, op2, op):
    """Apply the two-operand callable *op* to *op1* and *op2*.

    The result is a new PartitionedSignal modelled (via
    ``PartitionedSignal.like``) on *op1* when that is a PartitionedSignal,
    otherwise on *op2*.  The assignment of ``op(...)`` to the result is
    added to the combinatorial domain of the result's module.
    """
    # whichever operand is partitioned serves as the template
    template = op1 if isinstance(op1, PartitionedSignal) else op2
    result = PartitionedSignal.like(template)
    lhs = getsig(op1)
    rhs = getsig(op2)
    result.m.d.comb += result.sig.eq(op(lhs, rhs))
    return result
43
44
class PartitionedSignal:
    """Dynamically partitionable counterpart of an nmigen Signal.

    Wraps a plain ``Signal`` (``self.sig``) together with a set of
    partition points (``self.partpoints``).  Operators that need to be
    partition-aware instantiate a dedicated submodule (adder, shifter,
    comparator, xor) inside the module registered with ``set_module``;
    operators that do not are applied to the underlying signal directly.

    ``set_module`` must be called before any operator is used, because
    every operator adds statements and/or submodules to ``self.m``.
    """

    def __init__(self, mask, *args, **kwargs):
        """Create the wrapped Signal and its partition points.

        *mask* is either an existing PartitionPoints instance, which is
        used as-is, or a value handed to ``make_partition`` together with
        the signal width.  All remaining arguments go to ``Signal``.
        """
        self.sig = Signal(*args, **kwargs)
        width = len(self.sig)  # get signal width
        # create partition points
        if isinstance(mask, PartitionPoints):
            self.partpoints = mask
        else:
            self.partpoints = make_partition(mask, width)
        # per-category counters used to give each submodule a unique name
        self.modnames = {}
        for name in ['add', 'eq', 'gt', 'ge', 'ls', 'xor']:
            self.modnames[name] = 0

    def set_module(self, m):
        """Register the module that operators add their logic to."""
        self.m = m

    def get_modname(self, category):
        """Return the next unique submodule name for *category*.

        Names have the form ``<category>_<n>`` with *n* counting up
        from 1.  *category* must be one of the keys set up in
        ``__init__``.
        """
        self.modnames[category] += 1
        return "%s_%d" % (category, self.modnames[category])

    def eq(self, val):
        """Return an assignment of *val* to the wrapped signal."""
        return self.sig.eq(getsig(val))

    @staticmethod
    def like(other, *args, **kwargs):
        """Builds a new PartitionedSignal with the same PartitionPoints and
        Signal properties as the other"""
        result = PartitionedSignal(other.partpoints)
        result.sig = Signal.like(other.sig, *args, **kwargs)
        result.m = other.m
        return result

    # unary ops that do not require partitioning

    def __invert__(self):
        result = PartitionedSignal.like(self)
        self.m.d.comb += result.sig.eq(~self.sig)
        return result

    # unary ops that require partitioning

    def __neg__(self):
        # negation is expressed as (0 - self) through the subtractor
        z = Const(0, len(self.sig))
        result, _ = self.sub_op(z, self)
        return result

    # binary ops that don't require partitioning

    def __and__(self, other):
        return applyop(self, other, and_)

    def __rand__(self, other):
        return applyop(other, self, and_)

    def __or__(self, other):
        return applyop(self, other, or_)

    def __ror__(self, other):
        return applyop(other, self, or_)

    def __xor__(self, other):
        return applyop(self, other, xor)

    def __rxor__(self, other):
        return applyop(other, self, xor)

    # binary ops that need partitioning

    # TODO: detect if the 2nd operand is a Const, a Signal or a
    # PartitionedSignal.  if it's a Const or a Signal, a global shift
    # can occur.  if it's a PartitionedSignal, that's much more interesting.
    def ls_op(self, op1, op2, carry, shr_flag=0):
        """Instantiate a partitioned shifter for ``op1`` shifted by ``op2``.

        A Const or Signal shift amount selects PartitionedScalarShift;
        anything else (after unwrapping with ``getsig``) selects
        PartitionedDynamicShift.  *shr_flag* is wired to the shifter's
        ``shift_right`` input.  *carry* is currently unused.

        Returns a tuple ``(shifter output, 0)``; the second element is a
        placeholder for a future carry-out.
        """
        op1 = getsig(op1)
        scalar = isinstance(op2, (Const, Signal))
        if scalar:
            pa = PartitionedScalarShift(len(op1), self.partpoints)
        else:
            op2 = getsig(op2)
            pa = PartitionedDynamicShift(len(op1), self.partpoints)
        setattr(self.m.submodules, self.get_modname('ls'), pa)
        comb = self.m.d.comb
        # the two shifter variants name their data/amount ports differently
        if scalar:
            comb += pa.data.eq(op1)
            comb += pa.shifter.eq(op2)
        else:
            comb += pa.a.eq(op1)
            comb += pa.b.eq(op2)
        comb += pa.shift_right.eq(shr_flag)
        # XXX TODO: carry-in, carry-out
        # comb += pa.carry_in.eq(carry)
        return (pa.output, 0)

    def __lshift__(self, other):
        z = Const(0, len(self.partpoints)+1)
        result, _ = self.ls_op(self, other, carry=z)  # TODO, carry
        return result

    def __rlshift__(self, other):
        raise NotImplementedError

    def __rshift__(self, other):
        z = Const(0, len(self.partpoints)+1)
        result, _ = self.ls_op(self, other, carry=z, shr_flag=1)  # TODO, carry
        return result

    def __rrshift__(self, other):
        raise NotImplementedError

    def _adder_op(self, a, b, carry):
        """Wire a new PartitionedAdder computing from *a*, *b*, *carry*.

        *a* and *b* must already be plain (unwrapped) values.  Returns
        ``(result, carry_out)`` where *result* is a new PartitionedSignal
        modelled on ``self`` and driven by the adder's output.
        """
        pa = PartitionedAdder(len(a), self.partpoints)
        setattr(self.m.submodules, self.get_modname('add'), pa)
        comb = self.m.d.comb
        comb += pa.a.eq(a)
        comb += pa.b.eq(b)
        comb += pa.carry_in.eq(carry)
        result = PartitionedSignal.like(self)
        comb += result.sig.eq(pa.output)
        return result, pa.carry_out

    def add_op(self, op1, op2, carry):
        """Add *op1* and *op2* with carry-in *carry*.

        Returns ``(result, carry_out)``; see ``_adder_op``.
        """
        return self._adder_op(getsig(op1), getsig(op2), carry)

    def sub_op(self, op1, op2, carry=~0):
        """Subtract *op2* from *op1* using the partitioned adder.

        The subtraction is expressed as ``op1 + ~op2`` with the carry-in
        defaulting to ``~0``.  Returns ``(result, carry_out)``; see
        ``_adder_op``.
        """
        return self._adder_op(getsig(op1), ~getsig(op2), carry)

    def __add__(self, other):
        result, _ = self.add_op(self, other, carry=0)
        return result

    def __radd__(self, other):
        # carry must be given explicitly: add_op has no default for it
        result, _ = self.add_op(other, self, carry=0)
        return result

    def __sub__(self, other):
        result, _ = self.sub_op(self, other)
        return result

    def __rsub__(self, other):
        result, _ = self.sub_op(other, self)
        return result

    def __mul__(self, other):
        # no partitioned multiplier is hooked up here yet
        raise NotImplementedError

    def __rmul__(self, other):
        raise NotImplementedError

    def __check_divisor(self):
        # NOTE(review): currently unused, kept for when division is added.
        # assumes Signal.shape() unpacks to (width, signed) - TODO confirm
        width, signed = self.sig.shape()
        if signed:
            # Python's division semantics and Verilog's division semantics
            # differ for negative divisors (Python uses div/mod, Verilog
            # uses quo/rem); for now, avoid the issue
            # completely by prohibiting such division operations.
            raise NotImplementedError(
                "Division by a signed value is not supported")

    def __mod__(self, other):
        raise NotImplementedError

    def __rmod__(self, other):
        raise NotImplementedError

    def __floordiv__(self, other):
        raise NotImplementedError

    def __rfloordiv__(self, other):
        raise NotImplementedError

    # binary comparison ops that need partitioning

    def _compare(self, width, op1, op2, opname, optype):
        """Instantiate a PartitionedEqGtGe comparing *op1* with *op2*.

        *opname* is the submodule naming category ('eq', 'gt' or 'ge')
        and *optype* the opcode wired to the comparator.  Returns the
        comparator's output signal.
        """
        # print (opname, op1, op2)
        pa = PartitionedEqGtGe(width, self.partpoints)
        setattr(self.m.submodules, self.get_modname(opname), pa)
        comb = self.m.d.comb
        comb += pa.opcode.eq(optype)  # set opcode
        comb += pa.a.eq(getsig(op1))
        comb += pa.b.eq(getsig(op2))
        return pa.output

    def __eq__(self, other):
        width = len(self.sig)
        return self._compare(width, self, other, "eq", PartitionedEqGtGe.EQ)

    def __ne__(self, other):
        width = len(self.sig)
        eq = self._compare(width, self, other, "eq", PartitionedEqGtGe.EQ)
        # not-equal is the bitwise inverse of the equality result
        ne = Signal(eq.width)
        self.m.d.comb += ne.eq(~eq)
        return ne

    def __gt__(self, other):
        width = len(self.sig)
        return self._compare(width, self, other, "gt", PartitionedEqGtGe.GT)

    def __lt__(self, other):
        width = len(self.sig)
        # swap operands, use gt to do lt
        return self._compare(width, other, self, "gt", PartitionedEqGtGe.GT)

    def __ge__(self, other):
        width = len(self.sig)
        return self._compare(width, self, other, "ge", PartitionedEqGtGe.GE)

    def __le__(self, other):
        width = len(self.sig)
        # swap operands, use ge to do le
        return self._compare(width, other, self, "ge", PartitionedEqGtGe.GE)

    # useful operators

    def bool(self):
        """Conversion to boolean.

        Returns
        -------
        Value, out
            ``1`` if any bits are set, ``0`` otherwise.
        """
        return self.any()  # have to see how this goes

    def any(self):
        """Check if any bits are ``1``.

        Returns
        -------
        Value, out
            ``1`` if any bits are set, ``0`` otherwise.
        """
        return self != Const(0)  # leverage the __ne__ operator here

    def all(self):
        """Check if all bits are ``1``.

        Returns
        -------
        Value, out
            ``1`` if all bits are set, ``0`` otherwise.
        """
        return self == Const(-1)  # leverage the __eq__ operator here

    def xor(self):
        """Compute pairwise exclusive-or of every bit.

        Returns
        -------
        Value, out
            ``1`` if an odd number of bits are set, ``0`` if an
            even number of bits are set.
        """
        width = len(self.sig)
        pa = PartitionedXOR(width, self.partpoints)
        setattr(self.m.submodules, self.get_modname("xor"), pa)
        self.m.d.comb += pa.a.eq(self.sig)
        return pa.output

    def implies(premise, conclusion):
        """Implication.

        Returns
        -------
        Value, out
            ``0`` if ``premise`` is true and ``conclusion`` is not,
            ``1`` otherwise.
        """
        # amazingly, this should actually work.
        return ~premise | conclusion