add first cut at formal proof for PartitionedXOR
[ieee754fpu.git] / src / ieee754 / part / partsig.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3
4 """
5 Copyright (C) 2020 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
6
7 dynamic-partitionable class similar to Signal, which, when the partition
8 is fully open will be identical to Signal. when partitions are closed,
9 the class turns into a SIMD variant of Signal. *this is dynamic*.
10
11 the basic fundamental idea is: write code once, and if you want a SIMD
12 version of it, use PartitionedSignal in place of Signal. job done.
13 this however requires the code to *not* be designed to use nmigen.If,
14 nmigen.Case, or other constructs: only Mux and other logic.
15
16 * http://bugs.libre-riscv.org/show_bug.cgi?id=132
17 """
18
19 from ieee754.part_mul_add.adder import PartitionedAdder
20 from ieee754.part_cmp.eq_gt_ge import PartitionedEqGtGe
21 from ieee754.part_bits.xor import PartitionedXOR
22 from ieee754.part_shift.part_shift_dynamic import PartitionedDynamicShift
23 from ieee754.part_shift.part_shift_scalar import PartitionedScalarShift
24 from ieee754.part_mul_add.partpoints import make_partition, PartitionPoints
25 from operator import or_, xor, and_, not_
26
27 from nmigen import (Signal, Const)
28
29
def getsig(op1):
    """Unwrap a PartitionedSignal to the plain Signal it wraps.

    Any other operand (Signal, Const, int, nmigen expression) is handed
    back untouched, so callers can pass mixed operand types freely.
    """
    return op1.sig if isinstance(op1, PartitionedSignal) else op1
34
35
def applyop(op1, op2, op):
    """Apply the binary callable ``op`` to both operands after unwrapping.

    Each operand goes through getsig (left first, then right), so a
    PartitionedSignal contributes its underlying Signal.
    """
    lhs = getsig(op1)
    rhs = getsig(op2)
    return op(lhs, rhs)
