"""Floating-point fused multiply-add (FMA).

Computes `z = (a * c) + b`, but rounds only once, at the end.
"""
6 from nmutil
.pipemodbase
import PipeModBase
7 from ieee754
.fpcommon
.basedata
import FPBaseData
8 from nmigen
.hdl
.ast
import Signal
9 from nmigen
.hdl
.dsl
import Module
10 from ieee754
.fpcommon
.getop
import FPPipeContext
11 from ieee754
.fpcommon
.fpbase
import FPRoundingMode
, MultiShiftRMerge
12 from ieee754
.fpfma
.util
import expanded_exponent_shape
, \
13 expanded_mantissa_shape
, get_fpformat
, multiplicand_mantissa_shape
class FPFMAInputData(FPBaseData):
    """Input data for the FMA pipeline.

    Extends `FPBaseData` (configured for three operands: `a`, `b`, `c`,
    used as `(a * c) + b`) with flags selecting which terms are negated.
    """

    def __init__(self, pspec):
        assert pspec.n_ops == 3
        super().__init__(pspec)

        self.negate_addend = Signal()
        """if the addend should be negated"""

        self.negate_product = Signal()
        """if the product should be negated"""

    def eq(self, i):
        """Return the list of assignments copying every field from `i`.

        The base-class assignments come first; the two negate flags are
        appended after them.
        """
        ret = super().eq(i)
        ret.append(self.negate_addend.eq(i.negate_addend))
        ret.append(self.negate_product.eq(i.negate_product))
        return ret

    def __iter__(self):
        """Yield the base-class signals, then the two negate flags."""
        yield from super().__iter__()
        yield self.negate_addend
        yield self.negate_product
class FPFMASpecialCasesDeNormOutData:
    """Output of the FMA special-cases/denormal stage.

    Carries either the unpacked, aligned operands for the main FMA stage
    or, when `do_bypass` is set, the final result in `bypassed_z`.
    """

    def __init__(self, pspec):
        fpf = get_fpformat(pspec)

        self.exponent = Signal(expanded_exponent_shape(fpf))
        """exponent of intermediate -- unbiased"""

        self.a_mantissa = Signal(multiplicand_mantissa_shape(fpf))
        """mantissa of a input -- un-normalized and with implicit bit added"""

        # b is the addend: it is shifted right into alignment with the
        # product, so it needs the wider (expanded) field rather than the
        # multiplicand width.
        self.b_mantissa = Signal(expanded_mantissa_shape(fpf))
        """mantissa of b input

        shifted to appropriate location for add and with implicit bit added
        """

        # c is a multiplicand (the product is a * c), so it uses the same
        # shape as a_mantissa.
        self.c_mantissa = Signal(multiplicand_mantissa_shape(fpf))
        """mantissa of c input -- un-normalized and with implicit bit added"""

        self.do_sub = Signal()
        """true if `b_mantissa` should be subtracted from
        `a_mantissa * c_mantissa` rather than added
        """

        self.bypassed_z = Signal(fpf.width)
        """final output value of the fma when `do_bypass` is set"""

        self.do_bypass = Signal()
        """set if `bypassed_z` is the final output value of the fma"""

        self.ctx = FPPipeContext(pspec)
        """pipeline context"""

        self.rm = Signal(FPRoundingMode, reset=FPRoundingMode.DEFAULT)
        """rounding mode"""

    def eq(self, i):
        """Return the list of assignments copying every field from `i`.

        `ctx` and `rm` are included so the pipeline context and rounding
        mode propagate through this data object.
        """
        return [
            self.exponent.eq(i.exponent),
            self.a_mantissa.eq(i.a_mantissa),
            self.b_mantissa.eq(i.b_mantissa),
            self.c_mantissa.eq(i.c_mantissa),
            self.do_sub.eq(i.do_sub),
            self.bypassed_z.eq(i.bypassed_z),
            self.do_bypass.eq(i.do_bypass),
            self.ctx.eq(i.ctx),
            self.rm.eq(i.rm),
        ]
