whoops accidentally removed bugreport link
[ieee754fpu.git] / src / ieee754 / part / partsig.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3
4 """
5 Copyright (C) 2020 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
6
7 dynamic-partitionable class similar to Signal, which, when the partition
8 is fully open, will be identical to Signal. when partitions are closed,
9 the class turns into a SIMD variant of Signal. *this is dynamic*.
10
11 the basic fundamental idea is: write code once, and if you want a SIMD
12 version of it, use PartitionedSignal in place of Signal. job done.
13 this however requires the code to *not* be designed to use nmigen.If,
14 nmigen.Case, or other constructs: only Mux and other logic.
15
16 * http://bugs.libre-riscv.org/show_bug.cgi?id=132
17 """
18
19 from ieee754.part_mul_add.adder import PartitionedAdder
20 from ieee754.part_cmp.eq_gt_ge import PartitionedEqGtGe
21 from ieee754.part_bits.xor import PartitionedXOR
22 from ieee754.part_bits.bool import PartitionedBool
23 from ieee754.part_bits.all import PartitionedAll
24 from ieee754.part_shift.part_shift_dynamic import PartitionedDynamicShift
25 from ieee754.part_shift.part_shift_scalar import PartitionedScalarShift
26 from ieee754.part_mul_add.partpoints import make_partition2, PartitionPoints
27 from ieee754.part_mux.part_mux import PMux
28 from ieee754.part_ass.passign import PAssign
29 from ieee754.part_cat.pcat import PCat
30 from operator import or_, xor, and_, not_
31
32 from nmigen import (Signal, Const)
33 from nmigen.hdl.ast import UserValue, Shape
34
35
def getsig(op1):
    """Unwrap a PartitionedSignal to its underlying ``Signal``.

    Any operand that is not a PartitionedSignal (plain Signal, Const,
    Python int, expression, ...) is handed back unchanged.
    """
    return op1.sig if isinstance(op1, PartitionedSignal) else op1
40
41
def applyop(op1, op2, op):
    """Apply binary operator ``op`` to two operands without partitioning.

    The result is a new PartitionedSignal modelled (via ``like``) on op1
    if op1 is a PartitionedSignal, otherwise on op2; it shares that
    operand's partition points and module. The result is driven
    combinatorially with ``op`` applied to the unwrapped operands.
    """
    template = op1 if isinstance(op1, PartitionedSignal) else op2
    result = PartitionedSignal.like(template)
    lhs, rhs = getsig(op1), getsig(op2)
    result.m.d.comb += result.sig.eq(op(lhs, rhs))
    return result
49
# Per-category counters used by PartitionedSignal.get_modname so that
# sub-modules created on demand get unique names (e.g. "add_1", "add_2").
# Mux is done slightly differently (has its own global).
# NOTE: a module-level ``global`` statement is a no-op, and the previous
# for-loop leaked a stray module-level ``name``; dict.fromkeys builds the
# identical dict (same keys, same order, all zero) without either.
modnames = dict.fromkeys(
    ['add', 'eq', 'gt', 'ge', 'ls', 'xor', 'bool', 'all'], 0)
