# prepare inf/zero/nans
         z_zero = FPNumBaseRecord(width, False, name="z_zero")
-        z_nan = FPNumBaseRecord(width, False, name="z_nan")
+        z_default_nan = FPNumBaseRecord(width, False, name="z_default_nan")
+        z_quieted_a = FPNumBaseRecord(width, False, name="z_quieted_a")
+        z_quieted_b = FPNumBaseRecord(width, False, name="z_quieted_b")
         z_infa = FPNumBaseRecord(width, False, name="z_infa")
         z_infb = FPNumBaseRecord(width, False, name="z_infb")
         comb += z_zero.zero(0)
-        comb += z_nan.nan(0)
+        comb += z_default_nan.nan(0)
+        comb += z_quieted_a.quieted_nan(a1)
+        comb += z_quieted_b.quieted_nan(b1)
         comb += z_infa.inf(a1.s)
         comb += z_infb.inf(b1.s)
 
 
         # this is the logic-decision-making for special-cases:
         # if a is NaN or b is NaN return NaN
+        #   if a is NaN return quieted_nan(a)
+        #   else return quieted_nan(b)
         # elif a is inf return inf (or NaN)
         #   if a is inf and signs don't match return NaN
         #   else return inf(a)
         oz = Mux(t_a1zero, b1.v, oz)
         oz = Mux(t_abz, Cat(self.i.b[:-1], absa), oz)
         oz = Mux(t_b1inf, z_infb.v, oz)
-        oz = Mux(t_a1inf, Mux(bexp128s, z_nan.v, z_infa.v), oz)
-        oz = Mux(t_abnan, z_nan.v, oz)
+        oz = Mux(t_a1inf, Mux(bexp128s, z_default_nan.v, z_infa.v), oz)
+        oz = Mux(t_abnan, Mux(a1.is_nan, z_quieted_a.v, z_quieted_b.v), oz)
 
         comb += self.o.oz.eq(oz)
 
 
 from nmutil.formaltest import FHDLTestCase
 from ieee754.fpadd.pipeline import FPADDBasePipe
 from nmigen.hdl.dsl import Module
-from nmigen.hdl.ast import AnySeq, Initial, Assert, AnyConst, Signal, Assume
+from nmigen.hdl.ast import Initial, Assert, AnyConst, Signal, Assume
 from nmigen.hdl.smtlib2 import SmtFloatingPoint, SmtSortFloatingPoint, \
     SmtSortFloat16, SmtSortFloat32, SmtSortFloat64, \
     ROUND_NEAREST_TIES_TO_EVEN
         z_fp = SmtFloatingPoint.from_bits(z, sort=sort)
         expected_fp = a_fp.add(b_fp, rm=rm)
         expected = Signal(width)
-        m.d.comb += expected.eq(AnySeq(width))
-        # Important Note: expected and z won't necessarily match bit-exactly
-        # if it's a NaN, all this checks for is z is also any NaN
-        m.d.comb += Assume((SmtFloatingPoint.from_bits(expected, sort=sort)
-                            == expected_fp).as_value())
-        # FIXME: check that it produces the correct NaNs
+        m.d.comb += expected.eq(AnyConst(width))
+        quiet_bit = 1 << (sort.mantissa_field_width - 1)
+        nan_exponent = ((1 << sort.eb) - 1) << sort.mantissa_field_width
+        with m.If(expected_fp.is_nan().as_value()):
+            with m.If(a_fp.is_nan().as_value()):
+                m.d.comb += Assume(expected == (a | quiet_bit))
+            with m.Elif(b_fp.is_nan().as_value()):
+                m.d.comb += Assume(expected == (b | quiet_bit))
+            with m.Else():
+                m.d.comb += Assume(expected == (nan_exponent | quiet_bit))
+        with m.Else():
+            m.d.comb += Assume(SmtFloatingPoint.from_bits(expected, sort=sort)
+                               .same(expected_fp).as_value())
         m.d.comb += a.eq(AnyConst(width))
         m.d.comb += b.eq(AnyConst(width))
         with m.If(dut.n.trigger):
-            m.d.sync += Assert((z_fp == expected_fp).as_value())
-        self.assertFormal(m, depth=5, solver="z3")
+            m.d.sync += Assert(z_fp.same(expected_fp).as_value())
+            m.d.sync += Assert(z == expected)
+        self.assertFormal(m, depth=5, solver="bitwuzla")
 
     # FIXME: check other rounding modes
     # FIXME: check exception flags