38
39
class PartitionedSignal:
    """Dynamically-partitionable stand-in for an nmigen ``Signal``.

    Holds an ordinary ``Signal`` in ``self.sig`` together with partition
    points in ``self.partpoints``.

    Bitwise operators that are unaffected by partitioning (``~``, ``&``,
    ``|``, ``^``) are forwarded directly to the underlying signal.

    Operators whose result depends on the partition boundaries work
    differently. These are add/sub/neg, shifts, comparisons and
    xor-reduction. Each one:

    * instantiates a partition-aware submodule;
    * registers it in ``self.m.submodules``;
    * drives its inputs from ``self.m.d.comb``;
    * returns the submodule's output.

    ``set_module(m)`` must therefore be called before any of those
    partition-aware operators are used.
    """

    def __init__(self, mask, *args, **kwargs):
        """Create the underlying Signal and its partition points.

        :param mask: a ``PartitionPoints`` instance, used as-is. Any other
            value is passed to ``make_partition`` together with the signal
            width to build one.
        :param args: forwarded to ``nmigen.Signal``.
        :param kwargs: forwarded to ``nmigen.Signal``.
        """
        self.sig = Signal(*args, **kwargs)
        width = len(self.sig)  # get signal width
        # create partition points
        if isinstance(mask, PartitionPoints):
            self.partpoints = mask
        else:
            self.partpoints = make_partition(mask, width)
        # per-category counters, used by get_modname to give every
        # generated submodule a unique name
        self.modnames = {name: 0
                         for name in ('add', 'eq', 'gt', 'ge', 'ls', 'xor')}

    def set_module(self, m):
        """Set the Module that partition-aware operators add their
        submodules and combinatorial statements to."""
        self.m = m

    def get_modname(self, category):
        """Return a fresh submodule name for ``category``.

        Each call bumps that category's counter, so successive calls give
        ``"<category>_1"``, ``"<category>_2"``, ...  A category that is not
        a key of ``self.modnames`` raises KeyError.
        """
        self.modnames[category] += 1
        return "%s_%d" % (category, self.modnames[category])

    def eq(self, val):
        """Return a statement assigning ``val`` (unwrapped if it is a
        PartitionedSignal) to the underlying signal."""
        return self.sig.eq(getsig(val))

    # unary ops that do not require partitioning

    def __invert__(self):
        return ~self.sig

    # unary ops that require partitioning

    def __neg__(self):
        """Partitioned two's-complement negation, computed as ``0 - self``."""
        # -x == 0 + ~x + 1 in every partition; sub_op inverts the second
        # operand and its default all-ones carry-in supplies the "+1".
        z = Const(0, len(self.sig))
        result, _ = self.sub_op(z, self)
        return result

    # binary ops that don't require partitioning

    def __and__(self, other):
        return applyop(self, other, and_)

    def __rand__(self, other):
        return applyop(other, self, and_)

    def __or__(self, other):
        return applyop(self, other, or_)

    def __ror__(self, other):
        return applyop(other, self, or_)

    def __xor__(self, other):
        return applyop(self, other, xor)

    def __rxor__(self, other):
        return applyop(other, self, xor)

    # binary ops that need partitioning

    # TODO: detect if the 2nd operand is a Const, a Signal or a
    # PartitionedSignal. if it's a Const or a Signal, a global shift
    # can occur. if it's a PartitionedSignal, that's much more interesting.
    def ls_op(self, op1, op2, carry, shr_flag=0):
        """Partitioned shift of ``op1`` by ``op2``.

        A ``Const`` or plain ``Signal`` shift amount uses
        PartitionedScalarShift. Anything else (e.g. a PartitionedSignal)
        is unwrapped and uses PartitionedDynamicShift.

        :param carry: currently ignored (carry-in is not wired up yet).
        :param shr_flag: drives the shifter's ``shift_right`` input.
            ``__lshift__`` passes 0 and ``__rshift__`` passes 1.
        :returns: ``(output, 0)``. The second slot is a placeholder for a
            future carry-out.
        """
        op1 = getsig(op1)
        scalar = isinstance(op2, (Const, Signal))
        if scalar:
            pa = PartitionedScalarShift(len(op1), self.partpoints)
        else:
            op2 = getsig(op2)
            pa = PartitionedDynamicShift(len(op1), self.partpoints)
        setattr(self.m.submodules, self.get_modname('ls'), pa)
        comb = self.m.d.comb
        if scalar:
            comb += pa.data.eq(op1)
            comb += pa.shifter.eq(op2)
        else:
            comb += pa.a.eq(op1)
            comb += pa.b.eq(op2)
        comb += pa.shift_right.eq(shr_flag)
        # XXX TODO: carry-in, carry-out (``carry`` is not used yet)
        return (pa.output, 0)

    def __lshift__(self, other):
        z = Const(0, len(self.partpoints)+1)
        result, _ = self.ls_op(self, other, carry=z)  # TODO, carry
        return result

    def __rlshift__(self, other):
        raise NotImplementedError

    def __rshift__(self, other):
        z = Const(0, len(self.partpoints)+1)
        result, _ = self.ls_op(self, other, carry=z, shr_flag=1)  # TODO, carry
        return result

    def __rrshift__(self, other):
        raise NotImplementedError

    def add_op(self, op1, op2, carry=0):
        """Partitioned addition ``op1 + op2`` with carry-in ``carry``.

        The adder width is taken from ``self.sig`` (whose partition points
        are used), not from ``op1``, so ``op1`` may be a plain int as in
        the reflected operators.

        :returns: ``(output, carry_out)`` of the PartitionedAdder.
        """
        op1 = getsig(op1)
        op2 = getsig(op2)
        pa = PartitionedAdder(len(self.sig), self.partpoints)
        setattr(self.m.submodules, self.get_modname('add'), pa)
        comb = self.m.d.comb
        comb += pa.a.eq(op1)
        comb += pa.b.eq(op2)
        comb += pa.carry_in.eq(carry)
        return (pa.output, pa.carry_out)

    def sub_op(self, op1, op2, carry=~0):
        """Partitioned subtraction, computed as ``op1 + ~op2 + carry``.

        The default ``carry=~0`` drives every carry-in bit high. That is
        presumably one bit per partition, which gives two's-complement
        subtraction in each partition.

        :returns: ``(output, carry_out)`` of the PartitionedAdder.
        """
        return self.add_op(op1, ~getsig(op2), carry)

    def __add__(self, other):
        result, _ = self.add_op(self, other, carry=0)
        return result

    def __radd__(self, other):
        # addition is commutative: keep self as the first operand
        result, _ = self.add_op(self, other, carry=0)
        return result

    def __sub__(self, other):
        result, _ = self.sub_op(self, other)
        return result

    def __rsub__(self, other):
        result, _ = self.sub_op(other, self)
        return result

    def __mul__(self, other):
        raise NotImplementedError

    def __rmul__(self, other):
        raise NotImplementedError

    def __check_divisor(self):
        """Reject signed divisors (unused until division is implemented)."""
        if self.sig.signed:
            # Python's division semantics and Verilog's division semantics
            # differ for negative divisors (Python uses div/mod, Verilog
            # uses quo/rem); for now, avoid the issue
            # completely by prohibiting such division operations.
            raise NotImplementedError(
                "Division by a signed value is not supported")

    def __mod__(self, other):
        raise NotImplementedError

    def __rmod__(self, other):
        raise NotImplementedError

    def __floordiv__(self, other):
        raise NotImplementedError

    def __rfloordiv__(self, other):
        raise NotImplementedError

    # binary comparison ops that need partitioning

    def _compare(self, width, op1, op2, opname, optype):
        """Instantiate a PartitionedEqGtGe comparing ``op1`` with ``op2``.

        :param opname: submodule-name category ("eq", "gt" or "ge").
        :param optype: opcode driven into the comparator.
        :returns: the comparator's ``output`` signal.
        """
        pa = PartitionedEqGtGe(width, self.partpoints)
        setattr(self.m.submodules, self.get_modname(opname), pa)
        comb = self.m.d.comb
        comb += pa.opcode.eq(optype)  # set opcode
        comb += pa.a.eq(getsig(op1))
        comb += pa.b.eq(getsig(op2))
        return pa.output

    def __eq__(self, other):
        width = len(self.sig)
        return self._compare(width, self, other, "eq", PartitionedEqGtGe.EQ)

    def __ne__(self, other):
        width = len(self.sig)
        eq = self._compare(width, self, other, "eq", PartitionedEqGtGe.EQ)
        # invert the equality result bit-for-bit into a new signal
        ne = Signal(len(eq))
        self.m.d.comb += ne.eq(~eq)
        return ne

    def __gt__(self, other):
        width = len(self.sig)
        return self._compare(width, self, other, "gt", PartitionedEqGtGe.GT)

    def __lt__(self, other):
        width = len(self.sig)
        # swap operands, use gt to do lt
        return self._compare(width, other, self, "gt", PartitionedEqGtGe.GT)

    def __ge__(self, other):
        width = len(self.sig)
        return self._compare(width, self, other, "ge", PartitionedEqGtGe.GE)

    def __le__(self, other):
        width = len(self.sig)
        # swap operands, use ge to do le
        return self._compare(width, other, self, "ge", PartitionedEqGtGe.GE)

    # useful operators

    def bool(self):
        """Conversion to boolean.

        Returns
        -------
        Value, out
            ``1`` if any bits are set, ``0`` otherwise.
        """
        return self.any()  # have to see how this goes

    def any(self):
        """Check if any bits are ``1``.

        Returns
        -------
        Value, out
            ``1`` if any bits are set, ``0`` otherwise.
        """
        return self != Const(0)  # leverage the __ne__ operator here

    def all(self):
        """Check if all bits are ``1``.

        Returns
        -------
        Value, out
            ``1`` if all bits are set, ``0`` otherwise.
        """
        return self == Const(-1)  # leverage the __eq__ operator here

    def xor(self):
        """Compute pairwise exclusive-or of every bit.

        The reduction is done by a PartitionedXOR submodule, presumably
        once per partition.

        Returns
        -------
        Value, out
            ``1`` if an odd number of bits are set, ``0`` if an
            even number of bits are set.
        """
        width = len(self.sig)
        pa = PartitionedXOR(width, self.partpoints)
        setattr(self.m.submodules, self.get_modname("xor"), pa)
        self.m.d.comb += pa.a.eq(self.sig)
        return pa.output

    def implies(premise, conclusion):
        """Implication.

        Returns
        -------
        Value, out
            ``0`` if ``premise`` is true and ``conclusion`` is not,
            ``1`` otherwise.
        """
        # ~premise | conclusion; applyop unwraps ``conclusion`` if it is
        # itself a PartitionedSignal.
        return applyop(~premise, conclusion, or_)