class FPFMASpecialCasesDeNorm(PipeModBase):
    """FMA pipeline stage: special cases and operand unpacking.

    For `z = (a * c) + b` this stage either:

    * detects a NaN / infinity / zero combination whose result is known
      without doing the arithmetic, and emits it via `bypassed_z` with
      `do_bypass` set; or
    * unpacks exponents and mantissas, selects the larger of the product
      and addend exponents, and right-shifts `b`'s mantissa into alignment
      for the main stage.
    """

    def __init__(self, pspec):
        super().__init__(pspec, "sc_denorm")

    def ispec(self):
        return FPFMAInputData(self.pspec)

    def ospec(self):
        return FPFMASpecialCasesDeNormOutData(self.pspec)

    def elaborate(self, platform):
        m = Module()
        fpf = get_fpformat(self.pspec)
        # NOTE(review): assumes PipeModBase.__init__ stores the ispec() and
        # ospec() instances as self.i / self.o -- confirm against nmutil.
        inp = self.i
        out = self.o

        # unpack operand exponents and the product exponent
        a_exponent = Signal(expanded_exponent_shape(fpf))
        m.d.comb += a_exponent.eq(fpf.get_exponent(inp.a))
        b_exponent_in = Signal(expanded_exponent_shape(fpf))
        m.d.comb += b_exponent_in.eq(fpf.get_exponent(inp.b))
        c_exponent = Signal(expanded_exponent_shape(fpf))
        m.d.comb += c_exponent.eq(fpf.get_exponent(inp.c))
        prod_exponent = Signal(expanded_exponent_shape(fpf))
        m.d.comb += prod_exponent.eq(a_exponent + c_exponent)
        prod_exp_minus_b_exp = Signal(expanded_exponent_shape(fpf))
        m.d.comb += prod_exp_minus_b_exp.eq(prod_exponent - b_exponent_in)
        b_mantissa_in = Signal(fpf.fraction_width + 1)
        m.d.comb += b_mantissa_in.eq(fpf.get_mantissa_value(inp.b))

        # effective signs of the product and the addend after negation
        p_sign = Signal()
        m.d.comb += p_sign.eq(fpf.get_sign_field(inp.a) ^
                              fpf.get_sign_field(inp.c) ^ inp.negate_product)
        b_sign = Signal()
        m.d.comb += b_sign.eq(fpf.get_sign_field(inp.b) ^ inp.negate_addend)

        # the intermediate exponent is the larger of the product's and b's;
        # b is shifted right only when the product's exponent is larger
        exponent = Signal(expanded_exponent_shape(fpf))
        b_shift = Signal(expanded_exponent_shape(fpf))
        # use >= since that's just checking the sign bit
        with m.If(prod_exp_minus_b_exp >= 0):
            m.d.comb += [
                exponent.eq(prod_exponent),
                b_shift.eq(prod_exp_minus_b_exp),
            ]
        with m.Else():
            m.d.comb += [
                exponent.eq(b_exponent_in),
                b_shift.eq(0),
            ]

        # place b's mantissa at the top of the b_mantissa field, then shift
        # it right by b_shift (MultiShiftRMerge presumably merges the
        # shifted-out bits into a sticky bit -- see ieee754.fpcommon.fpbase)
        m.submodules.rshiftm = rshiftm = MultiShiftRMerge(out.b_mantissa.width)
        m.d.comb += [
            rshiftm.inp.eq(b_mantissa_in << (out.b_mantissa.width
                                             - b_mantissa_in.width)),
            rshiftm.diff.eq(b_shift),
        ]

        # handle special cases; earlier branches take priority.
        # NOTE: comparisons below are parenthesized on purpose -- in Python
        # `&` binds tighter than `==`/`!=`, so `x & y != z` would parse as
        # `(x & y) != z`.
        with m.If(fpf.is_nan(inp.a)):
            m.d.comb += [
                out.bypassed_z.eq(fpf.to_quiet_nan(inp.a)),
                out.do_bypass.eq(True),
            ]
        with m.Elif(fpf.is_nan(inp.b)):
            m.d.comb += [
                out.bypassed_z.eq(fpf.to_quiet_nan(inp.b)),
                out.do_bypass.eq(True),
            ]
        with m.Elif(fpf.is_nan(inp.c)):
            m.d.comb += [
                out.bypassed_z.eq(fpf.to_quiet_nan(inp.c)),
                out.do_bypass.eq(True),
            ]
        with m.Elif((fpf.is_zero(inp.a) & fpf.is_inf(inp.c))
                    | (fpf.is_inf(inp.a) & fpf.is_zero(inp.c))):
            # 0 * inf -- invalid, result is NaN
            m.d.comb += [
                out.bypassed_z.eq(fpf.quiet_nan()),
                out.do_bypass.eq(True),
            ]
        with m.Elif((fpf.is_inf(inp.a) | fpf.is_inf(inp.c))
                    & fpf.is_inf(inp.b) & (p_sign != b_sign)):
            # inf - inf -- invalid, result is NaN
            m.d.comb += [
                out.bypassed_z.eq(fpf.quiet_nan()),
                out.do_bypass.eq(True),
            ]
        with m.Elif(fpf.is_inf(inp.a) | fpf.is_inf(inp.c)):
            # infinite product plus a non-opposing addend
            m.d.comb += [
                out.bypassed_z.eq(fpf.inf(p_sign)),
                out.do_bypass.eq(True),
            ]
        with m.Elif(fpf.is_inf(inp.b)):
            # infinite addend plus a finite product
            m.d.comb += [
                out.bypassed_z.eq(fpf.inf(b_sign)),
                out.do_bypass.eq(True),
            ]
        with m.Elif((fpf.is_zero(inp.a) | fpf.is_zero(inp.c))
                    & fpf.is_zero(inp.b) & (p_sign == b_sign)):
            # zero + zero with matching signs keeps that sign
            m.d.comb += [
                out.bypassed_z.eq(fpf.zero(p_sign)),
                out.do_bypass.eq(True),
            ]
        # zero - zero handled by FPFMAMainStage
        with m.Else():
            m.d.comb += [
                out.bypassed_z.eq(0),
                out.do_bypass.eq(False),
            ]

        m.d.comb += [
            out.exponent.eq(exponent),
            out.a_mantissa.eq(fpf.get_mantissa_value(inp.a)),
            out.b_mantissa.eq(rshiftm.m),
            out.c_mantissa.eq(fpf.get_mantissa_value(inp.c)),
            out.do_sub.eq(p_sign != b_sign),
            # pass the pipeline context and rounding mode through; the
            # output data declares both fields
            out.ctx.eq(inp.ctx),
            out.rm.eq(inp.rm),
        ]
        return m