56
57
class PartitionedSignal(UserValue):
    """Dynamically-partitionable drop-in replacement for nmigen ``Signal``.

    Wraps a plain ``Signal`` (``self.sig``) together with partition points
    (``self.partpoints``). Arithmetic, shift and comparison operators are
    overridden so that they are performed by partition-aware sub-modules,
    which are added on demand to the module ``self.m``. ``self.m`` is set
    by :meth:`set_module` (or copied by :meth:`like`) and must be present
    before any operator is used.
    """
    # XXX ################################################### XXX
    # XXX Keep these functions in the same order as ast.Value XXX
    # XXX ################################################### XXX
    def __init__(self, mask, *args, src_loc_at=0, **kwargs):
        """Create a partitioned signal.

        * mask: either a ready-made ``PartitionPoints`` instance (used
          as-is), or a value passed to ``make_partition2`` together with
          the signal width to build one.
        * ``*args`` / ``**kwargs`` are passed straight to ``Signal``.
        """
        super().__init__(src_loc_at=src_loc_at)
        self.sig = Signal(*args, **kwargs)
        width = len(self.sig)  # get signal width
        # create partition points
        if isinstance(mask, PartitionPoints):
            self.partpoints = mask
        else:
            self.partpoints = make_partition2(mask, width)

    def set_module(self, m):
        """Set the module that receives sub-modules and comb statements."""
        self.m = m

    def get_modname(self, category):
        """Return a unique sub-module name such as ``"add_3"``.

        Uses the module-level ``modnames`` counters, so names are unique
        across all PartitionedSignal instances, not just this one.
        """
        modnames[category] += 1
        return "%s_%d" % (category, modnames[category])

    @staticmethod
    def like(other, *args, **kwargs):
        """Builds a new PartitionedSignal with the same PartitionPoints and
        Signal properties as the other.

        ``*args`` / ``**kwargs`` are forwarded to ``Signal.like`` (for
        example ``shape=`` to override the shape). The module ``m`` is
        shared with ``other``.
        """
        result = PartitionedSignal(PartitionPoints(other.partpoints))
        result.sig = Signal.like(other.sig, *args, **kwargs)
        result.m = other.m
        return result

    def lower(self):
        """Lower to the underlying plain ``Signal``."""
        return self.sig
    # now using __Assign__
    #def eq(self, val):
    #    return self.sig.eq(getsig(val))

    # nmigen-redirected constructs (Mux, Cat, Switch, Assign)

    def __Mux__(self, val1, val2):
        """Partition-aware Mux with self as the selector (via PMux)."""
        # print ("partsig mux", self, val1, val2)
        assert len(val1) == len(val2), \
            "PartitionedSignal width sources must be the same " \
            "val1 == %d, val2 == %d" % (len(val1), len(val2))
        return PMux(self.m, self.partpoints, self, val1, val2)

    # TODO, http://bugs.libre-riscv.org/show_bug.cgi?id=458
    #def __Switch__(self, cases, *, src_loc=None, src_loc_at=0,
    #               case_src_locs={}):

    def __Cat__(self, *args, src_loc_at=0):
        """Partition-aware Cat of self followed by args (via PCat).

        Every argument must itself be a PartitionedSignal.
        """
        args = [self] + list(args)
        for sig in args:
            assert isinstance(sig, PartitionedSignal), \
                "All PartitionedSignal.__Cat__ arguments must be " \
                "a PartitionedSignal. %s is not." % repr(sig)
        return PCat(self.m, args, self.partpoints)

    def __Assign__(self, val, *, src_loc_at=0):
        """Partition-aware assignment of val to self (via PAssign)."""
        # print ("partsig ass", self, val)
        return PAssign(self.m, self, val, self.partpoints)

    # no override needed, Value.__bool__ sufficient
    # def __bool__(self):

    # unary ops that do not require partitioning

    def __invert__(self):
        """Bitwise NOT; partition boundaries do not affect the result."""
        result = PartitionedSignal.like(self)
        self.m.d.comb += result.sig.eq(~self.sig)
        return result

    # unary ops that require partitioning

    def __neg__(self):
        """Negation, computed as ``0 - self`` with a partitioned adder."""
        z = Const(0, len(self.sig))
        result, _ = self.sub_op(z, self)
        return result

    # binary ops that need partitioning

    def add_op(self, op1, op2, carry):
        """Add op1 + op2 + carry using a PartitionedAdder.

        The adder width is ``len(op1)`` after unwrapping and it uses
        self's partition points. Returns ``(result, carry_out)``, where
        result is a PartitionedSignal shaped like self and carry_out is
        the adder's ``carry_out`` signal.
        """
        op1 = getsig(op1)
        op2 = getsig(op2)
        pa = PartitionedAdder(len(op1), self.partpoints)
        setattr(self.m.submodules, self.get_modname('add'), pa)
        comb = self.m.d.comb
        comb += pa.a.eq(op1)
        comb += pa.b.eq(op2)
        comb += pa.carry_in.eq(carry)
        result = PartitionedSignal.like(self)
        comb += result.sig.eq(pa.output)
        return result, pa.carry_out

    def sub_op(self, op1, op2, carry=~0):
        """Subtract op1 - op2 as op1 + ~op2 + carry with a PartitionedAdder.

        The default carry of ``~0`` drives every carry_in bit to 1
        (presumably one carry bit per partition), completing the
        two's-complement negation of op2. Returns ``(result, carry_out)``,
        as :meth:`add_op` does.
        """
        op1 = getsig(op1)
        op2 = getsig(op2)
        pa = PartitionedAdder(len(op1), self.partpoints)
        setattr(self.m.submodules, self.get_modname('add'), pa)
        comb = self.m.d.comb
        comb += pa.a.eq(op1)
        comb += pa.b.eq(~op2)
        comb += pa.carry_in.eq(carry)
        result = PartitionedSignal.like(self)
        comb += result.sig.eq(pa.output)
        return result, pa.carry_out

    def __add__(self, other):
        result, _ = self.add_op(self, other, carry=0)
        return result

    def __radd__(self, other):
        # https://bugs.libre-soc.org/show_bug.cgi?id=718
        # addition is commutative: keep self as op1 so the adder is sized
        # from this signal (other may be a plain int with no len()), and
        # pass the carry explicitly -- add_op has no default for it.
        result, _ = self.add_op(self, other, carry=0)
        return result

    def __sub__(self, other):
        result, _ = self.sub_op(self, other)
        return result

    def __rsub__(self, other):
        # https://bugs.libre-soc.org/show_bug.cgi?id=718
        # sub_op sizes its adder from len(op1), so a plain Python int on
        # the left-hand side must first become a Const of our width.
        if isinstance(other, int):
            other = Const(other, len(self.sig))
        result, _ = self.sub_op(other, self)
        return result

    def __mul__(self, other):
        raise NotImplementedError  # too complicated at the moment

    def __rmul__(self, other):
        raise NotImplementedError  # too complicated at the moment

    # not needed: same as Value.__check_divisor
    #def __check_divisor(self):

    def __mod__(self, other):
        raise NotImplementedError

    def __rmod__(self, other):
        raise NotImplementedError

    def __floordiv__(self, other):
        raise NotImplementedError

    def __rfloordiv__(self, other):
        raise NotImplementedError

    # not needed: same as Value.__check_shamt
    #def __check_shamt(self):

    # TODO: detect if the 2nd operand is a Const, a Signal or a
    # PartitionedSignal. if it's a Const or a Signal, a global shift
    # can occur. if it's a PartitionedSignal, that's much more interesting.
    def ls_op(self, op1, op2, carry, shr_flag=0):
        """Shift op1 by op2 (left, or right when shr_flag is set).

        A Const/Signal shift amount selects PartitionedScalarShift; any
        other amount is unwrapped and fed to PartitionedDynamicShift.
        ``carry`` is currently unused. Returns ``(pa.output, 0)``; note
        pa.output is the raw sub-module output, not a PartitionedSignal.
        """
        op1 = getsig(op1)
        scalar = isinstance(op2, (Const, Signal))
        if scalar:
            pa = PartitionedScalarShift(len(op1), self.partpoints)
        else:
            op2 = getsig(op2)
            pa = PartitionedDynamicShift(len(op1), self.partpoints)
        # TODO: case where the *shifter* is a PartitionedSignal but
        # the thing *being* Shifted is a scalar (Signal, expression)
        # https://bugs.libre-soc.org/show_bug.cgi?id=718
        setattr(self.m.submodules, self.get_modname('ls'), pa)
        comb = self.m.d.comb
        if scalar:
            comb += pa.data.eq(op1)
            comb += pa.shifter.eq(op2)
            comb += pa.shift_right.eq(shr_flag)
        else:
            comb += pa.a.eq(op1)
            comb += pa.b.eq(op2)
            comb += pa.shift_right.eq(shr_flag)
        # XXX TODO: carry-in, carry-out (for arithmetic shift)
        #comb += pa.carry_in.eq(carry)
        return (pa.output, 0)

    def __lshift__(self, other):
        z = Const(0, len(self.partpoints)+1)
        result, _ = self.ls_op(self, other, carry=z)  # TODO, carry
        return result

    def __rlshift__(self, other):
        # https://bugs.libre-soc.org/show_bug.cgi?id=718
        raise NotImplementedError

    def __rshift__(self, other):
        z = Const(0, len(self.partpoints)+1)
        result, _ = self.ls_op(self, other, carry=z, shr_flag=1)  # TODO, carry
        return result

    def __rrshift__(self, other):
        # https://bugs.libre-soc.org/show_bug.cgi?id=718
        raise NotImplementedError

    # binary ops that don't require partitioning

    def __and__(self, other):
        return applyop(self, other, and_)

    def __rand__(self, other):
        return applyop(other, self, and_)

    def __or__(self, other):
        return applyop(self, other, or_)

    def __ror__(self, other):
        return applyop(other, self, or_)

    def __xor__(self, other):
        return applyop(self, other, xor)

    def __rxor__(self, other):
        return applyop(other, self, xor)

    # binary comparison ops that need partitioning

    def _compare(self, width, op1, op2, opname, optype):
        """Compare op1 against op2 with a PartitionedEqGtGe sub-module.

        optype selects the comparison (EQ/GT/GE); opname is the sub-module
        name category. Returns the sub-module's ``output`` signal.
        """
        # print (opname, op1, op2)
        pa = PartitionedEqGtGe(width, self.partpoints)
        setattr(self.m.submodules, self.get_modname(opname), pa)
        comb = self.m.d.comb
        comb += pa.opcode.eq(optype)  # set opcode
        comb += pa.a.eq(getsig(op1))
        comb += pa.b.eq(getsig(op2))
        return pa.output

    def __eq__(self, other):
        width = len(self.sig)
        return self._compare(width, self, other, "eq", PartitionedEqGtGe.EQ)

    def __ne__(self, other):
        # inverted equality; returns a plain Signal, not a PartitionedSignal
        width = len(self.sig)
        eq = self._compare(width, self, other, "eq", PartitionedEqGtGe.EQ)
        ne = Signal(eq.width)
        self.m.d.comb += ne.eq(~eq)
        return ne

    def __lt__(self, other):
        width = len(self.sig)
        # swap operands, use gt to do lt
        return self._compare(width, other, self, "gt", PartitionedEqGtGe.GT)

    def __le__(self, other):
        width = len(self.sig)
        # swap operands, use ge to do le
        return self._compare(width, other, self, "ge", PartitionedEqGtGe.GE)

    def __gt__(self, other):
        width = len(self.sig)
        return self._compare(width, self, other, "gt", PartitionedEqGtGe.GT)

    def __ge__(self, other):
        width = len(self.sig)
        return self._compare(width, self, other, "ge", PartitionedEqGtGe.GE)

    # no override needed: Value.__abs__ is general enough it does the job
    # def __abs__(self):

    def __len__(self):
        return len(self.sig)

    # TODO, http://bugs.libre-riscv.org/show_bug.cgi?id=716
    # def __getitem__(self, key):

    def __new_sign(self, signed):
        """Return a copy of self with the same width but given signedness."""
        shape = Shape(len(self), signed=signed)
        result = PartitionedSignal.like(self, shape=shape)
        self.m.d.comb += result.sig.eq(self.sig)
        return result

    # http://bugs.libre-riscv.org/show_bug.cgi?id=719
    def as_unsigned(self):
        return self.__new_sign(False)

    def as_signed(self):
        return self.__new_sign(True)

    # useful operators

    def bool(self):
        """Conversion to boolean.

        Returns
        -------
        Value, out
            ``1`` if any bits are set, ``0`` otherwise.
        """
        width = len(self.sig)
        pa = PartitionedBool(width, self.partpoints)
        setattr(self.m.submodules, self.get_modname("bool"), pa)
        self.m.d.comb += pa.a.eq(self.sig)
        return pa.output

    def any(self):
        """Check if any bits are ``1``.

        Returns
        -------
        Value, out
            ``1`` if any bits are set, ``0`` otherwise.
        """
        return self != Const(0)  # leverage the __ne__ operator here

    def all(self):
        """Check if all bits are ``1``.

        Returns
        -------
        Value, out
            ``1`` if all bits are set, ``0`` otherwise.
        """
        # something wrong with PartitionedAll, but self == Const(-1)"
        # XXX https://bugs.libre-soc.org/show_bug.cgi?id=176#c17
        #width = len(self.sig)
        #pa = PartitionedAll(width, self.partpoints)
        #setattr(self.m.submodules, self.get_modname("all"), pa)
        #self.m.d.comb += pa.a.eq(self.sig)
        #return pa.output
        return self == Const(-1)  # leverage the __eq__ operator here

    def xor(self):
        """Compute pairwise exclusive-or of every bit.

        Returns
        -------
        Value, out
            ``1`` if an odd number of bits are set, ``0`` if an
            even number of bits are set.
        """
        width = len(self.sig)
        pa = PartitionedXOR(width, self.partpoints)
        setattr(self.m.submodules, self.get_modname("xor"), pa)
        self.m.d.comb += pa.a.eq(self.sig)
        return pa.output

    # not needed: Value.implies does the job
    # def implies(premise, conclusion):

    # TODO. contains a Value.cast which means an override is needed (on both)
    # def bit_select(self, offset, width):
    # def word_select(self, offset, width):

    # not needed: Value.matches, amazingly, should do the job
    # def matches(self, *patterns):

    # TODO, http://bugs.libre-riscv.org/show_bug.cgi?id=713
    def shape(self):
        return self.sig.shape()