tests: move out of the main package.
[nmigen.git] / tests / test_hdl_xfrm.py
1 # nmigen: UnusedElaboratable=no
2
3 from nmigen.hdl.ast import *
4 from nmigen.hdl.cd import *
5 from nmigen.hdl.ir import *
6 from nmigen.hdl.xfrm import *
7 from nmigen.hdl.mem import *
8
9 from .utils import *
10
11
class DomainRenamerTestCase(FHDLTestCase):
    """Tests for ``DomainRenamer``, which renames clock domains in a fragment."""

    def setUp(self):
        self.s1 = Signal()
        self.s2 = Signal()
        self.s3 = Signal()
        self.s4 = Signal()
        self.s5 = Signal()

    def test_rename_signals(self):
        # A bare string renames only the default "sync" domain: references to
        # "other" and the combinational (None) drivers are left untouched.
        f = Fragment()
        f.add_statements(
            self.s1.eq(ClockSignal()),
            ResetSignal().eq(self.s2),
            self.s3.eq(0),
            self.s4.eq(ClockSignal("other")),
            self.s5.eq(ResetSignal("other")),
        )
        f.add_driver(self.s1, None)
        f.add_driver(self.s2, None)
        f.add_driver(self.s3, "sync")

        f = DomainRenamer("pix")(f)
        self.assertRepr(f.statements, """
        (
            (eq (sig s1) (clk pix))
            (eq (rst pix) (sig s2))
            (eq (sig s3) (const 1'd0))
            (eq (sig s4) (clk other))
            (eq (sig s5) (rst other))
        )
        """)
        self.assertEqual(f.drivers, {
            None: SignalSet((self.s1, self.s2)),
            "pix": SignalSet((self.s3,)),
        })

    def test_rename_multi(self):
        # A dict renames several domains in one pass.
        f = Fragment()
        f.add_statements(
            self.s1.eq(ClockSignal()),
            self.s2.eq(ResetSignal("other")),
        )

        f = DomainRenamer({"sync": "pix", "other": "pix2"})(f)
        self.assertRepr(f.statements, """
        (
            (eq (sig s1) (clk pix))
            (eq (sig s2) (rst pix2))
        )
        """)

    def test_rename_cd(self):
        # Renaming mutates the ClockDomain object itself and re-keys the
        # fragment's domain table; domains not in the map keep their names.
        cd_sync = ClockDomain()
        cd_pix = ClockDomain()

        f = Fragment()
        f.add_domains(cd_sync, cd_pix)

        f = DomainRenamer("ext")(f)
        self.assertEqual(cd_sync.name, "ext")
        self.assertEqual(f.domains, {
            "ext": cd_sync,
            "pix": cd_pix,
        })

    def test_rename_cd_preserves_allow_reset_less(self):
        # allow_reset_less must survive renaming, so lowering against the
        # reset-less target domain yields a constant instead of an error.
        cd_pix = ClockDomain(reset_less=True)

        f = Fragment()
        f.add_domains(cd_pix)
        f.add_statements(
            self.s1.eq(ResetSignal(allow_reset_less=True)),
        )

        f = DomainRenamer("pix")(f)
        f = DomainLowerer()(f)
        self.assertRepr(f.statements, """
        (
            (eq (sig s1) (const 1'd0))
        )
        """)

    def test_rename_cd_subfragment(self):
        # The same ClockDomain object is registered in both the parent and
        # the subfragment; both must end up keyed by the new name.
        cd_sync = ClockDomain()
        cd_pix = ClockDomain()

        f1 = Fragment()
        f1.add_domains(cd_sync, cd_pix)
        f2 = Fragment()
        f2.add_domains(cd_sync)
        f1.add_subfragment(f2)

        f1 = DomainRenamer("ext")(f1)
        self.assertEqual(cd_sync.name, "ext")
        self.assertEqual(f1.domains, {
            "ext": cd_sync,
            "pix": cd_pix,
        })
        # The transformed subfragment must see the renamed domain as well.
        (f2, _), = f1.subfragments
        self.assertEqual(f2.domains, {
            "ext": cd_sync,
        })

    def test_rename_wrong_to_comb(self):
        with self.assertRaisesRegex(ValueError,
                r"^Domain 'sync' may not be renamed to 'comb'$"):
            DomainRenamer("comb")

    def test_rename_wrong_from_comb(self):
        with self.assertRaisesRegex(ValueError,
                r"^Domain 'comb' may not be renamed$"):
            DomainRenamer({"comb": "sync"})
