working on implementing smtlib2
authorJacob Lifshay <programmerjake@gmail.com>
Fri, 20 May 2022 04:53:16 +0000 (21:53 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Fri, 20 May 2022 04:53:16 +0000 (21:53 -0700)
.gitignore
nmigen/hdl/smtlib2.py
tests/test_hdl_smtlib2.py [new file with mode: 0644]

index d11a3eb61749de07cfec0f1a7027b9a93d308774..8dc0a8fe1c9055b2c43576775dc89a67c8f8b4b1 100644 (file)
@@ -7,6 +7,8 @@ __pycache__/
 # coverage
 /.coverage
 /htmlcov
+/cov.xml
+/coverage.xml
 
 # tests
 /tests/spec_*/
index b93bdecfc36b15ca4fe7323161771207496ec95d..04846dd3e1d0cfacd526d8ec41975b1859f182b6 100644 (file)
@@ -6,25 +6,151 @@ from numbers import Rational
 from .._utils import final, flatten
 from . import ast, dsl, ir
 from typing import overload, TYPE_CHECKING
-if TYPE_CHECKING:
+if TYPE_CHECKING:  # :nobr:
     # make typechecker check final
-    from typing import final
+    from typing import final  # :nocov:
 
 __all__ = [
-    # FIXME
+    'RNA',
+    'RNE',
+    'ROUND_DEFAULT',
+    'ROUND_NEAREST_TIES_TO_AWAY',
+    'ROUND_NEAREST_TIES_TO_EVEN',
+    'ROUND_TOWARD_NEGATIVE',
+    'ROUND_TOWARD_POSITIVE',
+    'ROUND_TOWARD_ZERO',
+    'RTN',
+    'RTP',
+    'RTZ',
+    'RoundingModeEnum',
+    'SmtBitVec',
+    'SmtBitVecAdd',
+    'SmtBitVecAnd',
+    'SmtBitVecConcat',
+    'SmtBitVecConst',
+    'SmtBitVecDiv',
+    'SmtBitVecExtract',
+    'SmtBitVecITE',
+    'SmtBitVecInput',
+    'SmtBitVecLShift',
+    'SmtBitVecLt',
+    'SmtBitVecMul',
+    'SmtBitVecNeg',
+    'SmtBitVecNot',
+    'SmtBitVecOr',
+    'SmtBitVecRShift',
+    'SmtBitVecRem',
+    'SmtBitVecToNat',
+    'SmtBitVecToSignal',
+    'SmtBitVecXor',
+    'SmtBool',
+    'SmtBoolAnd',
+    'SmtBoolConst',
+    'SmtBoolITE',
+    'SmtBoolImplies',
+    'SmtBoolNot',
+    'SmtBoolOr',
+    'SmtBoolXor',
+    'SmtDistinct',
+    'SmtFloatingPoint',
+    'SmtFloatingPointAbs',
+    'SmtFloatingPointAdd',
+    'SmtFloatingPointDiv',
+    'SmtFloatingPointEq',
+    'SmtFloatingPointFma',
+    'SmtFloatingPointFromBits',
+    'SmtFloatingPointFromFP',
+    'SmtFloatingPointFromParts',
+    'SmtFloatingPointFromReal',
+    'SmtFloatingPointFromSignedBV',
+    'SmtFloatingPointFromUnsignedBV',
+    'SmtFloatingPointGE',
+    'SmtFloatingPointGt',
+    'SmtFloatingPointITE',
+    'SmtFloatingPointIsInfinite',
+    'SmtFloatingPointIsNaN',
+    'SmtFloatingPointIsNegative',
+    'SmtFloatingPointIsNormal',
+    'SmtFloatingPointIsPositive',
+    'SmtFloatingPointIsSubnormal',
+    'SmtFloatingPointIsZero',
+    'SmtFloatingPointLE',
+    'SmtFloatingPointLt',
+    'SmtFloatingPointMax',
+    'SmtFloatingPointMin',
+    'SmtFloatingPointMul',
+    'SmtFloatingPointNaN',
+    'SmtFloatingPointNeg',
+    'SmtFloatingPointNegInfinity',
+    'SmtFloatingPointNegZero',
+    'SmtFloatingPointPosInfinity',
+    'SmtFloatingPointPosZero',
+    'SmtFloatingPointRem',
+    'SmtFloatingPointRoundToIntegral',
+    'SmtFloatingPointSqrt',
+    'SmtFloatingPointSub',
+    'SmtFloatingPointToReal',
+    'SmtFloatingPointToSignedBV',
+    'SmtFloatingPointToUnsignedBV',
+    'SmtITE',
+    'SmtInt',
+    'SmtIntAbs',
+    'SmtIntAdd',
+    'SmtIntConst',
+    'SmtIntEuclidDiv',
+    'SmtIntEuclidRem',
+    'SmtIntGE',
+    'SmtIntGt',
+    'SmtIntITE',
+    'SmtIntLE',
+    'SmtIntLt',
+    'SmtIntMul',
+    'SmtIntNeg',
+    'SmtIntSub',
+    'SmtIntToReal',
+    'SmtNatToBitVec',
+    'SmtReal',
+    'SmtRealAdd',
+    'SmtRealConst',
+    'SmtRealDiv',
+    'SmtRealGE',
+    'SmtRealGt',
+    'SmtRealITE',
+    'SmtRealIsInt',
+    'SmtRealLE',
+    'SmtRealLt',
+    'SmtRealMul',
+    'SmtRealNeg',
+    'SmtRealSub',
+    'SmtRealToInt',
+    'SmtRoundingMode',
+    'SmtRoundingModeConst',
+    'SmtRoundingModeITE',
+    'SmtSame',
+    'SmtSort',
+    'SmtSortBitVec',
+    'SmtSortBool',
+    'SmtSortFloat128',
+    'SmtSortFloat16',
+    'SmtSortFloat32',
+    'SmtSortFloat64',
+    'SmtSortFloatingPoint',
+    'SmtSortInt',
+    'SmtSortReal',
+    'SmtSortRoundingMode',
+    'SmtValue',
 ]
 
 
 @dataclass(frozen=True, unsafe_hash=True, eq=True)
-class SmtSort(meta=ABCMeta):
+class SmtSort(metaclass=ABCMeta):
     @abstractmethod
     def _smtlib2_expr(self, expr_state):
-        return str(...)
+        return str(...)  # :nocov:
 
     @abstractmethod
-    @staticmethod
-    def _ite_class():
-        return SmtITE
+    def _ite_class(self):
+        return SmtITE  # :nocov:
 
 
 @final
@@ -145,15 +271,14 @@ class _ExprState:
 
 
 @dataclass(frozen=True, unsafe_hash=False, eq=False)
-class SmtValue(ast.DUID, meta=ABCMeta):
+class SmtValue(ast.DUID, metaclass=ABCMeta):
     @abstractmethod
-    @staticmethod
-    def sort():
-        return SmtSort()
+    def sort(self):
+        return SmtSort()  # type: ignore :nocov:
 
     @abstractmethod
     def _smtlib2_expr(self, expr_state):
-        return str(...)
+        return str(...)  # :nocov:
 
     def same(self, other, *rest):
         return SmtSame(self, other, *rest)
@@ -171,35 +296,40 @@ class SmtBool(SmtValue):
     def sort():
         return SmtSortBool()
 
-    def __new__(cls, value=None):
-        if cls is not SmtBool:
-            assert value is None
-            return super().__new__(cls)
-        assert isinstance(value, bool)
-        return SmtBoolConst(value)
+    @staticmethod
+    def make(value):
+        value = ast.Value.cast(value)
+        if isinstance(value, ast.Const):
+            return SmtBoolConst(bool(value.value))
+        return SmtBitVecInput(value).bool()
 
     # type deduction:
     @overload
-    def ite(self, then_v: "SmtBool", else_v: "SmtBool") -> "SmtBoolITE": ...
+    def ite(self, then_v: "SmtBool",
+            else_v: "SmtBool") -> "SmtBoolITE": ...  # :nocov:
+
     @overload
-    def ite(self, then_v: "SmtInt", else_v: "SmtInt") -> "SmtIntITE": ...
+    def ite(self, then_v: "SmtInt",
+            else_v: "SmtInt") -> "SmtIntITE": ...  # :nocov:
+
     @overload
-    def ite(self, then_v: "SmtReal", else_v: "SmtReal") -> "SmtRealITE": ...
+    def ite(self, then_v: "SmtReal",
+            else_v: "SmtReal") -> "SmtRealITE": ...  # :nocov:
 
     @overload
     def ite(self, then_v: "SmtBitVec", else_v: "SmtBitVec") -> "SmtBitVecITE":
-        ...
+        ...  # :nocov:
 
     @overload
     def ite(self, then_v: "SmtRoundingMode",
-            else_v: "SmtRoundingMode") -> "SmtRoundingModeITE": ...
+            else_v: "SmtRoundingMode") -> "SmtRoundingModeITE": ...  # :nocov:
 
     @overload
