assert isinstance(expr_state, _ExprState) # :nocov:
return str(...) # :nocov:
+ @abstractmethod
def _expected_input_class(self):
- return SmtValue
+ return SmtValue # :nocov:
def __init__(self, inp):
object.__setattr__(self, "inp", inp)
assert isinstance(expr_state, _ExprState) # :nocov:
return str(...) # :nocov:
+ @abstractmethod
def _expected_input_class(self):
- return SmtValue
+ return SmtValue # :nocov:
@property
def input_sort(self):
class SmtBoolConst(SmtBool):
value: bool
+ def __post_init__(self):
+ assert isinstance(self.value, bool)
+
def _smtlib2_expr(self, expr_state):
assert isinstance(expr_state, _ExprState)
return "true" if self.value else "false"
@final
-@dataclass(frozen=True, unsafe_hash=False, eq=False)
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
class SmtBoolNot(_SmtUnary, SmtBool):
inp: SmtBool
@final
-@dataclass(frozen=True, unsafe_hash=False, eq=False)
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
class SmtRealNeg(_SmtUnary, SmtReal):
inp: SmtReal
@final
-@dataclass(frozen=True, unsafe_hash=False, eq=False)
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
class SmtRealIsInt(_SmtUnary, SmtBool):
inp: SmtReal
@final
-@dataclass(frozen=True, unsafe_hash=False, eq=False)
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
class SmtIntNeg(_SmtUnary, SmtInt):
inp: SmtInt
@final
-@dataclass(frozen=True, unsafe_hash=False, eq=False)
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
class SmtIntAbs(_SmtUnary, SmtInt):
inp: SmtInt
@final
-@dataclass(frozen=True, unsafe_hash=False, eq=False)
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
class SmtIntToReal(_SmtUnary, SmtReal):
inp: SmtInt
@final
-@dataclass(frozen=True, unsafe_hash=False, eq=False)
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
class SmtRealToInt(_SmtUnary, SmtInt):
inp: SmtReal
@final
-@dataclass(frozen=True, unsafe_hash=False, eq=False)
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
class SmtBitVecToNat(_SmtUnary, SmtInt):
inp: SmtBitVec
return "fp.geq"
-@dataclass(frozen=True, unsafe_hash=False, eq=False)
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
class _SmtFloatingPointToBoolUnary(_SmtUnary, SmtBool):
inp: SmtFloatingPoint
@final
-@dataclass(frozen=True, unsafe_hash=False, eq=False)
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
class SmtFloatingPointToReal(_SmtUnary, SmtReal):
inp: SmtFloatingPoint
@final
-@dataclass(frozen=True, unsafe_hash=False, eq=False)
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
class SmtFloatingPointToUnsignedBV(_SmtUnary, SmtBitVec):
inp: SmtFloatingPoint
from nmigen.hdl.smtlib2 import *
from nmigen.hdl.smtlib2 import _ExprState
-from nmigen.hdl.ast import ValueDict, Signal
+from nmigen.hdl.ast import AnyConst, Assert, Mux, ValueDict, Signal, Value
+from nmigen.hdl.dsl import Module
from .utils import FHDLTestCase
self.assertEqual(state.input_bit_ranges, ValueDict(input_bit_ranges))
self.assertEqual(state.input_width, input_width)
+ def assertSmtSame(self, expr, expected, *, inputs, assumptions=()):
+ assert isinstance(expr, (SmtBitVec, SmtBool))
+ expected = Value.cast(expected)
+ m = Module()
+ inputs = list(inputs)
+ for sig in inputs:
+ assert isinstance(sig, Signal)
+ any_const = AnyConst(sig.shape())
+ any_const.src_loc = sig.src_loc # makes it easier to debug
+ m.d.comb += sig.eq(any_const)
+ expr_s = expr.to_signal()
+ m.submodules += expr_s
+ expected_s = Signal(expected.shape())
+ m.d.comb += [
+ expected_s.eq(expected),
+ Assert(expr_s.o_sig == expected_s),
+ *assumptions,
+ ]
+ self.assertFormal(m)
+
class TestSorts(SmtTestCase):
- def test_bool(self):
+ def test_sort_bool(self):
self.assertSmtExpr(SmtSortBool(), SmtSortBool, "Bool")
- def test_int(self):
+ def test_sort_int(self):
self.assertSmtExpr(SmtSortInt(), SmtSortInt, "Int")
- def test_real(self):
+ def test_sort_real(self):
self.assertSmtExpr(SmtSortReal(), SmtSortReal, "Real")
- def test_bv(self):
+ def test_sort_bv(self):
self.assertSmtExpr(SmtSortBitVec(1), SmtSortBitVec, "(_ BitVec 1)")
self.assertSmtExpr(SmtSortBitVec(16), SmtSortBitVec, "(_ BitVec 16)")
- def test_rm(self):
+ def test_sort_rm(self):
self.assertSmtExpr(SmtSortRoundingMode(),
SmtSortRoundingMode, "RoundingMode")
- def test_fp(self):
+ def test_sort_fp(self):
self.assertSmtExpr(SmtSortFloat16(), SmtSortFloatingPoint, "Float16")
self.assertSmtExpr(SmtSortFloat32(), SmtSortFloatingPoint, "Float32")
self.assertSmtExpr(SmtSortFloat64(), SmtSortFloatingPoint, "Float64")
class TestBool(SmtTestCase):
- def test_make(self):
+ def test_bool_make(self):
self.assertSmtExpr(SmtBool.make(False), SmtBoolConst, "false")
self.assertSmtExpr(SmtBool.make(True), SmtBoolConst, "true")
sig = Signal()
input_bit_ranges=[(sig2.bool(), range(0, 1))],
input_width=1)
- def test_ite(self):
+ def test_bool_same(self):
+ a = Signal()
+ b = Signal()
+ c = Signal()
+ expr = SmtBool.make(a).same(SmtBool.make(b), SmtBool.make(c))
+ self.assertSmtExpr(
+ expr,
+ SmtSame,
+ "(= (distinct #b0 ((_ extract 0 0) A)) "
+ "(distinct #b0 ((_ extract 1 1) A)) "
+ "(distinct #b0 ((_ extract 2 2) A)))",
+ input_bit_ranges=[
+ (a, range(0, 1)),
+ (b, range(1, 2)),
+ (c, range(2, 3)),
+ ],
+ input_width=3,
+ )
+ self.assertSmtSame(expr, (a == b) & (b == c), inputs=(a, b))
+
+ def test_bool_distinct(self):
+ # only check 2 inputs since if we were to check 3 inputs it would
+ # always return false since there are only 2 possible Bool values
+ # but you'd need at least 3 to make distinct return True since every
+ # input needs to be different than all others.
+ a = Signal()
+ b = Signal()
+ expr = SmtBool.make(a).distinct(SmtBool.make(b))
+ self.assertSmtExpr(
+ expr,
+ SmtDistinct,
+ "(distinct (distinct #b0 ((_ extract 0 0) A)) "
+ "(distinct #b0 ((_ extract 1 1) A)))",
+ input_bit_ranges=[
+ (a, range(0, 1)),
+ (b, range(1, 2)),
+ ],
+ input_width=2,
+ )
+ self.assertSmtSame(expr, a != b, inputs=(a, b))
+
+ def test_bool_ite(self):
a = Signal()
b = Signal()
c = Signal()
+ expr = SmtBool.make(a).ite(SmtBool.make(b), SmtBool.make(c))
self.assertSmtExpr(
- SmtBool.make(a).ite(SmtBool.make(b), SmtBool.make(c)),
+ expr,
SmtBoolITE,
"(ite (distinct #b0 ((_ extract 0 0) A)) "
"(distinct #b0 ((_ extract 1 1) A)) "
],
input_width=3,
)
+ self.assertSmtSame(expr, Mux(a, b, c), inputs=(a, b, c))
+
+ def test_bool_bool(self):
+ with self.assertRaises(TypeError):
+ bool(SmtBool.make(False))
+
+ def test_bool_invert(self):
+ a = Signal()
+ expr = ~SmtBool.make(a)
+ self.assertSmtExpr(
+ expr,
+ SmtBoolNot,
+ "(not (distinct #b0 ((_ extract 0 0) A)))",
+ input_bit_ranges=[
+ (a, range(0, 1)),
+ ],
+ input_width=1,
+ )
+ self.assertSmtSame(expr, ~a, inputs=(a,))
+
+ def test_bool_and(self):
+ a = Signal()
+ b = Signal()
+ expr = SmtBool.make(a) & SmtBool.make(b)
+ self.assertSmtExpr(
+ expr,
+ SmtBoolAnd,
+ "(and (distinct #b0 ((_ extract 0 0) A)) "
+ "(distinct #b0 ((_ extract 1 1) A)))",
+ input_bit_ranges=[
+ (a, range(0, 1)),
+ (b, range(1, 2)),
+ ],
+ input_width=2,
+ )
+ self.assertEqual(repr(expr),
+ repr(SmtBool.make(b).__rand__(SmtBool.make(a))))
+ self.assertSmtSame(expr, a & b, inputs=(a, b))
+
+ def test_bool_xor(self):
+ a = Signal()
+ b = Signal()
+ expr = SmtBool.make(a) ^ SmtBool.make(b)
+ self.assertSmtExpr(
+ expr,
+ SmtBoolXor,
+ "(xor (distinct #b0 ((_ extract 0 0) A)) "
+ "(distinct #b0 ((_ extract 1 1) A)))",
+ input_bit_ranges=[
+ (a, range(0, 1)),
+ (b, range(1, 2)),
+ ],
+ input_width=2,
+ )
+ self.assertEqual(repr(expr),
+ repr(SmtBool.make(b).__rxor__(SmtBool.make(a))))
+ self.assertSmtSame(expr, a ^ b, inputs=(a, b))
+
+ def test_bool_or(self):
+ a = Signal()
+ b = Signal()
+ expr = SmtBool.make(a) | SmtBool.make(b)
+ self.assertSmtExpr(
+ expr,
+ SmtBoolOr,
+ "(or (distinct #b0 ((_ extract 0 0) A)) "
+ "(distinct #b0 ((_ extract 1 1) A)))",
+ input_bit_ranges=[
+ (a, range(0, 1)),
+ (b, range(1, 2)),
+ ],
+ input_width=2,
+ )
+ self.assertEqual(repr(expr),
+ repr(SmtBool.make(b).__ror__(SmtBool.make(a))))
+ self.assertSmtSame(expr, a | b, inputs=(a, b))
+
+ def test_bool_eq(self):
+ a = Signal()
+ b = Signal()
+ expr = SmtBool.make(a) == SmtBool.make(b)
+ self.assertSmtExpr(
+ expr,
+ SmtSame,
+ "(= (distinct #b0 ((_ extract 0 0) A)) "
+ "(distinct #b0 ((_ extract 1 1) A)))",
+ input_bit_ranges=[
+ (a, range(0, 1)),
+ (b, range(1, 2)),
+ ],
+ input_width=2,
+ )
+ self.assertSmtSame(expr, a == b, inputs=(a, b))
+
+ def test_bool_ne(self):
+ a = Signal()
+ b = Signal()
+ expr = SmtBool.make(a) != SmtBool.make(b)
+ self.assertSmtExpr(
+ expr,
+ SmtDistinct,
+ "(distinct (distinct #b0 ((_ extract 0 0) A)) "
+ "(distinct #b0 ((_ extract 1 1) A)))",
+ input_bit_ranges=[
+ (a, range(0, 1)),
+ (b, range(1, 2)),
+ ],
+ input_width=2,
+ )
+ self.assertSmtSame(expr, a != b, inputs=(a, b))
+
+ def test_bool_implies(self):
+ a = Signal()
+ b = Signal()
+ expr = SmtBool.make(a).implies(SmtBool.make(b))
+ self.assertSmtExpr(
+ expr,
+ SmtBoolImplies,
+ "(=> (distinct #b0 ((_ extract 0 0) A)) "
+ "(distinct #b0 ((_ extract 1 1) A)))",
+ input_bit_ranges=[
+ (a, range(0, 1)),
+ (b, range(1, 2)),
+ ],
+ input_width=2,
+ )
+ self.assertSmtSame(expr, a.implies(b), inputs=(a, b))
+
# FIXME: add more tests