add set-logic ALL clause to stop cvc5 warning
[ieee754fpu.git] / src / ieee754 / fpcmp / fpcmp.py
1 # IEEE Floating Point Comparison: FEQ, FLT, FLE
2 # Copyright (C) 2019 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
3 # Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
4
5
6 from nmigen import Module, Signal, Mux
7
8 from nmutil.pipemodbase import PipeModBase
9 from ieee754.fpcommon.basedata import FPBaseData
10 from ieee754.fpcommon.packdata import FPPackData
11 from ieee754.fpcommon.fpbase import FPNumDecode, FPNumBaseRecord
12
13
class FPCMPPipeMod(PipeModBase):
    """
    Floating point comparison: FEQ, FLT, FLE

    Opcodes (read from ``ctx.op``), as implemented by ``elaborate``:
    - 0b00 - FLT - floating point less than
    - 0b01 - FLE - floating point less than or equal to
    - 0b10 - FEQ - floating point equals
    - 0b11 - not a named operation; currently evaluates like FEQ

    The result ``z`` is 1 when the comparison holds and 0 otherwise.
    If either operand is NaN the result is forced to 0.  +0 and -0
    compare as equal.

    NOTE(review): RISC-V funct3 encodes FLE=0b000, FLT=0b001, FEQ=0b010,
    which differs from the FLT/FLE mapping above; whatever drives ctx.op
    must account for that -- confirm against callers and tests.
    """

    def __init__(self, in_pspec):
        self.in_pspec = in_pspec
        super().__init__(in_pspec, "fpcmp")

    def ispec(self):
        """Input spec: the two operands ``a``/``b`` plus context."""
        return FPBaseData(self.in_pspec)

    def ospec(self):
        """Output spec: the comparison result ``z`` plus context."""
        return FPPackData(self.in_pspec)

    def elaborate(self, platform):
        m = Module()

        # useful clarity variables
        comb = m.d.comb
        width = self.pspec.width
        opcode = self.i.ctx.op
        z1 = self.o.z

        # Decode both operands.  The records are only needed to build the
        # decoders; everything below uses the decoder objects.
        a_rec = FPNumBaseRecord(width, False)
        b_rec = FPNumBaseRecord(width, False)
        m.submodules.sc_decode_a = a1 = FPNumDecode(None, a_rec)
        m.submodules.sc_decode_b = b1 = FPNumDecode(None, b_rec)

        comb += [a1.v.eq(self.i.a),
                 b1.v.eq(self.i.b)]

        # magnitudes: every bit except the top (sign) bit
        a_mag = a1.v[0:width-1]
        b_mag = b1.v[0:width-1]

        # +0 and -0 must compare equal: both magnitudes are zero
        both_zero = Signal()
        comb += both_zero.eq((a_mag == 0) & (b_mag == 0))

        ab_equal = Signal()
        comb += ab_equal.eq((a1.v == b1.v) | both_zero)

        contains_nan = Signal()
        comb += contains_nan.eq(a1.is_nan | b1.is_nan)

        # a_lt_b:
        #   if both zero:              0
        #   elif a1.s != b1.s:         a1.s > b1.s  (a has sign set, b not)
        #   elif a1.s == 0 (both +):   a_mag < b_mag
        #   else (both -):             a_mag > b_mag  (larger mag is smaller)
        lt_signs_differ = Signal()
        comb += lt_signs_differ.eq(a1.s > b1.s)

        lt_signs_same = Signal()
        comb += lt_signs_same.eq(Mux(a1.s,
                                     (a_mag > b_mag),
                                     (a_mag < b_mag)))

        a_lt_b = Signal()
        comb += a_lt_b.eq(Mux(both_zero, 0,
                              Mux(a1.s == b1.s,
                                  lt_signs_same,
                                  lt_signs_differ)))

        # Result ignoring NaNs:
        #   opcode 0b00 (FLT): a_lt_b
        #   opcode 0b01 (FLE): a_lt_b | ab_equal
        #   opcode 0b10 (FEQ): ab_equal
        #   opcode 0b11:       ab_equal
        no_nan = Signal()
        comb += no_nan.eq(
            Mux(opcode != 0b00, ab_equal, 0) |
            Mux(opcode[1], 0, a_lt_b))

        # any NaN operand forces the result to 0
        comb += z1.eq(Mux(contains_nan, 0, no_nan))

        # copy the context (muxid, operator)
        comb += self.o.ctx.eq(self.i.ctx)

        return m