-    def ite(self, then_v: "SmtFloatingPoint",
-            else_v: "SmtFloatingPoint") -> "SmtFloatingPointITE": ...
+    def ite(self, then_v: "SmtFloatingPoint", else_v: "SmtFloatingPoint"
+            ) -> "SmtFloatingPointITE": ...  # :nocov:
 
-    def ite(self, then_v, else_v):
-        return SmtITE(self, then_v, else_v)
+    def ite(self, then_v, else_v):  # type: ignore
+        return SmtITE.make(self, then_v, else_v)
 
     def __invert__(self):
         return SmtBoolNot(self)
@@ -232,8 +362,8 @@ class SmtBool(SmtValue):
         return SmtBoolImplies(self, *rest)
 
     def to_bit_vec(self):
-        return self.ite(SmtBitVec(1, width=1),
-                        SmtBitVec(0, width=1))
+        return self.ite(SmtBitVecConst(1, width=1),
+                        SmtBitVecConst(0, width=1))
 
     def to_signal(self):
         return self.to_bit_vec().to_signal()
@@ -245,15 +375,57 @@ class SmtITE(SmtValue):
     then_v: SmtValue
     else_v: SmtValue
 
-    def __new__(cls, cond, then_v, else_v):
+    # type deduction:
+    @overload
+    @staticmethod
+    def make(cond: "SmtBool", then_v: "SmtBool",
+             else_v: "SmtBool") -> "SmtBoolITE": ...  # :nocov:
+
+    @overload
+    @staticmethod
+    def make(cond: "SmtBool", then_v: "SmtInt",
+             else_v: "SmtInt") -> "SmtIntITE": ...  # :nocov:
+
+    @overload
+    @staticmethod
+    def make(cond: "SmtBool", then_v: "SmtReal",
+             else_v: "SmtReal") -> "SmtRealITE": ...  # :nocov:
+
+    @overload
+    @staticmethod
+    def make(cond: "SmtBool", then_v: "SmtBitVec",
+             else_v: "SmtBitVec") -> "SmtBitVecITE": ...  # :nocov:
+
+    @overload
+    @staticmethod
+    def make(cond: "SmtBool", then_v: "SmtRoundingMode",
+             else_v: "SmtRoundingMode") -> "SmtRoundingModeITE": ...  # :nocov:
+
+    @overload
+    @staticmethod
+    def make(cond: "SmtBool", then_v: "SmtFloatingPoint",
+             else_v: "SmtFloatingPoint") -> "SmtFloatingPointITE":
+        ...  # :nocov:
+
+    @staticmethod
+    def make(cond, then_v, else_v):
+        assert isinstance(cond, SmtBool)
+        assert isinstance(then_v, SmtValue)
+        assert isinstance(else_v, SmtValue)
+        sort = then_v.sort()
+        assert sort == else_v.sort()
+        return sort._ite_class()(cond, then_v, else_v)  # type: ignore
+
+    def __init__(self, cond, then_v, else_v):
         assert isinstance(cond, SmtBool)
         assert isinstance(then_v, SmtValue)
         assert isinstance(else_v, SmtValue)
         sort = then_v.sort()
         assert sort == else_v.sort()
-        if cls is SmtITE:
-            return sort._ite_class()(cond, then_v, else_v)
-        return super().__new__(cls)
+        assert isinstance(self, sort._ite_class())
+        object.__setattr__(self, "cond", cond)
+        object.__setattr__(self, "then_v", then_v)
+        object.__setattr__(self, "else_v", else_v)
 
     def _smtlib2_expr(self, expr_state):
         assert isinstance(expr_state, _ExprState)
@@ -264,13 +436,13 @@ class SmtITE(SmtValue):
 
 
 @dataclass(frozen=True, unsafe_hash=False, eq=False)
-class _SmtNArySameSort(SmtValue):
+class _SmtNAry(SmtValue):
     inputs: "tuple[SmtValue, ...]"
 
     @abstractmethod
     def _smtlib2_expr_op(self, expr_state):
-        assert isinstance(expr_state, _ExprState)
-        return str(...)
+        assert isinstance(expr_state, _ExprState)  # :nocov:
+        return str(...)  # :nocov:
 
     def _expected_input_class(self):
         return SmtValue
@@ -281,11 +453,11 @@ class _SmtNArySameSort(SmtValue):
 
         for i in inputs:
             assert isinstance(i, self._expected_input_class())
-            assert i.sort == self.input_sort, "all input sorts must match"
+            assert i.sort() == self.input_sort, "all input sorts must match"
 
     @property
     def input_sort(self):
-        return self.inputs[0].sort
+        return self.inputs[0].sort()
 
     def _smtlib2_expr(self, expr_state):
         assert isinstance(expr_state, _ExprState)
@@ -300,8 +472,8 @@ class _SmtUnary(SmtValue):
 
     @abstractmethod
     def _smtlib2_expr_op(self, expr_state):
-        assert isinstance(expr_state, _ExprState)
-        return str(...)
+        assert isinstance(expr_state, _ExprState)  # :nocov:
+        return str(...)  # :nocov:
 
     def _expected_input_class(self):
         return SmtValue