122
123
class DomainLowererTestCase(FHDLTestCase):
    """Tests for ``DomainLowerer``, which resolves ``ClockSignal``/``ResetSignal``
    references against the clock domains a fragment defines."""

    def setUp(self):
        self.s = Signal()

    @staticmethod
    def _fragment(statement, *domains):
        # Build a one-statement fragment owning the given domains (possibly none).
        frag = Fragment()
        frag.add_domains(*domains)
        frag.add_statements(statement)
        return frag

    def test_lower_clk(self):
        cd_sync = ClockDomain("sync")
        frag = self._fragment(self.s.eq(ClockSignal("sync")), cd_sync)

        lowered = DomainLowerer()(frag)
        self.assertRepr(lowered.statements, """
        (
            (eq (sig s) (sig clk))
        )
        """)

    def test_lower_rst(self):
        cd_sync = ClockDomain("sync")
        frag = self._fragment(self.s.eq(ResetSignal("sync")), cd_sync)

        lowered = DomainLowerer()(frag)
        self.assertRepr(lowered.statements, """
        (
            (eq (sig s) (sig rst))
        )
        """)

    def test_lower_rst_reset_less(self):
        # A reset-less domain lowers an allow_reset_less ResetSignal to 0.
        cd_sync = ClockDomain("sync", reset_less=True)
        frag = self._fragment(
            self.s.eq(ResetSignal("sync", allow_reset_less=True)), cd_sync)

        lowered = DomainLowerer()(frag)
        self.assertRepr(lowered.statements, """
        (
            (eq (sig s) (const 1'd0))
        )
        """)

    def test_lower_drivers(self):
        # Driven ClockSignal/ResetSignal entries are replaced by the domain's
        # concrete clk/rst signals.
        cd_sync = ClockDomain("sync")
        cd_pix = ClockDomain("pix")
        frag = Fragment()
        frag.add_domains(cd_sync, cd_pix)
        frag.add_driver(ClockSignal("pix"), None)
        frag.add_driver(ResetSignal("pix"), "sync")

        lowered = DomainLowerer()(frag)
        self.assertEqual(lowered.drivers, {
            None: SignalSet((cd_pix.clk,)),
            "sync": SignalSet((cd_pix.rst,)),
        })

    def test_lower_wrong_domain(self):
        frag = self._fragment(self.s.eq(ClockSignal("xxx")))

        with self.assertRaisesRegex(DomainError,
                r"^Signal \(clk xxx\) refers to nonexistent domain 'xxx'$"):
            DomainLowerer()(frag)

    def test_lower_wrong_reset_less_domain(self):
        cd_sync = ClockDomain("sync", reset_less=True)
        frag = self._fragment(self.s.eq(ResetSignal("sync")), cd_sync)

        with self.assertRaisesRegex(DomainError,
                r"^Signal \(rst sync\) refers to reset of reset-less domain 'sync'$"):
            DomainLowerer()(frag)
208
209
class SampleLowererTestCase(FHDLTestCase):
    """Tests for ``SampleLowerer``, which replaces ``Sample`` expressions with
    chains of per-domain delay signals."""

    def setUp(self):
        self.i = Signal()
        self.o1 = Signal()
        self.o2 = Signal()
        self.o3 = Signal()

    def test_lower_signal(self):
        frag = Fragment()
        frag.add_statements(
            self.o1.eq(Sample(self.i, 2, "sync")),
            self.o2.eq(Sample(self.i, 1, "sync")),
            self.o3.eq(Sample(self.i, 1, "pix")),
        )

        lowered = SampleLowerer()(frag)
        # The one-cycle "sync" delay is shared by both sync samples.
        self.assertRepr(lowered.statements, """
        (
            (eq (sig o1) (sig $sample$s$i$sync$2))
            (eq (sig o2) (sig $sample$s$i$sync$1))
            (eq (sig o3) (sig $sample$s$i$pix$1))
            (eq (sig $sample$s$i$sync$1) (sig i))
            (eq (sig $sample$s$i$sync$2) (sig $sample$s$i$sync$1))
            (eq (sig $sample$s$i$pix$1) (sig i))
        )
        """)
        expected_driver_counts = {"sync": 2, "pix": 1}
        for domain, count in expected_driver_counts.items():
            self.assertEqual(len(lowered.drivers[domain]), count)

    def test_lower_const(self):
        frag = Fragment()
        frag.add_statements(
            self.o1.eq(Sample(1, 2, "sync")),
        )

        lowered = SampleLowerer()(frag)
        self.assertRepr(lowered.statements, """
        (
            (eq (sig o1) (sig $sample$c$1$sync$2))
            (eq (sig $sample$c$1$sync$1) (const 1'd1))
            (eq (sig $sample$c$1$sync$2) (sig $sample$c$1$sync$1))
        )
        """)
        sync_drivers = lowered.drivers["sync"]
        self.assertEqual(len(sync_drivers), 2)
