hdl.xfrm: lower resets in DomainLowerer as well.
[nmigen.git] / nmigen / test / test_hdl_xfrm.py
1 from ..hdl.ast import *
2 from ..hdl.cd import *
3 from ..hdl.ir import *
4 from ..hdl.xfrm import *
5 from ..hdl.mem import *
6 from .tools import *
7
8
class DomainRenamerTestCase(FHDLTestCase):
    """Tests for ``DomainRenamer``.

    Covers renaming of ``ClockSignal``/``ResetSignal`` references, of driver
    domains, and of ``ClockDomain`` objects (including ones shared with a
    subfragment), plus rejection of renames to or from ``comb``.
    """

    def setUp(self):
        # The expected reprs below (e.g. ``(sig s1)``) rely on the names these
        # signals pick up from their attribute names, so keep them unchanged.
        self.s1 = Signal()
        self.s2 = Signal()
        self.s3 = Signal()
        self.s4 = Signal()
        self.s5 = Signal()

    def test_rename_signals(self):
        f = Fragment()
        f.add_statements(
            self.s1.eq(ClockSignal()),
            ResetSignal().eq(self.s2),
            self.s3.eq(0),
            self.s4.eq(ClockSignal("other")),
            self.s5.eq(ResetSignal("other")),
        )
        f.add_driver(self.s1, None)
        f.add_driver(self.s2, None)
        f.add_driver(self.s3, "sync")

        # A plain string renames only the default "sync" domain; references
        # to "other" and drivers registered under None are left alone.
        f = DomainRenamer("pix")(f)
        self.assertRepr(f.statements, """
        (
            (eq (sig s1) (clk pix))
            (eq (rst pix) (sig s2))
            (eq (sig s3) (const 1'd0))
            (eq (sig s4) (clk other))
            (eq (sig s5) (rst other))
        )
        """)
        self.assertEqual(f.drivers, {
            None: SignalSet((self.s1, self.s2)),
            "pix": SignalSet((self.s3,)),
        })

    def test_rename_multi(self):
        f = Fragment()
        f.add_statements(
            self.s1.eq(ClockSignal()),
            self.s2.eq(ResetSignal("other")),
        )

        # A dict maps each listed domain to its own new name.
        f = DomainRenamer({"sync": "pix", "other": "pix2"})(f)
        self.assertRepr(f.statements, """
        (
            (eq (sig s1) (clk pix))
            (eq (sig s2) (rst pix2))
        )
        """)

    def test_rename_cd(self):
        # ClockDomain() derives "sync"/"pix" from these local names.
        cd_sync = ClockDomain()
        cd_pix = ClockDomain()

        f = Fragment()
        f.add_domains(cd_sync, cd_pix)

        # The ClockDomain object itself is renamed and re-keyed.
        f = DomainRenamer("ext")(f)
        self.assertEqual(cd_sync.name, "ext")
        self.assertEqual(f.domains, {
            "ext": cd_sync,
            "pix": cd_pix,
        })

    def test_rename_cd_subfragment(self):
        cd_sync = ClockDomain()
        cd_pix = ClockDomain()

        f1 = Fragment()
        f1.add_domains(cd_sync, cd_pix)
        f2 = Fragment()
        # The same domain object is also added to the subfragment.
        f2.add_domains(cd_sync)
        f1.add_subfragment(f2)

        f1 = DomainRenamer("ext")(f1)
        self.assertEqual(cd_sync.name, "ext")
        self.assertEqual(f1.domains, {
            "ext": cd_sync,
            "pix": cd_pix,
        })

    def test_rename_wrong_to_comb(self):
        with self.assertRaises(ValueError,
                msg="Domain 'sync' may not be renamed to 'comb'"):
            DomainRenamer("comb")

    def test_rename_wrong_from_comb(self):
        with self.assertRaises(ValueError,
                msg="Domain 'comb' may not be renamed"):
            DomainRenamer({"comb": "sync"})