@@ -324,23 +496,22 @@ class _SmtBinary(SmtValue):
 
     @abstractmethod
     def _smtlib2_expr_op(self, expr_state):
-        assert isinstance(expr_state, _ExprState)
-        return str(...)
+        assert isinstance(expr_state, _ExprState)  # :nocov:
+        return str(...)  # :nocov:
 
     def _expected_input_class(self):
         return SmtValue
 
-    def _expected_input_lhs_class(self):
-        return self._expected_input_class()
-
-    def _expected_input_rhs_class(self):
-        return self._expected_input_class()
+    @property
+    def input_sort(self):
+        return self.lhs.sort()
 
     def __init__(self, lhs, rhs):
         object.__setattr__(self, "lhs", lhs)
         object.__setattr__(self, "rhs", rhs)
-        assert isinstance(lhs, self._expected_input_lhs_class())
-        assert isinstance(rhs, self._expected_input_rhs_class())
+        assert isinstance(lhs, self._expected_input_class())
+        assert isinstance(rhs, self._expected_input_class())
+        assert lhs.sort() == rhs.sort()
 
     def _smtlib2_expr(self, expr_state):
         assert isinstance(expr_state, _ExprState)
@@ -352,7 +523,7 @@ class _SmtBinary(SmtValue):
 
 @final
 @dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
-class SmtSame(_SmtNArySameSort, SmtBool):
+class SmtSame(_SmtNAry, SmtBool):
     def _smtlib2_expr_op(self, expr_state):
         assert isinstance(expr_state, _ExprState)
         return "="
@@ -360,7 +531,7 @@ class SmtSame(_SmtNArySameSort, SmtBool):
 
 @final
 @dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
-class SmtDistinct(_SmtNArySameSort, SmtBool):
+class SmtDistinct(_SmtNAry, SmtBool):
     def _smtlib2_expr_op(self, expr_state):
         assert isinstance(expr_state, _ExprState)
         return "distinct"
@@ -390,15 +561,12 @@ class SmtBoolNot(_SmtUnary, SmtBool):
 
 
 @dataclass(frozen=True, unsafe_hash=False, eq=False)
-class _SmtBoolNAryMinBinary(_SmtNArySameSort, SmtBool):
+class _SmtBoolNAryMinBinary(_SmtNAry, SmtBool):
     inputs: "tuple[SmtBool, ...]"
 
     def _expected_input_class(self):
         return SmtBool
 
-    def __new__(cls, *inputs):
-        return super().__new__(cls)
-
     def __init__(self, *inputs):
         super().__init__(*inputs)
         assert len(self.inputs) >= 2, "not enough inputs"
@@ -442,11 +610,11 @@ class SmtBoolITE(SmtITE, SmtBool):
     then_v: SmtBool
     else_v: SmtBool
 
-    def __new__(cls, cond, then_v, else_v):
+    def __init__(self, cond, then_v, else_v):
         assert isinstance(cond, SmtBool)
         assert isinstance(then_v, SmtBool)
         assert isinstance(else_v, SmtBool)
-        return SmtITE.__new__(cls, cond, then_v, else_v)
+        super().__init__(cond, then_v, else_v)
 
 
 @dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
@@ -455,10 +623,8 @@ class SmtReal(SmtValue):
     def sort():
         return SmtSortReal()
 
-    def __new__(cls, value=None):
-        if cls is not SmtReal:
-            assert value is None
-            return super().__new__(cls)
+    @staticmethod
+    def make(value):
         if isinstance(value, SmtInt):
             return SmtIntToReal(value)
         assert isinstance(value, Rational), ("value must be a rational "
@@ -466,12 +632,6 @@ class SmtReal(SmtValue):
                                              "floats aren't yet supported.")
         return SmtRealConst(value)
 
-    def __init__(self, value=None):
-        # __init__ for IDE type checker
-        # value != None already handled by __new__
-        assert value is None, "invalid argument"
-        return super().__init__()
-
     def __neg__(self):
         return SmtRealNeg(self)
 
@@ -521,14 +681,14 @@ class SmtReal(SmtValue):
         return SmtRealGE(self, other)
 
     def __abs__(self):
-        return self.__lt__(SmtReal(0)).ite(-self, self)
+        return self.__lt__(SmtRealConst(0)).ite(-self, self)
 
     def __floor__(self):
         return SmtRealToInt(self)
 
     def __trunc__(self):
-        return self.__lt__(SmtReal(0)).ite(-(-self).__floor__(),
-                                           self.__floor__())
+        return self.__lt__(SmtRealConst(0)).ite(-(-self).__floor__(),
+                                                self.__floor__())
 
     def __ceil__(self):
         return -(-self).__floor__()
@@ -586,30 +746,24 @@ class SmtRealIsInt(_SmtUnary, SmtBool):
 
 
 @dataclass(frozen=True, unsafe_hash=False, eq=False)
-class _SmtRealNAryMinBinary(_SmtNArySameSort, SmtReal):
+class _SmtRealNAryMinBinary(_SmtNAry, SmtReal):
     inputs: "tuple[SmtReal, ...]"
 
     def _expected_input_class(self):
         return SmtReal
 
-    def __new__(cls, *inputs):
-        return _SmtNArySameSort.__new__(cls)
-
     def __init__(self, *inputs):
         super().__init__(*inputs)
         assert len(self.inputs) >= 2, "not enough inputs"
 
 
 @dataclass(frozen=True, unsafe_hash=False, eq=False)
-class _SmtRealCompareOp(_SmtNArySameSort, SmtBool):
+class _SmtRealCompareOp(_SmtNAry, SmtBool):
     inputs: "tuple[SmtReal, ...]"
 
     def _expected_input_class(self):
         return SmtReal
 
-    def __new__(cls, *inputs):
-        return _SmtNArySameSort.__new__(cls)
-
     def __init__(self, *inputs):
         super().__init__(*inputs)
         assert len(self.inputs) >= 2, "not enough inputs"
@@ -685,11 +839,11 @@ class SmtRealITE(SmtITE, SmtReal):
     then_v: SmtReal
     else_v: SmtReal
 
-    def __new__(cls, cond, then_v, else_v):
+    def __init__(self, cond, then_v, else_v):
         assert isinstance(cond, SmtBool)
         assert isinstance(then_v, SmtReal)
         assert isinstance(else_v, SmtReal)
-        return SmtITE.__new__(cls, cond, then_v, else_v)
+        super().__init__(cond, then_v, else_v)
 
 
 @dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
@@ -698,19 +852,11 @@ class SmtInt(SmtValue):
     def sort():
         return SmtSortInt()
 
-    def __new__(cls, value=None):
-        if cls is not SmtInt:
-            assert value is None
-            return super().__new__(cls)
+    @staticmethod
+    def make(value):
         assert isinstance(value, int), "value must be an integer"
         return SmtIntConst(value)
 