254
255
class SwitchCleanerTestCase(FHDLTestCase):
    """Tests for ``SwitchCleaner``, which drops switches that end up
    containing no statements (including nested ones)."""

    def test_clean(self):
        a = Signal()
        b = Signal()
        stmts = [
            Switch(a, {
                1: a.eq(0),
                0: [
                    b.eq(1),
                    # Only contains an empty switch, so it vanishes entirely.
                    Switch(b, {1: [
                        Switch(a|b, {})
                    ]})
                ]
            })
        ]

        self.assertRepr(SwitchCleaner()(stmts), """
        (
            (switch (sig a)
                (case 1
                    (eq (sig a) (const 1'd0)))
                (case 0
                    (eq (sig b) (const 1'd1)))
            )
        )
        """)
283
284
class LHSGroupAnalyzerTestCase(FHDLTestCase):
    """Tests for ``LHSGroupAnalyzer``, which partitions assigned signals into
    groups of signals that are assigned together."""

    @staticmethod
    def _groups(statements):
        # Run the analyzer and return its groups in iteration order.
        return list(LHSGroupAnalyzer()(statements).values())

    def test_no_group_unrelated(self):
        a = Signal()
        b = Signal()
        statements = [a.eq(0), b.eq(0)]

        self.assertEqual(self._groups(statements), [
            SignalSet((a,)),
            SignalSet((b,)),
        ])

    def test_group_related(self):
        a = Signal()
        b = Signal()
        statements = [a.eq(0), Cat(a, b).eq(0)]

        self.assertEqual(self._groups(statements), [
            SignalSet((a, b)),
        ])

    def test_no_loops(self):
        # Repeating an already-merged assignment must not create new groups.
        a = Signal()
        b = Signal()
        statements = [a.eq(0), Cat(a, b).eq(0), Cat(a, b).eq(0)]

        self.assertEqual(self._groups(statements), [
            SignalSet((a, b)),
        ])

    def test_switch(self):
        # A switch test signal does not join the group of the switch body.
        a = Signal()
        b = Signal()
        statements = [
            a.eq(0),
            Switch(a, {
                1: b.eq(0),
            }),
        ]

        self.assertEqual(self._groups(statements), [
            SignalSet((a,)),
            SignalSet((b,)),
        ])

    def test_lhs_empty(self):
        self.assertEqual(self._groups([Cat().eq(0)]), [])
351
352
class LHSGroupFilterTestCase(FHDLTestCase):
    """Tests for ``LHSGroupFilter``, which keeps only assignments to signals
    in a given set while preserving the switch structure around them."""

    def test_filter(self):
        a = Signal()
        b = Signal()
        stmts = [
            Switch(a, {
                1: a.eq(0),
                0: [
                    b.eq(1),
                    Switch(b, {1: []})
                ]
            })
        ]

        # The assignment to `b` is filtered out, leaving case 0 empty.
        self.assertRepr(LHSGroupFilter(SignalSet((a,)))(stmts), """
        (
            (switch (sig a)
                (case 1
                    (eq (sig a) (const 1'd0)))
                (case 0 )
            )
        )
        """)

    def test_lhs_empty(self):
        stmts = [
            Cat().eq(0)
        ]

        self.assertRepr(LHSGroupFilter(SignalSet())(stmts), "()")
