test.sim: generalize assertOperator. NFC.
[nmigen.git] / nmigen / test / test_sim.py
1 from .tools import *
2 from ..hdl.ast import *
3 from ..hdl.ir import *
4 from ..back.pysim import *
5
6
class SimulatorUnitTestCase(FHDLTestCase):
    """Unit tests for ``back.pysim.Simulator``.

    Each test describes a statement as a callable ``stmt(y, *inputs)`` and
    uses :meth:`assertStatement` to check the value ``y`` holds after the
    inputs have been driven with constant values.
    """

    def assertStatement(self, stmt, inputs, output):
        """Simulate ``stmt`` and assert that its output signal equals ``output``.

        :param stmt: callable ``stmt(y, *isigs)`` returning the statement(s)
            added to the fragment; ``y`` is the output signal and ``isigs``
            are fresh signals with the same shapes as ``inputs``.
        :param inputs: values (anything ``Value.wrap`` accepts) driven onto
            the input signals, in order. Any number of inputs is supported.
        :param output: expected value of ``y``; its shape also sets the
            shape of ``y``.

        As a debugging aid, traces are written to ``test.vcd`` and
        ``test.gtkw`` in the current directory (overwritten on every call).
        """
        inputs = [Value.wrap(i) for i in inputs]
        output = Value.wrap(output)

        # Generate one name per input ("a", "b", "c", ...) rather than zipping
        # against a fixed string, which would silently drop extra inputs.
        isigs = [Signal(value.shape(),
                        name=chr(ord("a") + n) if n < 26 else "i{}".format(n))
                 for n, value in enumerate(inputs)]
        osig = Signal(output.shape(), name="y")

        frag = Fragment()
        frag.add_statements(stmt(osig, *isigs))
        frag.add_driver(osig)

        # The trace files are opened in the same `with` statement as the
        # simulator so they are always closed, even if the simulation or an
        # assertion raises. Context managers exit in reverse order, so the
        # simulator is exited before the files are closed.
        with open("test.vcd", "w") as vcd_file, \
                open("test.gtkw", "w") as gtkw_file, \
                Simulator(frag,
                          vcd_file=vcd_file,
                          gtkw_file=gtkw_file,
                          traces=[*isigs, osig]) as sim:
            def process():
                for isig, value in zip(isigs, inputs):
                    yield isig.eq(value)
                yield Delay()
                self.assertEqual((yield osig), output.value)
            sim.add_process(process)
            sim.run()

    def test_invert(self):
        stmt = lambda y, a: y.eq(~a)
        self.assertStatement(stmt, [C(0b0000, 4)], C(0b1111, 4))
        self.assertStatement(stmt, [C(0b1010, 4)], C(0b0101, 4))
        self.assertStatement(stmt, [C(0, 4)], C(-1, 4))

    def test_neg(self):
        stmt = lambda y, a: y.eq(-a)
        self.assertStatement(stmt, [C(0b0000, 4)], C(0b0000, 4))
        self.assertStatement(stmt, [C(0b0001, 4)], C(0b1111, 4))
        self.assertStatement(stmt, [C(0b1010, 4)], C(0b0110, 4))
        self.assertStatement(stmt, [C(1, 4)], C(-1, 4))
        self.assertStatement(stmt, [C(5, 4)], C(-5, 4))

    def test_bool(self):
        stmt = lambda y, a: y.eq(a.bool())
        self.assertStatement(stmt, [C(0, 4)], C(0))
        self.assertStatement(stmt, [C(1, 4)], C(1))
        self.assertStatement(stmt, [C(2, 4)], C(1))

    def test_add(self):
        stmt = lambda y, a, b: y.eq(a + b)
        self.assertStatement(stmt, [C(0, 4), C(1, 4)], C(1, 4))
        self.assertStatement(stmt, [C(-5, 4), C(-5, 4)], C(-10, 5))

    def test_sub(self):
        stmt = lambda y, a, b: y.eq(a - b)
        self.assertStatement(stmt, [C(2, 4), C(1, 4)], C(1, 4))
        self.assertStatement(stmt, [C(0, 4), C(1, 4)], C(-1, 4))
        self.assertStatement(stmt, [C(0, 4), C(10, 4)], C(-10, 5))

    def test_and(self):
        stmt = lambda y, a, b: y.eq(a & b)
        self.assertStatement(stmt, [C(0b1100, 4), C(0b1010, 4)], C(0b1000, 4))

    def test_or(self):
        stmt = lambda y, a, b: y.eq(a | b)
        self.assertStatement(stmt, [C(0b1100, 4), C(0b1010, 4)], C(0b1110, 4))

    def test_xor(self):
        stmt = lambda y, a, b: y.eq(a ^ b)
        self.assertStatement(stmt, [C(0b1100, 4), C(0b1010, 4)], C(0b0110, 4))

    def test_shl(self):
        stmt = lambda y, a, b: y.eq(a << b)
        self.assertStatement(stmt, [C(0b1001, 4), C(0)], C(0b1001, 5))
        self.assertStatement(stmt, [C(0b1001, 4), C(3)], C(0b1001000, 7))
        self.assertStatement(stmt, [C(0b1001, 4), C(-2)], C(0b10, 7))

    def test_shr(self):
        stmt = lambda y, a, b: y.eq(a >> b)
        self.assertStatement(stmt, [C(0b1001, 4), C(0)], C(0b1001, 4))
        self.assertStatement(stmt, [C(0b1001, 4), C(2)], C(0b10, 4))
        self.assertStatement(stmt, [C(0b1001, 4), C(-2)], C(0b100100, 5))

    def test_eq(self):
        stmt = lambda y, a, b: y.eq(a == b)
        self.assertStatement(stmt, [C(0, 4), C(0, 4)], C(1))
        self.assertStatement(stmt, [C(0, 4), C(1, 4)], C(0))
        self.assertStatement(stmt, [C(1, 4), C(0, 4)], C(0))

    def test_ne(self):
        stmt = lambda y, a, b: y.eq(a != b)
        self.assertStatement(stmt, [C(0, 4), C(0, 4)], C(0))
        self.assertStatement(stmt, [C(0, 4), C(1, 4)], C(1))
        self.assertStatement(stmt, [C(1, 4), C(0, 4)], C(1))

    def test_lt(self):
        stmt = lambda y, a, b: y.eq(a < b)
        self.assertStatement(stmt, [C(0, 4), C(0, 4)], C(0))
        self.assertStatement(stmt, [C(0, 4), C(1, 4)], C(1))
        self.assertStatement(stmt, [C(1, 4), C(0, 4)], C(0))

    def test_ge(self):
        stmt = lambda y, a, b: y.eq(a >= b)
        self.assertStatement(stmt, [C(0, 4), C(0, 4)], C(1))
        self.assertStatement(stmt, [C(0, 4), C(1, 4)], C(0))
        self.assertStatement(stmt, [C(1, 4), C(0, 4)], C(1))

    def test_gt(self):
        stmt = lambda y, a, b: y.eq(a > b)
        self.assertStatement(stmt, [C(0, 4), C(0, 4)], C(0))
        self.assertStatement(stmt, [C(0, 4), C(1, 4)], C(0))
        self.assertStatement(stmt, [C(1, 4), C(0, 4)], C(1))

    def test_le(self):
        stmt = lambda y, a, b: y.eq(a <= b)
        self.assertStatement(stmt, [C(0, 4), C(0, 4)], C(1))
        self.assertStatement(stmt, [C(0, 4), C(1, 4)], C(1))
        self.assertStatement(stmt, [C(1, 4), C(0, 4)], C(0))

    def test_mux(self):
        stmt = lambda y, a, b, c: y.eq(Mux(c, a, b))
        self.assertStatement(stmt, [C(2, 4), C(3, 4), C(0)], C(3, 4))
        self.assertStatement(stmt, [C(2, 4), C(3, 4), C(1)], C(2, 4))

    def test_slice(self):
        stmt1 = lambda y, a: y.eq(a[2])
        self.assertStatement(stmt1, [C(0b10110100, 8)], C(0b1, 1))
        stmt2 = lambda y, a: y.eq(a[2:4])
        self.assertStatement(stmt2, [C(0b10110100, 8)], C(0b01, 2))

    def test_part(self):
        stmt = lambda y, a, b: y.eq(a.part(b, 3))
        self.assertStatement(stmt, [C(0b10110100, 8), C(0)], C(0b100, 3))
        self.assertStatement(stmt, [C(0b10110100, 8), C(2)], C(0b101, 3))
        self.assertStatement(stmt, [C(0b10110100, 8), C(3)], C(0b110, 3))

    def test_cat(self):
        stmt = lambda y, *xs: y.eq(Cat(*xs))
        self.assertStatement(stmt, [C(0b10, 2), C(0b01, 2)], C(0b0110, 4))

    def test_repl(self):
        stmt = lambda y, a: y.eq(Repl(a, 3))
        self.assertStatement(stmt, [C(0b10, 2)], C(0b101010, 6))

    def test_array(self):
        array = Array([1, 4, 10])
        stmt = lambda y, a: y.eq(array[a])
        self.assertStatement(stmt, [C(0)], C(1))
        self.assertStatement(stmt, [C(1)], C(4))
        self.assertStatement(stmt, [C(2)], C(10))

    def test_array_index(self):
        array = Array(Array(x * y for y in range(10)) for x in range(10))
        stmt = lambda y, a, b: y.eq(array[a][b])
        for x in range(10):
            for y in range(10):
                self.assertStatement(stmt, [C(x), C(y)], C(x * y))

    def test_array_attr(self):
        from collections import namedtuple
        pair = namedtuple("pair", ("p", "n"))

        array = Array(pair(x, -x) for x in range(10))
        stmt = lambda y, a: y.eq(array[a].p + array[a].n)
        for i in range(10):
            self.assertStatement(stmt, [C(i)], C(0))