switch to exact version of cython
[ieee754fpu.git] / src / ieee754 / fpmax / fpmax.py
1 # IEEE Floating Point Maximum / Minimum (FMAX, FMIN)
2 # Copyright (C) 2019 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
3 # Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
4
5
6 from nmigen import Module, Signal, Mux
7
8 from nmutil.pipemodbase import PipeModBase
9 from ieee754.fpcommon.basedata import FPBaseData
10 from ieee754.fpcommon.packdata import FPPackData
11 from ieee754.fpcommon.fpbase import FPNumDecode, FPNumBaseRecord
12
13
class FPMAXPipeMod(PipeModBase):
    """ FP maximum / minimum (FMAX / FMIN)

    Selects one of the two operands (or a NaN) depending on bit 0 of the
    operation code, self.i.ctx.op.  No other opcode bits are examined:

        self.i.ctx.op[0] == 0 : z = max(a, b)
        self.i.ctx.op[0] == 1 : z = min(a, b)

    NaN handling:

    * if exactly one operand is NaN, the other (non-NaN) operand is
      returned unchanged;
    * if both operands are NaN, the value produced by ``nan2(0)`` on the
      decoded operand record is returned (presumably the canonical
      positive quiet NaN -- confirm against FPNumBaseRecord.nan2).

    Signed zeros are ordered with -0 < +0, so max(-0, +0) == +0 and
    min(-0, +0) == -0.

    NOTE(review): no exception flags are produced here (e.g. "invalid"
    for a signalling-NaN input); if the pipeline is expected to report
    them, that still needs adding.

    The incoming context (muxid, operator) is passed through unchanged.
    """
    def __init__(self, in_pspec):
        self.in_pspec = in_pspec
        super().__init__(in_pspec, "fpmax")

    def ispec(self):
        """Input spec: operands ``a`` and ``b`` plus the pipeline context."""
        return FPBaseData(self.in_pspec)

    def ospec(self):
        """Output spec: result ``z`` plus the pass-through context."""
        return FPPackData(self.in_pspec)

    def elaborate(self, platform):
        m = Module()

        # useful clarity variables
        comb = m.d.comb
        width = self.pspec.width
        opcode = self.i.ctx.op
        z1 = self.o.z

        # Decode both operands into sign / exponent / mantissa and the
        # classification flags (is_nan, ...).  The raw records and the
        # decoders get separate names instead of rebinding one variable.
        a_rec = FPNumBaseRecord(width, False)
        b_rec = FPNumBaseRecord(width, False)
        m.submodules.sc_decode_a = a1 = FPNumDecode(None, a_rec)
        m.submodules.sc_decode_b = b1 = FPNumDecode(None, b_rec)

        comb += [a1.v.eq(self.i.a),
                 b1.v.eq(self.i.b)]

        # ---- NaN path ----------------------------------------------------
        # if both are NaN:  result = freshly created NaN (nan2(0))
        # else:             result = whichever operand is *not* NaN
        has_nan = Signal()
        both_nan = Signal()
        comb += has_nan.eq(a1.is_nan | b1.is_nan)
        comb += both_nan.eq(a1.is_nan & b1.is_nan)

        some_nans = Signal(width)
        comb += some_nans.eq(Mux(both_nan,
                                 a1.fp.nan2(0),
                                 Mux(a1.is_nan, self.i.b, self.i.a)))

        # ---- ordered (non-NaN) path --------------------------------------
        # Signs differ: the non-negative operand is the larger one, so the
        # maximum is b exactly when a is negative.  XOR with opcode[0]
        # inverts the choice to give the minimum.  Because only the sign
        # bit is consulted here, -0 is ordered below +0.
        signs_different = Signal()
        comb += signs_different.eq(a1.s != b1.s)

        signs_different_value = Signal(width)
        comb += signs_different_value.eq(Mux(a1.s ^ opcode[0],
                                             self.i.b, self.i.a))

        # Signs equal: with a shared sign bit, comparing the raw bit
        # patterns as unsigned integers compares magnitudes.  A larger
        # magnitude means a larger value for positive operands but a
        # smaller value for negative ones, hence the XOR with the common
        # sign bit; XOR with opcode[0] again switches max -> min:
        #     pick a if (a.v > b.v) ^ sign ^ opcode[0] else b
        gt = Signal()
        comb += gt.eq(a1.v > b1.v)

        signs_same = Signal(width)
        comb += signs_same.eq(Mux(gt ^ a1.s ^ opcode[0],
                                  self.i.a, self.i.b))

        no_nans = Signal(width)
        comb += no_nans.eq(Mux(signs_different, signs_different_value,
                               signs_same))

        # NaN rules take priority over the ordered comparison
        comb += z1.eq(Mux(has_nan, some_nans, no_nans))

        # copy the context (muxid, operator)
        comb += self.o.ctx.eq(self.i.ctx)

        return m