remove unneeded imports
[nmutil.git] / src / nmutil / formal / test_byterev.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay
3
4 import unittest
5 from nmigen.hdl.ast import AnyConst, Assert, Signal, Assume
6 from nmigen.hdl.dsl import Module
7 from nmutil.formaltest import FHDLTestCase
8 from nmutil.byterev import byte_reverse
9 from nmutil.grev import grev
10
11
# A byte is 2**LOG2_BYTE_SIZE == 8 bits wide.
LOG2_BYTE_SIZE = 3

# Reversal lengths, in bytes, that the tests treat as valid: the powers of
# two from 1 up to 8 bytes.
VALID_BYTE_REVERSE_LENGTHS = (1, 2, 4, 8)
14
15
class TestByteReverse(FHDLTestCase):
    """Formal checks of ``byte_reverse`` against a ``grev``-based model.

    Each ``test_*`` method checks one data width (8 to 128 bits).  The
    reversal length is either fixed to one value or left symbolic.
    """

    def tst(self, log2_width, rev_length=None):
        """Build a ``byte_reverse`` circuit and formally check it.

        :param log2_width: log2 of the data width in bits; must be at least
            ``LOG2_BYTE_SIZE`` so the data holds at least one byte.
        :param rev_length: number of bytes to reverse.  It must be one of
            ``VALID_BYTE_REVERSE_LENGTHS`` and must fit in the data width.
            When ``None``, the length is an ``AnyConst`` signal, so the
            proof covers every valid length that fits in the data width.
        """
        assert isinstance(log2_width, int) and log2_width >= LOG2_BYTE_SIZE
        assert rev_length is None or rev_length in VALID_BYTE_REVERSE_LENGTHS
        width = 1 << log2_width
        byte_size = 1 << LOG2_BYTE_SIZE
        # A fixed length wider than the data would make the Assume()s below
        # unsatisfiable, and the formal check would then pass vacuously.
        assert rev_length is None or rev_length << LOG2_BYTE_SIZE <= width, \
            "rev_length does not fit in the data width"

        m = Module()
        inp = Signal(width)
        m.d.comb += inp.eq(AnyConst(width))
        length_sig = Signal(range(max(VALID_BYTE_REVERSE_LENGTHS) + 1))
        m.d.comb += length_sig.eq(AnyConst(length_sig.shape()))

        if rev_length is None:
            # symbolic length: byte_reverse receives the signal itself
            rev_length = length_sig
        else:
            # fixed length: byte_reverse receives the Python int, and
            # length_sig is pinned to it so the constraints below still apply
            m.d.comb += Assume(length_sig == rev_length)

        # Only consider valid lengths whose bytes fit in the data width.
        with m.Switch(length_sig):
            for length in VALID_BYTE_REVERSE_LENGTHS:
                with m.Case(length):
                    m.d.comb += Assume(width >= length << LOG2_BYTE_SIZE)
            with m.Default():
                m.d.comb += Assume(False)

        out = byte_reverse(m, name="out", data=inp, length=rev_length)

        # Reference model: for each power-of-two chunk size, reverse the
        # bytes with grev, then keep only the low chunk_size bits.
        expected = Signal(width)
        for log2_chunk_size in range(LOG2_BYTE_SIZE, log2_width + 1):
            chunk_size = 1 << log2_chunk_size
            chunk_byte_size = chunk_size >> LOG2_BYTE_SIZE
            # The mask has bits LOG2_BYTE_SIZE..log2_chunk_size-1 set.
            # Assuming nmutil.grev is the usual generalized reverse (swap
            # adjacent 2**k-bit groups for each set bit k), this reverses
            # the byte order within every chunk_size-bit chunk.
            chunk_sizes = chunk_size - byte_size
            with m.If(rev_length == chunk_byte_size):
                m.d.comb += expected.eq(grev(inp, chunk_sizes, log2_width)
                                        & ((1 << chunk_size) - 1))

        m.d.comb += Assert(expected == out)

        self.assertFormal(m)

    def test_8_len_1(self):
        self.tst(log2_width=3, rev_length=1)

    def test_8(self):
        self.tst(log2_width=3)

    def test_16_len_1(self):
        self.tst(log2_width=4, rev_length=1)

    def test_16_len_2(self):
        self.tst(log2_width=4, rev_length=2)

    def test_16(self):
        self.tst(log2_width=4)

    def test_32_len_1(self):
        self.tst(log2_width=5, rev_length=1)

    def test_32_len_2(self):
        self.tst(log2_width=5, rev_length=2)

    def test_32_len_4(self):
        self.tst(log2_width=5, rev_length=4)

    def test_32(self):
        self.tst(log2_width=5)

    def test_64_len_1(self):
        self.tst(log2_width=6, rev_length=1)

    def test_64_len_2(self):
        self.tst(log2_width=6, rev_length=2)

    def test_64_len_4(self):
        self.tst(log2_width=6, rev_length=4)

    def test_64_len_8(self):
        self.tst(log2_width=6, rev_length=8)

    def test_64(self):
        self.tst(log2_width=6)

    def test_128_len_1(self):
        self.tst(log2_width=7, rev_length=1)

    def test_128_len_2(self):
        self.tst(log2_width=7, rev_length=2)

    def test_128_len_4(self):
        self.tst(log2_width=7, rev_length=4)

    def test_128_len_8(self):
        self.tst(log2_width=7, rev_length=8)

    def test_128(self):
        self.tst(log2_width=7)
110
111
if __name__ == "__main__":
    # When executed as a script, run the TestCase classes defined in this
    # module ("__main__" is unittest.main's default module, spelled out).
    unittest.main(module="__main__")