101
102
class DomainLowererTestCase(FHDLTestCase):
    """Tests for ``DomainLowerer``.

    ``ClockSignal``/``ResetSignal`` references, both in statements and in driver
    registrations, are replaced by the concrete signals of the fragment's clock
    domains. The reset of a reset-less domain lowers to a constant 0.
    """

    def setUp(self):
        self.s = Signal()

    def test_lower_clk(self):
        # Keep the local name ``sync``: ClockDomain() derives the domain name
        # (and thus the ``clk`` signal name in the expected repr) from it.
        sync = ClockDomain()
        frag = Fragment()
        frag.add_domains(sync)
        frag.add_statements(self.s.eq(ClockSignal("sync")))

        lowered = DomainLowerer()(frag)
        expected = """
        (
            (eq (sig s) (sig clk))
        )
        """
        self.assertRepr(lowered.statements, expected)

    def test_lower_rst(self):
        sync = ClockDomain()
        frag = Fragment()
        frag.add_domains(sync)
        frag.add_statements(self.s.eq(ResetSignal("sync")))

        lowered = DomainLowerer()(frag)
        expected = """
        (
            (eq (sig s) (sig rst))
        )
        """
        self.assertRepr(lowered.statements, expected)

    def test_lower_rst_reset_less(self):
        # Explicitly allowed read of a reset-less domain's reset -> const 0.
        sync = ClockDomain(reset_less=True)
        frag = Fragment()
        frag.add_domains(sync)
        frag.add_statements(self.s.eq(ResetSignal("sync", allow_reset_less=True)))

        lowered = DomainLowerer()(frag)
        expected = """
        (
            (eq (sig s) (const 1'd0))
        )
        """
        self.assertRepr(lowered.statements, expected)

    def test_lower_drivers(self):
        sync = ClockDomain()
        pix = ClockDomain()
        frag = Fragment()
        frag.add_domains(sync, pix)
        frag.add_driver(ClockSignal("pix"), None)
        frag.add_driver(ResetSignal("pix"), "sync")

        lowered = DomainLowerer()(frag)
        # Driver keys (domains) are preserved; only the driven signals change.
        self.assertEqual(lowered.drivers, {
            None: SignalSet((pix.clk,)),
            "sync": SignalSet((pix.rst,)),
        })

    def test_lower_wrong_domain(self):
        frag = Fragment()
        frag.add_statements(self.s.eq(ClockSignal("xxx")))

        with self.assertRaises(DomainError,
                msg="Signal (clk xxx) refers to nonexistent domain 'xxx'"):
            DomainLowerer()(frag)

    def test_lower_wrong_reset_less_domain(self):
        sync = ClockDomain(reset_less=True)
        frag = Fragment()
        frag.add_domains(sync)
        # Without allow_reset_less=True this reference is an error.
        frag.add_statements(self.s.eq(ResetSignal("sync")))

        with self.assertRaises(DomainError,
                msg="Signal (rst sync) refers to reset of reset-less domain 'sync'"):
            DomainLowerer()(frag)
187
188
class SampleLowererTestCase(FHDLTestCase):
    """Tests for ``SampleLowerer``.

    Each ``Sample(value, n, domain)`` is replaced by a ``$sample$...$n`` signal
    at the end of a chain of intermediate signals driven in ``domain``.
    """

    def setUp(self):
        self.i = Signal()
        self.o1 = Signal()
        self.o2 = Signal()
        self.o3 = Signal()

    def test_lower_signal(self):
        frag = Fragment()
        frag.add_statements(
            self.o1.eq(Sample(self.i, 2, "sync")),
            self.o2.eq(Sample(self.i, 1, "sync")),
            self.o3.eq(Sample(self.i, 1, "pix")),
        )

        lowered = SampleLowerer()(frag)
        # The 1-cycle "sync" sample is shared by o2 and by o1's 2-cycle chain,
        # so only two new signals end up driven in "sync".
        expected = """
        (
            (eq (sig o1) (sig $sample$s$i$sync$2))
            (eq (sig o2) (sig $sample$s$i$sync$1))
            (eq (sig o3) (sig $sample$s$i$pix$1))
            (eq (sig $sample$s$i$sync$1) (sig i))
            (eq (sig $sample$s$i$sync$2) (sig $sample$s$i$sync$1))
            (eq (sig $sample$s$i$pix$1) (sig i))
        )
        """
        self.assertRepr(lowered.statements, expected)
        self.assertEqual(len(lowered.drivers["sync"]), 2)
        self.assertEqual(len(lowered.drivers["pix"]), 1)

    def test_lower_const(self):
        frag = Fragment()
        frag.add_statements(self.o1.eq(Sample(1, 2, "sync")))

        lowered = SampleLowerer()(frag)
        # Constants get a "$c$" tag in the generated signal names.
        expected = """
        (
            (eq (sig o1) (sig $sample$c$1$sync$2))
            (eq (sig $sample$c$1$sync$1) (const 1'd1))
            (eq (sig $sample$c$1$sync$2) (sig $sample$c$1$sync$1))
        )
        """
        self.assertRepr(lowered.statements, expected)
        self.assertEqual(len(lowered.drivers["sync"]), 2)
