fhdl.ir: add tests for port propagation.
[nmigen.git] / nmigen / test / test_fhdl_dsl.py
1 import re
2 import unittest
3 from contextlib import contextmanager
4
5 from nmigen.fhdl.ast import *
6 from nmigen.fhdl.dsl import *
7
8
class DSLTestCase(unittest.TestCase):
    """Tests for the `Module` DSL: domains, If/Elif/Else, Switch/Case,
    submodules and lowering to a fragment."""

    def setUp(self):
        # Signals are created without explicit names, yet the expected reprs
        # below call them (sig s1), (sig c1), ... -- the names appear to be
        # inferred from these assignment targets, so keep this exact form.
        self.s1 = Signal()
        self.s2 = Signal()
        self.s3 = Signal()
        self.c1 = Signal()
        self.c2 = Signal()
        self.c3 = Signal()
        self.w1 = Signal(4)

    @contextmanager
    def assertRaises(self, exception, msg=None):
        """Assert that the ``with`` body raises ``exception``.

        Context-manager form only. Unlike ``unittest.TestCase.assertRaises``,
        whose ``msg`` argument is merely the text reported on failure, here
        ``msg`` (when given) must equal ``str()`` of the raised exception.
        """
        with super().assertRaises(exception) as cm:
            yield
        # Reached only after the expected exception was raised and suppressed
        # by the stock context manager; a missing exception fails above.
        if msg is not None:
            self.assertEqual(str(cm.exception), msg)

    def assertRepr(self, obj, repr_str):
        """Assert that ``repr(Statement.wrap(obj))`` equals ``repr_str``.

        ``repr_str`` may span several indented lines: whitespace runs collapse
        to a single space and the space between two consecutive opening (or
        closing) parentheses is dropped before comparing.
        """
        obj = Statement.wrap(obj)
        repr_str = re.sub(r"\s+", " ", repr_str)
        repr_str = re.sub(r"\( (?=\()", "(", repr_str)
        repr_str = re.sub(r"\) (?=\))", ")", repr_str)
        self.assertEqual(repr(obj), repr_str.strip())

    # -- domain assignment (m.d.<domain> += ...) ---------------------------

    def test_d_comb(self):
        m = Module()
        m.d.comb += self.c1.eq(1)
        m._flush()
        self.assertEqual(m._driving[self.c1], None)
        self.assertRepr(m._statements, """(
            (eq (sig c1) (const 1'd1))
        )""")

    def test_d_sync(self):
        m = Module()
        m.d.sync += self.c1.eq(1)
        m._flush()
        self.assertEqual(m._driving[self.c1], "sync")
        self.assertRepr(m._statements, """(
            (eq (sig c1) (const 1'd1))
        )""")

    def test_d_pix(self):
        m = Module()
        m.d.pix += self.c1.eq(1)
        m._flush()
        self.assertEqual(m._driving[self.c1], "pix")
        self.assertRepr(m._statements, """(
            (eq (sig c1) (const 1'd1))
        )""")

    def test_d_index(self):
        m = Module()
        m.d["pix"] += self.c1.eq(1)
        m._flush()
        self.assertEqual(m._driving[self.c1], "pix")
        self.assertRepr(m._statements, """(
            (eq (sig c1) (const 1'd1))
        )""")

    def test_d_no_conflict(self):
        # Passes as long as neither assignment raises: two bits of the same
        # signal are driven from the same domain.
        m = Module()
        m.d.comb += self.w1[0].eq(1)
        m.d.comb += self.w1[1].eq(1)

    def test_d_conflict(self):
        m = Module()
        m.d.comb += self.c1.eq(1)
        # Only the second assignment, from a different domain, may raise.
        with self.assertRaises(SyntaxError,
                msg="Driver-driver conflict: trying to drive (sig c1) from d.sync, but it "
                    "is already driven from d.comb"):
            m.d.sync += self.c1.eq(1)

    def test_d_wrong(self):
        m = Module()
        with self.assertRaises(AttributeError,
                msg="Cannot assign 'd.pix' attribute; did you mean 'd.pix +='?"):
            m.d.pix = None

    def test_d_asgn_wrong(self):
        m = Module()
        with self.assertRaises(SyntaxError,
                msg="Only assignments may be appended to d.sync"):
            m.d.sync += Switch(self.s1, {})

    def test_comb_wrong(self):
        m = Module()
        with self.assertRaises(AttributeError,
                msg="'Module' object has no attribute 'comb'; did you mean 'd.comb'?"):
            m.comb += self.c1.eq(1)

    def test_sync_wrong(self):
        m = Module()
        with self.assertRaises(AttributeError,
                msg="'Module' object has no attribute 'sync'; did you mean 'd.sync'?"):
            m.sync += self.c1.eq(1)

    def test_attr_wrong(self):
        m = Module()
        with self.assertRaises(AttributeError,
                msg="'Module' object has no attribute 'nonexistentattr'"):
            m.nonexistentattr

    # -- If / Elif / Else ---------------------------------------------------

    def test_If(self):
        m = Module()
        with m.If(self.s1):
            m.d.comb += self.c1.eq(1)
        m._flush()
        self.assertRepr(m._statements, """
        (
            (switch (cat (sig s1))
                (case 1 (eq (sig c1) (const 1'd1)))
            )
        )
        """)

    def test_If_Elif(self):
        m = Module()
        with m.If(self.s1):
            m.d.comb += self.c1.eq(1)
        with m.Elif(self.s2):
            m.d.sync += self.c2.eq(0)
        m._flush()
        self.assertRepr(m._statements, """
        (
            (switch (cat (sig s1) (sig s2))
                (case -1 (eq (sig c1) (const 1'd1)))
                (case 1- (eq (sig c2) (const 0'd0)))
            )
        )
        """)

    def test_If_Elif_Else(self):
        m = Module()
        with m.If(self.s1):
            m.d.comb += self.c1.eq(1)
        with m.Elif(self.s2):
            m.d.sync += self.c2.eq(0)
        with m.Else():
            m.d.comb += self.c3.eq(1)
        m._flush()
        self.assertRepr(m._statements, """
        (
            (switch (cat (sig s1) (sig s2))
                (case -1 (eq (sig c1) (const 1'd1)))
                (case 1- (eq (sig c2) (const 0'd0)))
                (case -- (eq (sig c3) (const 1'd1)))
            )
        )
        """)

    def test_If_If(self):
        m = Module()
        with m.If(self.s1):
            m.d.comb += self.c1.eq(1)
        with m.If(self.s2):
            m.d.comb += self.c2.eq(1)
        m._flush()
        self.assertRepr(m._statements, """
        (
            (switch (cat (sig s1))
                (case 1 (eq (sig c1) (const 1'd1)))
            )
            (switch (cat (sig s2))
                (case 1 (eq (sig c2) (const 1'd1)))
            )
        )
        """)

    def test_If_nested_If(self):
        m = Module()
        with m.If(self.s1):
            m.d.comb += self.c1.eq(1)
            with m.If(self.s2):
                m.d.comb += self.c2.eq(1)
        m._flush()
        self.assertRepr(m._statements, """
        (
            (switch (cat (sig s1))
                (case 1 (eq (sig c1) (const 1'd1))
                    (switch (cat (sig s2))
                        (case 1 (eq (sig c2) (const 1'd1)))
                    )
                )
            )
        )
        """)

    def test_If_dangling_Else(self):
        # The Else sits at the outer level, so it must bind to If(s1),
        # not to the nested If(s2).
        m = Module()
        with m.If(self.s1):
            m.d.comb += self.c1.eq(1)
            with m.If(self.s2):
                m.d.comb += self.c2.eq(1)
        with m.Else():
            m.d.comb += self.c3.eq(1)
        m._flush()
        self.assertRepr(m._statements, """
        (
            (switch (cat (sig s1))
                (case 1
                    (eq (sig c1) (const 1'd1))
                    (switch (cat (sig s2))
                        (case 1 (eq (sig c2) (const 1'd1)))
                    )
                )
                (case -
                    (eq (sig c3) (const 1'd1))
                )
            )
        )
        """)

    def test_Elif_wrong(self):
        m = Module()
        with self.assertRaises(SyntaxError,
                msg="Elif without preceding If"):
            with m.Elif(self.s2):
                pass

    def test_Else_wrong(self):
        m = Module()
        with self.assertRaises(SyntaxError,
                msg="Else without preceding If/Elif"):
            with m.Else():
                pass

    def test_If_wide(self):
        m = Module()
        with m.If(self.w1):
            m.d.comb += self.c1.eq(1)
        m._flush()
        self.assertRepr(m._statements, """
        (
            (switch (cat (b (sig w1)))
                (case 1 (eq (sig c1) (const 1'd1)))
            )
        )
        """)

    # -- Switch / Case ------------------------------------------------------

    def test_Switch(self):
        m = Module()
        with m.Switch(self.w1):
            with m.Case(3):
                m.d.comb += self.c1.eq(1)
            with m.Case("11--"):
                m.d.comb += self.c2.eq(1)
        m._flush()
        self.assertRepr(m._statements, """
        (
            (switch (sig w1)
                (case 0011 (eq (sig c1) (const 1'd1)))
                (case 11-- (eq (sig c2) (const 1'd1)))
            )
        )
        """)

    def test_Switch_default(self):
        m = Module()
        with m.Switch(self.w1):
            with m.Case(3):
                m.d.comb += self.c1.eq(1)
            with m.Case():
                m.d.comb += self.c2.eq(1)
        m._flush()
        self.assertRepr(m._statements, """
        (
            (switch (sig w1)
                (case 0011 (eq (sig c1) (const 1'd1)))
                (case ---- (eq (sig c2) (const 1'd1)))
            )
        )
        """)

    def test_Case_width_wrong(self):
        m = Module()
        with m.Switch(self.w1):
            with self.assertRaises(SyntaxError,
                    msg="Case value '--' must have the same width as test (which is 4)"):
                with m.Case("--"):
                    pass

    def test_Case_outside_Switch_wrong(self):
        m = Module()
        with self.assertRaises(SyntaxError,
                msg="Case is not permitted outside of Switch"):
            with m.Case():
                pass

    def test_If_inside_Switch_wrong(self):
        m = Module()
        with m.Switch(self.s1):
            with self.assertRaises(SyntaxError,
                    msg="If is not permitted inside of Switch"):
                with m.If(self.s2):
                    pass

    def test_auto_pop_ctrl(self):
        # No explicit _flush(): adding a statement outside the `with` block
        # is expected to close the pending If before the new statement.
        m = Module()
        with m.If(self.w1):
            m.d.comb += self.c1.eq(1)
        m.d.comb += self.c2.eq(1)
        self.assertRepr(m._statements, """
        (
            (switch (cat (b (sig w1)))
                (case 1 (eq (sig c1) (const 1'd1)))
            )
            (eq (sig c2) (const 1'd1))
        )
        """)

    # -- submodules ---------------------------------------------------------

    def test_submodule_anon(self):
        m1 = Module()
        m2 = Module()
        m1.submodules += m2
        self.assertEqual(m1._submodules, [(m2, None)])

    def test_submodule_anon_multi(self):
        m1 = Module()
        m2 = Module()
        m3 = Module()
        m1.submodules += m2, m3
        self.assertEqual(m1._submodules, [(m2, None), (m3, None)])

    def test_submodule_named(self):
        m1 = Module()
        m2 = Module()
        m1.submodules.foo = m2
        self.assertEqual(m1._submodules, [(m2, "foo")])

    def test_submodule_wrong(self):
        m = Module()
        with self.assertRaises(TypeError,
                msg="Trying to add '1', which does not implement .get_fragment(), as a submodule"):
            m.submodules.foo = 1
        with self.assertRaises(TypeError,
                msg="Trying to add '1', which does not implement .get_fragment(), as a submodule"):
            m.submodules += 1

    # -- lowering -----------------------------------------------------------

    def test_lower(self):
        m1 = Module()
        m1.d.comb += self.c1.eq(self.s1)
        m2 = Module()
        m2.d.comb += self.c2.eq(self.s2)
        m2.d.sync += self.c3.eq(self.s3)
        m1.submodules.foo = m2

        f1 = m1.lower(platform=None)
        self.assertRepr(f1.statements, """
        (
            (eq (sig c1) (sig s1))
        )
        """)
        self.assertEqual(f1.drivers, {
            None: ValueSet((self.c1,))
        })
        self.assertEqual(len(f1.subfragments), 1)
        (f2, f2_name), = f1.subfragments
        self.assertEqual(f2_name, "foo")
        self.assertRepr(f2.statements, """
        (
            (eq (sig c2) (sig s2))
            (eq (sig c3) (sig s3))
        )
        """)
        self.assertEqual(f2.drivers, {
            None: ValueSet((self.c2,)),
            "sync": ValueSet((self.c3,))
        })
        self.assertEqual(len(f2.subfragments), 0)
378 })
379 self.assertEqual(len(f2.subfragments), 0)