From: Jacob Lifshay Date: Thu, 19 May 2022 08:00:01 +0000 (-0700) Subject: working on adding smtlib2 expression support X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=8429519f5e64a81c5ff439be3db0238c5899906d;p=nmigen.git working on adding smtlib2 expression support --- diff --git a/nmigen/hdl/smtlib2.py b/nmigen/hdl/smtlib2.py new file mode 100644 index 0000000..b93bdec --- /dev/null +++ b/nmigen/hdl/smtlib2.py @@ -0,0 +1,1735 @@ +from abc import ABCMeta, abstractmethod +from dataclasses import dataclass, field +import enum +from fractions import Fraction +from numbers import Rational +from .._utils import final, flatten +from . import ast, dsl, ir +from typing import overload, TYPE_CHECKING +if TYPE_CHECKING: + # make typechecker check final + from typing import final + +__all__ = [ + # FIXME +] + + +@dataclass(frozen=True, unsafe_hash=True, eq=True) +class SmtSort(meta=ABCMeta): + @abstractmethod + def _smtlib2_expr(self, expr_state): + return str(...) + + @abstractmethod + @staticmethod + def _ite_class(): + return SmtITE + + +@final +@dataclass(frozen=True, unsafe_hash=True, eq=True) +class SmtSortBool(SmtSort): + def _smtlib2_expr(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "Bool" + + @staticmethod + def _ite_class(): + return SmtBoolITE + + +@final +@dataclass(frozen=True, unsafe_hash=True, eq=True) +class SmtSortInt(SmtSort): + def _smtlib2_expr(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "Int" + + @staticmethod + def _ite_class(): + return SmtIntITE + + +@final +@dataclass(frozen=True, unsafe_hash=True, eq=True) +class SmtSortReal(SmtSort): + def _smtlib2_expr(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "Real" + + @staticmethod + def _ite_class(): + return SmtRealITE + + +@final +@dataclass(frozen=True, unsafe_hash=True, eq=True) +class SmtSortBitVec(SmtSort): + width: int + + def __post_init__(self): + assert isinstance(self.width, int) and self.width > 0, "invalid width" + + def _smtlib2_expr(self, expr_state): + assert isinstance(expr_state, _ExprState) + return f"(_ BitVec {self.width})" + + @staticmethod + def _ite_class(): + return SmtBitVecITE + + +@final +@dataclass(frozen=True, unsafe_hash=True, eq=True) +class SmtSortRoundingMode(SmtSort): + def _smtlib2_expr(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "RoundingMode" + + @staticmethod + def _ite_class(): + return SmtRoundingModeITE + + +@final +@dataclass(frozen=True, unsafe_hash=True, eq=True) +class SmtSortFloatingPoint(SmtSort): + eb: int + """ number of bits in the exponent """ + + sb: int + """ number of bits in the significand including the hidden bit """ + + def __post_init__(self): + assert isinstance(self.eb, int) and self.eb > 1, "invalid eb" + assert isinstance(self.sb, int) and self.sb > 1, "invalid sb" + + def _smtlib2_expr(self, expr_state): + assert isinstance(expr_state, _ExprState) + if self == SmtSortFloat16(): + return "Float16" + if self == SmtSortFloat32(): + return "Float32" + if self == SmtSortFloat64(): + return "Float64" + if self == SmtSortFloat128(): + return "Float128" + return f"(_ FloatingPoint {self.eb} {self.sb})" + + @staticmethod + def _ite_class(): + return SmtFloatingPointITE + + +def SmtSortFloat16(): + return SmtSortFloatingPoint(eb=5, sb=11) + + +def SmtSortFloat32(): + return SmtSortFloatingPoint(eb=8, sb=24) + + +def SmtSortFloat64(): + return SmtSortFloatingPoint(eb=11, sb=53) + + +def SmtSortFloat128(): + return SmtSortFloatingPoint(eb=15, sb=113) + + +@dataclass +class _ExprState: + input_bit_ranges: ast.ValueDict = field(default_factory=ast.ValueDict) + input_width: int = 0 + + +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class SmtValue(ast.DUID, meta=ABCMeta): + @abstractmethod + @staticmethod + def sort(): + return SmtSort() + + @abstractmethod + def _smtlib2_expr(self, expr_state): + return str(...) + + def same(self, other, *rest): + return SmtSame(self, other, *rest) + + def distinct(self, other, *rest): + return SmtDistinct(self, other, *rest) + + def __bool__(self): + raise TypeError("Attempted to convert SmtValue to Python boolean") + + +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class SmtBool(SmtValue): + @staticmethod + def sort(): + return SmtSortBool() + + def __new__(cls, value=None): + if cls is not SmtBool: + assert value is None + return super().__new__(cls) + assert isinstance(value, bool) + return SmtBoolConst(value) + + # type deduction: + @overload + def ite(self, then_v: "SmtBool", else_v: "SmtBool") -> "SmtBoolITE": ... + @overload + def ite(self, then_v: "SmtInt", else_v: "SmtInt") -> "SmtIntITE": ... + @overload + def ite(self, then_v: "SmtReal", else_v: "SmtReal") -> "SmtRealITE": ... + + @overload + def ite(self, then_v: "SmtBitVec", else_v: "SmtBitVec") -> "SmtBitVecITE": + ... + + @overload + def ite(self, then_v: "SmtRoundingMode", + else_v: "SmtRoundingMode") -> "SmtRoundingModeITE": ... + + @overload + def ite(self, then_v: "SmtFloatingPoint", + else_v: "SmtFloatingPoint") -> "SmtFloatingPointITE": ... + + def ite(self, then_v, else_v): + return SmtITE(self, then_v, else_v) + + def __invert__(self): + return SmtBoolNot(self) + + def __and__(self, other): + return SmtBoolAnd(self, other) + + def __rand__(self, other): + return SmtBoolAnd(other, self) + + def __xor__(self, other): + return SmtBoolXor(self, other) + + def __rxor__(self, other): + return SmtBoolXor(other, self) + + def __or__(self, other): + return SmtBoolOr(self, other) + + def __ror__(self, other): + return SmtBoolOr(other, self) + + def __eq__(self, other): + return self.same(other) + + def __ne__(self, other): + return self.distinct(other) + + def implies(self, *rest): + return SmtBoolImplies(self, *rest) + + def to_bit_vec(self): + return self.ite(SmtBitVec(1, width=1), + SmtBitVec(0, width=1)) + + def to_signal(self): + return self.to_bit_vec().to_signal() + + +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class SmtITE(SmtValue): + cond: SmtBool + then_v: SmtValue + else_v: SmtValue + + def __new__(cls, cond, then_v, else_v): + assert isinstance(cond, SmtBool) + assert isinstance(then_v, SmtValue) + assert isinstance(else_v, SmtValue) + sort = then_v.sort() + assert sort == else_v.sort() + if cls is SmtITE: + return sort._ite_class()(cond, then_v, else_v) + return super().__new__(cls) + + def _smtlib2_expr(self, expr_state): + assert isinstance(expr_state, _ExprState) + cond_s = self.cond._smtlib2_expr(expr_state) + then_s = self.then_v._smtlib2_expr(expr_state) + else_s = self.else_v._smtlib2_expr(expr_state) + return f"(ite {cond_s} {then_s} {else_s})" + + +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class _SmtNArySameSort(SmtValue): + inputs: "tuple[SmtValue, ...]" + + @abstractmethod + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return str(...) + + def _expected_input_class(self): + return SmtValue + + def __init__(self, *inputs): + object.__setattr__(self, "inputs", inputs) + assert len(inputs) > 0, "not enough inputs" + + for i in inputs: + assert isinstance(i, self._expected_input_class()) + assert i.sort == self.input_sort, "all input sorts must match" + + @property + def input_sort(self): + return self.inputs[0].sort + + def _smtlib2_expr(self, expr_state): + assert isinstance(expr_state, _ExprState) + op = self._smtlib2_expr_op(expr_state) + args = ' '.join(i._smtlib2_expr(expr_state) for i in self.inputs) + return f"({op} {args})" + + +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class _SmtUnary(SmtValue): + inp: SmtValue + + @abstractmethod + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return str(...) + + def _expected_input_class(self): + return SmtValue + + def __init__(self, inp): + object.__setattr__(self, "inp", inp) + assert isinstance(inp, self._expected_input_class()) + + def _smtlib2_expr(self, expr_state): + assert isinstance(expr_state, _ExprState) + op = self._smtlib2_expr_op(expr_state) + inp = self.inp._smtlib2_expr(expr_state) + return f"({op} {inp})" + + +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class _SmtBinary(SmtValue): + lhs: SmtValue + rhs: SmtValue + + @abstractmethod + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return str(...) + + def _expected_input_class(self): + return SmtValue + + def _expected_input_lhs_class(self): + return self._expected_input_class() + + def _expected_input_rhs_class(self): + return self._expected_input_class() + + def __init__(self, lhs, rhs): + object.__setattr__(self, "lhs", lhs) + object.__setattr__(self, "rhs", rhs) + assert isinstance(lhs, self._expected_input_lhs_class()) + assert isinstance(rhs, self._expected_input_rhs_class()) + + def _smtlib2_expr(self, expr_state): + assert isinstance(expr_state, _ExprState) + op = self._smtlib2_expr_op(expr_state) + lhs = self.lhs._smtlib2_expr(expr_state) + rhs = self.rhs._smtlib2_expr(expr_state) + return f"({op} {lhs} {rhs})" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtSame(_SmtNArySameSort, SmtBool): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "=" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtDistinct(_SmtNArySameSort, SmtBool): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "distinct" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class SmtBoolConst(SmtBool): + value: bool + + def _smtlib2_expr(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "true" if self.value else "false" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class SmtBoolNot(_SmtUnary, SmtBool): + inp: SmtBool + + def _expected_input_class(self): + return SmtBool + + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "not" + + +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class _SmtBoolNAryMinBinary(_SmtNArySameSort, SmtBool): + inputs: "tuple[SmtBool, ...]" + + def _expected_input_class(self): + return SmtBool + + def __new__(cls, *inputs): + return super().__new__(cls) + + def __init__(self, *inputs): + super().__init__(*inputs) + assert len(self.inputs) >= 2, "not enough inputs" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtBoolImplies(_SmtBoolNAryMinBinary): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "=>" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtBoolAnd(_SmtBoolNAryMinBinary): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "and" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtBoolOr(_SmtBoolNAryMinBinary): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "or" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtBoolXor(_SmtBoolNAryMinBinary): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "xor" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class SmtBoolITE(SmtITE, SmtBool): + then_v: SmtBool + else_v: SmtBool + + def __new__(cls, cond, then_v, else_v): + assert isinstance(cond, SmtBool) + assert isinstance(then_v, SmtBool) + assert isinstance(else_v, SmtBool) + return SmtITE.__new__(cls, cond, then_v, else_v) + + +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtReal(SmtValue): + @staticmethod + def sort(): + return SmtSortReal() + + def __new__(cls, value=None): + if cls is not SmtReal: + assert value is None + return super().__new__(cls) + if isinstance(value, SmtInt): + return SmtIntToReal(value) + assert isinstance(value, Rational), ("value must be a rational " + "number, irrational numbers and " + "floats aren't yet supported.") + return SmtRealConst(value) + + def __init__(self, value=None): + # __init__ for IDE type checker + # value != None already handled by __new__ + assert value is None, "invalid argument" + return super().__init__() + + def __neg__(self): + return SmtRealNeg(self) + + def __pos__(self): + return self + + def __add__(self, other): + return SmtRealAdd(self, other) + + def __radd__(self, other): + return SmtRealAdd(other, self) + + def __sub__(self, other): + return SmtRealSub(self, other) + + def __rsub__(self, other): + return SmtRealSub(other, self) + + def __mul__(self, other): + return SmtRealMul(self, other) + + def __rmul__(self, other): + return SmtRealMul(other, self) + + def __truediv__(self, other): + return SmtRealDiv(self, other) + + def __rtruediv__(self, other): + return SmtRealDiv(other, self) + + def __eq__(self, other): + return SmtSame(self, other) + + def __ne__(self, other): + return SmtDistinct(self, other) + + def __lt__(self, other): + return SmtRealLt(self, other) + + def __le__(self, other): + return SmtRealLE(self, other) + + def __gt__(self, other): + return SmtRealGt(self, other) + + def __ge__(self, other): + return SmtRealGE(self, other) + + def __abs__(self): + return self.__lt__(SmtReal(0)).ite(-self, self) + + def __floor__(self): + return SmtRealToInt(self) + + def __trunc__(self): + return self.__lt__(SmtReal(0)).ite(-(-self).__floor__(), + self.__floor__()) + + def __ceil__(self): + return -(-self).__floor__() + + def is_int(self): + return SmtRealIsInt(self) + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class SmtRealConst(SmtReal): + value: Fraction + + def __init__(self, value): + assert isinstance(value, Rational), ("value must be a rational " + "number, irrational numbers and " + "floats aren't yet supported.") + value = Fraction(value) + object.__setattr__(self, "value", value) + + def _smtlib2_expr(self, expr_state): + assert isinstance(expr_state, _ExprState) + retval = f"{abs(self.value.numerator)}.0" + if self.value.numerator < 0: + retval = f"(- {retval})" + if self.value.denominator != 1: + retval = f"(/ {retval} {self.value.denominator}.0)" + return retval + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class SmtRealNeg(_SmtUnary, SmtReal): + inp: SmtReal + + def _expected_input_class(self): + return SmtReal + + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "-" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class SmtRealIsInt(_SmtUnary, SmtBool): + inp: SmtReal + + def _expected_input_class(self): + return SmtReal + + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "is_int" + + +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class _SmtRealNAryMinBinary(_SmtNArySameSort, SmtReal): + inputs: "tuple[SmtReal, ...]" + + def _expected_input_class(self): + return SmtReal + + def __new__(cls, *inputs): + return _SmtNArySameSort.__new__(cls) + + def __init__(self, *inputs): + super().__init__(*inputs) + assert len(self.inputs) >= 2, "not enough inputs" + + +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class _SmtRealCompareOp(_SmtNArySameSort, SmtBool): + inputs: "tuple[SmtReal, ...]" + + def _expected_input_class(self): + return SmtReal + + def __new__(cls, *inputs): + return _SmtNArySameSort.__new__(cls) + + def __init__(self, *inputs): + super().__init__(*inputs) + assert len(self.inputs) >= 2, "not enough inputs" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtRealLt(_SmtRealCompareOp): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "<" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtRealLE(_SmtRealCompareOp): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "<=" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtRealGt(_SmtRealCompareOp): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return ">" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtRealGE(_SmtRealCompareOp): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return ">=" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtRealAdd(_SmtRealNAryMinBinary): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "+" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtRealSub(_SmtRealNAryMinBinary): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "-" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtRealMul(_SmtRealNAryMinBinary): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "*" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtRealDiv(_SmtRealNAryMinBinary): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "/" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class SmtRealITE(SmtITE, SmtReal): + then_v: SmtReal + else_v: SmtReal + + def __new__(cls, cond, then_v, else_v): + assert isinstance(cond, SmtBool) + assert isinstance(then_v, SmtReal) + assert isinstance(else_v, SmtReal) + return SmtITE.__new__(cls, cond, then_v, else_v) + + +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtInt(SmtValue): + @staticmethod + def sort(): + return SmtSortInt() + + def __new__(cls, value=None): + if cls is not SmtInt: + assert value is None + return super().__new__(cls) + assert isinstance(value, int), "value must be an integer" + return SmtIntConst(value) + + def __init__(self, value=None): + # __init__ for IDE type checker + # value != None already handled by __new__ + assert value is None, "invalid argument" + return super().__init__() + + def __neg__(self): + return SmtIntNeg(self) + + def __pos__(self): + return self + + def __add__(self, other): + return SmtIntAdd(self, other) + + def __radd__(self, other): + return SmtIntAdd(other, self) + + def __sub__(self, other): + return SmtIntSub(self, other) + + def __rsub__(self, other): + return SmtIntSub(other, self) + + def __mul__(self, other): + return SmtIntMul(self, other) + + def __rmul__(self, other): + return SmtIntMul(other, self) + + def euclid_div(self, other): + return SmtIntEuclidDiv(self, other) + + def euclid_rem(self, other): + return SmtIntEuclidRem(self, other) + + def __eq__(self, other): + return SmtSame(self, other) + + def __ne__(self, other): + return SmtDistinct(self, other) + + def __lt__(self, other): + return SmtIntLt(self, other) + + def __le__(self, other): + return SmtIntLE(self, other) + + def __gt__(self, other): + return SmtIntGt(self, other) + + def __ge__(self, other): + return SmtIntGE(self, other) + + def __abs__(self): + return SmtIntAbs(self) + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class SmtIntConst(SmtInt): + value: int + + def __init__(self, value): + assert isinstance(value, int), "value must be an integer" + object.__setattr__(self, "value", value) + + def _smtlib2_expr(self, expr_state): + assert isinstance(expr_state, _ExprState) + retval = f"{abs(self.value.numerator)}" + if self.value.numerator < 0: + retval = f"(- {retval})" + return retval + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class SmtIntNeg(_SmtUnary, SmtInt): + inp: SmtInt + + def _expected_input_class(self): + return SmtInt + + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "-" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class SmtIntAbs(_SmtUnary, SmtInt): + inp: SmtInt + + def _expected_input_class(self): + return SmtInt + + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "abs" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class SmtIntToReal(_SmtUnary, SmtReal): + inp: SmtInt + + def _expected_input_class(self): + return SmtInt + + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "to_real" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class SmtRealToInt(_SmtUnary, SmtInt): + inp: SmtReal + + def _expected_input_class(self): + return SmtReal + + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "to_int" + + +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class _SmtIntNAryMinBinary(_SmtNArySameSort, SmtInt): + inputs: "tuple[SmtInt, ...]" + + def _expected_input_class(self): + return SmtInt + + def __new__(cls, *inputs): + return _SmtNArySameSort.__new__(cls) + + def __init__(self, *inputs): + super().__init__(*inputs) + assert len(self.inputs) >= 2, "not enough inputs" + + +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class _SmtIntBinary(_SmtBinary, SmtInt): + lhs: SmtInt + rhs: SmtInt + + def _expected_input_class(self): + return SmtInt + + +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class _SmtIntCompareOp(_SmtNArySameSort, SmtBool): + inputs: "tuple[SmtInt, ...]" + + def _expected_input_class(self): + return SmtInt + + def __new__(cls, *inputs): + return _SmtNArySameSort.__new__(cls) + + def __init__(self, *inputs): + super().__init__(*inputs) + assert len(self.inputs) >= 2, "not enough inputs" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtIntLt(_SmtIntCompareOp): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "<" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtIntLE(_SmtIntCompareOp): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "<=" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtIntGt(_SmtIntCompareOp): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return ">" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtIntGE(_SmtIntCompareOp): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return ">=" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtIntAdd(_SmtIntNAryMinBinary): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "+" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtIntSub(_SmtIntNAryMinBinary): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "-" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtIntMul(_SmtIntNAryMinBinary): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "*" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtIntEuclidDiv(_SmtIntNAryMinBinary): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "div" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtIntEuclidRem(_SmtIntBinary): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "rem" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class SmtIntITE(SmtITE, SmtInt): + then_v: SmtInt + else_v: SmtInt + + def __new__(cls, cond, then_v, else_v): + assert isinstance(cond, SmtBool) + assert isinstance(then_v, SmtInt) + assert isinstance(else_v, SmtInt) + return SmtITE.__new__(cls, cond, then_v, else_v) + + +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtBitVec(SmtValue): + width: int + + def sort(self): + return SmtSortBitVec(self.width) + + @staticmethod + def __make_bitvec(value=None, *, width=None): + if isinstance(value, int): + assert width is not None, "missing width" + assert isinstance(width, int) and width > 0, "invalid width" + return SmtBitVecConst(value, width=width) + assert width is None, "can't give both width and nMigen Value" + return SmtBitVecInput(value) + + def __new__(cls, *args, **kwargs): + if cls is not SmtInt: + return super().__new__(cls) + return SmtBitVec.__make_bitvec(*args, **kwargs) + + def __init__(self, value=None, *, width=None): + # __init__ for IDE type checker + # value != None already handled by __new__ + assert value is None, "invalid argument" + assert isinstance(width, int) and width > 0, "invalid width" + object.__setattr__(self, "width", width) + super().__init__() + + def __neg__(self): + return SmtBitVecNeg(self) + + def __invert__(self): + return SmtBitVecNot(self) + + def __pos__(self): + return self + + def __add__(self, other): + return SmtBitVecAdd(self, other) + + def __radd__(self, other): + return SmtBitVecAdd(other, self) + + def __sub__(self, other): + assert isinstance(other, SmtBitVec) + return self + -other + + def __rsub__(self, other): + assert isinstance(other, SmtBitVec) + return other + -self + + def __mul__(self, other): + return SmtBitVecMul(self, other) + + def __rmul__(self, other): + return SmtBitVecMul(other, self) + + def __floordiv__(self, other): + return SmtBitVecDiv(self, other) + + def __rfloordiv__(self, other): + return SmtBitVecDiv(other, self) + + def __mod__(self, other): + return SmtBitVecRem(self, other) + + def __rmod__(self, other): + return SmtBitVecRem(other, self) + + def __divmod__(self, other): + return self // other, self % other + + def __eq__(self, other): + return SmtSame(self, other) + + def __ne__(self, other): + return SmtDistinct(self, other) + + def __lt__(self, other): + return SmtBitVecLt(self, other) + + def __le__(self, other): + assert isinstance(other, SmtBitVec) + return other >= self + + def __gt__(self, other): + assert isinstance(other, SmtBitVec) + return other < self + + def __ge__(self, other): + return SmtBoolNot(SmtBitVecLt(self, other)) + + def __abs__(self): + return self.__lt__(SmtBitVec(0, width=self.width)).ite(-self, self) + + def __and__(self, other): + return SmtBitVecAnd(self, other) + + def __rand__(self, other): + return SmtBitVecAnd(other, self) + + def __or__(self, other): + return SmtBitVecOr(self, other) + + def __ror__(self, other): + return SmtBitVecOr(other, self) + + def __xor__(self, other): + return SmtBitVecXor(self, other) + + def __rxor__(self, other): + return SmtBitVecXor(other, self) + + def __lshift__(self, other): + return SmtBitVecLShift(self, other) + + def __rlshift__(self, other): + return SmtBitVecLShift(other, self) + + def __rshift__(self, other): + return SmtBitVecRShift(self, other) + + def __rrshift__(self, other): + return SmtBitVecRShift(other, self) + + def __getitem__(self, index): + if isinstance(index, int): + if index < 0: + index += self.width + return SmtBitVecExtract(self, range(index, index + 1)) + assert isinstance(index, slice) + r = range(*index.indices(self.width)) + assert len(r) > 0, "empty slice" + if r.step != 1: + return SmtBitVecConcat(self[i] for i in r) + return SmtBitVecExtract(self, r) + + def to_signal(self): + return SmtBitVecToSignal(self) + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtBitVecConcat(SmtBitVec): + inputs: "tuple[SmtBitVec, ...]" + """inputs in lsb to msb order""" + + def __init__(self, *inputs): + inputs = tuple(flatten(inputs)) + object.__setattr__(self, "inputs", inputs) + assert len(inputs) > 0, "not enough inputs" + + width = 0 + for i in inputs: + assert isinstance(i, SmtBitVec) + width += i.width + + super().__init__(width=width) + + def _smtlib2_expr(self, expr_state): + assert isinstance(expr_state, _ExprState) + args = [i._smtlib2_expr(expr_state) for i in self.inputs] + return f"(concat {' '.join(reversed(args))})" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class SmtBitVecExtract(_SmtUnary, SmtBitVec): + inp: SmtBitVec + bit_range: range + + def __init__(self, inp, bit_range): + assert isinstance(inp, SmtBitVec) + if isinstance(bit_range, int): + if bit_range < 0: + bit_range += inp.width + bit_range = range(bit_range, bit_range + 1) + assert isinstance(bit_range, range) + assert len(bit_range) > 0, "empty range" + assert bit_range.step == 1, "unsupported range step" + assert 0 <= bit_range.start < bit_range.stop <= inp.width, \ + "bit range out-of-range" + _SmtUnary.__init__(self, inp) + SmtBitVec.__init__(self, width=len(bit_range)) + + def _expected_input_class(self): + return SmtBitVec + + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return f"(_ extract {self.bit_range.stop - 1} {self.bit_range.start})" + + +class SmtBitVecToSignal(dsl.Elaboratable): + def __init__(self, bv): + if isinstance(bv, SmtBool): + bv = bv.to_bit_vec() + assert isinstance(bv, SmtBitVec) + self.bv = bv + self.o_sig = ast.Signal(ast.unsigned(self.bv.width)) + self.__expr_state = _ExprState() + self.__expr = bv._smtlib2_expr(self.__expr_state) + if self.__expr_state.input_width == 0: + # can't have 0-width input + self.__expr_state.input_bit_ranges[ast.Const(0, 1)] = range(1) + self.__expr_state.input_width = 1 + self.__instance = ir.Instance( + "$smtlib2_expr", + p_A_WIDTH=self.__expr_state.input_width, + p_Y_WIDTH=self.bv.width, + p_EXPR=self.__expr, + i_A=ast.Cat(self.__expr_state.input_bit_ranges.keys()), + o_Y=self.o_sig) + + def elaborate(self, platform): + return self.__instance + + +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class SmtBitVecInput(SmtBitVec): + value: ast.Value + + def __init_subclass__(cls): + try: + _ = SmtBitVecConst + except AttributeError: + # only possible when we're defining SmtBitVecConst + return + raise TypeError("subclassing SmtBitVecInput isn't supported") + + def __new__(cls, value): + value = ast.Value.cast(value) + assert isinstance(value, ast.Value) + if isinstance(value, ast.Const): + if cls is not SmtBitVecConst: + return SmtBitVecConst(value) + retval = super().__new__(cls) + object.__setattr__(retval, "value", value) + return retval + + def __init__(self, value): + # self.value assigned in __new__ + super().__init__(width=self.value.shape().width) # type: ignore + + def _smtlib2_expr(self, expr_state): + assert isinstance(expr_state, _ExprState) + if self.value not in expr_state.input_bit_ranges: + start = expr_state.input_width + expr_state.input_width += self.width + r = range(start, expr_state.input_width) + expr_state.input_bit_ranges[self.value] = r + else: + r = expr_state.input_bit_ranges[self.value] + return f"((_ extract {r.stop - 1} {r.start}) A)" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class SmtBitVecConst(SmtBitVecInput): + value: ast.Const + + def __new__(cls, value, *, width=None): + if isinstance(value, ast.Const): + assert width is None + # decompose -- needed since we switch to unsigned + width = value.width + value = value.value + assert isinstance(value, int), "value must be an integer" + assert isinstance(width, int) and width > 0, "invalid width" + value = ast.Const(value, ast.unsigned(width)) + return super().__new__(cls, value) + + def __init__(self, value, *, width=None): + # rest of logic already handled by __new__ + super().__init__(self.value) + + def _smtlib2_expr(self, expr_state): + assert isinstance(expr_state, _ExprState) + hex_digit_count, hex_digit_count_rem = divmod(self.width, 4) + if hex_digit_count_rem == 0: + digit_count = hex_digit_count + digits = hex(self.value.value) + else: + digit_count = self.width + digits = bin(self.value.value) + return "#" + digits[1] + digits[2:].zfill(digit_count) + + +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class _SmtBitVecUnary(_SmtUnary, SmtBitVec): + inp: SmtBitVec + + def __init__(self, inp): + assert isinstance(inp, SmtBitVec) + _SmtUnary.__init__(self, inp) + SmtBitVec.__init__(self, width=inp.width) + + def _expected_input_class(self): + return SmtBitVec + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtBitVecNeg(_SmtBitVecUnary): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "bvneg" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtBitVecNot(_SmtBitVecUnary): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "bvnot" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class SmtBitVecToNat(_SmtUnary, SmtInt): + inp: SmtBitVec + + def _expected_input_class(self): + return SmtBitVec + + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "bv2nat" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class SmtNatToBitVec(_SmtUnary, SmtBitVec): + inp: SmtInt + + def _expected_input_class(self): + return SmtInt + + def __init__(self, inp, *, width): + assert isinstance(inp, SmtInt) + assert isinstance(width, int) and width > 0, "invalid width" + _SmtUnary.__init__(self, inp) + SmtBitVec.__init__(self, width=width) + + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return f"(_ nat2bv {self.width})" + + +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class _SmtBitVecNAryMinBinary(_SmtNArySameSort, SmtBitVec): + inputs: "tuple[SmtBitVec, ...]" + + def _expected_input_class(self): + return SmtBitVec + + def __new__(cls, *inputs): + return _SmtNArySameSort.__new__(cls) + + def __init__(self, *inputs): + _SmtNArySameSort.__init__(self, *inputs) + assert len(self.inputs) >= 2, "not enough inputs" + SmtBitVec.__init__(self, width=self.inputs[0].width) + + +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class _SmtBitVecBinary(_SmtBinary, SmtBitVec): + lhs: SmtBitVec + rhs: SmtBitVec + + def __init__(self, lhs, rhs): + _SmtBinary.__init__(self, lhs, rhs) + SmtBitVec.__init__(self, width=self.lhs.width) + + def _expected_input_class(self): + return SmtBitVec + + +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class _SmtBitVecCompareOp(_SmtNArySameSort, SmtBool): + inputs: "tuple[SmtBitVec, ...]" + + def _expected_input_class(self): + return SmtBitVec + + def __new__(cls, *inputs): + return _SmtNArySameSort.__new__(cls) + + def __init__(self, *inputs): + super().__init__(*inputs) + assert len(self.inputs) >= 2, "not enough inputs" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtBitVecLt(_SmtBitVecCompareOp): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "bvult" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtBitVecAdd(_SmtBitVecNAryMinBinary): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "bvadd" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtBitVecAnd(_SmtBitVecNAryMinBinary): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "bvand" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtBitVecOr(_SmtBitVecNAryMinBinary): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "bvor" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtBitVecXor(_SmtBitVecNAryMinBinary): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "bvxor" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtBitVecMul(_SmtBitVecNAryMinBinary): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "bvmul" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtBitVecDiv(_SmtBitVecNAryMinBinary): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "bvudiv" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtBitVecRem(_SmtBitVecBinary): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "bvurem" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtBitVecLShift(_SmtBitVecBinary): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "bvshl" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtBitVecRShift(_SmtBitVecBinary): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "bvlshr" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class SmtBitVecITE(SmtITE, SmtBitVec): + then_v: SmtBitVec + else_v: SmtBitVec + + def __new__(cls, cond, then_v, else_v): + assert isinstance(cond, SmtBool) + assert isinstance(then_v, SmtBitVec) + assert isinstance(else_v, SmtBitVec) + return SmtITE.__new__(cls, cond, then_v, else_v) + + def __init__(self, cond, then_v, else_v): + SmtITE.__init__(self, cond, then_v, else_v) + SmtBitVec.__init__(self, width=self.then_v.width) + + +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class SmtRoundingMode(SmtValue): + def __new__(cls, value=None): + if cls is SmtRoundingMode: + return SmtRoundingModeConst(value) + return super().__new__(cls) + + def __init__(self, value=None): + # __init__ for IDE type checker + # value != None already handled by __new__ + assert value is None, "invalid argument" + super().__init__() + + @staticmethod + def sort(): + return SmtSortRoundingMode() + + def __eq__(self, other): + return self.same(other) + + def __ne__(self, other): + return self.distinct(other) + + +@final +class RoundingModeEnum(enum.Enum): + RNE = "RNE" + ROUND_DEFAULT = RNE + ROUND_NEAREST_TIES_TO_EVEN = RNE + RNA = "RNA" + ROUND_NEAREST_TIES_TO_AWAY = RNA + RTP = "RTP" + ROUND_TOWARD_POSITIVE = RTP + RTN = "RTN" + ROUND_TOWARD_NEGATIVE = RTN + RTZ = "RTZ" + ROUND_TOWARD_ZERO = RTZ + + def __init__(self, value): + assert isinstance(value, str) + self._smtlib2_expr = value + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class SmtRoundingModeConst(SmtValue): + value: RoundingModeEnum + + def __new__(cls, value): + value = RoundingModeEnum(value) + try: + if value is RoundingModeEnum.RNE: + return RNE + elif value is RoundingModeEnum.RNA: + return RNA + elif value is RoundingModeEnum.RTP: + return RTP + elif value is RoundingModeEnum.RTN: + return RTN + else: + assert value is RoundingModeEnum.RTZ + return RTZ + except AttributeError: + # instance not created yet + return super().__new__(cls) + + def _smtlib2_expr(self, expr_state): + assert isinstance(expr_state, _ExprState) + return self.value._smtlib2_expr + + +RNE = SmtRoundingModeConst(RoundingModeEnum.RNE) +ROUND_NEAREST_TIES_TO_EVEN = SmtRoundingModeConst( + RoundingModeEnum.ROUND_NEAREST_TIES_TO_EVEN) +ROUND_DEFAULT = SmtRoundingModeConst( + RoundingModeEnum.ROUND_DEFAULT) +RNA = SmtRoundingModeConst(RoundingModeEnum.RNA) +ROUND_NEAREST_TIES_TO_AWAY = SmtRoundingModeConst( + RoundingModeEnum.ROUND_NEAREST_TIES_TO_AWAY) +RTP = SmtRoundingModeConst(RoundingModeEnum.RTP) +ROUND_TOWARD_POSITIVE = SmtRoundingModeConst( + RoundingModeEnum.ROUND_TOWARD_POSITIVE) +RTN = SmtRoundingModeConst(RoundingModeEnum.RTN) +ROUND_TOWARD_NEGATIVE = SmtRoundingModeConst( + RoundingModeEnum.ROUND_TOWARD_NEGATIVE) +RTZ = SmtRoundingModeConst(RoundingModeEnum.RTZ) +ROUND_TOWARD_ZERO = SmtRoundingModeConst(RoundingModeEnum.ROUND_TOWARD_ZERO) + +# check all rounding modes are accounted for: +# (we check rather than just make everything here so IDEs can more easily +# see the definitions) +for name, rm in RoundingModeEnum.__members__.items(): + assert globals()[name] is SmtRoundingModeConst(rm), f"mismatch {name}" +del rm, name + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class SmtRoundingModeITE(SmtITE, SmtRoundingMode): + then_v: SmtRoundingMode + else_v: SmtRoundingMode + + def __new__(cls, cond, then_v, else_v): + assert isinstance(cond, SmtBool) + assert isinstance(then_v, SmtRoundingMode) + assert isinstance(else_v, SmtRoundingMode) + return SmtITE.__new__(cls, cond, then_v, else_v) + + +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtFloatingPoint(SmtValue): + eb: int + sb: int + + def sort(self): + return SmtSortFloatingPoint(eb=self.eb, sb=self.sb) + + def __init__(self, *, eb=None, sb=None, sort=None): + if sort is not None: + assert eb is sb is None and isinstance(sort, SmtSortFloatingPoint) + eb = sort.eb + sb = sort.sb + assert isinstance(self.eb, int) and self.eb > 1, "invalid eb" + assert isinstance(self.sb, int) and self.sb > 1, "invalid sb" + object.__setattr__(self, "eb", eb) + object.__setattr__(self, "sb", sb) + + @staticmethod + def nan(*, eb=None, sb=None, sort=None): + return SmtFloatingPointNaN(eb=eb, sb=sb, sort=sort) + + @staticmethod + def zero(*, sign=None, eb=None, sb=None, sort=None): + return SmtFloatingPointZero(sign=sign, eb=eb, sb=sb, sort=sort) + + @staticmethod + def inf(*, sign=None, eb=None, sb=None, sort=None): + return SmtFloatingPointInf(sign=sign, eb=eb, sb=sb, sort=sort) + + @staticmethod + def from_parts(*, sign, exponent, mantissa): + return SmtFloatingPointFromParts(sign=sign, exponent=exponent, + mantissa=mantissa) + + def __abs__(self): + return SmtFloatingPointAbs(self) + + def __neg__(self): + return SmtFloatingPointNeg(self) + + def __pos__(self): + return self + + def add(self, other, *, rm): + return SmtFloatingPointAdd(self, other, rm=rm) + + def sub(self, other, *, rm): + return SmtFloatingPointSub(self, other, rm=rm) + + def mul(self, other, *, rm): + return SmtFloatingPointMul(self, other, rm=rm) + + def div(self, other, *, rm): + return SmtFloatingPointDiv(self, other, rm=rm) + + def fma(self, factor, term, *, rm): + """returns `self * factor + term`""" + return SmtFloatingPointFma(self, factor, term, rm=rm) + + def sqrt(self, *, rm): + return SmtFloatingPointSqrt(self, rm=rm) + + def rem(self, other, *, rm): + return SmtFloatingPointRem(self, other, rm=rm) + + def round_to_integral(self, *, rm): + return SmtFloatingPointRoundToIntegral(self, rm=rm) + + def min(self, other): + return SmtFloatingPointMin(self, other) + + def max(self, other): + return SmtFloatingPointMax(self, other) + + def __eq__(self, other): + return SmtFloatingPointEq(self, other) + + def __ne__(self, other): + return ~SmtFloatingPointEq(self, other) + + def __lt__(self, other): + return ~SmtFloatingPointLt(self, other) + + def __le__(self, other): + return ~SmtFloatingPointLE(self, other) + + def __gt__(self, other): + return ~SmtFloatingPointGt(self, other) + + def __ge__(self, other): + return ~SmtFloatingPointGE(self, other) + + def is_normal(self): + return SmtFloatingPointIsNormal(self) + + def is_subnormal(self): + return SmtFloatingPointIsSubnormal(self) + + def is_zero(self): + return SmtFloatingPointIsZero(self) + + def is_infinite(self): + return SmtFloatingPointIsInfinite(self) + + def is_nan(self): + return SmtFloatingPointIsNaN(self) + + def is_negative(self): + return SmtFloatingPointIsNegative(self) + + def is_positive(self): + return SmtFloatingPointIsPositive(self) + + @staticmethod + def from_signed_bv(bv, *, rm, eb=None, sb=None, sort=None): + return SmtFloatingPointFromSignedBV(bv, rm=rm, eb=eb, sb=sb, + sort=sort) + + @staticmethod + def from_unsigned_bv(bv, *, rm, eb=None, sb=None, sort=None): + return SmtFloatingPointFromUnsignedBV(bv, rm=rm, eb=eb, sb=sb, + sort=sort) + + @staticmethod + def from_real(value, *, rm, eb=None, sb=None, sort=None): + return SmtFloatingPointFromReal(value, rm=rm, eb=eb, sb=sb, + sort=sort) + + @staticmethod + def from_int(value, *, rm, eb=None, sb=None, sort=None): + return SmtFloatingPointFromReal(SmtIntToReal(value), rm=rm, eb=eb, + sb=sb, sort=sort) + + @staticmethod + def from_fp(value, *, rm, eb=None, sb=None, sort=None): + return SmtFloatingPointFromFP(value, rm=rm, eb=eb, + sb=sb, sort=sort) + + @staticmethod + def from_bits(bv, *, eb=None, sb=None, sort=None): + return SmtFloatingPointFromBits(bv, eb=eb, sb=sb, sort=sort) + + def to_real(self): + return SmtFloatingPointToReal(self) + + def __ceil__(self): + return SmtFloatingPointToReal(self).__ceil__() + + def __floor__(self): + return SmtFloatingPointToReal(self).__floor__() + + def __trunc__(self): + return SmtFloatingPointToReal(self).__trunc__() + + def to_unsigned_bv(self, *, width): + return SmtFloatingPointToUnsignedBV(self, width=width) + + def to_signed_bv(self, *, width): + return SmtFloatingPointToSignedBV(self, width=width) + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class SmtFloatingPointITE(SmtITE, SmtFloatingPoint): + then_v: SmtFloatingPoint + else_v: SmtFloatingPoint + + def __new__(cls, cond, then_v, else_v): + assert isinstance(cond, SmtBool) + assert isinstance(then_v, SmtFloatingPoint) + assert isinstance(else_v, SmtFloatingPoint) + return SmtITE.__new__(cls, cond, then_v, else_v) + + def __init__(self, cond, then_v, else_v): + SmtITE.__init__(self, cond, then_v, else_v) + SmtFloatingPoint.__init__(self, sort=self.then_v.sort())