233
234
class SwitchCleanerTestCase(FHDLTestCase):
    """Tests for ``SwitchCleaner``, which prunes switches left without any
    statements (including switches that only contained such empty switches)."""

    def test_clean(self):
        # ``a``/``b`` names appear in the expected repr; keep them.
        a = Signal()
        b = Signal()
        stmts = [
            Switch(a, {
                1: a.eq(0),
                0: [
                    b.eq(1),
                    # Both of the nested switches below are expected to be
                    # removed: the inner one is empty, and the outer one
                    # contains nothing else.
                    Switch(b, {1: [
                        Switch(a|b, {})
                    ]})
                ]
            })
        ]

        self.assertRepr(SwitchCleaner()(stmts), """
        (
            (switch (sig a)
                (case 1
                    (eq (sig a) (const 1'd0)))
                (case 0
                    (eq (sig b) (const 1'd1)))
            )
        )
        """)
262
263
class LHSGroupAnalyzerTestCase(FHDLTestCase):
    """Tests for ``LHSGroupAnalyzer``.

    Signals assigned together on one left-hand side (e.g. through ``Cat``) are
    merged into a single group. Signals only read (such as a switch test value)
    do not link groups.
    """

    def _groups(self, stmts):
        # Run the analyzer and return its groups in iteration order.
        return list(LHSGroupAnalyzer()(stmts).values())

    def test_no_group_unrelated(self):
        a = Signal()
        b = Signal()
        self.assertEqual(self._groups([a.eq(0), b.eq(0)]), [
            SignalSet((a,)),
            SignalSet((b,)),
        ])

    def test_group_related(self):
        a = Signal()
        b = Signal()
        self.assertEqual(self._groups([a.eq(0), Cat(a, b).eq(0)]), [
            SignalSet((a, b)),
        ])

    def test_no_loops(self):
        # Merging the same pair twice must not produce duplicate groups.
        a = Signal()
        b = Signal()
        self.assertEqual(
            self._groups([a.eq(0), Cat(a, b).eq(0), Cat(a, b).eq(0)]),
            [SignalSet((a, b))],
        )

    def test_switch(self):
        a = Signal()
        b = Signal()
        self.assertEqual(self._groups([a.eq(0), Switch(a, {1: b.eq(0)})]), [
            SignalSet((a,)),
            SignalSet((b,)),
        ])

    def test_lhs_empty(self):
        # An empty Cat on the LHS yields no groups at all.
        self.assertEqual(self._groups([Cat().eq(0)]), [])
330
331
class LHSGroupFilterTestCase(FHDLTestCase):
    """Tests for ``LHSGroupFilter``, which keeps only the assignments whose LHS
    signals belong to the given ``SignalSet``. The switch skeleton around them
    is preserved, and nested switches that are left empty are dropped."""

    def test_filter(self):
        # ``a``/``b`` names appear in the expected repr; keep them.
        a = Signal()
        b = Signal()
        stmts = [
            Switch(a, {
                1: a.eq(0),
                0: [
                    b.eq(1),
                    Switch(b, {1: []})
                ]
            })
        ]

        # Filtering on {a} drops ``b.eq(1)`` and the now-empty ``Switch(b, ...)``,
        # leaving an empty ``case 0``.
        self.assertRepr(LHSGroupFilter(SignalSet((a,)))(stmts), """
        (
            (switch (sig a)
                (case 1
                    (eq (sig a) (const 1'd0)))
                (case 0 )
            )
        )
        """)

    def test_lhs_empty(self):
        stmts = [
            Cat().eq(0)
        ]

        self.assertRepr(LHSGroupFilter(SignalSet())(stmts), "()")