-    def __init__(self, value=None):
-        # __init__ for IDE type checker
-        # value != None already handled by __new__
-        assert value is None, "invalid argument"
-        return super().__init__()
-
     def __neg__(self):
         return SmtIntNeg(self)
 
@@ -833,15 +979,12 @@ class SmtRealToInt(_SmtUnary, SmtInt):
 
 
 @dataclass(frozen=True, unsafe_hash=False, eq=False)
-class _SmtIntNAryMinBinary(_SmtNArySameSort, SmtInt):
+class _SmtIntNAryMinBinary(_SmtNAry, SmtInt):
     inputs: "tuple[SmtInt, ...]"
 
     def _expected_input_class(self):
         return SmtInt
 
-    def __new__(cls, *inputs):
-        return _SmtNArySameSort.__new__(cls)
-
     def __init__(self, *inputs):
         super().__init__(*inputs)
         assert len(self.inputs) >= 2, "not enough inputs"
@@ -857,15 +1000,12 @@ class _SmtIntBinary(_SmtBinary, SmtInt):
 
 
 @dataclass(frozen=True, unsafe_hash=False, eq=False)
-class _SmtIntCompareOp(_SmtNArySameSort, SmtBool):
+class _SmtIntCompareOp(_SmtNAry, SmtBool):
     inputs: "tuple[SmtInt, ...]"
 
     def _expected_input_class(self):
         return SmtInt
 
-    def __new__(cls, *inputs):
-        return _SmtNArySameSort.__new__(cls)
-
     def __init__(self, *inputs):
         super().__init__(*inputs)
         assert len(self.inputs) >= 2, "not enough inputs"
@@ -949,11 +1089,11 @@ class SmtIntITE(SmtITE, SmtInt):
     then_v: SmtInt
     else_v: SmtInt
 
-    def __new__(cls, cond, then_v, else_v):
+    def __init__(self, cond, then_v, else_v):
         assert isinstance(cond, SmtBool)
         assert isinstance(then_v, SmtInt)
         assert isinstance(else_v, SmtInt)
-        return SmtITE.__new__(cls, cond, then_v, else_v)
+        super().__init__(cond, then_v, else_v)
 
 
 @dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
@@ -964,7 +1104,7 @@ class SmtBitVec(SmtValue):
         return SmtSortBitVec(self.width)
 
     @staticmethod
-    def __make_bitvec(value=None, *, width=None):
+    def make(value=None, *, width=None):
         if isinstance(value, int):
             assert width is not None, "missing width"
             assert isinstance(width, int) and width > 0, "invalid width"
@@ -972,15 +1112,7 @@ class SmtBitVec(SmtValue):
         assert width is None, "can't give both width and nMigen Value"
         return SmtBitVecInput(value)
 
-    def __new__(cls, *args, **kwargs):
-        if cls is not SmtInt:
-            return super().__new__(cls)
-        return SmtBitVec.__make_bitvec(*args, **kwargs)
-
-    def __init__(self, value=None, *, width=None):
-        # __init__ for IDE type checker
-        # value != None already handled by __new__
-        assert value is None, "invalid argument"
+    def __init__(self, *, width=None):
         assert isinstance(width, int) and width > 0, "invalid width"
         object.__setattr__(self, "width", width)
         super().__init__()
@@ -1050,7 +1182,8 @@ class SmtBitVec(SmtValue):
         return SmtBoolNot(SmtBitVecLt(self, other))
 
     def __abs__(self):
-        return self.__lt__(SmtBitVec(0, width=self.width)).ite(-self, self)
+        lt = self.__lt__(SmtBitVecConst(0, width=self.width))
+        return lt.ite(-self, self)
 
     def __and__(self, other):
         return SmtBitVecAnd(self, other)
@@ -1094,6 +1227,9 @@ class SmtBitVec(SmtValue):
             return SmtBitVecConcat(self[i] for i in r)
         return SmtBitVecExtract(self, r)
 
+    def bool(self):
+        return self != SmtBitVecConst(0, width=self.width)
+
     def to_signal(self):
         return SmtBitVecToSignal(self)
 
@@ -1182,23 +1318,23 @@ class SmtBitVecInput(SmtBitVec):
     def __init_subclass__(cls):
         try:
             _ = SmtBitVecConst
-        except AttributeError:
+        except NameError:
             # only possible when we're defining SmtBitVecConst
             return
         raise TypeError("subclassing SmtBitVecInput isn't supported")
 
-    def __new__(cls, value):
+    @staticmethod
+    def make(value):
         value = ast.Value.cast(value)
         assert isinstance(value, ast.Value)
         if isinstance(value, ast.Const):
-            if cls is not SmtBitVecConst:
-                return SmtBitVecConst(value)
-        retval = super().__new__(cls)
-        object.__setattr__(retval, "value", value)
-        return retval
+            return SmtBitVecConst(value)
+        return SmtBitVecInput(value)
 
     def __init__(self, value):
-        # self.value assigned in __new__
+        value = ast.Value.cast(value)
+        assert isinstance(value, ast.Value)
+        object.__setattr__(self, "value", value)
         super().__init__(width=self.value.shape().width)  # type: ignore
 
     def _smtlib2_expr(self, expr_state):
@@ -1218,7 +1354,7 @@ class SmtBitVecInput(SmtBitVec):
 class SmtBitVecConst(SmtBitVecInput):
     value: ast.Const
 
-    def __new__(cls, value, *, width=None):
+    def __init__(self, value, *, width=None):
         if isinstance(value, ast.Const):
             assert width is None
             # decompose -- needed since we switch to unsigned
@@ -1227,11 +1363,7 @@ class SmtBitVecConst(SmtBitVecInput):
         assert isinstance(value, int), "value must be an integer"
         assert isinstance(width, int) and width > 0, "invalid width"
         value = ast.Const(value, ast.unsigned(width))
-        return super().__new__(cls, value)
-
-    def __init__(self, value, *, width=None):
-        # rest of logic already handled by __new__
-        super().__init__(self.value)
+        super().__init__(value)
 
     def _smtlib2_expr(self, expr_state):
         assert isinstance(expr_state, _ExprState)
@@ -1307,17 +1439,14 @@ class SmtNatToBitVec(_SmtUnary, SmtBitVec):
 
 
 @dataclass(frozen=True, unsafe_hash=False, eq=False)
-class _SmtBitVecNAryMinBinary(_SmtNArySameSort, SmtBitVec):
+class _SmtBitVecNAryMinBinary(_SmtNAry, SmtBitVec):
     inputs: "tuple[SmtBitVec, ...]"
 
     def _expected_input_class(self):
         return SmtBitVec
 
-    def __new__(cls, *inputs):
-        return _SmtNArySameSort.__new__(cls)
-
     def __init__(self, *inputs):
-        _SmtNArySameSort.__init__(self, *inputs)
+        _SmtNAry.__init__(self, *inputs)
         assert len(self.inputs) >= 2, "not enough inputs"
         SmtBitVec.__init__(self, width=self.inputs[0].width)
 
