From: Jacob Lifshay Date: Fri, 20 May 2022 08:55:20 +0000 (-0700) Subject: adding more tests X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=refs%2Fheads%2Fsmtlib2-expr-support;p=nmigen.git adding more tests --- diff --git a/nmigen/hdl/smtlib2.py b/nmigen/hdl/smtlib2.py index 04846dd..a32bde8 100644 --- a/nmigen/hdl/smtlib2.py +++ b/nmigen/hdl/smtlib2.py @@ -475,8 +475,9 @@ class _SmtUnary(SmtValue): assert isinstance(expr_state, _ExprState) # :nocov: return str(...) # :nocov: + @abstractmethod def _expected_input_class(self): - return SmtValue + return SmtValue # :nocov: def __init__(self, inp): object.__setattr__(self, "inp", inp) @@ -499,8 +500,9 @@ class _SmtBinary(SmtValue): assert isinstance(expr_state, _ExprState) # :nocov: return str(...) # :nocov: + @abstractmethod def _expected_input_class(self): - return SmtValue + return SmtValue # :nocov: @property def input_sort(self): @@ -542,13 +544,16 @@ class SmtDistinct(_SmtNAry, SmtBool): class SmtBoolConst(SmtBool): value: bool + def __post_init__(self): + assert isinstance(self.value, bool) + def _smtlib2_expr(self, expr_state): assert isinstance(expr_state, _ExprState) return "true" if self.value else "false" @final -@dataclass(frozen=True, unsafe_hash=False, eq=False) +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) class SmtBoolNot(_SmtUnary, SmtBool): inp: SmtBool @@ -720,7 +725,7 @@ class SmtRealConst(SmtReal): @final -@dataclass(frozen=True, unsafe_hash=False, eq=False) +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) class SmtRealNeg(_SmtUnary, SmtReal): inp: SmtReal @@ -733,7 +738,7 @@ class SmtRealNeg(_SmtUnary, SmtReal): @final -@dataclass(frozen=True, unsafe_hash=False, eq=False) +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) class SmtRealIsInt(_SmtUnary, SmtBool): inp: SmtReal @@ -927,7 +932,7 @@ class SmtIntConst(SmtInt): @final -@dataclass(frozen=True, unsafe_hash=False, eq=False) +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) class SmtIntNeg(_SmtUnary, SmtInt): inp: SmtInt @@ -940,7 +945,7 @@ class SmtIntNeg(_SmtUnary, SmtInt): @final -@dataclass(frozen=True, unsafe_hash=False, eq=False) +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) class SmtIntAbs(_SmtUnary, SmtInt): inp: SmtInt @@ -953,7 +958,7 @@ class SmtIntAbs(_SmtUnary, SmtInt): @final -@dataclass(frozen=True, unsafe_hash=False, eq=False) +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) class SmtIntToReal(_SmtUnary, SmtReal): inp: SmtInt @@ -966,7 +971,7 @@ class SmtIntToReal(_SmtUnary, SmtReal): @final -@dataclass(frozen=True, unsafe_hash=False, eq=False) +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) class SmtRealToInt(_SmtUnary, SmtInt): inp: SmtReal @@ -1407,7 +1412,7 @@ class SmtBitVecNot(_SmtBitVecUnary): @final -@dataclass(frozen=True, unsafe_hash=False, eq=False) +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) class SmtBitVecToNat(_SmtUnary, SmtInt): inp: SmtBitVec @@ -2180,7 +2185,7 @@ class SmtFloatingPointGE(_SmtFloatingPointCompareOp): return "fp.geq" -@dataclass(frozen=True, unsafe_hash=False, eq=False) +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) class _SmtFloatingPointToBoolUnary(_SmtUnary, SmtBool): inp: SmtFloatingPoint @@ -2330,7 +2335,7 @@ class SmtFloatingPointFromBits(_SmtUnary, SmtFloatingPoint): @final -@dataclass(frozen=True, unsafe_hash=False, eq=False) +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) class SmtFloatingPointToReal(_SmtUnary, SmtReal): inp: SmtFloatingPoint @@ -2343,7 +2348,7 @@ class SmtFloatingPointToReal(_SmtUnary, SmtReal): @final -@dataclass(frozen=True, unsafe_hash=False, eq=False) +@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False) class SmtFloatingPointToUnsignedBV(_SmtUnary, SmtBitVec): inp: SmtFloatingPoint diff --git a/tests/test_hdl_smtlib2.py b/tests/test_hdl_smtlib2.py index 423ea8f..825c81f 100644 --- a/tests/test_hdl_smtlib2.py +++ b/tests/test_hdl_smtlib2.py @@ -1,6 +1,7 @@ from nmigen.hdl.smtlib2 import * from nmigen.hdl.smtlib2 import _ExprState -from nmigen.hdl.ast import ValueDict, Signal +from nmigen.hdl.ast import AnyConst, Assert, Mux, ValueDict, Signal, Value +from nmigen.hdl.dsl import Module from .utils import FHDLTestCase @@ -13,26 +14,46 @@ class SmtTestCase(FHDLTestCase): self.assertEqual(state.input_bit_ranges, ValueDict(input_bit_ranges)) self.assertEqual(state.input_width, input_width) + def assertSmtSame(self, expr, expected, *, inputs, assumptions=()): + assert isinstance(expr, (SmtBitVec, SmtBool)) + expected = Value.cast(expected) + m = Module() + inputs = list(inputs) + for sig in inputs: + assert isinstance(sig, Signal) + any_const = AnyConst(sig.shape()) + any_const.src_loc = sig.src_loc # makes it easier to debug + m.d.comb += sig.eq(any_const) + expr_s = expr.to_signal() + m.submodules += expr_s + expected_s = Signal(expected.shape()) + m.d.comb += [ + expected_s.eq(expected), + Assert(expr_s.o_sig == expected_s), + *assumptions, + ] + self.assertFormal(m) + class TestSorts(SmtTestCase): - def test_bool(self): + def test_sort_bool(self): self.assertSmtExpr(SmtSortBool(), SmtSortBool, "Bool") - def test_int(self): + def test_sort_int(self): self.assertSmtExpr(SmtSortInt(), SmtSortInt, "Int") - def test_real(self): + def test_sort_real(self): self.assertSmtExpr(SmtSortReal(), SmtSortReal, "Real") - def test_bv(self): + def test_sort_bv(self): self.assertSmtExpr(SmtSortBitVec(1), SmtSortBitVec, "(_ BitVec 1)") self.assertSmtExpr(SmtSortBitVec(16), SmtSortBitVec, "(_ BitVec 16)") - def test_rm(self): + def test_sort_rm(self): self.assertSmtExpr(SmtSortRoundingMode(), SmtSortRoundingMode, "RoundingMode") - def test_fp(self): + def test_sort_fp(self): self.assertSmtExpr(SmtSortFloat16(), SmtSortFloatingPoint, "Float16") self.assertSmtExpr(SmtSortFloat32(), SmtSortFloatingPoint, "Float32") self.assertSmtExpr(SmtSortFloat64(), SmtSortFloatingPoint, "Float64") @@ -44,7 +65,7 @@ class TestSorts(SmtTestCase): class TestBool(SmtTestCase): - def test_make(self): + def test_bool_make(self): self.assertSmtExpr(SmtBool.make(False), SmtBoolConst, "false") self.assertSmtExpr(SmtBool.make(True), SmtBoolConst, "true") sig = Signal() @@ -58,12 +79,54 @@ class TestBool(SmtTestCase): input_bit_ranges=[(sig2.bool(), range(0, 1))], input_width=1) - def test_ite(self): + def test_bool_same(self): + a = Signal() + b = Signal() + c = Signal() + expr = SmtBool.make(a).same(SmtBool.make(b), SmtBool.make(c)) + self.assertSmtExpr( + expr, + SmtSame, + "(= (distinct #b0 ((_ extract 0 0) A)) " + "(distinct #b0 ((_ extract 1 1) A)) " + "(distinct #b0 ((_ extract 2 2) A)))", + input_bit_ranges=[ + (a, range(0, 1)), + (b, range(1, 2)), + (c, range(2, 3)), + ], + input_width=3, + ) + self.assertSmtSame(expr, (a == b) & (b == c), inputs=(a, b)) + + def test_bool_distinct(self): + # only check 2 inputs since if we were to check 3 inputs it would + # always return false since there are only 2 possible Bool values + # but you'd need at least 3 to make distinct return True since every + # input needs to be different than all others. + a = Signal() + b = Signal() + expr = SmtBool.make(a).distinct(SmtBool.make(b)) + self.assertSmtExpr( + expr, + SmtDistinct, + "(distinct (distinct #b0 ((_ extract 0 0) A)) " + "(distinct #b0 ((_ extract 1 1) A)))", + input_bit_ranges=[ + (a, range(0, 1)), + (b, range(1, 2)), + ], + input_width=2, + ) + self.assertSmtSame(expr, a != b, inputs=(a, b)) + + def test_bool_ite(self): a = Signal() b = Signal() c = Signal() + expr = SmtBool.make(a).ite(SmtBool.make(b), SmtBool.make(c)) self.assertSmtExpr( - SmtBool.make(a).ite(SmtBool.make(b), SmtBool.make(c)), + expr, SmtBoolITE, "(ite (distinct #b0 ((_ extract 0 0) A)) " "(distinct #b0 ((_ extract 1 1) A)) " @@ -75,5 +138,133 @@ class TestBool(SmtTestCase): ], input_width=3, ) + self.assertSmtSame(expr, Mux(a, b, c), inputs=(a, b, c)) + + def test_bool_bool(self): + with self.assertRaises(TypeError): + bool(SmtBool.make(False)) + + def test_bool_invert(self): + a = Signal() + expr = ~SmtBool.make(a) + self.assertSmtExpr( + expr, + SmtBoolNot, + "(not (distinct #b0 ((_ extract 0 0) A)))", + input_bit_ranges=[ + (a, range(0, 1)), + ], + input_width=1, + ) + self.assertSmtSame(expr, ~a, inputs=(a,)) + + def test_bool_and(self): + a = Signal() + b = Signal() + expr = SmtBool.make(a) & SmtBool.make(b) + self.assertSmtExpr( + expr, + SmtBoolAnd, + "(and (distinct #b0 ((_ extract 0 0) A)) " + "(distinct #b0 ((_ extract 1 1) A)))", + input_bit_ranges=[ + (a, range(0, 1)), + (b, range(1, 2)), + ], + input_width=2, + ) + self.assertEqual(repr(expr), + repr(SmtBool.make(b).__rand__(SmtBool.make(a)))) + self.assertSmtSame(expr, a & b, inputs=(a, b)) + + def test_bool_xor(self): + a = Signal() + b = Signal() + expr = SmtBool.make(a) ^ SmtBool.make(b) + self.assertSmtExpr( + expr, + SmtBoolXor, + "(xor (distinct #b0 ((_ extract 0 0) A)) " + "(distinct #b0 ((_ extract 1 1) A)))", + input_bit_ranges=[ + (a, range(0, 1)), + (b, range(1, 2)), + ], + input_width=2, + ) + self.assertEqual(repr(expr), + repr(SmtBool.make(b).__rxor__(SmtBool.make(a)))) + self.assertSmtSame(expr, a ^ b, inputs=(a, b)) + + def test_bool_or(self): + a = Signal() + b = Signal() + expr = SmtBool.make(a) | SmtBool.make(b) + self.assertSmtExpr( + expr, + SmtBoolOr, + "(or (distinct #b0 ((_ extract 0 0) A)) " + "(distinct #b0 ((_ extract 1 1) A)))", + input_bit_ranges=[ + (a, range(0, 1)), + (b, range(1, 2)), + ], + input_width=2, + ) + self.assertEqual(repr(expr), + repr(SmtBool.make(b).__ror__(SmtBool.make(a)))) + self.assertSmtSame(expr, a | b, inputs=(a, b)) + + def test_bool_eq(self): + a = Signal() + b = Signal() + expr = SmtBool.make(a) == SmtBool.make(b) + self.assertSmtExpr( + expr, + SmtSame, + "(= (distinct #b0 ((_ extract 0 0) A)) " + "(distinct #b0 ((_ extract 1 1) A)))", + input_bit_ranges=[ + (a, range(0, 1)), + (b, range(1, 2)), + ], + input_width=2, + ) + self.assertSmtSame(expr, a == b, inputs=(a, b)) + + def test_bool_ne(self): + a = Signal() + b = Signal() + expr = SmtBool.make(a) != SmtBool.make(b) + self.assertSmtExpr( + expr, + SmtDistinct, + "(distinct (distinct #b0 ((_ extract 0 0) A)) " + "(distinct #b0 ((_ extract 1 1) A)))", + input_bit_ranges=[ + (a, range(0, 1)), + (b, range(1, 2)), + ], + input_width=2, + ) + self.assertSmtSame(expr, a != b, inputs=(a, b)) + + def test_bool_implies(self): + a = Signal() + b = Signal() + expr = SmtBool.make(a).implies(SmtBool.make(b)) + self.assertSmtExpr( + expr, + SmtBoolImplies, + "(=> (distinct #b0 ((_ extract 0 0) A)) " + "(distinct #b0 ((_ extract 1 1) A)))", + input_bit_ranges=[ + (a, range(0, 1)), + (b, range(1, 2)), + ], + input_width=2, + ) + self.assertSmtSame(expr, a.implies(b), inputs=(a, b)) + # FIXME: add more tests