363
364
class ResetInserterTestCase(FHDLTestCase):
    """Tests for ``ResetInserter``.

    A switch on the reset control is appended. When the control is 1, it
    restores each signal driven in the affected domain to that signal's reset
    value. Reset-less signals get an empty case.
    """

    def setUp(self):
        self.s1 = Signal()
        self.s2 = Signal(reset=1)
        self.s3 = Signal(reset=1, reset_less=True)
        self.c1 = Signal()

    def test_reset_default(self):
        # A bare signal argument targets the "sync" domain.
        frag = Fragment()
        frag.add_statements(self.s1.eq(1))
        frag.add_driver(self.s1, "sync")

        result = ResetInserter(self.c1)(frag)
        expected = """
        (
            (eq (sig s1) (const 1'd1))
            (switch (sig c1)
                (case 1 (eq (sig s1) (const 1'd0)))
            )
        )
        """
        self.assertRepr(result.statements, expected)

    def test_reset_cd(self):
        # Only signals driven in "pix" receive the reset logic.
        frag = Fragment()
        frag.add_statements(
            self.s1.eq(1),
            self.s2.eq(0),
        )
        frag.add_domains(ClockDomain("sync"))
        frag.add_driver(self.s1, "sync")
        frag.add_driver(self.s2, "pix")

        result = ResetInserter({"pix": self.c1})(frag)
        expected = """
        (
            (eq (sig s1) (const 1'd1))
            (eq (sig s2) (const 1'd0))
            (switch (sig c1)
                (case 1 (eq (sig s2) (const 1'd1)))
            )
        )
        """
        self.assertRepr(result.statements, expected)

    def test_reset_value(self):
        # s2 has reset=1, so the inserted assignment uses 1.
        frag = Fragment()
        frag.add_statements(self.s2.eq(0))
        frag.add_driver(self.s2, "sync")

        result = ResetInserter(self.c1)(frag)
        expected = """
        (
            (eq (sig s2) (const 1'd0))
            (switch (sig c1)
                (case 1 (eq (sig s2) (const 1'd1)))
            )
        )
        """
        self.assertRepr(result.statements, expected)

    def test_reset_less(self):
        # s3 is reset_less, so the case body stays empty.
        frag = Fragment()
        frag.add_statements(self.s3.eq(0))
        frag.add_driver(self.s3, "sync")

        result = ResetInserter(self.c1)(frag)
        expected = """
        (
            (eq (sig s3) (const 1'd0))
            (switch (sig c1)
                (case 1 )
            )
        )
        """
        self.assertRepr(result.statements, expected)
443
444
class EnableInserterTestCase(FHDLTestCase):
    """Tests for ``EnableInserter``.

    A switch on the enable control is appended. When the control is 0, it
    assigns each signal driven in the affected domain to itself. The
    transformation also reaches subfragments and gates memory port enables.
    """

    def setUp(self):
        # Names (s1, s2, c1) appear in the expected reprs; keep them.
        self.s1 = Signal()
        self.s2 = Signal()
        self.c1 = Signal()

    def test_enable_default(self):
        # A bare signal argument targets the "sync" domain.
        f = Fragment()
        f.add_statements(
            self.s1.eq(1)
        )
        f.add_driver(self.s1, "sync")

        f = EnableInserter(self.c1)(f)
        self.assertRepr(f.statements, """
        (
            (eq (sig s1) (const 1'd1))
            (switch (sig c1)
                (case 0 (eq (sig s1) (sig s1)))
            )
        )
        """)

    def test_enable_cd(self):
        # Only signals driven in "pix" receive the enable logic.
        f = Fragment()
        f.add_statements(
            self.s1.eq(1),
            self.s2.eq(0),
        )
        f.add_driver(self.s1, "sync")
        f.add_driver(self.s2, "pix")

        f = EnableInserter({"pix": self.c1})(f)
        self.assertRepr(f.statements, """
        (
            (eq (sig s1) (const 1'd1))
            (eq (sig s2) (const 1'd0))
            (switch (sig c1)
                (case 0 (eq (sig s2) (sig s2)))
            )
        )
        """)

    def test_enable_subfragment(self):
        f1 = Fragment()
        f1.add_statements(
            self.s1.eq(1)
        )
        f1.add_driver(self.s1, "sync")

        f2 = Fragment()
        f2.add_statements(
            self.s2.eq(1)
        )
        f2.add_driver(self.s2, "sync")
        f1.add_subfragment(f2)

        f1 = EnableInserter(self.c1)(f1)
        # Exactly one subfragment is expected; unpack its fragment.
        (f2, _), = f1.subfragments
        self.assertRepr(f1.statements, """
        (
            (eq (sig s1) (const 1'd1))
            (switch (sig c1)
                (case 0 (eq (sig s1) (sig s1)))
            )
        )
        """)
        self.assertRepr(f2.statements, """
        (
            (eq (sig s2) (const 1'd1))
            (switch (sig c1)
                (case 0 (eq (sig s2) (sig s2)))
            )
        )
        """)

    def test_enable_read_port(self):
        # ``mem`` gives the port signal its "mem_r_en" name.
        mem = Memory(width=8, depth=4)
        f = EnableInserter(self.c1)(mem.read_port(transparent=False)).elaborate(platform=None)
        # EN becomes a mux: original enable when c1 is set, else 0.
        self.assertRepr(f.named_ports["EN"][0], """
        (m (sig c1) (sig mem_r_en) (const 1'd0))
        """)

    def test_enable_write_port(self):
        mem = Memory(width=8, depth=4)
        f = EnableInserter(self.c1)(mem.write_port()).elaborate(platform=None)
        self.assertRepr(f.named_ports["EN"][0], """
        (m (sig c1) (cat (repl (slice (sig mem_w_en) 0:1) 8)) (const 8'd0))
        """)