@@ -1336,19 +1465,17 @@ class _SmtBitVecBinary(_SmtBinary, SmtBitVec):
 
 
 @dataclass(frozen=True, unsafe_hash=False, eq=False)
-class _SmtBitVecCompareOp(_SmtNArySameSort, SmtBool):
-    inputs: "tuple[SmtBitVec, ...]"
+class _SmtBitVecCompareOp(_SmtBinary, SmtBool):
+    lhs: SmtBitVec
+    rhs: SmtBitVec
+
+    def __init__(self, lhs, rhs):
+        _SmtBinary.__init__(self, lhs, rhs)
+        SmtBool.__init__(self)
 
     def _expected_input_class(self):
         return SmtBitVec
 
-    def __new__(cls, *inputs):
-        return _SmtNArySameSort.__new__(cls)
-
-    def __init__(self, *inputs):
-        super().__init__(*inputs)
-        assert len(self.inputs) >= 2, "not enough inputs"
-
 
 @final
 @dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
@@ -1436,29 +1563,19 @@ class SmtBitVecITE(SmtITE, SmtBitVec):
     then_v: SmtBitVec
     else_v: SmtBitVec
 
-    def __new__(cls, cond, then_v, else_v):
+    def __init__(self, cond, then_v, else_v):
         assert isinstance(cond, SmtBool)
         assert isinstance(then_v, SmtBitVec)
         assert isinstance(else_v, SmtBitVec)
-        return SmtITE.__new__(cls, cond, then_v, else_v)
-
-    def __init__(self, cond, then_v, else_v):
         SmtITE.__init__(self, cond, then_v, else_v)
         SmtBitVec.__init__(self, width=self.then_v.width)
 
 
 @dataclass(frozen=True, unsafe_hash=False, eq=False)
 class SmtRoundingMode(SmtValue):
-    def __new__(cls, value=None):
-        if cls is SmtRoundingMode:
-            return SmtRoundingModeConst(value)
-        return super().__new__(cls)
-
-    def __init__(self, value=None):
-        # __init__ for IDE type checker
-        # value != None already handled by __new__
-        assert value is None, "invalid argument"
-        super().__init__()
+    @staticmethod
+    def make(value):
+        return SmtRoundingModeConst(value)
 
     @staticmethod
     def sort():
@@ -1492,11 +1609,11 @@ class RoundingModeEnum(enum.Enum):
 
 @final
 @dataclass(frozen=True, unsafe_hash=False, eq=False)
-class SmtRoundingModeConst(SmtValue):
+class SmtRoundingModeConst(SmtRoundingMode):
     value: RoundingModeEnum
 
-    def __new__(cls, value):
-        value = RoundingModeEnum(value)
+    def __new__(cls, *args, **kwargs):
+        value = RoundingModeEnum(*args, **kwargs)
         try:
             if value is RoundingModeEnum.RNE:
                 return RNE
@@ -1509,10 +1626,14 @@ class SmtRoundingModeConst(SmtValue):
             else:
                 assert value is RoundingModeEnum.RTZ
                 return RTZ
-        except AttributeError:
+        except NameError:
             # instance not created yet
             return super().__new__(cls)
 
+    def __init__(self, value):
+        assert isinstance(value, RoundingModeEnum)
+        object.__setattr__(self, "value", value)
+
     def _smtlib2_expr(self, expr_state):
         assert isinstance(expr_state, _ExprState)
         return self.value._smtlib2_expr
@@ -1549,11 +1670,11 @@ class SmtRoundingModeITE(SmtITE, SmtRoundingMode):
     then_v: SmtRoundingMode
     else_v: SmtRoundingMode
 
-    def __new__(cls, cond, then_v, else_v):
+    def __init__(self, cond, then_v, else_v):
         assert isinstance(cond, SmtBool)
         assert isinstance(then_v, SmtRoundingMode)
         assert isinstance(else_v, SmtRoundingMode)
-        return SmtITE.__new__(cls, cond, then_v, else_v)
+        super().__init__(cond, then_v, else_v)
 
 
 @dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
@@ -1580,11 +1701,33 @@ class SmtFloatingPoint(SmtValue):
 
     @staticmethod
     def zero(*, sign=None, eb=None, sb=None, sort=None):
-        return SmtFloatingPointZero(sign=sign, eb=eb, sb=sb, sort=sort)
+        if sign is None:
+            return SmtFloatingPointPosZero(eb=eb, sb=sb, sort=sort)
+        if isinstance(sign, SmtBitVec):
+            assert sign.width == 1, "invalid sign width"
+            sign = sign.bool()
+        assert isinstance(sign, SmtBool)
+        return sign.ite(SmtFloatingPointNegZero(eb=eb, sb=sb, sort=sort),
+                        SmtFloatingPointPosZero(eb=eb, sb=sb, sort=sort))
+
+    @staticmethod
+    def neg_zero(*, eb=None, sb=None, sort=None):
+        return SmtFloatingPointNegZero(eb=eb, sb=sb, sort=sort)
 
     @staticmethod
-    def inf(*, sign=None, eb=None, sb=None, sort=None):
-        return SmtFloatingPointInf(sign=sign, eb=eb, sb=sb, sort=sort)
+    def infinity(*, sign=None, eb=None, sb=None, sort=None):
+        if sign is None:
+            return SmtFloatingPointPosInfinity(eb=eb, sb=sb, sort=sort)
+        if isinstance(sign, SmtBitVec):
+            assert sign.width == 1, "invalid sign width"
+            sign = sign.bool()
+        assert isinstance(sign, SmtBool)
+        return sign.ite(SmtFloatingPointNegInfinity(eb=eb, sb=sb, sort=sort),
+                        SmtFloatingPointPosInfinity(eb=eb, sb=sb, sort=sort))
+
+    @staticmethod
+    def neg_infinity(*, eb=None, sb=None, sort=None):
+        return SmtFloatingPointNegInfinity(eb=eb, sb=sb, sort=sort)
 
     @staticmethod
     def from_parts(*, sign, exponent, mantissa):
@@ -1619,8 +1762,8 @@ class SmtFloatingPoint(SmtValue):
     def sqrt(self, *, rm):
         return SmtFloatingPointSqrt(self, rm=rm)
 
-    def rem(self, other, *, rm):
-        return SmtFloatingPointRem(self, other, rm=rm)
+    def rem(self, other):
+        return SmtFloatingPointRem(self, other)
 
     def round_to_integral(self, *, rm):
         return SmtFloatingPointRoundToIntegral(self, rm=rm)
@@ -1638,16 +1781,16 @@ class SmtFloatingPoint(SmtValue):
         return ~SmtFloatingPointEq(self, other)
 
     def __lt__(self, other):
-        return ~SmtFloatingPointLt(self, other)
+        return SmtFloatingPointLt(self, other)
 
     def __le__(self, other):
-        return ~SmtFloatingPointLE(self, other)
+        return SmtFloatingPointLE(self, other)
 
     def __gt__(self, other):
-        return ~SmtFloatingPointGt(self, other)
+        return SmtFloatingPointGt(self, other)
 
     def __ge__(self, other):
