--- /dev/null
+from abc import ABCMeta, abstractmethod
+from dataclasses import dataclass, field
+import enum
+from fractions import Fraction
+from numbers import Rational
+from .._utils import final, flatten
+from . import ast, dsl, ir
+from typing import overload, TYPE_CHECKING
+if TYPE_CHECKING:
+ # make typechecker check final
+ from typing import final
+
+__all__ = [
+ # FIXME
+]
+
+
+@dataclass(frozen=True, unsafe_hash=True, eq=True)
+class SmtSort(meta=ABCMeta):
+ @abstractmethod
+ def _smtlib2_expr(self, expr_state):
+ return str(...)
+
+ @abstractmethod
+ @staticmethod
+ def _ite_class():
+ return SmtITE
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=True, eq=True)
+class SmtSortBool(SmtSort):
+ def _smtlib2_expr(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "Bool"
+
+ @staticmethod
+ def _ite_class():
+ return SmtBoolITE
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=True, eq=True)
+class SmtSortInt(SmtSort):
+ def _smtlib2_expr(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "Int"
+
+ @staticmethod
+ def _ite_class():
+ return SmtIntITE
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=True, eq=True)
+class SmtSortReal(SmtSort):
+ def _smtlib2_expr(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "Real"
+
+ @staticmethod
+ def _ite_class():
+ return SmtRealITE
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=True, eq=True)
+class SmtSortBitVec(SmtSort):
+ width: int
+
+ def __post_init__(self):
+ assert isinstance(self.width, int) and self.width > 0, "invalid width"
+
+ def _smtlib2_expr(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return f"(_ BitVec {self.width})"
+
+ @staticmethod
+ def _ite_class():
+ return SmtBitVecITE
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=True, eq=True)
+class SmtSortRoundingMode(SmtSort):
+ def _smtlib2_expr(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "RoundingMode"
+
+ @staticmethod
+ def _ite_class():
+ return SmtRoundingModeITE
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=True, eq=True)
+class SmtSortFloatingPoint(SmtSort):
+ eb: int
+ """ number of bits in the exponent """
+
+ sb: int
+ """ number of bits in the significand including the hidden bit """
+
+ def __post_init__(self):
+ assert isinstance(self.eb, int) and self.eb > 1, "invalid eb"
+ assert isinstance(self.sb, int) and self.sb > 1, "invalid sb"
+
+ def _smtlib2_expr(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ if self == SmtSortFloat16():
+ return "Float16"
+ if self == SmtSortFloat32():
+ return "Float32"
+ if self == SmtSortFloat64():
+ return "Float64"
+ if self == SmtSortFloat128():
+ return "Float128"
+ return f"(_ FloatingPoint {self.eb} {self.sb})"
+
+ @staticmethod
+ def _ite_class():
+ return SmtFloatingPointITE
+
+
+def SmtSortFloat16():
+ return SmtSortFloatingPoint(eb=5, sb=11)
+
+
+def SmtSortFloat32():
+ return SmtSortFloatingPoint(eb=8, sb=24)
+
+
+def SmtSortFloat64():
+ return SmtSortFloatingPoint(eb=11, sb=53)
+
+
+def SmtSortFloat128():
+ return SmtSortFloatingPoint(eb=15, sb=113)
+
+
+@dataclass
+class _ExprState:
+ input_bit_ranges: ast.ValueDict = field(default_factory=ast.ValueDict)
+ input_width: int = 0
+
+
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class SmtValue(ast.DUID, meta=ABCMeta):
+ @abstractmethod
+ @staticmethod
+ def sort():
+ return SmtSort()
+
+ @abstractmethod
+ def _smtlib2_expr(self, expr_state):
+ return str(...)
+
+ def same(self, other, *rest):
+ return SmtSame(self, other, *rest)
+
+ def distinct(self, other, *rest):
+ return SmtDistinct(self, other, *rest)
+
+ def __bool__(self):
+ raise TypeError("Attempted to convert SmtValue to Python boolean")
+
+
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class SmtBool(SmtValue):
+ @staticmethod
+ def sort():
+ return SmtSortBool()
+
+ def __new__(cls, value=None):
+ if cls is not SmtBool:
+ assert value is None
+ return super().__new__(cls)
+ assert isinstance(value, bool)
+ return SmtBoolConst(value)
+
+ # type deduction:
+ @overload
+ def ite(self, then_v: "SmtBool", else_v: "SmtBool") -> "SmtBoolITE": ...
+ @overload
+ def ite(self, then_v: "SmtInt", else_v: "SmtInt") -> "SmtIntITE": ...
+ @overload
+ def ite(self, then_v: "SmtReal", else_v: "SmtReal") -> "SmtRealITE": ...
+
+ @overload
+ def ite(self, then_v: "SmtBitVec", else_v: "SmtBitVec") -> "SmtBitVecITE":
+ ...
+
+ @overload
+ def ite(self, then_v: "SmtRoundingMode",
+ else_v: "SmtRoundingMode") -> "SmtRoundingModeITE": ...
+
+ @overload
+ def ite(self, then_v: "SmtFloatingPoint",
+ else_v: "SmtFloatingPoint") -> "SmtFloatingPointITE": ...
+
+ def ite(self, then_v, else_v):
+ return SmtITE(self, then_v, else_v)
+
+ def __invert__(self):
+ return SmtBoolNot(self)
+
+ def __and__(self, other):
+ return SmtBoolAnd(self, other)
+
+ def __rand__(self, other):
+ return SmtBoolAnd(other, self)
+
+ def __xor__(self, other):
+ return SmtBoolXor(self, other)
+
+ def __rxor__(self, other):
+ return SmtBoolXor(other, self)
+
+ def __or__(self, other):
+ return SmtBoolOr(self, other)
+
+ def __ror__(self, other):
+ return SmtBoolOr(other, self)
+
+ def __eq__(self, other):
+ return self.same(other)
+
+ def __ne__(self, other):
+ return self.distinct(other)
+
+ def implies(self, *rest):
+ return SmtBoolImplies(self, *rest)
+
+ def to_bit_vec(self):
+ return self.ite(SmtBitVec(1, width=1),
+ SmtBitVec(0, width=1))
+
+ def to_signal(self):
+ return self.to_bit_vec().to_signal()
+
+
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class SmtITE(SmtValue):
+ cond: SmtBool
+ then_v: SmtValue
+ else_v: SmtValue
+
+ def __new__(cls, cond, then_v, else_v):
+ assert isinstance(cond, SmtBool)
+ assert isinstance(then_v, SmtValue)
+ assert isinstance(else_v, SmtValue)
+ sort = then_v.sort()
+ assert sort == else_v.sort()
+ if cls is SmtITE:
+ return sort._ite_class()(cond, then_v, else_v)
+ return super().__new__(cls)
+
+ def _smtlib2_expr(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ cond_s = self.cond._smtlib2_expr(expr_state)
+ then_s = self.then_v._smtlib2_expr(expr_state)
+ else_s = self.else_v._smtlib2_expr(expr_state)
+ return f"(ite {cond_s} {then_s} {else_s})"
+
+
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class _SmtNArySameSort(SmtValue):
+ inputs: "tuple[SmtValue, ...]"
+
+ @abstractmethod
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return str(...)
+
+ def _expected_input_class(self):
+ return SmtValue
+
+ def __init__(self, *inputs):
+ object.__setattr__(self, "inputs", inputs)
+ assert len(inputs) > 0, "not enough inputs"
+
+ for i in inputs:
+ assert isinstance(i, self._expected_input_class())
+ assert i.sort == self.input_sort, "all input sorts must match"
+
+ @property
+ def input_sort(self):
+ return self.inputs[0].sort
+
+ def _smtlib2_expr(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ op = self._smtlib2_expr_op(expr_state)
+ args = ' '.join(i._smtlib2_expr(expr_state) for i in self.inputs)
+ return f"({op} {args})"
+
+
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class _SmtUnary(SmtValue):
+ inp: SmtValue
+
+ @abstractmethod
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return str(...)
+
+ def _expected_input_class(self):
+ return SmtValue
+
+ def __init__(self, inp):
+ object.__setattr__(self, "inp", inp)
+ assert isinstance(inp, self._expected_input_class())
+
+ def _smtlib2_expr(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ op = self._smtlib2_expr_op(expr_state)
+ inp = self.inp._smtlib2_expr(expr_state)
+ return f"({op} {inp})"
+
+
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class _SmtBinary(SmtValue):
+ lhs: SmtValue
+ rhs: SmtValue
+
+ @abstractmethod
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return str(...)
+
+ def _expected_input_class(self):
+ return SmtValue
+
+ def _expected_input_lhs_class(self):
+ return self._expected_input_class()
+
+ def _expected_input_rhs_class(self):
+ return self._expected_input_class()
+
+ def __init__(self, lhs, rhs):
+ object.__setattr__(self, "lhs", lhs)
+ object.__setattr__(self, "rhs", rhs)
+ assert isinstance(lhs, self._expected_input_lhs_class())
+ assert isinstance(rhs, self._expected_input_rhs_class())
+
+ def _smtlib2_expr(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ op = self._smtlib2_expr_op(expr_state)
+ lhs = self.lhs._smtlib2_expr(expr_state)
+ rhs = self.rhs._smtlib2_expr(expr_state)
+ return f"({op} {lhs} {rhs})"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtSame(_SmtNArySameSort, SmtBool):
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "="
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtDistinct(_SmtNArySameSort, SmtBool):
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "distinct"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class SmtBoolConst(SmtBool):
+ value: bool
+
+ def _smtlib2_expr(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "true" if self.value else "false"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class SmtBoolNot(_SmtUnary, SmtBool):
+ inp: SmtBool
+
+ def _expected_input_class(self):
+ return SmtBool
+
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "not"
+
+
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class _SmtBoolNAryMinBinary(_SmtNArySameSort, SmtBool):
+ inputs: "tuple[SmtBool, ...]"
+
+ def _expected_input_class(self):
+ return SmtBool
+
+ def __new__(cls, *inputs):
+ return super().__new__(cls)
+
+ def __init__(self, *inputs):
+ super().__init__(*inputs)
+ assert len(self.inputs) >= 2, "not enough inputs"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtBoolImplies(_SmtBoolNAryMinBinary):
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "=>"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtBoolAnd(_SmtBoolNAryMinBinary):
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "and"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtBoolOr(_SmtBoolNAryMinBinary):
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "or"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtBoolXor(_SmtBoolNAryMinBinary):
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "xor"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class SmtBoolITE(SmtITE, SmtBool):
+ then_v: SmtBool
+ else_v: SmtBool
+
+ def __new__(cls, cond, then_v, else_v):
+ assert isinstance(cond, SmtBool)
+ assert isinstance(then_v, SmtBool)
+ assert isinstance(else_v, SmtBool)
+ return SmtITE.__new__(cls, cond, then_v, else_v)
+
+
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtReal(SmtValue):
+ @staticmethod
+ def sort():
+ return SmtSortReal()
+
+ def __new__(cls, value=None):
+ if cls is not SmtReal:
+ assert value is None
+ return super().__new__(cls)
+ if isinstance(value, SmtInt):
+ return SmtIntToReal(value)
+ assert isinstance(value, Rational), ("value must be a rational "
+ "number, irrational numbers and "
+ "floats aren't yet supported.")
+ return SmtRealConst(value)
+
+ def __init__(self, value=None):
+ # __init__ for IDE type checker
+ # value != None already handled by __new__
+ assert value is None, "invalid argument"
+ return super().__init__()
+
+ def __neg__(self):
+ return SmtRealNeg(self)
+
+ def __pos__(self):
+ return self
+
+ def __add__(self, other):
+ return SmtRealAdd(self, other)
+
+ def __radd__(self, other):
+ return SmtRealAdd(other, self)
+
+ def __sub__(self, other):
+ return SmtRealSub(self, other)
+
+ def __rsub__(self, other):
+ return SmtRealSub(other, self)
+
+ def __mul__(self, other):
+ return SmtRealMul(self, other)
+
+ def __rmul__(self, other):
+ return SmtRealMul(other, self)
+
+ def __truediv__(self, other):
+ return SmtRealDiv(self, other)
+
+ def __rtruediv__(self, other):
+ return SmtRealDiv(other, self)
+
+ def __eq__(self, other):
+ return SmtSame(self, other)
+
+ def __ne__(self, other):
+ return SmtDistinct(self, other)
+
+ def __lt__(self, other):
+ return SmtRealLt(self, other)
+
+ def __le__(self, other):
+ return SmtRealLE(self, other)
+
+ def __gt__(self, other):
+ return SmtRealGt(self, other)
+
+ def __ge__(self, other):
+ return SmtRealGE(self, other)
+
+ def __abs__(self):
+ return self.__lt__(SmtReal(0)).ite(-self, self)
+
+ def __floor__(self):
+ return SmtRealToInt(self)
+
+ def __trunc__(self):
+ return self.__lt__(SmtReal(0)).ite(-(-self).__floor__(),
+ self.__floor__())
+
+ def __ceil__(self):
+ return -(-self).__floor__()
+
+ def is_int(self):
+ return SmtRealIsInt(self)
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class SmtRealConst(SmtReal):
+ value: Fraction
+
+ def __init__(self, value):
+ assert isinstance(value, Rational), ("value must be a rational "
+ "number, irrational numbers and "
+ "floats aren't yet supported.")
+ value = Fraction(value)
+ object.__setattr__(self, "value", value)
+
+ def _smtlib2_expr(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ retval = f"{abs(self.value.numerator)}.0"
+ if self.value.numerator < 0:
+ retval = f"(- {retval})"
+ if self.value.denominator != 1:
+ retval = f"(/ {retval} {self.value.denominator}.0)"
+ return retval
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class SmtRealNeg(_SmtUnary, SmtReal):
+ inp: SmtReal
+
+ def _expected_input_class(self):
+ return SmtReal
+
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "-"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class SmtRealIsInt(_SmtUnary, SmtBool):
+ inp: SmtReal
+
+ def _expected_input_class(self):
+ return SmtReal
+
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "is_int"
+
+
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class _SmtRealNAryMinBinary(_SmtNArySameSort, SmtReal):
+ inputs: "tuple[SmtReal, ...]"
+
+ def _expected_input_class(self):
+ return SmtReal
+
+ def __new__(cls, *inputs):
+ return _SmtNArySameSort.__new__(cls)
+
+ def __init__(self, *inputs):
+ super().__init__(*inputs)
+ assert len(self.inputs) >= 2, "not enough inputs"
+
+
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class _SmtRealCompareOp(_SmtNArySameSort, SmtBool):
+ inputs: "tuple[SmtReal, ...]"
+
+ def _expected_input_class(self):
+ return SmtReal
+
+ def __new__(cls, *inputs):
+ return _SmtNArySameSort.__new__(cls)
+
+ def __init__(self, *inputs):
+ super().__init__(*inputs)
+ assert len(self.inputs) >= 2, "not enough inputs"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtRealLt(_SmtRealCompareOp):
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "<"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtRealLE(_SmtRealCompareOp):
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "<="
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtRealGt(_SmtRealCompareOp):
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return ">"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtRealGE(_SmtRealCompareOp):
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return ">="
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtRealAdd(_SmtRealNAryMinBinary):
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "+"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtRealSub(_SmtRealNAryMinBinary):
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "-"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtRealMul(_SmtRealNAryMinBinary):
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "*"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtRealDiv(_SmtRealNAryMinBinary):
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "/"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class SmtRealITE(SmtITE, SmtReal):
+ then_v: SmtReal
+ else_v: SmtReal
+
+ def __new__(cls, cond, then_v, else_v):
+ assert isinstance(cond, SmtBool)
+ assert isinstance(then_v, SmtReal)
+ assert isinstance(else_v, SmtReal)
+ return SmtITE.__new__(cls, cond, then_v, else_v)
+
+
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtInt(SmtValue):
+ @staticmethod
+ def sort():
+ return SmtSortInt()
+
+ def __new__(cls, value=None):
+ if cls is not SmtInt:
+ assert value is None
+ return super().__new__(cls)
+ assert isinstance(value, int), "value must be an integer"
+ return SmtIntConst(value)
+
+ def __init__(self, value=None):
+ # __init__ for IDE type checker
+ # value != None already handled by __new__
+ assert value is None, "invalid argument"
+ return super().__init__()
+
+ def __neg__(self):
+ return SmtIntNeg(self)
+
+ def __pos__(self):
+ return self
+
+ def __add__(self, other):
+ return SmtIntAdd(self, other)
+
+ def __radd__(self, other):
+ return SmtIntAdd(other, self)
+
+ def __sub__(self, other):
+ return SmtIntSub(self, other)
+
+ def __rsub__(self, other):
+ return SmtIntSub(other, self)
+
+ def __mul__(self, other):
+ return SmtIntMul(self, other)
+
+ def __rmul__(self, other):
+ return SmtIntMul(other, self)
+
+ def euclid_div(self, other):
+ return SmtIntEuclidDiv(self, other)
+
+ def euclid_rem(self, other):
+ return SmtIntEuclidRem(self, other)
+
+ def __eq__(self, other):
+ return SmtSame(self, other)
+
+ def __ne__(self, other):
+ return SmtDistinct(self, other)
+
+ def __lt__(self, other):
+ return SmtIntLt(self, other)
+
+ def __le__(self, other):
+ return SmtIntLE(self, other)
+
+ def __gt__(self, other):
+ return SmtIntGt(self, other)
+
+ def __ge__(self, other):
+ return SmtIntGE(self, other)
+
+ def __abs__(self):
+ return SmtIntAbs(self)
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class SmtIntConst(SmtInt):
+ value: int
+
+ def __init__(self, value):
+ assert isinstance(value, int), "value must be an integer"
+ object.__setattr__(self, "value", value)
+
+ def _smtlib2_expr(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ retval = f"{abs(self.value.numerator)}"
+ if self.value.numerator < 0:
+ retval = f"(- {retval})"
+ return retval
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class SmtIntNeg(_SmtUnary, SmtInt):
+ inp: SmtInt
+
+ def _expected_input_class(self):
+ return SmtInt
+
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "-"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class SmtIntAbs(_SmtUnary, SmtInt):
+ inp: SmtInt
+
+ def _expected_input_class(self):
+ return SmtInt
+
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "abs"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class SmtIntToReal(_SmtUnary, SmtReal):
+ inp: SmtInt
+
+ def _expected_input_class(self):
+ return SmtInt
+
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "to_real"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class SmtRealToInt(_SmtUnary, SmtInt):
+ inp: SmtReal
+
+ def _expected_input_class(self):
+ return SmtReal
+
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "to_int"
+
+
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class _SmtIntNAryMinBinary(_SmtNArySameSort, SmtInt):
+ inputs: "tuple[SmtInt, ...]"
+
+ def _expected_input_class(self):
+ return SmtInt
+
+ def __new__(cls, *inputs):
+ return _SmtNArySameSort.__new__(cls)
+
+ def __init__(self, *inputs):
+ super().__init__(*inputs)
+ assert len(self.inputs) >= 2, "not enough inputs"
+
+
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class _SmtIntBinary(_SmtBinary, SmtInt):
+ lhs: SmtInt
+ rhs: SmtInt
+
+ def _expected_input_class(self):
+ return SmtInt
+
+
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class _SmtIntCompareOp(_SmtNArySameSort, SmtBool):
+ inputs: "tuple[SmtInt, ...]"
+
+ def _expected_input_class(self):
+ return SmtInt
+
+ def __new__(cls, *inputs):
+ return _SmtNArySameSort.__new__(cls)
+
+ def __init__(self, *inputs):
+ super().__init__(*inputs)
+ assert len(self.inputs) >= 2, "not enough inputs"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtIntLt(_SmtIntCompareOp):
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "<"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtIntLE(_SmtIntCompareOp):
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "<="
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtIntGt(_SmtIntCompareOp):
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return ">"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtIntGE(_SmtIntCompareOp):
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return ">="
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtIntAdd(_SmtIntNAryMinBinary):
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "+"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtIntSub(_SmtIntNAryMinBinary):
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "-"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtIntMul(_SmtIntNAryMinBinary):
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "*"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtIntEuclidDiv(_SmtIntNAryMinBinary):
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "div"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtIntEuclidRem(_SmtIntBinary):
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "rem"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class SmtIntITE(SmtITE, SmtInt):
+ then_v: SmtInt
+ else_v: SmtInt
+
+ def __new__(cls, cond, then_v, else_v):
+ assert isinstance(cond, SmtBool)
+ assert isinstance(then_v, SmtInt)
+ assert isinstance(else_v, SmtInt)
+ return SmtITE.__new__(cls, cond, then_v, else_v)
+
+
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtBitVec(SmtValue):
+ width: int
+
+ def sort(self):
+ return SmtSortBitVec(self.width)
+
+ @staticmethod
+ def __make_bitvec(value=None, *, width=None):
+ if isinstance(value, int):
+ assert width is not None, "missing width"
+ assert isinstance(width, int) and width > 0, "invalid width"
+ return SmtBitVecConst(value, width=width)
+ assert width is None, "can't give both width and nMigen Value"
+ return SmtBitVecInput(value)
+
+ def __new__(cls, *args, **kwargs):
+ if cls is not SmtInt:
+ return super().__new__(cls)
+ return SmtBitVec.__make_bitvec(*args, **kwargs)
+
+ def __init__(self, value=None, *, width=None):
+ # __init__ for IDE type checker
+ # value != None already handled by __new__
+ assert value is None, "invalid argument"
+ assert isinstance(width, int) and width > 0, "invalid width"
+ object.__setattr__(self, "width", width)
+ super().__init__()
+
+ def __neg__(self):
+ return SmtBitVecNeg(self)
+
+ def __invert__(self):
+ return SmtBitVecNot(self)
+
+ def __pos__(self):
+ return self
+
+ def __add__(self, other):
+ return SmtBitVecAdd(self, other)
+
+ def __radd__(self, other):
+ return SmtBitVecAdd(other, self)
+
+ def __sub__(self, other):
+ assert isinstance(other, SmtBitVec)
+ return self + -other
+
+ def __rsub__(self, other):
+ assert isinstance(other, SmtBitVec)
+ return other + -self
+
+ def __mul__(self, other):
+ return SmtBitVecMul(self, other)
+
+ def __rmul__(self, other):
+ return SmtBitVecMul(other, self)
+
+ def __floordiv__(self, other):
+ return SmtBitVecDiv(self, other)
+
+ def __rfloordiv__(self, other):
+ return SmtBitVecDiv(other, self)
+
+ def __mod__(self, other):
+ return SmtBitVecRem(self, other)
+
+ def __rmod__(self, other):
+ return SmtBitVecRem(other, self)
+
+ def __divmod__(self, other):
+ return self // other, self % other
+
+ def __eq__(self, other):
+ return SmtSame(self, other)
+
+ def __ne__(self, other):
+ return SmtDistinct(self, other)
+
+ def __lt__(self, other):
+ return SmtBitVecLt(self, other)
+
+ def __le__(self, other):
+ assert isinstance(other, SmtBitVec)
+ return other >= self
+
+ def __gt__(self, other):
+ assert isinstance(other, SmtBitVec)
+ return other < self
+
+ def __ge__(self, other):
+ return SmtBoolNot(SmtBitVecLt(self, other))
+
+ def __abs__(self):
+ return self.__lt__(SmtBitVec(0, width=self.width)).ite(-self, self)
+
+ def __and__(self, other):
+ return SmtBitVecAnd(self, other)
+
+ def __rand__(self, other):
+ return SmtBitVecAnd(other, self)
+
+ def __or__(self, other):
+ return SmtBitVecOr(self, other)
+
+ def __ror__(self, other):
+ return SmtBitVecOr(other, self)
+
+ def __xor__(self, other):
+ return SmtBitVecXor(self, other)
+
+ def __rxor__(self, other):
+ return SmtBitVecXor(other, self)
+
+ def __lshift__(self, other):
+ return SmtBitVecLShift(self, other)
+
+ def __rlshift__(self, other):
+ return SmtBitVecLShift(other, self)
+
+ def __rshift__(self, other):
+ return SmtBitVecRShift(self, other)
+
+ def __rrshift__(self, other):
+ return SmtBitVecRShift(other, self)
+
+ def __getitem__(self, index):
+ if isinstance(index, int):
+ if index < 0:
+ index += self.width
+ return SmtBitVecExtract(self, range(index, index + 1))
+ assert isinstance(index, slice)
+ r = range(*index.indices(self.width))
+ assert len(r) > 0, "empty slice"
+ if r.step != 1:
+ return SmtBitVecConcat(self[i] for i in r)
+ return SmtBitVecExtract(self, r)
+
+ def to_signal(self):
+ return SmtBitVecToSignal(self)
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtBitVecConcat(SmtBitVec):
+ inputs: "tuple[SmtBitVec, ...]"
+ """inputs in lsb to msb order"""
+
+ def __init__(self, *inputs):
+ inputs = tuple(flatten(inputs))
+ object.__setattr__(self, "inputs", inputs)
+ assert len(inputs) > 0, "not enough inputs"
+
+ width = 0
+ for i in inputs:
+ assert isinstance(i, SmtBitVec)
+ width += i.width
+
+ super().__init__(width=width)
+
+ def _smtlib2_expr(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ args = [i._smtlib2_expr(expr_state) for i in self.inputs]
+ return f"(concat {' '.join(reversed(args))})"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class SmtBitVecExtract(_SmtUnary, SmtBitVec):
+ inp: SmtBitVec
+ bit_range: range
+
+ def __init__(self, inp, bit_range):
+ assert isinstance(inp, SmtBitVec)
+ if isinstance(bit_range, int):
+ if bit_range < 0:
+ bit_range += inp.width
+ bit_range = range(bit_range, bit_range + 1)
+ assert isinstance(bit_range, range)
+ assert len(bit_range) > 0, "empty range"
+ assert bit_range.step == 1, "unsupported range step"
+ assert 0 <= bit_range.start < bit_range.stop <= inp.width, \
+ "bit range out-of-range"
+ _SmtUnary.__init__(self, inp)
+ SmtBitVec.__init__(self, width=len(bit_range))
+
+ def _expected_input_class(self):
+ return SmtBitVec
+
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return f"(_ extract {self.bit_range.stop - 1} {self.bit_range.start})"
+
+
+class SmtBitVecToSignal(dsl.Elaboratable):
+ def __init__(self, bv):
+ if isinstance(bv, SmtBool):
+ bv = bv.to_bit_vec()
+ assert isinstance(bv, SmtBitVec)
+ self.bv = bv
+ self.o_sig = ast.Signal(ast.unsigned(self.bv.width))
+ self.__expr_state = _ExprState()
+ self.__expr = bv._smtlib2_expr(self.__expr_state)
+ if self.__expr_state.input_width == 0:
+ # can't have 0-width input
+ self.__expr_state.input_bit_ranges[ast.Const(0, 1)] = range(1)
+ self.__expr_state.input_width = 1
+ self.__instance = ir.Instance(
+ "$smtlib2_expr",
+ p_A_WIDTH=self.__expr_state.input_width,
+ p_Y_WIDTH=self.bv.width,
+ p_EXPR=self.__expr,
+ i_A=ast.Cat(self.__expr_state.input_bit_ranges.keys()),
+ o_Y=self.o_sig)
+
+ def elaborate(self, platform):
+ return self.__instance
+
+
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class SmtBitVecInput(SmtBitVec):
+ value: ast.Value
+
+ def __init_subclass__(cls):
+ try:
+ _ = SmtBitVecConst
+ except AttributeError:
+ # only possible when we're defining SmtBitVecConst
+ return
+ raise TypeError("subclassing SmtBitVecInput isn't supported")
+
+ def __new__(cls, value):
+ value = ast.Value.cast(value)
+ assert isinstance(value, ast.Value)
+ if isinstance(value, ast.Const):
+ if cls is not SmtBitVecConst:
+ return SmtBitVecConst(value)
+ retval = super().__new__(cls)
+ object.__setattr__(retval, "value", value)
+ return retval
+
+ def __init__(self, value):
+ # self.value assigned in __new__
+ super().__init__(width=self.value.shape().width) # type: ignore
+
+ def _smtlib2_expr(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ if self.value not in expr_state.input_bit_ranges:
+ start = expr_state.input_width
+ expr_state.input_width += self.width
+ r = range(start, expr_state.input_width)
+ expr_state.input_bit_ranges[self.value] = r
+ else:
+ r = expr_state.input_bit_ranges[self.value]
+ return f"((_ extract {r.stop - 1} {r.start}) A)"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class SmtBitVecConst(SmtBitVecInput):
+ value: ast.Const
+
+ def __new__(cls, value, *, width=None):
+ if isinstance(value, ast.Const):
+ assert width is None
+ # decompose -- needed since we switch to unsigned
+ width = value.width
+ value = value.value
+ assert isinstance(value, int), "value must be an integer"
+ assert isinstance(width, int) and width > 0, "invalid width"
+ value = ast.Const(value, ast.unsigned(width))
+ return super().__new__(cls, value)
+
+ def __init__(self, value, *, width=None):
+ # rest of logic already handled by __new__
+ super().__init__(self.value)
+
+ def _smtlib2_expr(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ hex_digit_count, hex_digit_count_rem = divmod(self.width, 4)
+ if hex_digit_count_rem == 0:
+ digit_count = hex_digit_count
+ digits = hex(self.value.value)
+ else:
+ digit_count = self.width
+ digits = bin(self.value.value)
+ return "#" + digits[1] + digits[2:].zfill(digit_count)
+
+
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class _SmtBitVecUnary(_SmtUnary, SmtBitVec):
+ inp: SmtBitVec
+
+ def __init__(self, inp):
+ assert isinstance(inp, SmtBitVec)
+ _SmtUnary.__init__(self, inp)
+ SmtBitVec.__init__(self, width=inp.width)
+
+ def _expected_input_class(self):
+ return SmtBitVec
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtBitVecNeg(_SmtBitVecUnary):
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "bvneg"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtBitVecNot(_SmtBitVecUnary):
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "bvnot"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class SmtBitVecToNat(_SmtUnary, SmtInt):
+ inp: SmtBitVec
+
+ def _expected_input_class(self):
+ return SmtBitVec
+
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "bv2nat"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class SmtNatToBitVec(_SmtUnary, SmtBitVec):
+ inp: SmtInt
+
+ def _expected_input_class(self):
+ return SmtInt
+
+ def __init__(self, inp, *, width):
+ assert isinstance(inp, SmtInt)
+ assert isinstance(width, int) and width > 0, "invalid width"
+ _SmtUnary.__init__(self, inp)
+ SmtBitVec.__init__(self, width=width)
+
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return f"(_ nat2bv {self.width})"
+
+
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class _SmtBitVecNAryMinBinary(_SmtNArySameSort, SmtBitVec):
+ inputs: "tuple[SmtBitVec, ...]"
+
+ def _expected_input_class(self):
+ return SmtBitVec
+
+ def __new__(cls, *inputs):
+ return _SmtNArySameSort.__new__(cls)
+
+ def __init__(self, *inputs):
+ _SmtNArySameSort.__init__(self, *inputs)
+ assert len(self.inputs) >= 2, "not enough inputs"
+ SmtBitVec.__init__(self, width=self.inputs[0].width)
+
+
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class _SmtBitVecBinary(_SmtBinary, SmtBitVec):
+ lhs: SmtBitVec
+ rhs: SmtBitVec
+
+ def __init__(self, lhs, rhs):
+ _SmtBinary.__init__(self, lhs, rhs)
+ SmtBitVec.__init__(self, width=self.lhs.width)
+
+ def _expected_input_class(self):
+ return SmtBitVec
+
+
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class _SmtBitVecCompareOp(_SmtNArySameSort, SmtBool):
+ inputs: "tuple[SmtBitVec, ...]"
+
+ def _expected_input_class(self):
+ return SmtBitVec
+
+ def __new__(cls, *inputs):
+ return _SmtNArySameSort.__new__(cls)
+
+ def __init__(self, *inputs):
+ super().__init__(*inputs)
+ assert len(self.inputs) >= 2, "not enough inputs"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtBitVecLt(_SmtBitVecCompareOp):
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "bvult"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtBitVecAdd(_SmtBitVecNAryMinBinary):
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "bvadd"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtBitVecAnd(_SmtBitVecNAryMinBinary):
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "bvand"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtBitVecOr(_SmtBitVecNAryMinBinary):
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "bvor"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtBitVecXor(_SmtBitVecNAryMinBinary):
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "bvxor"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtBitVecMul(_SmtBitVecNAryMinBinary):
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "bvmul"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtBitVecDiv(_SmtBitVecNAryMinBinary):
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "bvudiv"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtBitVecRem(_SmtBitVecBinary):
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "bvurem"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtBitVecLShift(_SmtBitVecBinary):
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "bvshl"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtBitVecRShift(_SmtBitVecBinary):
+ def _smtlib2_expr_op(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return "bvlshr"
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class SmtBitVecITE(SmtITE, SmtBitVec):
+ then_v: SmtBitVec
+ else_v: SmtBitVec
+
+ def __new__(cls, cond, then_v, else_v):
+ assert isinstance(cond, SmtBool)
+ assert isinstance(then_v, SmtBitVec)
+ assert isinstance(else_v, SmtBitVec)
+ return SmtITE.__new__(cls, cond, then_v, else_v)
+
+ def __init__(self, cond, then_v, else_v):
+ SmtITE.__init__(self, cond, then_v, else_v)
+ SmtBitVec.__init__(self, width=self.then_v.width)
+
+
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class SmtRoundingMode(SmtValue):
+ def __new__(cls, value=None):
+ if cls is SmtRoundingMode:
+ return SmtRoundingModeConst(value)
+ return super().__new__(cls)
+
+ def __init__(self, value=None):
+ # __init__ for IDE type checker
+ # value != None already handled by __new__
+ assert value is None, "invalid argument"
+ super().__init__()
+
+ @staticmethod
+ def sort():
+ return SmtSortRoundingMode()
+
+ def __eq__(self, other):
+ return self.same(other)
+
+ def __ne__(self, other):
+ return self.distinct(other)
+
+
+@final
+class RoundingModeEnum(enum.Enum):
+ RNE = "RNE"
+ ROUND_DEFAULT = RNE
+ ROUND_NEAREST_TIES_TO_EVEN = RNE
+ RNA = "RNA"
+ ROUND_NEAREST_TIES_TO_AWAY = RNA
+ RTP = "RTP"
+ ROUND_TOWARD_POSITIVE = RTP
+ RTN = "RTN"
+ ROUND_TOWARD_NEGATIVE = RTN
+ RTZ = "RTZ"
+ ROUND_TOWARD_ZERO = RTZ
+
+ def __init__(self, value):
+ assert isinstance(value, str)
+ self._smtlib2_expr = value
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class SmtRoundingModeConst(SmtValue):
+ value: RoundingModeEnum
+
+ def __new__(cls, value):
+ value = RoundingModeEnum(value)
+ try:
+ if value is RoundingModeEnum.RNE:
+ return RNE
+ elif value is RoundingModeEnum.RNA:
+ return RNA
+ elif value is RoundingModeEnum.RTP:
+ return RTP
+ elif value is RoundingModeEnum.RTN:
+ return RTN
+ else:
+ assert value is RoundingModeEnum.RTZ
+ return RTZ
+ except AttributeError:
+ # instance not created yet
+ return super().__new__(cls)
+
+ def _smtlib2_expr(self, expr_state):
+ assert isinstance(expr_state, _ExprState)
+ return self.value._smtlib2_expr
+
+
+RNE = SmtRoundingModeConst(RoundingModeEnum.RNE)
+ROUND_NEAREST_TIES_TO_EVEN = SmtRoundingModeConst(
+ RoundingModeEnum.ROUND_NEAREST_TIES_TO_EVEN)
+ROUND_DEFAULT = SmtRoundingModeConst(
+ RoundingModeEnum.ROUND_DEFAULT)
+RNA = SmtRoundingModeConst(RoundingModeEnum.RNA)
+ROUND_NEAREST_TIES_TO_AWAY = SmtRoundingModeConst(
+ RoundingModeEnum.ROUND_NEAREST_TIES_TO_AWAY)
+RTP = SmtRoundingModeConst(RoundingModeEnum.RTP)
+ROUND_TOWARD_POSITIVE = SmtRoundingModeConst(
+ RoundingModeEnum.ROUND_TOWARD_POSITIVE)
+RTN = SmtRoundingModeConst(RoundingModeEnum.RTN)
+ROUND_TOWARD_NEGATIVE = SmtRoundingModeConst(
+ RoundingModeEnum.ROUND_TOWARD_NEGATIVE)
+RTZ = SmtRoundingModeConst(RoundingModeEnum.RTZ)
+ROUND_TOWARD_ZERO = SmtRoundingModeConst(RoundingModeEnum.ROUND_TOWARD_ZERO)
+
+# check all rounding modes are accounted for:
+# (we check rather than just make everything here so IDEs can more easily
+# see the definitions)
+for name, rm in RoundingModeEnum.__members__.items():
+ assert globals()[name] is SmtRoundingModeConst(rm), f"mismatch {name}"
+del rm, name
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class SmtRoundingModeITE(SmtITE, SmtRoundingMode):
+ then_v: SmtRoundingMode
+ else_v: SmtRoundingMode
+
+ def __new__(cls, cond, then_v, else_v):
+ assert isinstance(cond, SmtBool)
+ assert isinstance(then_v, SmtRoundingMode)
+ assert isinstance(else_v, SmtRoundingMode)
+ return SmtITE.__new__(cls, cond, then_v, else_v)
+
+
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
+class SmtFloatingPoint(SmtValue):
+ eb: int
+ sb: int
+
+ def sort(self):
+ return SmtSortFloatingPoint(eb=self.eb, sb=self.sb)
+
+ def __init__(self, *, eb=None, sb=None, sort=None):
+ if sort is not None:
+ assert eb is sb is None and isinstance(sort, SmtSortFloatingPoint)
+ eb = sort.eb
+ sb = sort.sb
+ assert isinstance(self.eb, int) and self.eb > 1, "invalid eb"
+ assert isinstance(self.sb, int) and self.sb > 1, "invalid sb"
+ object.__setattr__(self, "eb", eb)
+ object.__setattr__(self, "sb", sb)
+
+ @staticmethod
+ def nan(*, eb=None, sb=None, sort=None):
+ return SmtFloatingPointNaN(eb=eb, sb=sb, sort=sort)
+
+ @staticmethod
+ def zero(*, sign=None, eb=None, sb=None, sort=None):
+ return SmtFloatingPointZero(sign=sign, eb=eb, sb=sb, sort=sort)
+
+ @staticmethod
+ def inf(*, sign=None, eb=None, sb=None, sort=None):
+ return SmtFloatingPointInf(sign=sign, eb=eb, sb=sb, sort=sort)
+
+ @staticmethod
+ def from_parts(*, sign, exponent, mantissa):
+ return SmtFloatingPointFromParts(sign=sign, exponent=exponent,
+ mantissa=mantissa)
+
+ def __abs__(self):
+ return SmtFloatingPointAbs(self)
+
+ def __neg__(self):
+ return SmtFloatingPointNeg(self)
+
+ def __pos__(self):
+ return self
+
+ def add(self, other, *, rm):
+ return SmtFloatingPointAdd(self, other, rm=rm)
+
+ def sub(self, other, *, rm):
+ return SmtFloatingPointSub(self, other, rm=rm)
+
+ def mul(self, other, *, rm):
+ return SmtFloatingPointMul(self, other, rm=rm)
+
+ def div(self, other, *, rm):
+ return SmtFloatingPointDiv(self, other, rm=rm)
+
+ def fma(self, factor, term, *, rm):
+ """returns `self * factor + term`"""
+ return SmtFloatingPointFma(self, factor, term, rm=rm)
+
+ def sqrt(self, *, rm):
+ return SmtFloatingPointSqrt(self, rm=rm)
+
+ def rem(self, other, *, rm):
+ return SmtFloatingPointRem(self, other, rm=rm)
+
+ def round_to_integral(self, *, rm):
+ return SmtFloatingPointRoundToIntegral(self, rm=rm)
+
+ def min(self, other):
+ return SmtFloatingPointMin(self, other)
+
+ def max(self, other):
+ return SmtFloatingPointMax(self, other)
+
+ def __eq__(self, other):
+ return SmtFloatingPointEq(self, other)
+
+ def __ne__(self, other):
+ return ~SmtFloatingPointEq(self, other)
+
+ def __lt__(self, other):
+ return ~SmtFloatingPointLt(self, other)
+
+ def __le__(self, other):
+ return ~SmtFloatingPointLE(self, other)
+
+ def __gt__(self, other):
+ return ~SmtFloatingPointGt(self, other)
+
+ def __ge__(self, other):
+ return ~SmtFloatingPointGE(self, other)
+
+ def is_normal(self):
+ return SmtFloatingPointIsNormal(self)
+
+ def is_subnormal(self):
+ return SmtFloatingPointIsSubnormal(self)
+
+ def is_zero(self):
+ return SmtFloatingPointIsZero(self)
+
+ def is_infinite(self):
+ return SmtFloatingPointIsInfinite(self)
+
+ def is_nan(self):
+ return SmtFloatingPointIsNaN(self)
+
+ def is_negative(self):
+ return SmtFloatingPointIsNegative(self)
+
+ def is_positive(self):
+ return SmtFloatingPointIsPositive(self)
+
+ @staticmethod
+ def from_signed_bv(bv, *, rm, eb=None, sb=None, sort=None):
+ return SmtFloatingPointFromSignedBV(bv, rm=rm, eb=eb, sb=sb,
+ sort=sort)
+
+ @staticmethod
+ def from_unsigned_bv(bv, *, rm, eb=None, sb=None, sort=None):
+ return SmtFloatingPointFromUnsignedBV(bv, rm=rm, eb=eb, sb=sb,
+ sort=sort)
+
+ @staticmethod
+ def from_real(value, *, rm, eb=None, sb=None, sort=None):
+ return SmtFloatingPointFromReal(value, rm=rm, eb=eb, sb=sb,
+ sort=sort)
+
+ @staticmethod
+ def from_int(value, *, rm, eb=None, sb=None, sort=None):
+ return SmtFloatingPointFromReal(SmtIntToReal(value), rm=rm, eb=eb,
+ sb=sb, sort=sort)
+
+ @staticmethod
+ def from_fp(value, *, rm, eb=None, sb=None, sort=None):
+ return SmtFloatingPointFromFP(value, rm=rm, eb=eb,
+ sb=sb, sort=sort)
+
+ @staticmethod
+ def from_bits(bv, *, eb=None, sb=None, sort=None):
+ return SmtFloatingPointFromBits(bv, eb=eb, sb=sb, sort=sort)
+
+ def to_real(self):
+ return SmtFloatingPointToReal(self)
+
+ def __ceil__(self):
+ return SmtFloatingPointToReal(self).__ceil__()
+
+ def __floor__(self):
+ return SmtFloatingPointToReal(self).__floor__()
+
+ def __trunc__(self):
+ return SmtFloatingPointToReal(self).__trunc__()
+
+ def to_unsigned_bv(self, *, width):
+ return SmtFloatingPointToUnsignedBV(self, width=width)
+
+ def to_signed_bv(self, *, width):
+ return SmtFloatingPointToSignedBV(self, width=width)
+
+
+@final
+@dataclass(frozen=True, unsafe_hash=False, eq=False)
+class SmtFloatingPointITE(SmtITE, SmtFloatingPoint):
+ then_v: SmtFloatingPoint
+ else_v: SmtFloatingPoint
+
+ def __new__(cls, cond, then_v, else_v):
+ assert isinstance(cond, SmtBool)
+ assert isinstance(then_v, SmtFloatingPoint)
+ assert isinstance(else_v, SmtFloatingPoint)
+ return SmtITE.__new__(cls, cond, then_v, else_v)
+
+ def __init__(self, cond, then_v, else_v):
+ SmtITE.__init__(self, cond, then_v, else_v)
+ SmtFloatingPoint.__init__(self, sort=self.then_v.sort())