From: Jacob Lifshay Date: Fri, 20 May 2022 04:53:16 +0000 (-0700) Subject: working on implementing smtlib2 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=5a2849ba81d5ed341bf4d61d6bae1058b4ffa143;p=nmigen.git working on implementing smtlib2 --- diff --git a/.gitignore b/.gitignore index d11a3eb..8dc0a8f 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,8 @@ __pycache__/ # coverage /.coverage /htmlcov +/cov.xml +/coverage.xml # tests /tests/spec_*/ diff --git a/nmigen/hdl/smtlib2.py b/nmigen/hdl/smtlib2.py index b93bdec..04846dd 100644 --- a/nmigen/hdl/smtlib2.py +++ b/nmigen/hdl/smtlib2.py @@ -6,25 +6,151 @@ from numbers import Rational from .._utils import final, flatten from . import ast, dsl, ir from typing import overload, TYPE_CHECKING -if TYPE_CHECKING: +if TYPE_CHECKING: # :nobr: # make typechecker check final - from typing import final + from typing import final # :nocov: __all__ = [ - # FIXME + 'RNA', + 'RNE', + 'ROUND_DEFAULT', + 'ROUND_NEAREST_TIES_TO_AWAY', + 'ROUND_NEAREST_TIES_TO_EVEN', + 'ROUND_TOWARD_NEGATIVE', + 'ROUND_TOWARD_POSITIVE', + 'ROUND_TOWARD_ZERO', + 'RTN', + 'RTP', + 'RTZ', + 'RoundingModeEnum', + 'SmtBitVec', + 'SmtBitVecAdd', + 'SmtBitVecAnd', + 'SmtBitVecConcat', + 'SmtBitVecConst', + 'SmtBitVecDiv', + 'SmtBitVecExtract', + 'SmtBitVecITE', + 'SmtBitVecInput', + 'SmtBitVecLShift', + 'SmtBitVecLt', + 'SmtBitVecMul', + 'SmtBitVecNeg', + 'SmtBitVecNot', + 'SmtBitVecOr', + 'SmtBitVecRShift', + 'SmtBitVecRem', + 'SmtBitVecToNat', + 'SmtBitVecToSignal', + 'SmtBitVecXor', + 'SmtBool', + 'SmtBoolAnd', + 'SmtBoolConst', + 'SmtBoolITE', + 'SmtBoolImplies', + 'SmtBoolNot', + 'SmtBoolOr', + 'SmtBoolXor', + 'SmtDistinct', + 'SmtFloatingPoint', + 'SmtFloatingPointAbs', + 'SmtFloatingPointAdd', + 'SmtFloatingPointDiv', + 'SmtFloatingPointEq', + 'SmtFloatingPointFma', + 'SmtFloatingPointFromBits', + 'SmtFloatingPointFromFP', + 'SmtFloatingPointFromParts', + 'SmtFloatingPointFromReal', + 'SmtFloatingPointFromSignedBV', + 'SmtFloatingPointFromUnsignedBV', + 'SmtFloatingPointGE', + 'SmtFloatingPointGt', + 'SmtFloatingPointITE', + 'SmtFloatingPointIsInfinite', + 'SmtFloatingPointIsNaN', + 'SmtFloatingPointIsNegative', + 'SmtFloatingPointIsNormal', + 'SmtFloatingPointIsPositive', + 'SmtFloatingPointIsSubnormal', + 'SmtFloatingPointIsZero', + 'SmtFloatingPointLE', + 'SmtFloatingPointLt', + 'SmtFloatingPointMax', + 'SmtFloatingPointMin', + 'SmtFloatingPointMul', + 'SmtFloatingPointNaN', + 'SmtFloatingPointNeg', + 'SmtFloatingPointNegInfinity', + 'SmtFloatingPointNegZero', + 'SmtFloatingPointPosInfinity', + 'SmtFloatingPointPosZero', + 'SmtFloatingPointRem', + 'SmtFloatingPointRoundToIntegral', + 'SmtFloatingPointSqrt', + 'SmtFloatingPointSub', + 'SmtFloatingPointToReal', + 'SmtFloatingPointToSignedBV', + 'SmtFloatingPointToUnsignedBV', + 'SmtITE', + 'SmtInt', + 'SmtIntAbs', + 'SmtIntAdd', + 'SmtIntConst', + 'SmtIntEuclidDiv', + 'SmtIntEuclidRem', + 'SmtIntGE', + 'SmtIntGt', + 'SmtIntITE', + 'SmtIntLE', + 'SmtIntLt', + 'SmtIntMul', + 'SmtIntNeg', + 'SmtIntSub', + 'SmtIntToReal', + 'SmtNatToBitVec', + 'SmtReal', + 'SmtRealAdd', + 'SmtRealConst', + 'SmtRealDiv', + 'SmtRealGE', + 'SmtRealGt', + 'SmtRealITE', + 'SmtRealIsInt', + 'SmtRealLE', + 'SmtRealLt', + 'SmtRealMul', + 'SmtRealNeg', + 'SmtRealSub', + 'SmtRealToInt', + 'SmtRoundingMode', + 'SmtRoundingModeConst', + 'SmtRoundingModeITE', + 'SmtSame', + 'SmtSort', + 'SmtSortBitVec', + 'SmtSortBool', + 'SmtSortFloat128', + 'SmtSortFloat16', + 'SmtSortFloat32', + 'SmtSortFloat64', + 'SmtSortFloatingPoint', + 'SmtSortInt', + 'SmtSortReal', + 'SmtSortRoundingMode', + 'SmtValue', ] @dataclass(frozen=True, unsafe_hash=True, eq=True) -class SmtSort(meta=ABCMeta): +class SmtSort(metaclass=ABCMeta): @abstractmethod def _smtlib2_expr(self, expr_state): - return str(...) + return str(...) # :nocov: @abstractmethod - @staticmethod - def _ite_class(): - return SmtITE + def _ite_class(self): + return SmtITE # :nocov: @final @@ -145,15 +271,14 @@ class _ExprState: @dataclass(frozen=True, unsafe_hash=False, eq=False) -class SmtValue(ast.DUID, meta=ABCMeta): +class SmtValue(ast.DUID, metaclass=ABCMeta): @abstractmethod - @staticmethod - def sort(): - return SmtSort() + def sort(self): + return SmtSort() # type: ignore :nocov: @abstractmethod def _smtlib2_expr(self, expr_state): - return str(...) + return str(...) # :nocov: def same(self, other, *rest): return SmtSame(self, other, *rest) @@ -171,35 +296,40 @@ class SmtBool(SmtValue): def sort(): return SmtSortBool() - def __new__(cls, value=None): - if cls is not SmtBool: - assert value is None - return super().__new__(cls) - assert isinstance(value, bool) - return SmtBoolConst(value) + @staticmethod + def make(value): + value = ast.Value.cast(value) + if isinstance(value, ast.Const): + return SmtBoolConst(bool(value.value)) + return SmtBitVecInput(value).bool() # type deduction: @overload - def ite(self, then_v: "SmtBool", else_v: "SmtBool") -> "SmtBoolITE": ... + def ite(self, then_v: "SmtBool", + else_v: "SmtBool") -> "SmtBoolITE": ... # :nocov: + @overload - def ite(self, then_v: "SmtInt", else_v: "SmtInt") -> "SmtIntITE": ... + def ite(self, then_v: "SmtInt", + else_v: "SmtInt") -> "SmtIntITE": ... # :nocov: + @overload - def ite(self, then_v: "SmtReal", else_v: "SmtReal") -> "SmtRealITE": ... + def ite(self, then_v: "SmtReal", + else_v: "SmtReal") -> "SmtRealITE": ... # :nocov: @overload def ite(self, then_v: "SmtBitVec", else_v: "SmtBitVec") -> "SmtBitVecITE": - ... + ... # :nocov: @overload def ite(self, then_v: "SmtRoundingMode", - else_v: "SmtRoundingMode") -> "SmtRoundingModeITE": ... + else_v: "SmtRoundingMode") -> "SmtRoundingModeITE": ... # :nocov: @overload - def ite(self, then_v: "SmtFloatingPoint", - else_v: "SmtFloatingPoint") -> "SmtFloatingPointITE": ... + def ite(self, then_v: "SmtFloatingPoint", else_v: "SmtFloatingPoint" + ) -> "SmtFloatingPointITE": ... # :nocov: - def ite(self, then_v, else_v): - return SmtITE(self, then_v, else_v) + def ite(self, then_v, else_v): # type: ignore + return SmtITE.make(self, then_v, else_v) def __invert__(self): return SmtBoolNot(self) @@ -232,8 +362,8 @@ class SmtBool(SmtValue): return SmtBoolImplies(self, *rest) def to_bit_vec(self): - return self.ite(SmtBitVec(1, width=1), - SmtBitVec(0, width=1)) + return self.ite(SmtBitVecConst(1, width=1), + SmtBitVecConst(0, width=1)) def to_signal(self): return self.to_bit_vec().to_signal() @@ -245,15 +375,57 @@ class SmtITE(SmtValue): then_v: SmtValue else_v: SmtValue - def __new__(cls, cond, then_v, else_v): + # type deduction: + @overload + @staticmethod + def make(cond: "SmtBool", then_v: "SmtBool", + else_v: "SmtBool") -> "SmtBoolITE": ... # :nocov: + + @overload + @staticmethod + def make(cond: "SmtBool", then_v: "SmtInt", + else_v: "SmtInt") -> "SmtIntITE": ... # :nocov: + + @overload + @staticmethod + def make(cond: "SmtBool", then_v: "SmtReal", + else_v: "SmtReal") -> "SmtRealITE": ... # :nocov: + + @overload + @staticmethod + def make(cond: "SmtBool", then_v: "SmtBitVec", + else_v: "SmtBitVec") -> "SmtBitVecITE": ... # :nocov: + + @overload + @staticmethod + def make(cond: "SmtBool", then_v: "SmtRoundingMode", + else_v: "SmtRoundingMode") -> "SmtRoundingModeITE": ... # :nocov: + + @overload + @staticmethod + def make(cond: "SmtBool", then_v: "SmtFloatingPoint", + else_v: "SmtFloatingPoint") -> "SmtFloatingPointITE": + ... # :nocov: + + @staticmethod + def make(cond, then_v, else_v): + assert isinstance(cond, SmtBool) + assert isinstance(then_v, SmtValue) + assert isinstance(else_v, SmtValue) + sort = then_v.sort() + assert sort == else_v.sort() + return sort._ite_class()(cond, then_v, else_v) # type: ignore + + def __init__(self, cond, then_v, else_v): assert isinstance(cond, SmtBool) assert isinstance(then_v, SmtValue) assert isinstance(else_v, SmtValue) sort = then_v.sort() assert sort == else_v.sort() - if cls is SmtITE: - return sort._ite_class()(cond, then_v, else_v) - return super().__new__(cls) + assert isinstance(self, sort._ite_class()) + object.__setattr__(self, "cond", cond) + object.__setattr__(self, "then_v", then_v) + object.__setattr__(self, "else_v", else_v) def _smtlib2_expr(self, expr_state): assert isinstance(expr_state, _ExprState) @@ -264,13 +436,13 @@ class SmtITE(SmtValue): @dataclass(frozen=True, unsafe_hash=False, eq=False) -class _SmtNArySameSort(SmtValue): +class _SmtNAry(SmtValue): inputs: "tuple[SmtValue, ...]" @abstractmethod def _smtlib2_expr_op(self, expr_state): - assert isinstance(expr_state, _ExprState) - return str(...) + assert isinstance(expr_state, _ExprState) # :nocov: + return str(...) # :nocov: def _expected_input_class(self): return SmtValue @@ -281,11 +453,11 @@ class _SmtNArySameSort(SmtValue): for i in inputs: assert isinstance(i, self._expected_input_class()) - assert i.sort == self.input_sort, "all input sorts must match" + assert i.sort() == self.input_sort, "all input sorts must match" @property def input_sort(self): - return self.inputs[0].sort + return self.inputs[0].sort() def _smtlib2_expr(self, expr_state): assert isinstance(expr_state, _ExprState) @@ -300,8 +472,8 @@ class _SmtUnary(SmtValue): @abstractmethod def _smtlib2_expr_op(self, expr_state): - assert isinstance(expr_state, _ExprState) - return str(...) + assert isinstance(expr_state, _ExprState) # :nocov: + return str(...) # :nocov: def _expected_input_class(self): return SmtValue @@ -324,23 +496,22 @@ class _SmtBinary(SmtValue): @abstractmethod def _smtlib2_expr_op(self, expr_state): - assert isinstance(expr_state, _ExprState) - return str(...) + assert isinstance(expr_state, _ExprState) # :nocov: + return str(...) # :nocov: def _expected_input_class(self): return SmtValue - def _expected_input_lhs_class(self): - return self._expected_input_class() - - def _expected_input_rhs_class(self): - return self._expected_input_class() + @property + def input_sort(self): + return self.lhs.sort() def __init__(self, lhs, rhs): object.__setattr__(self, "lhs", lhs) object.__setattr__(self, "rhs", rhs) - assert isinstance(lhs, self._expected_input_lhs_class()) - assert isinstance(rhs, self._expected_input_rhs_class()) + assert isinstance(lhs, self._expected_input_class()) + assert isinstance(rhs, self._expected_input_class()) + assert lhs.sort() == rhs.sort() def _smtlib2_expr(self, expr_state): assert isinstance(expr_state, _ExprState) @@ -352,7 +523,7 @@ class _SmtBinary(SmtValue): @final @dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) -class SmtSame(_SmtNArySameSort, SmtBool): +class SmtSame(_SmtNAry, SmtBool): def _smtlib2_expr_op(self, expr_state): assert isinstance(expr_state, _ExprState) return "=" @@ -360,7 +531,7 @@ class SmtSame(_SmtNArySameSort, SmtBool): @final @dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) -class SmtDistinct(_SmtNArySameSort, SmtBool): +class SmtDistinct(_SmtNAry, SmtBool): def _smtlib2_expr_op(self, expr_state): assert isinstance(expr_state, _ExprState) return "distinct" @@ -390,15 +561,12 @@ class SmtBoolNot(_SmtUnary, SmtBool): @dataclass(frozen=True, unsafe_hash=False, eq=False) -class _SmtBoolNAryMinBinary(_SmtNArySameSort, SmtBool): +class _SmtBoolNAryMinBinary(_SmtNAry, SmtBool): inputs: "tuple[SmtBool, ...]" def _expected_input_class(self): return SmtBool - def __new__(cls, *inputs): - return super().__new__(cls) - def __init__(self, *inputs): super().__init__(*inputs) assert len(self.inputs) >= 2, "not enough inputs" @@ -442,11 +610,11 @@ class SmtBoolITE(SmtITE, SmtBool): then_v: SmtBool else_v: SmtBool - def __new__(cls, cond, then_v, else_v): + def __init__(self, cond, then_v, else_v): assert isinstance(cond, SmtBool) assert isinstance(then_v, SmtBool) assert isinstance(else_v, SmtBool) - return SmtITE.__new__(cls, cond, then_v, else_v) + super().__init__(cond, then_v, else_v) @dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) @@ -455,10 +623,8 @@ class SmtReal(SmtValue): def sort(): return SmtSortReal() - def __new__(cls, value=None): - if cls is not SmtReal: - assert value is None - return super().__new__(cls) + @staticmethod + def make(value): if isinstance(value, SmtInt): return SmtIntToReal(value) assert isinstance(value, Rational), ("value must be a rational " @@ -466,12 +632,6 @@ class SmtReal(SmtValue): "floats aren't yet supported.") return SmtRealConst(value) - def __init__(self, value=None): - # __init__ for IDE type checker - # value != None already handled by __new__ - assert value is None, "invalid argument" - return super().__init__() - def __neg__(self): return SmtRealNeg(self) @@ -521,14 +681,14 @@ class SmtReal(SmtValue): return SmtRealGE(self, other) def __abs__(self): - return self.__lt__(SmtReal(0)).ite(-self, self) + return self.__lt__(SmtRealConst(0)).ite(-self, self) def __floor__(self): return SmtRealToInt(self) def __trunc__(self): - return self.__lt__(SmtReal(0)).ite(-(-self).__floor__(), - self.__floor__()) + return self.__lt__(SmtRealConst(0)).ite(-(-self).__floor__(), + self.__floor__()) def __ceil__(self): return -(-self).__floor__() @@ -586,30 +746,24 @@ class SmtRealIsInt(_SmtUnary, SmtBool): @dataclass(frozen=True, unsafe_hash=False, eq=False) -class _SmtRealNAryMinBinary(_SmtNArySameSort, SmtReal): +class _SmtRealNAryMinBinary(_SmtNAry, SmtReal): inputs: "tuple[SmtReal, ...]" def _expected_input_class(self): return SmtReal - def __new__(cls, *inputs): - return _SmtNArySameSort.__new__(cls) - def __init__(self, *inputs): super().__init__(*inputs) assert len(self.inputs) >= 2, "not enough inputs" @dataclass(frozen=True, unsafe_hash=False, eq=False) -class _SmtRealCompareOp(_SmtNArySameSort, SmtBool): +class _SmtRealCompareOp(_SmtNAry, SmtBool): inputs: "tuple[SmtReal, ...]" def _expected_input_class(self): return SmtReal - def __new__(cls, *inputs): - return _SmtNArySameSort.__new__(cls) - def __init__(self, *inputs): super().__init__(*inputs) assert len(self.inputs) >= 2, "not enough inputs" @@ -685,11 +839,11 @@ class SmtRealITE(SmtITE, SmtReal): then_v: SmtReal else_v: SmtReal - def __new__(cls, cond, then_v, else_v): + def __init__(self, cond, then_v, else_v): assert isinstance(cond, SmtBool) assert isinstance(then_v, SmtReal) assert isinstance(else_v, SmtReal) - return SmtITE.__new__(cls, cond, then_v, else_v) + super().__init__(cond, then_v, else_v) @dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) @@ -698,19 +852,11 @@ class SmtInt(SmtValue): def sort(): return SmtSortInt() - def __new__(cls, value=None): - if cls is not SmtInt: - assert value is None - return super().__new__(cls) + @staticmethod + def make(value): assert isinstance(value, int), "value must be an integer" return SmtIntConst(value) - def __init__(self, value=None): - # __init__ for IDE type checker - # value != None already handled by __new__ - assert value is None, "invalid argument" - return super().__init__() - def __neg__(self): return SmtIntNeg(self) @@ -833,15 +979,12 @@ class SmtRealToInt(_SmtUnary, SmtInt): @dataclass(frozen=True, unsafe_hash=False, eq=False) -class _SmtIntNAryMinBinary(_SmtNArySameSort, SmtInt): +class _SmtIntNAryMinBinary(_SmtNAry, SmtInt): inputs: "tuple[SmtInt, ...]" def _expected_input_class(self): return SmtInt - def __new__(cls, *inputs): - return _SmtNArySameSort.__new__(cls) - def __init__(self, *inputs): super().__init__(*inputs) assert len(self.inputs) >= 2, "not enough inputs" @@ -857,15 +1000,12 @@ class _SmtIntBinary(_SmtBinary, SmtInt): @dataclass(frozen=True, unsafe_hash=False, eq=False) -class _SmtIntCompareOp(_SmtNArySameSort, SmtBool): +class _SmtIntCompareOp(_SmtNAry, SmtBool): inputs: "tuple[SmtInt, ...]" def _expected_input_class(self): return SmtInt - def __new__(cls, *inputs): - return _SmtNArySameSort.__new__(cls) - def __init__(self, *inputs): super().__init__(*inputs) assert len(self.inputs) >= 2, "not enough inputs" @@ -949,11 +1089,11 @@ class SmtIntITE(SmtITE, SmtInt): then_v: SmtInt else_v: SmtInt - def __new__(cls, cond, then_v, else_v): + def __init__(self, cond, then_v, else_v): assert isinstance(cond, SmtBool) assert isinstance(then_v, SmtInt) assert isinstance(else_v, SmtInt) - return SmtITE.__new__(cls, cond, then_v, else_v) + super().__init__(cond, then_v, else_v) @dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) @@ -964,7 +1104,7 @@ class SmtBitVec(SmtValue): return SmtSortBitVec(self.width) @staticmethod - def __make_bitvec(value=None, *, width=None): + def make(value=None, *, width=None): if isinstance(value, int): assert width is not None, "missing width" assert isinstance(width, int) and width > 0, "invalid width" @@ -972,15 +1112,7 @@ class SmtBitVec(SmtValue): assert width is None, "can't give both width and nMigen Value" return SmtBitVecInput(value) - def __new__(cls, *args, **kwargs): - if cls is not SmtInt: - return super().__new__(cls) - return SmtBitVec.__make_bitvec(*args, **kwargs) - - def __init__(self, value=None, *, width=None): - # __init__ for IDE type checker - # value != None already handled by __new__ - assert value is None, "invalid argument" + def __init__(self, *, width=None): assert isinstance(width, int) and width > 0, "invalid width" object.__setattr__(self, "width", width) super().__init__() @@ -1050,7 +1182,8 @@ class SmtBitVec(SmtValue): return SmtBoolNot(SmtBitVecLt(self, other)) def __abs__(self): - return self.__lt__(SmtBitVec(0, width=self.width)).ite(-self, self) + lt = self.__lt__(SmtBitVecConst(0, width=self.width)) + return lt.ite(-self, self) def __and__(self, other): return SmtBitVecAnd(self, other) @@ -1094,6 +1227,9 @@ class SmtBitVec(SmtValue): return SmtBitVecConcat(self[i] for i in r) return SmtBitVecExtract(self, r) + def bool(self): + return self != SmtBitVecConst(0, width=self.width) + def to_signal(self): return SmtBitVecToSignal(self) @@ -1182,23 +1318,23 @@ class SmtBitVecInput(SmtBitVec): def __init_subclass__(cls): try: _ = SmtBitVecConst - except AttributeError: + except NameError: # only possible when we're defining SmtBitVecConst return raise TypeError("subclassing SmtBitVecInput isn't supported") - def __new__(cls, value): + @staticmethod + def make(value): value = ast.Value.cast(value) assert isinstance(value, ast.Value) if isinstance(value, ast.Const): - if cls is not SmtBitVecConst: - return SmtBitVecConst(value) - retval = super().__new__(cls) - object.__setattr__(retval, "value", value) - return retval + return SmtBitVecConst(value) + return SmtBitVecInput(value) def __init__(self, value): - # self.value assigned in __new__ + value = ast.Value.cast(value) + assert isinstance(value, ast.Value) + object.__setattr__(self, "value", value) super().__init__(width=self.value.shape().width) # type: ignore def _smtlib2_expr(self, expr_state): @@ -1218,7 +1354,7 @@ class SmtBitVecInput(SmtBitVec): class SmtBitVecConst(SmtBitVecInput): value: ast.Const - def __new__(cls, value, *, width=None): + def __init__(self, value, *, width=None): if isinstance(value, ast.Const): assert width is None # decompose -- needed since we switch to unsigned @@ -1227,11 +1363,7 @@ class SmtBitVecConst(SmtBitVecInput): assert isinstance(value, int), "value must be an integer" assert isinstance(width, int) and width > 0, "invalid width" value = ast.Const(value, ast.unsigned(width)) - return super().__new__(cls, value) - - def __init__(self, value, *, width=None): - # rest of logic already handled by __new__ - super().__init__(self.value) + super().__init__(value) def _smtlib2_expr(self, expr_state): assert isinstance(expr_state, _ExprState) @@ -1307,17 +1439,14 @@ class SmtNatToBitVec(_SmtUnary, SmtBitVec): @dataclass(frozen=True, unsafe_hash=False, eq=False) -class _SmtBitVecNAryMinBinary(_SmtNArySameSort, SmtBitVec): +class _SmtBitVecNAryMinBinary(_SmtNAry, SmtBitVec): inputs: "tuple[SmtBitVec, ...]" def _expected_input_class(self): return SmtBitVec - def __new__(cls, *inputs): - return _SmtNArySameSort.__new__(cls) - def __init__(self, *inputs): - _SmtNArySameSort.__init__(self, *inputs) + _SmtNAry.__init__(self, *inputs) assert len(self.inputs) >= 2, "not enough inputs" SmtBitVec.__init__(self, width=self.inputs[0].width) @@ -1336,19 +1465,17 @@ class _SmtBitVecBinary(_SmtBinary, SmtBitVec): @dataclass(frozen=True, unsafe_hash=False, eq=False) -class _SmtBitVecCompareOp(_SmtNArySameSort, SmtBool): - inputs: "tuple[SmtBitVec, ...]" +class _SmtBitVecCompareOp(_SmtBinary, SmtBool): + lhs: SmtBitVec + rhs: SmtBitVec + + def __init__(self, lhs, rhs): + _SmtBinary.__init__(self, lhs, rhs) + SmtBool.__init__(self) def _expected_input_class(self): return SmtBitVec - def __new__(cls, *inputs): - return _SmtNArySameSort.__new__(cls) - - def __init__(self, *inputs): - super().__init__(*inputs) - assert len(self.inputs) >= 2, "not enough inputs" - @final @dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) @@ -1436,29 +1563,19 @@ class SmtBitVecITE(SmtITE, SmtBitVec): then_v: SmtBitVec else_v: SmtBitVec - def __new__(cls, cond, then_v, else_v): + def __init__(self, cond, then_v, else_v): assert isinstance(cond, SmtBool) assert isinstance(then_v, SmtBitVec) assert isinstance(else_v, SmtBitVec) - return SmtITE.__new__(cls, cond, then_v, else_v) - - def __init__(self, cond, then_v, else_v): SmtITE.__init__(self, cond, then_v, else_v) SmtBitVec.__init__(self, width=self.then_v.width) @dataclass(frozen=True, unsafe_hash=False, eq=False) class SmtRoundingMode(SmtValue): - def __new__(cls, value=None): - if cls is SmtRoundingMode: - return SmtRoundingModeConst(value) - return super().__new__(cls) - - def __init__(self, value=None): - # __init__ for IDE type checker - # value != None already handled by __new__ - assert value is None, "invalid argument" - super().__init__() + @staticmethod + def make(value): + return SmtRoundingModeConst(value) @staticmethod def sort(): @@ -1492,11 +1609,11 @@ class RoundingModeEnum(enum.Enum): @final @dataclass(frozen=True, unsafe_hash=False, eq=False) -class SmtRoundingModeConst(SmtValue): +class SmtRoundingModeConst(SmtRoundingMode): value: RoundingModeEnum - def __new__(cls, value): - value = RoundingModeEnum(value) + def __new__(cls, *args, **kwargs): + value = RoundingModeEnum(*args, **kwargs) try: if value is RoundingModeEnum.RNE: return RNE @@ -1509,10 +1626,14 @@ class SmtRoundingModeConst(SmtValue): else: assert value is RoundingModeEnum.RTZ return RTZ - except AttributeError: + except NameError: # instance not created yet return super().__new__(cls) + def __init__(self, value): + assert isinstance(value, RoundingModeEnum) + object.__setattr__(self, "value", value) + def _smtlib2_expr(self, expr_state): assert isinstance(expr_state, _ExprState) return self.value._smtlib2_expr @@ -1549,11 +1670,11 @@ class SmtRoundingModeITE(SmtITE, SmtRoundingMode): then_v: SmtRoundingMode else_v: SmtRoundingMode - def __new__(cls, cond, then_v, else_v): + def __init__(self, cond, then_v, else_v): assert isinstance(cond, SmtBool) assert isinstance(then_v, SmtRoundingMode) assert isinstance(else_v, SmtRoundingMode) - return SmtITE.__new__(cls, cond, then_v, else_v) + super().__init__(cond, then_v, else_v) @dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) @@ -1580,11 +1701,33 @@ class SmtFloatingPoint(SmtValue): @staticmethod def zero(*, sign=None, eb=None, sb=None, sort=None): - return SmtFloatingPointZero(sign=sign, eb=eb, sb=sb, sort=sort) + if sign is None: + return SmtFloatingPointPosZero(eb=eb, sb=sb, sort=sort) + if isinstance(sign, SmtBitVec): + assert sign.width == 1, "invalid sign width" + sign = sign.bool() + assert isinstance(sign, SmtBool) + return sign.ite(SmtFloatingPointNegZero(eb=eb, sb=sb, sort=sort), + SmtFloatingPointPosZero(eb=eb, sb=sb, sort=sort)) + + @staticmethod + def neg_zero(*, eb=None, sb=None, sort=None): + return SmtFloatingPointNegZero(eb=eb, sb=sb, sort=sort) @staticmethod - def inf(*, sign=None, eb=None, sb=None, sort=None): - return SmtFloatingPointInf(sign=sign, eb=eb, sb=sb, sort=sort) + def infinity(*, sign=None, eb=None, sb=None, sort=None): + if sign is None: + return SmtFloatingPointPosInfinity(eb=eb, sb=sb, sort=sort) + if isinstance(sign, SmtBitVec): + assert sign.width == 1, "invalid sign width" + sign = sign.bool() + assert isinstance(sign, SmtBool) + return sign.ite(SmtFloatingPointNegInfinity(eb=eb, sb=sb, sort=sort), + SmtFloatingPointPosInfinity(eb=eb, sb=sb, sort=sort)) + + @staticmethod + def neg_infinity(*, eb=None, sb=None, sort=None): + return SmtFloatingPointNegInfinity(eb=eb, sb=sb, sort=sort) @staticmethod def from_parts(*, sign, exponent, mantissa): @@ -1619,8 +1762,8 @@ class SmtFloatingPoint(SmtValue): def sqrt(self, *, rm): return SmtFloatingPointSqrt(self, rm=rm) - def rem(self, other, *, rm): - return SmtFloatingPointRem(self, other, rm=rm) + def rem(self, other): + return SmtFloatingPointRem(self, other) def round_to_integral(self, *, rm): return SmtFloatingPointRoundToIntegral(self, rm=rm) @@ -1638,16 +1781,16 @@ class SmtFloatingPoint(SmtValue): return ~SmtFloatingPointEq(self, other) def __lt__(self, other): - return ~SmtFloatingPointLt(self, other) + return SmtFloatingPointLt(self, other) def __le__(self, other): - return ~SmtFloatingPointLE(self, other) + return SmtFloatingPointLE(self, other) def __gt__(self, other): - return ~SmtFloatingPointGt(self, other) + return SmtFloatingPointGt(self, other) def __ge__(self, other): - return ~SmtFloatingPointGE(self, other) + return SmtFloatingPointGE(self, other) def is_normal(self): return SmtFloatingPointIsNormal(self) @@ -1724,12 +1867,510 @@ class SmtFloatingPointITE(SmtITE, SmtFloatingPoint): then_v: SmtFloatingPoint else_v: SmtFloatingPoint - def __new__(cls, cond, then_v, else_v): + def __init__(self, cond, then_v, else_v): assert isinstance(cond, SmtBool) assert isinstance(then_v, SmtFloatingPoint) assert isinstance(else_v, SmtFloatingPoint) - return SmtITE.__new__(cls, cond, then_v, else_v) - - def __init__(self, cond, then_v, else_v): SmtITE.__init__(self, cond, then_v, else_v) SmtFloatingPoint.__init__(self, sort=self.then_v.sort()) + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtFloatingPointNaN(SmtFloatingPoint): + def _smtlib2_expr(self, expr_state): + assert isinstance(expr_state, _ExprState) + return f"(_ NaN {self.eb} {self.sb})" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtFloatingPointPosZero(SmtFloatingPoint): + def _smtlib2_expr(self, expr_state): + assert isinstance(expr_state, _ExprState) + return f"(_ +zero {self.eb} {self.sb})" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtFloatingPointNegZero(SmtFloatingPoint): + def _smtlib2_expr(self, expr_state): + assert isinstance(expr_state, _ExprState) + return f"(_ -zero {self.eb} {self.sb})" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtFloatingPointPosInfinity(SmtFloatingPoint): + def _smtlib2_expr(self, expr_state): + assert isinstance(expr_state, _ExprState) + return f"(_ +oo {self.eb} {self.sb})" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtFloatingPointNegInfinity(SmtFloatingPoint): + def _smtlib2_expr(self, expr_state): + assert isinstance(expr_state, _ExprState) + return f"(_ -oo {self.eb} {self.sb})" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class SmtFloatingPointFromParts(SmtFloatingPoint): + sign: SmtBitVec + exponent: SmtBitVec + mantissa: SmtBitVec + + def __init__(self, sign, exponent, mantissa): + if isinstance(sign, SmtBool): + sign = sign.to_bit_vec() + assert isinstance(sign, SmtBitVec) and sign.width == 1 + assert isinstance(exponent, SmtBitVec) + assert isinstance(mantissa, SmtBitVec) + super().__init__(eb=exponent.width, sb=mantissa.width + 1) + + def _smtlib2_expr(self, expr_state): + assert isinstance(expr_state, _ExprState) + sign = self.sign._smtlib2_expr(expr_state) + exponent = self.exponent._smtlib2_expr(expr_state) + mantissa = self.mantissa._smtlib2_expr(expr_state) + return f"(fp {sign} {exponent} {mantissa})" + + +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class _SmtFloatingPointUnary(_SmtUnary, SmtFloatingPoint): + inp: SmtFloatingPoint + + def __init__(self, inp): + assert isinstance(inp, SmtFloatingPoint) + _SmtUnary.__init__(self, inp) + SmtFloatingPoint.__init__(self, eb=inp.eb, sb=inp.sb) + + def _expected_input_class(self): + return SmtFloatingPoint + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtFloatingPointAbs(_SmtFloatingPointUnary): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "fp.abs" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtFloatingPointNeg(_SmtFloatingPointUnary): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "fp.neg" + + +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class _SmtFloatingPointBinaryRounded(_SmtBinary, SmtFloatingPoint): + lhs: SmtFloatingPoint + rhs: SmtFloatingPoint + rm: SmtRoundingMode + + def __init__(self, lhs, rhs, *, rm): + assert isinstance(rm, SmtRoundingMode) + _SmtBinary.__init__(self, lhs, rhs) + object.__setattr__(self, "rm", rm) + SmtFloatingPoint.__init__(self, eb=self.lhs.eb, sb=self.rhs.sb) + + def _expected_input_class(self): + return SmtFloatingPoint + + def _smtlib2_expr(self, expr_state): + assert isinstance(expr_state, _ExprState) + op = self._smtlib2_expr_op(expr_state) + rm = self.rm._smtlib2_expr(expr_state) + lhs = self.lhs._smtlib2_expr(expr_state) + rhs = self.rhs._smtlib2_expr(expr_state) + return f"({op} {rm} {lhs} {rhs})" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtFloatingPointAdd(_SmtFloatingPointBinaryRounded): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "fp.add" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtFloatingPointSub(_SmtFloatingPointBinaryRounded): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "fp.sub" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtFloatingPointMul(_SmtFloatingPointBinaryRounded): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "fp.mul" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtFloatingPointDiv(_SmtFloatingPointBinaryRounded): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "fp.div" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtFloatingPointFma(SmtFloatingPoint): + """returns `self.factor1 * self.factor2 + self.term`""" + factor1: SmtFloatingPoint + factor2: SmtFloatingPoint + term: SmtFloatingPoint + rm: SmtRoundingMode + + def __init__(self, factor1, factor2, term, *, rm): + assert isinstance(factor1, SmtFloatingPoint) + assert isinstance(factor2, SmtFloatingPoint) + assert isinstance(term, SmtFloatingPoint) + assert factor1.sort() == factor2.sort() == term.sort(), "sort mismatch" + assert isinstance(rm, SmtRoundingMode) + object.__setattr__(self, "factor1", factor1) + object.__setattr__(self, "factor2", factor2) + object.__setattr__(self, "term", term) + object.__setattr__(self, "rm", rm) + super().__init__(eb=factor1.eb, sb=factor1.sb) + + def _smtlib2_expr(self, expr_state): + assert isinstance(expr_state, _ExprState) + rm = self.rm._smtlib2_expr(expr_state) + factor1 = self.factor1._smtlib2_expr(expr_state) + factor2 = self.factor2._smtlib2_expr(expr_state) + term = self.term._smtlib2_expr(expr_state) + return f"(fp {rm} {factor1} {factor2} {term})" + + +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class _SmtFloatingPointUnaryRounded(_SmtUnary, SmtFloatingPoint): + inp: SmtFloatingPoint + rm: SmtRoundingMode + + def __init__(self, inp, *, rm): + assert isinstance(rm, SmtRoundingMode) + _SmtUnary.__init__(self, inp) + object.__setattr__(self, "rm", rm) + SmtFloatingPoint.__init__(self, eb=inp.eb, sb=inp.sb) + + def _expected_input_class(self): + return SmtFloatingPoint + + def _smtlib2_expr(self, expr_state): + assert isinstance(expr_state, _ExprState) + op = self._smtlib2_expr_op(expr_state) + rm = self.rm._smtlib2_expr(expr_state) + inp = self.inp._smtlib2_expr(expr_state) + return f"({op} {rm} {inp})" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtFloatingPointSqrt(_SmtFloatingPointUnaryRounded): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "fp.sqrt" + + +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class _SmtFloatingPointBinary(_SmtBinary, SmtFloatingPoint): + lhs: SmtFloatingPoint + rhs: SmtFloatingPoint + + def __init__(self, lhs, rhs): + _SmtBinary.__init__(self, lhs, rhs) + SmtFloatingPoint.__init__(self, eb=self.lhs.eb, sb=self.rhs.sb) + + def _expected_input_class(self): + return SmtFloatingPoint + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtFloatingPointRem(_SmtFloatingPointBinary): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "fp.rem" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtFloatingPointRoundToIntegral(_SmtFloatingPointUnaryRounded): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "fp.roundToIntegral" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtFloatingPointMin(_SmtFloatingPointBinary): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "fp.min" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtFloatingPointMax(_SmtFloatingPointBinary): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "fp.max" + + +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class _SmtFloatingPointCompareOp(_SmtNAry, SmtBool): + inputs: "tuple[SmtFloatingPoint, ...]" + + def _expected_input_class(self): + return SmtFloatingPoint + + def __init__(self, *inputs): + super().__init__(*inputs) + assert len(self.inputs) >= 2, "not enough inputs" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtFloatingPointEq(_SmtFloatingPointCompareOp): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "fp.eq" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtFloatingPointLt(_SmtFloatingPointCompareOp): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "fp.lt" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtFloatingPointLE(_SmtFloatingPointCompareOp): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "fp.leq" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtFloatingPointGt(_SmtFloatingPointCompareOp): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "fp.gt" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtFloatingPointGE(_SmtFloatingPointCompareOp): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "fp.geq" + + +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class _SmtFloatingPointToBoolUnary(_SmtUnary, SmtBool): + inp: SmtFloatingPoint + + def _expected_input_class(self): + return SmtFloatingPoint + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class SmtFloatingPointIsNormal(_SmtFloatingPointToBoolUnary): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "fp.isNormal" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class SmtFloatingPointIsSubnormal(_SmtFloatingPointToBoolUnary): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "fp.isSubnormal" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class SmtFloatingPointIsZero(_SmtFloatingPointToBoolUnary): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "fp.isZero" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class SmtFloatingPointIsInfinite(_SmtFloatingPointToBoolUnary): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "fp.isInfinite" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class SmtFloatingPointIsNaN(_SmtFloatingPointToBoolUnary): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "fp.isNaN" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class SmtFloatingPointIsNegative(_SmtFloatingPointToBoolUnary): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "fp.isNegative" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class SmtFloatingPointIsPositive(_SmtFloatingPointToBoolUnary): + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "fp.isPositive" + + +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class _SmtFloatingPointRoundFrom(_SmtUnary, SmtFloatingPoint): + rm: SmtRoundingMode + + def __init__(self, inp, *, rm, eb=None, sb=None, sort=None): + assert isinstance(rm, SmtRoundingMode) + _SmtUnary.__init__(self, inp) + object.__setattr__(self, "rm", rm) + SmtFloatingPoint.__init__(self, eb=eb, sb=sb, sort=sort) + + def _smtlib2_expr(self, expr_state): + assert isinstance(expr_state, _ExprState) + op = self._smtlib2_expr_op(expr_state) + rm = self.rm._smtlib2_expr(expr_state) + inp = self.inp._smtlib2_expr(expr_state) + return f"({op} {rm} {inp})" + + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return f"(_ to_fp {self.eb} {self.sb})" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtFloatingPointFromSignedBV(_SmtFloatingPointRoundFrom): + inp: SmtBitVec + + def _expected_input_class(self): + return SmtBitVec + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtFloatingPointFromUnsignedBV(_SmtFloatingPointRoundFrom): + inp: SmtBitVec + + def _expected_input_class(self): + return SmtBitVec + + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return f"(_ to_fp_unsigned {self.eb} {self.sb})" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtFloatingPointFromReal(_SmtFloatingPointRoundFrom): + inp: SmtReal + + def _expected_input_class(self): + return SmtReal + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) +class SmtFloatingPointFromFP(_SmtFloatingPointRoundFrom): + inp: SmtFloatingPoint + + def _expected_input_class(self): + return SmtFloatingPoint + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class SmtFloatingPointFromBits(_SmtUnary, SmtFloatingPoint): + inp: SmtBitVec + + def __init__(self, inp, *, eb=None, sb=None, sort=None): + assert isinstance(inp, SmtBitVec) + _SmtUnary.__init__(self, inp) + SmtFloatingPoint.__init__(self, eb=eb, sb=sb, sort=sort) + mantissa_field_width = sb - 1 + sign_field_width = 1 + expected_width = sign_field_width + eb + mantissa_field_width + assert inp.width == expected_width, \ + f"input BitVec is the wrong width, expected {expected_width}" + + def _expected_input_class(self): + return SmtBitVec + + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return f"(_ to_fp {self.eb} {self.sb})" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class SmtFloatingPointToReal(_SmtUnary, SmtReal): + inp: SmtFloatingPoint + + def _expected_input_class(self): + return SmtFloatingPoint + + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "fp.to_real" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class SmtFloatingPointToUnsignedBV(_SmtUnary, SmtBitVec): + inp: SmtFloatingPoint + + def __init__(self, inp, *, width): + _SmtUnary.__init__(self, inp) + SmtBitVec.__init__(self, width=width) + + def _expected_input_class(self): + return SmtFloatingPoint + + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "fp.to_ubv" + + +@final +@dataclass(frozen=True, unsafe_hash=False, eq=False) +class SmtFloatingPointToSignedBV(_SmtUnary, SmtBitVec): + inp: SmtFloatingPoint + + def __init__(self, inp, *, width): + _SmtUnary.__init__(self, inp) + SmtBitVec.__init__(self, width=width) + + def _expected_input_class(self): + return SmtFloatingPoint + + def _smtlib2_expr_op(self, expr_state): + assert isinstance(expr_state, _ExprState) + return "fp.to_sbv" diff --git a/tests/test_hdl_smtlib2.py b/tests/test_hdl_smtlib2.py new file mode 100644 index 0000000..423ea8f --- /dev/null +++ b/tests/test_hdl_smtlib2.py @@ -0,0 +1,79 @@ +from nmigen.hdl.smtlib2 import * +from nmigen.hdl.smtlib2 import _ExprState +from nmigen.hdl.ast import ValueDict, Signal +from .utils import FHDLTestCase + + +class SmtTestCase(FHDLTestCase): + def assertSmtExpr(self, v, ty, expected_expr, input_bit_ranges=(), input_width=0): + self.assertIs(type(v), ty) + state = _ExprState() + expr = v._smtlib2_expr(state) + self.assertEqual(expr, expected_expr) + self.assertEqual(state.input_bit_ranges, ValueDict(input_bit_ranges)) + self.assertEqual(state.input_width, input_width) + + +class TestSorts(SmtTestCase): + def test_bool(self): + self.assertSmtExpr(SmtSortBool(), SmtSortBool, "Bool") + + def test_int(self): + self.assertSmtExpr(SmtSortInt(), SmtSortInt, "Int") + + def test_real(self): + self.assertSmtExpr(SmtSortReal(), SmtSortReal, "Real") + + def test_bv(self): + self.assertSmtExpr(SmtSortBitVec(1), SmtSortBitVec, "(_ BitVec 1)") + self.assertSmtExpr(SmtSortBitVec(16), SmtSortBitVec, "(_ BitVec 16)") + + def test_rm(self): + self.assertSmtExpr(SmtSortRoundingMode(), + SmtSortRoundingMode, "RoundingMode") + + def test_fp(self): + self.assertSmtExpr(SmtSortFloat16(), SmtSortFloatingPoint, "Float16") + self.assertSmtExpr(SmtSortFloat32(), SmtSortFloatingPoint, "Float32") + self.assertSmtExpr(SmtSortFloat64(), SmtSortFloatingPoint, "Float64") + self.assertSmtExpr(SmtSortFloat128(), SmtSortFloatingPoint, "Float128") + self.assertSmtExpr(SmtSortFloatingPoint(3, 5), + SmtSortFloatingPoint, "(_ FloatingPoint 3 5)") + self.assertSmtExpr(SmtSortFloatingPoint(2, 2), + SmtSortFloatingPoint, "(_ FloatingPoint 2 2)") + + +class TestBool(SmtTestCase): + def test_make(self): + self.assertSmtExpr(SmtBool.make(False), SmtBoolConst, "false") + self.assertSmtExpr(SmtBool.make(True), SmtBoolConst, "true") + sig = Signal() + self.assertSmtExpr(SmtBool.make(sig), SmtDistinct, + "(distinct #b0 ((_ extract 0 0) A))", + input_bit_ranges=[(sig, range(0, 1))], + input_width=1) + sig2 = Signal(2) + self.assertSmtExpr(SmtBool.make(sig2.bool()), SmtDistinct, + "(distinct #b0 ((_ extract 0 0) A))", + input_bit_ranges=[(sig2.bool(), range(0, 1))], + input_width=1) + + def test_ite(self): + a = Signal() + b = Signal() + c = Signal() + self.assertSmtExpr( + SmtBool.make(a).ite(SmtBool.make(b), SmtBool.make(c)), + SmtBoolITE, + "(ite (distinct #b0 ((_ extract 0 0) A)) " + "(distinct #b0 ((_ extract 1 1) A)) " + "(distinct #b0 ((_ extract 2 2) A)))", + input_bit_ranges=[ + (a, range(0, 1)), + (b, range(1, 2)), + (c, range(2, 3)), + ], + input_width=3, + ) + +# FIXME: add more tests