384
385
class ResetInserterTestCase(FHDLTestCase):
    """Tests for ``ResetInserter``, which appends a switch on a control signal
    that assigns each signal driven from the selected domain its reset value."""

    def setUp(self):
        self.s1 = Signal()
        self.s2 = Signal(reset=1)
        self.s3 = Signal(reset=1, reset_less=True)
        self.c1 = Signal()

    @staticmethod
    def _fragment(statements, drivers, domains=()):
        # Build a fragment from statements, (signal, domain) driver pairs and
        # optional clock domains.
        frag = Fragment()
        frag.add_statements(*statements)
        frag.add_domains(*domains)
        for signal, domain in drivers:
            frag.add_driver(signal, domain)
        return frag

    def test_reset_default(self):
        frag = self._fragment([self.s1.eq(1)], [(self.s1, "sync")])

        transformed = ResetInserter(self.c1)(frag)
        self.assertRepr(transformed.statements, """
        (
            (eq (sig s1) (const 1'd1))
            (switch (sig c1)
                (case 1 (eq (sig s1) (const 1'd0)))
            )
        )
        """)

    def test_reset_cd(self):
        # Only signals driven from the "pix" domain receive the reset.
        frag = self._fragment(
            [self.s1.eq(1), self.s2.eq(0)],
            [(self.s1, "sync"), (self.s2, "pix")],
            domains=[ClockDomain("sync")],
        )

        transformed = ResetInserter({"pix": self.c1})(frag)
        self.assertRepr(transformed.statements, """
        (
            (eq (sig s1) (const 1'd1))
            (eq (sig s2) (const 1'd0))
            (switch (sig c1)
                (case 1 (eq (sig s2) (const 1'd1)))
            )
        )
        """)

    def test_reset_value(self):
        # The signal's declared reset value (1) is used, not zero.
        frag = self._fragment([self.s2.eq(0)], [(self.s2, "sync")])

        transformed = ResetInserter(self.c1)(frag)
        self.assertRepr(transformed.statements, """
        (
            (eq (sig s2) (const 1'd0))
            (switch (sig c1)
                (case 1 (eq (sig s2) (const 1'd1)))
            )
        )
        """)

    def test_reset_less(self):
        # A reset_less signal is skipped, leaving the reset case empty.
        frag = self._fragment([self.s3.eq(0)], [(self.s3, "sync")])

        transformed = ResetInserter(self.c1)(frag)
        self.assertRepr(transformed.statements, """
        (
            (eq (sig s3) (const 1'd0))
            (switch (sig c1)
                (case 1 )
            )
        )
        """)
464
465
class EnableInserterTestCase(FHDLTestCase):
    """Tests for ``EnableInserter``, which holds signals driven from the
    selected domain at their current value while the enable is low."""

    def setUp(self):
        self.s1 = Signal()
        self.s2 = Signal()
        self.c1 = Signal()

    def test_enable_default(self):
        f = Fragment()
        f.add_statements(
            self.s1.eq(1)
        )
        f.add_driver(self.s1, "sync")

        f = EnableInserter(self.c1)(f)
        self.assertRepr(f.statements, """
        (
            (eq (sig s1) (const 1'd1))
            (switch (sig c1)
                (case 0 (eq (sig s1) (sig s1)))
            )
        )
        """)

    def test_enable_cd(self):
        # Only signals driven from the "pix" domain are gated.
        f = Fragment()
        f.add_statements(
            self.s1.eq(1),
            self.s2.eq(0),
        )
        f.add_driver(self.s1, "sync")
        f.add_driver(self.s2, "pix")

        f = EnableInserter({"pix": self.c1})(f)
        self.assertRepr(f.statements, """
        (
            (eq (sig s1) (const 1'd1))
            (eq (sig s2) (const 1'd0))
            (switch (sig c1)
                (case 0 (eq (sig s2) (sig s2)))
            )
        )
        """)

    def test_enable_subfragment(self):
        # The transform must also be applied to every subfragment.
        f1 = Fragment()
        f1.add_statements(
            self.s1.eq(1)
        )
        f1.add_driver(self.s1, "sync")

        f2 = Fragment()
        f2.add_statements(
            self.s2.eq(1)
        )
        f2.add_driver(self.s2, "sync")
        f1.add_subfragment(f2)

        f1 = EnableInserter(self.c1)(f1)
        (f2, _), = f1.subfragments
        self.assertRepr(f1.statements, """
        (
            (eq (sig s1) (const 1'd1))
            (switch (sig c1)
                (case 0 (eq (sig s1) (sig s1)))
            )
        )
        """)
        self.assertRepr(f2.statements, """
        (
            (eq (sig s2) (const 1'd1))
            (switch (sig c1)
                (case 0 (eq (sig s2) (sig s2)))
            )
        )
        """)

    def test_enable_read_port(self):
        # The memory port's EN input is muxed with the enable signal.
        mem = Memory(width=8, depth=4)
        f = EnableInserter(self.c1)(mem.read_port(transparent=False)).elaborate(platform=None)
        self.assertRepr(f.named_ports["EN"][0], """
        (m (sig c1) (sig mem_r_en) (const 1'd0))
        """)

    def test_enable_write_port(self):
        mem = Memory(width=8, depth=4)
        f = EnableInserter(self.c1)(mem.write_port()).elaborate(platform=None)
        self.assertRepr(f.named_ports["EN"][0], """
        (m (sig c1) (cat (repl (slice (sig mem_w_en) 0:1) 8)) (const 8'd0))
        """)
