From 061119d8afac3a1ed4935bd2b893015261e5a7cc Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Thu, 4 Aug 2022 00:05:17 -0700 Subject: [PATCH] add byte_reverse formal proof --- src/nmutil/formal/test_byterev.py | 116 ++++++++++++++++++++++++++++++ 1 file changed, 116 insertions(+) create mode 100644 src/nmutil/formal/test_byterev.py diff --git a/src/nmutil/formal/test_byterev.py b/src/nmutil/formal/test_byterev.py new file mode 100644 index 0000000..d141c80 --- /dev/null +++ b/src/nmutil/formal/test_byterev.py @@ -0,0 +1,116 @@ +# SPDX-License-Identifier: LGPL-3-or-later +# Copyright 2022 Jacob Lifshay + +from functools import reduce +import operator +import unittest +from nmigen.hdl.ast import AnyConst, Assert, Signal, Const, Assume +from nmigen.hdl.dsl import Module +from nmutil.formaltest import FHDLTestCase +from nmutil.byterev import byte_reverse +from nmutil.grev import grev +from nmutil.sim_util import write_il + + +VALID_BYTE_REVERSE_LENGTHS = tuple(1 << i for i in range(4)) +LOG2_BYTE_SIZE = 3 + + +class TestByteReverse(FHDLTestCase): + def tst(self, log2_width, rev_length=None): + assert isinstance(log2_width, int) and log2_width >= LOG2_BYTE_SIZE + assert rev_length is None or rev_length in VALID_BYTE_REVERSE_LENGTHS + m = Module() + width = 1 << log2_width + inp = Signal(width) + m.d.comb += inp.eq(AnyConst(width)) + length_sig = Signal(range(max(VALID_BYTE_REVERSE_LENGTHS) + 1)) + m.d.comb += length_sig.eq(AnyConst(length_sig.shape())) + + if rev_length is None: + rev_length = length_sig + else: + m.d.comb += Assume(length_sig == rev_length) + + with m.Switch(length_sig): + for l in VALID_BYTE_REVERSE_LENGTHS: + with m.Case(l): + m.d.comb += Assume(width >= l << LOG2_BYTE_SIZE) + with m.Default(): + m.d.comb += Assume(False) + + out = byte_reverse(m, name="out", data=inp, length=rev_length) + + expected = Signal(width) + for log2_chunk_size in range(LOG2_BYTE_SIZE, log2_width + 1): + chunk_size = 1 << log2_chunk_size + chunk_byte_size = chunk_size >> LOG2_BYTE_SIZE + chunk_sizes = chunk_size - 8 + with m.If(rev_length == chunk_byte_size): + m.d.comb += expected.eq(grev(inp, chunk_sizes, log2_width) + & ((1 << chunk_size) - 1)) + + m.d.comb += Assert(expected == out) + + self.assertFormal(m) + + def test_8_len_1(self): + self.tst(log2_width=3, rev_length=1) + + def test_8(self): + self.tst(log2_width=3) + + def test_16_len_1(self): + self.tst(log2_width=4, rev_length=1) + + def test_16_len_2(self): + self.tst(log2_width=4, rev_length=2) + + def test_16(self): + self.tst(log2_width=4) + + def test_32_len_1(self): + self.tst(log2_width=5, rev_length=1) + + def test_32_len_2(self): + self.tst(log2_width=5, rev_length=2) + + def test_32_len_4(self): + self.tst(log2_width=5, rev_length=4) + + def test_32(self): + self.tst(log2_width=5) + + def test_64_len_1(self): + self.tst(log2_width=6, rev_length=1) + + def test_64_len_2(self): + self.tst(log2_width=6, rev_length=2) + + def test_64_len_4(self): + self.tst(log2_width=6, rev_length=4) + + def test_64_len_8(self): + self.tst(log2_width=6, rev_length=8) + + def test_64(self): + self.tst(log2_width=6) + + def test_128_len_1(self): + self.tst(log2_width=7, rev_length=1) + + def test_128_len_2(self): + self.tst(log2_width=7, rev_length=2) + + def test_128_len_4(self): + self.tst(log2_width=7, rev_length=4) + + def test_128_len_8(self): + self.tst(log2_width=7, rev_length=8) + + def test_128(self): + self.tst(log2_width=7) + + +if __name__ == "__main__": + unittest.main() -- 2.30.2