-        return ~SmtFloatingPointGE(self, other)
+        return SmtFloatingPointGE(self, other)
 
     def is_normal(self):
         return SmtFloatingPointIsNormal(self)
@@ -1724,12 +1867,510 @@ class SmtFloatingPointITE(SmtITE, SmtFloatingPoint):
     then_v: SmtFloatingPoint
     else_v: SmtFloatingPoint
 
-    def __new__(cls, cond, then_v, else_v):
+    def __init__(self, cond, then_v, else_v):
         assert isinstance(cond, SmtBool)
         assert isinstance(then_v, SmtFloatingPoint)
         assert isinstance(else_v, SmtFloatingPoint)
-        return SmtITE.__new__(cls, cond, then_v, else_v)
-
-    def __init__(self, cond, then_v, else_v):
         SmtITE.__init__(self, cond, then_v, else_v)
         SmtFloatingPoint.__init__(self, sort=self.then_v.sort())
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtFloatingPointNaN(SmtFloatingPoint):
+    def _smtlib2_expr(self, expr_state):
+        assert isinstance(expr_state, _ExprState)
+        return f"(_ NaN {self.eb} {self.sb})"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtFloatingPointPosZero(SmtFloatingPoint):
+    def _smtlib2_expr(self, expr_state):
+        assert isinstance(expr_state, _ExprState)
+        return f"(_ +zero {self.eb} {self.sb})"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtFloatingPointNegZero(SmtFloatingPoint):
+    def _smtlib2_expr(self, expr_state):
+        assert isinstance(expr_state, _ExprState)
+        return f"(_ -zero {self.eb} {self.sb})"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtFloatingPointPosInfinity(SmtFloatingPoint):
+    def _smtlib2_expr(self, expr_state):
+        assert isinstance(expr_state, _ExprState)
+        return f"(_ +oo {self.eb} {self.sb})"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtFloatingPointNegInfinity(SmtFloatingPoint):
+    def _smtlib2_expr(self, expr_state):
+        assert isinstance(expr_state, _ExprState)
+        return f"(_ -oo {self.eb} {self.sb})"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class SmtFloatingPointFromParts(SmtFloatingPoint):
+    sign: SmtBitVec
+    exponent: SmtBitVec
+    mantissa: SmtBitVec
+
+    def __init__(self, sign, exponent, mantissa):
+        if isinstance(sign, SmtBool):
+            sign = sign.to_bit_vec()
+        assert isinstance(sign, SmtBitVec) and sign.width == 1
+        assert isinstance(exponent, SmtBitVec)
+        assert isinstance(mantissa, SmtBitVec)
+        super().__init__(eb=exponent.width, sb=mantissa.width + 1)
+
+    def _smtlib2_expr(self, expr_state):
+        assert isinstance(expr_state, _ExprState)
+        sign = self.sign._smtlib2_expr(expr_state)
+        exponent = self.exponent._smtlib2_expr(expr_state)
+        mantissa = self.mantissa._smtlib2_expr(expr_state)
+        return f"(fp {sign} {exponent} {mantissa})"
+
+
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class _SmtFloatingPointUnary(_SmtUnary, SmtFloatingPoint):
+    inp: SmtFloatingPoint
+
+    def __init__(self, inp):
+        assert isinstance(inp, SmtFloatingPoint)
+        _SmtUnary.__init__(self, inp)
+        SmtFloatingPoint.__init__(self, eb=inp.eb, sb=inp.sb)
+
+    def _expected_input_class(self):
+        return SmtFloatingPoint
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtFloatingPointAbs(_SmtFloatingPointUnary):
+    def _smtlib2_expr_op(self, expr_state):
+        assert isinstance(expr_state, _ExprState)
+        return "fp.abs"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtFloatingPointNeg(_SmtFloatingPointUnary):
+    def _smtlib2_expr_op(self, expr_state):
+        assert isinstance(expr_state, _ExprState)
+        return "fp.neg"
+
+
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class _SmtFloatingPointBinaryRounded(_SmtBinary, SmtFloatingPoint):
+    lhs: SmtFloatingPoint
+    rhs: SmtFloatingPoint
+    rm: SmtRoundingMode
+
+    def __init__(self, lhs, rhs, *, rm):
+        assert isinstance(rm, SmtRoundingMode)
+        _SmtBinary.__init__(self, lhs, rhs)
+        object.__setattr__(self, "rm", rm)
+        SmtFloatingPoint.__init__(self, eb=self.lhs.eb, sb=self.rhs.sb)
+
+    def _expected_input_class(self):
+        return SmtFloatingPoint
+
+    def _smtlib2_expr(self, expr_state):
+        assert isinstance(expr_state, _ExprState)
+        op = self._smtlib2_expr_op(expr_state)
+        rm = self.rm._smtlib2_expr(expr_state)
+        lhs = self.lhs._smtlib2_expr(expr_state)
+        rhs = self.rhs._smtlib2_expr(expr_state)
+        return f"({op} {rm} {lhs} {rhs})"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtFloatingPointAdd(_SmtFloatingPointBinaryRounded):
+    def _smtlib2_expr_op(self, expr_state):
+        assert isinstance(expr_state, _ExprState)
+        return "fp.add"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtFloatingPointSub(_SmtFloatingPointBinaryRounded):
+    def _smtlib2_expr_op(self, expr_state):
+        assert isinstance(expr_state, _ExprState)
+        return "fp.sub"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtFloatingPointMul(_SmtFloatingPointBinaryRounded):
+    def _smtlib2_expr_op(self, expr_state):
+        assert isinstance(expr_state, _ExprState)
+        return "fp.mul"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtFloatingPointDiv(_SmtFloatingPointBinaryRounded):
+    def _smtlib2_expr_op(self, expr_state):
+        assert isinstance(expr_state, _ExprState)
+        return "fp.div"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtFloatingPointFma(SmtFloatingPoint):
+    """returns `self.factor1 * self.factor2 + self.term`"""
+    factor1: SmtFloatingPoint
+    factor2: SmtFloatingPoint
+    term: SmtFloatingPoint
+    rm: SmtRoundingMode
+
+    def __init__(self, factor1, factor2, term, *, rm):
+        assert isinstance(factor1, SmtFloatingPoint)
+        assert isinstance(factor2, SmtFloatingPoint)
+        assert isinstance(term, SmtFloatingPoint)
+        assert factor1.sort() == factor2.sort() == term.sort(), "sort mismatch"
+        assert isinstance(rm, SmtRoundingMode)
+        object.__setattr__(self, "factor1", factor1)
+        object.__setattr__(self, "factor2", factor2)
+        object.__setattr__(self, "term", term)
+        object.__setattr__(self, "rm", rm)
+        super().__init__(eb=factor1.eb, sb=factor1.sb)
+
+    def _smtlib2_expr(self, expr_state):
+        assert isinstance(expr_state, _ExprState)
+        rm = self.rm._smtlib2_expr(expr_state)
+        factor1 = self.factor1._smtlib2_expr(expr_state)
+        factor2 = self.factor2._smtlib2_expr(expr_state)
+        term = self.term._smtlib2_expr(expr_state)
+        return f"(fp {rm} {factor1} {factor2} {term})"
+
+
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class _SmtFloatingPointUnaryRounded(_SmtUnary, SmtFloatingPoint):
+    inp: SmtFloatingPoint
+    rm: SmtRoundingMode
+
+    def __init__(self, inp, *, rm):
+        assert isinstance(rm, SmtRoundingMode)
+        _SmtUnary.__init__(self, inp)
+        object.__setattr__(self, "rm", rm)
+        SmtFloatingPoint.__init__(self, eb=inp.eb, sb=inp.sb)
+
+    def _expected_input_class(self):
+        return SmtFloatingPoint
+
+    def _smtlib2_expr(self, expr_state):
+        assert isinstance(expr_state, _ExprState)
+        op = self._smtlib2_expr_op(expr_state)
+        rm = self.rm._smtlib2_expr(expr_state)
+        inp = self.inp._smtlib2_expr(expr_state)
+        return f"({op} {rm} {inp})"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtFloatingPointSqrt(_SmtFloatingPointUnaryRounded):
+    def _smtlib2_expr_op(self, expr_state):
+        assert isinstance(expr_state, _ExprState)
+        return "fp.sqrt"
+
+
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class _SmtFloatingPointBinary(_SmtBinary, SmtFloatingPoint):
+    lhs: SmtFloatingPoint
+    rhs: SmtFloatingPoint
+
+    def __init__(self, lhs, rhs):
+        _SmtBinary.__init__(self, lhs, rhs)
+        SmtFloatingPoint.__init__(self, eb=self.lhs.eb, sb=self.rhs.sb)
+
+    def _expected_input_class(self):
+        return SmtFloatingPoint
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtFloatingPointRem(_SmtFloatingPointBinary):
+    def _smtlib2_expr_op(self, expr_state):
+        assert isinstance(expr_state, _ExprState)
+        return "fp.rem"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtFloatingPointRoundToIntegral(_SmtFloatingPointUnaryRounded):
+    def _smtlib2_expr_op(self, expr_state):
+        assert isinstance(expr_state, _ExprState)
+        return "fp.roundToIntegral"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtFloatingPointMin(_SmtFloatingPointBinary):
+    def _smtlib2_expr_op(self, expr_state):
+        assert isinstance(expr_state, _ExprState)
+        return "fp.min"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtFloatingPointMax(_SmtFloatingPointBinary):
+    def _smtlib2_expr_op(self, expr_state):
+        assert isinstance(expr_state, _ExprState)
+        return "fp.max"
+
+
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class _SmtFloatingPointCompareOp(_SmtNAry, SmtBool):
+    inputs: "tuple[SmtFloatingPoint, ...]"
+
+    def _expected_input_class(self):
+        return SmtFloatingPoint
+
+    def __init__(self, *inputs):
+        super().__init__(*inputs)
+        assert len(self.inputs) >= 2, "not enough inputs"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtFloatingPointEq(_SmtFloatingPointCompareOp):
+    def _smtlib2_expr_op(self, expr_state):
+        assert isinstance(expr_state, _ExprState)
+        return "fp.eq"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtFloatingPointLt(_SmtFloatingPointCompareOp):
+    def _smtlib2_expr_op(self, expr_state):
+        assert isinstance(expr_state, _ExprState)
+        return "fp.lt"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtFloatingPointLE(_SmtFloatingPointCompareOp):
+    def _smtlib2_expr_op(self, expr_state):
+        assert isinstance(expr_state, _ExprState)
+        return "fp.leq"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtFloatingPointGt(_SmtFloatingPointCompareOp):
+    def _smtlib2_expr_op(self, expr_state):
+        assert isinstance(expr_state, _ExprState)
+        return "fp.gt"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtFloatingPointGE(_SmtFloatingPointCompareOp):
+    def _smtlib2_expr_op(self, expr_state):
+        assert isinstance(expr_state, _ExprState)
+        return "fp.geq"
+
+
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class _SmtFloatingPointToBoolUnary(_SmtUnary, SmtBool):
+    inp: SmtFloatingPoint
+
+    def _expected_input_class(self):
+        return SmtFloatingPoint
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class SmtFloatingPointIsNormal(_SmtFloatingPointToBoolUnary):
+    def _smtlib2_expr_op(self, expr_state):
+        assert isinstance(expr_state, _ExprState)
+        return "fp.isNormal"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class SmtFloatingPointIsSubnormal(_SmtFloatingPointToBoolUnary):
+    def _smtlib2_expr_op(self, expr_state):
+        assert isinstance(expr_state, _ExprState)
+        return "fp.isSubnormal"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class SmtFloatingPointIsZero(_SmtFloatingPointToBoolUnary):
+    def _smtlib2_expr_op(self, expr_state):
+        assert isinstance(expr_state, _ExprState)
+        return "fp.isZero"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class SmtFloatingPointIsInfinite(_SmtFloatingPointToBoolUnary):
+    def _smtlib2_expr_op(self, expr_state):
+        assert isinstance(expr_state, _ExprState)
+        return "fp.isInfinite"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class SmtFloatingPointIsNaN(_SmtFloatingPointToBoolUnary):
+    def _smtlib2_expr_op(self, expr_state):
+        assert isinstance(expr_state, _ExprState)
+        return "fp.isNaN"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class SmtFloatingPointIsNegative(_SmtFloatingPointToBoolUnary):
+    def _smtlib2_expr_op(self, expr_state):
+        assert isinstance(expr_state, _ExprState)
+        return "fp.isNegative"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class SmtFloatingPointIsPositive(_SmtFloatingPointToBoolUnary):
+    def _smtlib2_expr_op(self, expr_state):
+        assert isinstance(expr_state, _ExprState)
+        return "fp.isPositive"
+
+
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class _SmtFloatingPointRoundFrom(_SmtUnary, SmtFloatingPoint):
+    rm: SmtRoundingMode
+
+    def __init__(self, inp, *, rm, eb=None, sb=None, sort=None):
+        assert isinstance(rm, SmtRoundingMode)
+        _SmtUnary.__init__(self, inp)
+        object.__setattr__(self, "rm", rm)
+        SmtFloatingPoint.__init__(self, eb=eb, sb=sb, sort=sort)
+
+    def _smtlib2_expr(self, expr_state):
+        assert isinstance(expr_state, _ExprState)
+        op = self._smtlib2_expr_op(expr_state)
+        rm = self.rm._smtlib2_expr(expr_state)
+        inp = self.inp._smtlib2_expr(expr_state)
+        return f"({op} {rm} {inp})"
+
+    def _smtlib2_expr_op(self, expr_state):
+        assert isinstance(expr_state, _ExprState)
+        return f"(_ to_fp {self.eb} {self.sb})"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtFloatingPointFromSignedBV(_SmtFloatingPointRoundFrom):
+    inp: SmtBitVec
+
+    def _expected_input_class(self):
+        return SmtBitVec
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtFloatingPointFromUnsignedBV(_SmtFloatingPointRoundFrom):
+    inp: SmtBitVec
+
+    def _expected_input_class(self):
+        return SmtBitVec
+
+    def _smtlib2_expr_op(self, expr_state):
+        assert isinstance(expr_state, _ExprState)
+        return f"(_ to_fp_unsigned {self.eb} {self.sb})"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtFloatingPointFromReal(_SmtFloatingPointRoundFrom):
+    inp: SmtReal
+
+    def _expected_input_class(self):
+        return SmtReal
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtFloatingPointFromFP(_SmtFloatingPointRoundFrom):
+    inp: SmtFloatingPoint
+
+    def _expected_input_class(self):
+        return SmtFloatingPoint
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class SmtFloatingPointFromBits(_SmtUnary, SmtFloatingPoint):
+    inp: SmtBitVec
+
+    def __init__(self, inp, *, eb=None, sb=None, sort=None):
+        assert isinstance(inp, SmtBitVec)
+        _SmtUnary.__init__(self, inp)
+        SmtFloatingPoint.__init__(self, eb=eb, sb=sb, sort=sort)
+        mantissa_field_width = sb - 1
+        sign_field_width = 1
+        expected_width = sign_field_width + eb + mantissa_field_width
+        assert inp.width == expected_width, \
+            f"input BitVec is the wrong width, expected {expected_width}"
+
+    def _expected_input_class(self):
+        return SmtBitVec
+
+    def _smtlib2_expr_op(self, expr_state):
+        assert isinstance(expr_state, _ExprState)
+        return f"(_ to_fp {self.eb} {self.sb})"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class SmtFloatingPointToReal(_SmtUnary, SmtReal):
+    inp: SmtFloatingPoint
+
+    def _expected_input_class(self):
+        return SmtFloatingPoint
+
+    def _smtlib2_expr_op(self, expr_state):
+        assert isinstance(expr_state, _ExprState)
+        return "fp.to_real"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class SmtFloatingPointToUnsignedBV(_SmtUnary, SmtBitVec):
+    inp: SmtFloatingPoint
+
+    def __init__(self, inp, *, width):
+        _SmtUnary.__init__(self, inp)
+        SmtBitVec.__init__(self, width=width)
+
+    def _expected_input_class(self):
+        return SmtFloatingPoint
+
+    def _smtlib2_expr_op(self, expr_state):
+        assert isinstance(expr_state, _ExprState)
+        return "fp.to_ubv"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class SmtFloatingPointToSignedBV(_SmtUnary, SmtBitVec):
+    inp: SmtFloatingPoint
+
+    def __init__(self, inp, *, width):
+        _SmtUnary.__init__(self, inp)
+        SmtBitVec.__init__(self, width=width)
+
+    def _expected_input_class(self):
+        return SmtFloatingPoint
+
+    def _smtlib2_expr_op(self, expr_state):
+        assert isinstance(expr_state, _ExprState)
+        return "fp.to_sbv"
diff --git a/tests/test_hdl_smtlib2.py b/tests/test_hdl_smtlib2.py
new file mode 100644 (file)
index 0000000..423ea8f
--- /dev/null
@@ -0,0 +1,79 @@
+from nmigen.hdl.smtlib2 import *
+from nmigen.hdl.smtlib2 import _ExprState
+from nmigen.hdl.ast import ValueDict, Signal
+from .utils import FHDLTestCase
+
+
+class SmtTestCase(FHDLTestCase):
+    def assertSmtExpr(self, v, ty, expected_expr, input_bit_ranges=(), input_width=0):
+        self.assertIs(type(v), ty)
+        state = _ExprState()
+        expr = v._smtlib2_expr(state)
+        self.assertEqual(expr, expected_expr)
+        self.assertEqual(state.input_bit_ranges, ValueDict(input_bit_ranges))
+        self.assertEqual(state.input_width, input_width)
+
+
+class TestSorts(SmtTestCase):
+    def test_bool(self):
+        self.assertSmtExpr(SmtSortBool(), SmtSortBool, "Bool")
+
+    def test_int(self):
+        self.assertSmtExpr(SmtSortInt(), SmtSortInt, "Int")
+
+    def test_real(self):
+        self.assertSmtExpr(SmtSortReal(), SmtSortReal, "Real")
+
+    def test_bv(self):
+        self.assertSmtExpr(SmtSortBitVec(1), SmtSortBitVec, "(_ BitVec 1)")
+        self.assertSmtExpr(SmtSortBitVec(16), SmtSortBitVec, "(_ BitVec 16)")
+
+    def test_rm(self):
+        self.assertSmtExpr(SmtSortRoundingMode(),
+                           SmtSortRoundingMode, "RoundingMode")
+
+    def test_fp(self):
+        self.assertSmtExpr(SmtSortFloat16(), SmtSortFloatingPoint, "Float16")
+        self.assertSmtExpr(SmtSortFloat32(), SmtSortFloatingPoint, "Float32")
+        self.assertSmtExpr(SmtSortFloat64(), SmtSortFloatingPoint, "Float64")
+        self.assertSmtExpr(SmtSortFloat128(), SmtSortFloatingPoint, "Float128")
+        self.assertSmtExpr(SmtSortFloatingPoint(3, 5),
+                           SmtSortFloatingPoint, "(_ FloatingPoint 3 5)")
+        self.assertSmtExpr(SmtSortFloatingPoint(2, 2),
+                           SmtSortFloatingPoint, "(_ FloatingPoint 2 2)")
+
+
+class TestBool(SmtTestCase):
+    def test_make(self):
+        self.assertSmtExpr(SmtBool.make(False), SmtBoolConst, "false")
+        self.assertSmtExpr(SmtBool.make(True), SmtBoolConst, "true")
+        sig = Signal()
+        self.assertSmtExpr(SmtBool.make(sig), SmtDistinct,
+                           "(distinct #b0 ((_ extract 0 0) A))",
+                           input_bit_ranges=[(sig, range(0, 1))],
+                           input_width=1)
+        sig2 = Signal(2)
+        self.assertSmtExpr(SmtBool.make(sig2.bool()), SmtDistinct,
+                           "(distinct #b0 ((_ extract 0 0) A))",
+                           input_bit_ranges=[(sig2.bool(), range(0, 1))],
+                           input_width=1)
+
+    def test_ite(self):
+        a = Signal()
+        b = Signal()
+        c = Signal()
+        self.assertSmtExpr(
+            SmtBool.make(a).ite(SmtBool.make(b), SmtBool.make(c)),
+            SmtBoolITE,
+            "(ite (distinct #b0 ((_ extract 0 0) A)) "
+            "(distinct #b0 ((_ extract 1 1) A)) "
+            "(distinct #b0 ((_ extract 2 2) A)))",
+            input_bit_ranges=[
+                (a, range(0, 1)),
+                (b, range(1, 2)),
+                (c, range(2, 3)),
+            ],
+            input_width=3,
+        )
+
+# FIXME: add more tests