556
557
class _MockElaboratable(Elaboratable):
    """Minimal elaboratable: drives ``s1`` to 1 from the "sync" domain."""

    def __init__(self):
        self.s1 = Signal()

    def elaborate(self, platform):
        fragment = Fragment()
        fragment.add_statements(self.s1.eq(1))
        fragment.add_driver(self.s1, "sync")
        return fragment
569
570
class TransformedElaboratableTestCase(FHDLTestCase):
    """Tests for wrapping elaboratables (rather than fragments) in transforms."""

    def setUp(self):
        self.c1 = Signal()
        self.c2 = Signal()

    def test_getattr(self):
        inner = _MockElaboratable()
        wrapped = EnableInserter(self.c1)(inner)

        # Attribute access is forwarded to the wrapped elaboratable.
        self.assertIs(wrapped.s1, inner.s1)

    def test_composition(self):
        inner = _MockElaboratable()
        enabled = EnableInserter(self.c1)(inner)
        reset = ResetInserter(self.c2)(enabled)

        # Transforming a TransformedElaboratable returns the same object
        # rather than nesting another wrapper.
        self.assertIsInstance(enabled, TransformedElaboratable)
        self.assertIs(enabled, reset)

        # Transforms are applied in the order they were added.
        frag = Fragment.get(reset, None)
        self.assertRepr(frag.statements, """
        (
            (eq (sig s1) (const 1'd1))
            (switch (sig c1)
                (case 0 (eq (sig s1) (sig s1)))
            )
            (switch (sig c2)
                (case 1 (eq (sig s1) (const 1'd0)))
            )
        )
        """)
602
603
class MockUserValue(UserValue):
    """``UserValue`` whose lowering is the value given at construction time."""

    def __init__(self, lowered):
        super().__init__()
        # Returned verbatim by ``lower``; may itself be another UserValue.
        self.lowered = lowered

    def lower(self):
        """Return the value passed to the constructor."""
        result = self.lowered
        return result
611
612
class UserValueTestCase(FHDLTestCase):
    """Checks that an assignment to a ``UserValue`` is handled by
    ``ResetInserter`` and ``DomainLowerer`` in terms of its lowered signal."""

    def setUp(self):
        self.s = Signal()
        self.c = Signal()
        self.uv = MockUserValue(self.s)

    def test_lower(self):
        frag = Fragment()
        frag.add_domains(ClockDomain("sync"))
        frag.add_statements(self.uv.eq(1))
        for lhs_signal in self.uv._lhs_signals():
            frag.add_driver(lhs_signal, "sync")

        # The inserted reset (on c) comes first; the domain reset (rst) is
        # added by DomainLowerer.
        frag = DomainLowerer()(ResetInserter(self.c)(frag))
        self.assertRepr(frag.statements, """
        (
            (eq (sig s) (const 1'd1))
            (switch (sig c)
                (case 1 (eq (sig s) (const 1'd0)))
            )
            (switch (sig rst)
                (case 1 (eq (sig s) (const 1'd0)))
            )
        )
        """)
642
643
class UserValueRecursiveTestCase(UserValueTestCase):
    """Repeats ``UserValueTestCase`` with a ``UserValue`` that lowers to
    another ``UserValue``, which in turn lowers to the signal."""

    def setUp(self):
        self.s = Signal()
        self.c = Signal()
        inner = MockUserValue(self.s)
        self.uv = MockUserValue(inner)

    # test_lower is inherited unchanged: the nested value must produce
    # exactly the same statements as the single-level one.