with m.If(self.i.ctx.op == 1): # SQRT
 
+            # if a is zero return zero
+            with m.If(a1.is_zero):
+                m.d.comb += self.o.out_do_z.eq(1)
+                m.d.comb += self.o.z.zero(a1.s)
+
             # -ve number is NaN
-            with m.If(a1.s):
+            with m.Elif(a1.s):
                 m.d.comb += self.o.out_do_z.eq(1)
                 m.d.comb += self.o.z.nan(0)
 
                 m.d.comb += self.o.out_do_z.eq(1)
                 m.d.comb += self.o.z.nan(0)
 
-            # if a is zero return zero
-            with m.Elif(a1.is_zero):
-                m.d.comb += self.o.out_do_z.eq(1)
-                m.d.comb += self.o.z.zero(0)
-
             # Denormalised Number checks next, so pass a/b data through
             with m.Else():
                 m.d.comb += self.o.out_do_z.eq(0)