535
536
class _MockElaboratable(Elaboratable):
    """Minimal elaboratable: drives its signal ``s1`` to 1 in the "sync" domain."""

    def __init__(self):
        # The attribute name gives the signal its "s1" name in reprs.
        self.s1 = Signal()

    def elaborate(self, platform):
        frag = Fragment()
        frag.add_statements(self.s1.eq(1))
        frag.add_driver(self.s1, "sync")
        return frag
548
549
class TransformedElaboratableTestCase(FHDLTestCase):
    """Tests for transformers applied to elaboratables rather than fragments."""

    def setUp(self):
        self.c1 = Signal()
        self.c2 = Signal()

    def test_getattr(self):
        # Attribute access is forwarded to the wrapped elaboratable.
        elab = _MockElaboratable()
        wrapped = EnableInserter(self.c1)(elab)

        self.assertIs(wrapped.s1, elab.s1)

    def test_composition(self):
        # Transforming an already-transformed elaboratable returns the same
        # wrapper; its transforms are then applied in order of application.
        elab = _MockElaboratable()
        te1 = EnableInserter(self.c1)(elab)
        te2 = ResetInserter(self.c2)(te1)

        self.assertIsInstance(te1, TransformedElaboratable)
        self.assertIs(te1, te2)

        frag = Fragment.get(te2, None)
        expected = """
        (
            (eq (sig s1) (const 1'd1))
            (switch (sig c1)
                (case 0 (eq (sig s1) (sig s1)))
            )
            (switch (sig c2)
                (case 1 (eq (sig s1) (const 1'd0)))
            )
        )
        """
        self.assertRepr(frag.statements, expected)
581
582
class MockUserValue(UserValue):
    """``UserValue`` stub whose ``lower()`` returns a value fixed at construction."""

    def __init__(self, lowered):
        super().__init__()
        # The value handed back verbatim by ``lower()``.
        self.lowered = lowered

    def lower(self):
        return self.lowered
590
591
class UserValueTestCase(FHDLTestCase):
    """Tests that transformers see through a ``UserValue`` to what it lowers to."""

    def setUp(self):
        self.s = Signal()
        self.c = Signal()
        self.uv = MockUserValue(self.s)

    def test_lower(self):
        # Keep the local name ``sync``: it becomes the domain name, and with it
        # the ``rst`` signal seen in the expected repr.
        sync = ClockDomain()
        frag = Fragment()
        frag.add_domains(sync)
        frag.add_statements(self.uv.eq(1))
        for sig in self.uv._lhs_signals():
            frag.add_driver(sig, "sync")

        frag = ResetInserter(self.c)(frag)
        frag = DomainLowerer()(frag)
        # First switch comes from ResetInserter, the second from the
        # domain's own reset inserted during lowering.
        expected = """
        (
            (eq (sig s) (const 1'd1))
            (switch (sig c)
                (case 1 (eq (sig s) (const 1'd0)))
            )
            (switch (sig rst)
                (case 1 (eq (sig s) (const 1'd0)))
            )
        )
        """
        self.assertRepr(frag.statements, expected)