d775671c7f685a7be8d56c240a55f77096611995
[ieee754fpu.git] / src / ieee754 / fpmul / fmul.py
1 from nmigen import Module, Signal, Cat, Mux, Array, Const
2 from nmigen.cli import main, verilog
3
4 from ieee754.fpcommon.fpbase import (FPNumIn, FPNumOut, FPOpIn,
5 FPOpOut, Overflow, FPBase, FPState)
6 from ieee754.fpcommon.getop import FPGetOp
7 from nmutil.nmoperator import eq
8
9
class FPMUL(FPBase):
    """ Sequential (FSM-based) floating-point multiplier.

        Operand ``a`` is read through ``in_a``, then operand ``b`` through
        ``in_b``; the product is written out through ``out_z``.  The FSM
        walks get_a -> get_b -> special_cases -> normalise_a -> normalise_b
        -> multiply_0 -> multiply_1 -> normalise_1 -> normalise_2 -> round
        -> corrections -> pack -> put_z and then back to get_a.  NaN, Inf
        and zero operands are resolved in special_cases and jump straight
        to put_z.
    """

    def __init__(self, width):
        """ Create the multiplier's input and output ports.

            :param width: total bit-width of the floating-point format
                          (the ``__main__`` entry uses 32).
        """
        FPBase.__init__(self)
        self.width = width

        # Operand inputs and result output; each port's data signal is
        # replaced with a plain `width`-bit Signal.
        self.in_a = FPOpIn(width)
        self.in_a.data_i = Signal(width)
        self.in_b = FPOpIn(width)
        self.in_b.data_i = Signal(width)
        self.out_z = FPOpOut(width)
        self.out_z.data_o = Signal(width)

        # Filled in by add_state(); elaborate() does not read it.
        self.states = []

    def add_state(self, state):
        """ Append ``state`` to ``self.states`` and return it unchanged.
        """
        self.states.append(state)
        return state

    def elaborate(self, platform=None):
        """ Create the HDL code-fragment for FPMUL.

            :param platform: not used by this method.
            :returns: an nmigen ``Module`` containing the multiplier FSM.
        """
        m = Module()

        # Latches for the two unpacked operands and the result.
        a = FPNumIn(None, self.width, False)
        b = FPNumIn(None, self.width, False)
        z = FPNumOut(self.width, False)

        # Width of one mantissa.  Kept under its own name: the original
        # code reused `mw` for both this and the product width.
        mant_width = z.m_width

        # Raw mantissa product.  Two mant_width-bit mantissas multiply to
        # at most 2*mant_width bits.  multiply_0 scales the product by 4,
        # which adds 2 more bits.  That leaves room to pull guard and
        # round bits out below the retained mantissa in multiply_1.
        product = Signal(2 * mant_width + 2)

        of = Overflow()
        m.submodules.of = of
        m.submodules.a = a
        m.submodules.b = b
        m.submodules.z = z

        m.d.comb += a.v.eq(self.in_a.v)
        m.d.comb += b.v.eq(self.in_b.v)

        # Sign of any result (Inf, zero or finite product): XOR of the
        # operand signs.
        sign = a.s ^ b.s

        with m.FSM():

            # ******
            # gets operand a

            with m.State("get_a"):
                res = self.get_op(m, self.in_a, a, "get_b")
                m.d.sync += eq([a, self.in_a.ready_o], res)

            # ******
            # gets operand b

            with m.State("get_b"):
                res = self.get_op(m, self.in_b, b, "special_cases")
                m.d.sync += eq([b, self.in_b.ready_o], res)

            # ******
            # special cases

            with m.State("special_cases"):
                # NaN in either operand -> NaN
                with m.If(a.is_nan | b.is_nan):
                    m.next = "put_z"
                    m.d.sync += z.nan(1)
                # a is Inf: Inf * 0 -> NaN, otherwise -> Inf
                with m.Elif(a.is_inf):
                    m.next = "put_z"
                    with m.If(b.is_zero):
                        m.d.sync += z.nan(1)
                    with m.Else():
                        m.d.sync += z.inf(sign)
                # b is Inf: 0 * Inf -> NaN, otherwise -> Inf
                with m.Elif(b.is_inf):
                    m.next = "put_z"
                    with m.If(a.is_zero):
                        m.d.sync += z.nan(1)
                    with m.Else():
                        m.d.sync += z.inf(sign)
                # either operand zero (and neither Inf/NaN) -> signed zero
                with m.Elif(a.is_zero | b.is_zero):
                    m.next = "put_z"
                    m.d.sync += z.zero(sign)
                # Denormalised Number checks
                with m.Else():
                    m.next = "normalise_a"
                    self.denormalise(m, a)
                    self.denormalise(m, b)

            # ******
            # normalise_a

            with m.State("normalise_a"):
                self.op_normalise(m, a, "normalise_b")

            # ******
            # normalise_b

            with m.State("normalise_b"):
                self.op_normalise(m, b, "multiply_0")

            # ******
            # multiply_0: sign, exponent and scaled mantissa product

            with m.State("multiply_0"):
                m.next = "multiply_1"
                m.d.sync += [
                    z.s.eq(sign),
                    # The +1 pairs with multiply_1 taking z.m from the top
                    # of the *4-scaled product.  Presumably normalise_1
                    # shifts back when the product's top bit is clear;
                    # confirm in FPBase.
                    z.e.eq(a.e + b.e + 1),
                    product.eq(a.m * b.m * 4),
                ]

            # ******
            # multiply_1: split product into mantissa + guard/round/sticky

            with m.State("multiply_1"):
                m.next = "normalise_1"
                m.d.sync += [
                    z.m.eq(product[mant_width + 2:]),
                    of.guard.eq(product[mant_width + 1]),
                    of.round_bit.eq(product[mant_width]),
                    # sticky: OR of every bit below the round bit
                    of.sticky.eq(product[0:mant_width] != 0),
                ]

            # ******
            # First stage of normalisation.

            with m.State("normalise_1"):
                self.normalise_1(m, z, of, "normalise_2")

            # ******
            # Second stage of normalisation.

            with m.State("normalise_2"):
                self.normalise_2(m, z, of, "round")

            # ******
            # rounding stage

            with m.State("round"):
                self.roundz(m, z, of.roundz)
                m.next = "corrections"

            # ******
            # correction stage

            with m.State("corrections"):
                self.corrections(m, z, "pack")

            # ******
            # pack stage

            with m.State("pack"):
                self.pack(m, z, "put_z")

            # ******
            # put_z stage

            with m.State("put_z"):
                self.put_z(m, z, self.out_z, "get_a")

        return m
172
173
if __name__ == "__main__":
    # Build a 32-bit multiplier and hand it to nmigen's CLI, exposing the
    # ports of both operand inputs and the result output.
    alu = FPMUL(width=32)
    io_ports = alu.in_a.ports()
    io_ports = io_ports + alu.in_b.ports()
    io_ports = io_ports + alu.out_z.ports()
    main(alu, ports=io_ports)