hdl.ast: fix shape calculation for *.
[nmigen.git] / nmigen / test / test_hdl_ast.py
1 from ..hdl.ast import *
2 from .tools import *
3
4
class ValueTestCase(FHDLTestCase):
    """Generic ``Value`` behaviour (wrapping, truthiness, len, indexing), via ``Const``."""

    def _assert_slice(self, value, start, end):
        # Shared check: ``value`` must be a ``Slice`` node covering bits [start, end).
        self.assertIsInstance(value, Slice)
        self.assertEqual(value.start, start)
        self.assertEqual(value.end, end)

    def test_wrap(self):
        # Plain ints and bools are wrapped into constants...
        for raw in (0, True):
            self.assertIsInstance(Value.wrap(raw), Const)
        # ...an existing Value is returned as-is...
        const = Const(0)
        self.assertIs(Value.wrap(const), const)
        # ...and anything else is rejected.
        with self.assertRaises(TypeError):
            Value.wrap("str")

    def test_bool(self):
        # Implicit truth testing of a Value must fail.
        with self.assertRaises(TypeError):
            bool(Const(0))

    def test_len(self):
        self.assertEqual(len(Const(10)), 4)

    def test_getitem_int(self):
        # Integer indices (negative ones count from the top) select one bit.
        self._assert_slice(Const(10)[0], 0, 1)
        self._assert_slice(Const(10)[-1], 3, 4)
        with self.assertRaises(IndexError):
            Const(10)[5]

    def test_getitem_slice(self):
        # Contiguous slices map onto a single Slice node.
        self._assert_slice(Const(10)[1:3], 1, 3)
        self._assert_slice(Const(10)[1:-2], 1, 2)
        # A strided slice becomes a Cat of one-bit slices.
        strided = Const(31)[::2]
        self.assertIsInstance(strided, Cat)
        for position, bit in enumerate((0, 2, 4)):
            self._assert_slice(strided.parts[position], bit, bit + 1)

    def test_getitem_wrong(self):
        with self.assertRaises(TypeError):
            Const(31)["str"]
58
59
class ConstTestCase(FHDLTestCase):
    """Shape inference, normalization, repr and hashability of ``Const``."""

    def test_shape(self):
        # Inferred shapes: the smallest width that holds the value, signed
        # only when the value is negative.
        inferred = (
            (0, (1, False)),
            (1, (1, False)),
            (10, (4, False)),
            (-10, (5, True)),
        )
        for value, shape in inferred:
            self.assertEqual(Const(value).shape(), shape)
        # Explicit shapes, given as a bare width or a (width, signed) pair.
        explicit = (
            (1, 4, (4, False)),
            (1, (4, True), (4, True)),
            (0, (0, False), (0, False)),
        )
        for value, requested, shape in explicit:
            self.assertEqual(Const(value, requested).shape(), shape)

    def test_shape_bad(self):
        with self.assertRaises(TypeError):
            Const(1, -1)

    def test_normalization(self):
        # 0b10110 (22) stored in a signed 5-bit field reads back as 22 - 32.
        self.assertEqual(Const(0b10110, (5, True)).value, -10)

    def test_value(self):
        self.assertEqual(Const(10).value, 10)

    def test_repr(self):
        for value, text in ((10, "(const 4'd10)"), (-10, "(const 5'sd-10)")):
            self.assertEqual(repr(Const(value)), text)

    def test_hash(self):
        # Values must not be usable as dict keys / set members.
        with self.assertRaises(TypeError):
            hash(Const(0))
88
89
class OperatorTestCase(FHDLTestCase):
    """Repr and result shape of the operators built from ``Value`` expressions."""

    def test_bool(self):
        # ``.bool()`` yields a single unsigned bit whatever the operand width.
        # Keep this the only ``test_bool`` in the class: a second ``def`` with
        # the same name would silently replace this one, and its cases would
        # never run.
        v1 = Const(0, 4).bool()
        self.assertEqual(repr(v1), "(b (const 4'd0))")
        self.assertEqual(v1.shape(), (1, False))
        v2 = Const(0).bool()
        self.assertEqual(repr(v2), "(b (const 1'd0))")
        self.assertEqual(v2.shape(), (1, False))

    def test_invert(self):
        v = ~Const(0, 4)
        self.assertEqual(repr(v), "(~ (const 4'd0))")
        self.assertEqual(v.shape(), (4, False))

    def test_neg(self):
        # Negating an unsigned value needs an extra bit for the sign; a signed
        # value keeps its width.
        v1 = -Const(0, (4, False))
        self.assertEqual(repr(v1), "(- (const 4'd0))")
        self.assertEqual(v1.shape(), (5, True))
        v2 = -Const(0, (4, True))
        self.assertEqual(repr(v2), "(- (const 4'sd0))")
        self.assertEqual(v2.shape(), (4, True))

    def test_add(self):
        # One bit wider than the widest operand; an unsigned operand mixed with
        # a signed one first needs an extra sign bit (4 -> 5, then +1 = 6).
        v1 = Const(0, (4, False)) + Const(0, (6, False))
        self.assertEqual(repr(v1), "(+ (const 4'd0) (const 6'd0))")
        self.assertEqual(v1.shape(), (7, False))
        v2 = Const(0, (4, True)) + Const(0, (6, True))
        self.assertEqual(v2.shape(), (7, True))
        v3 = Const(0, (4, True)) + Const(0, (4, False))
        self.assertEqual(v3.shape(), (6, True))
        v4 = Const(0, (4, False)) + Const(0, (4, True))
        self.assertEqual(v4.shape(), (6, True))
        v5 = 10 + Const(0, 4)
        self.assertEqual(v5.shape(), (5, False))

    def test_sub(self):
        # Same width rules as addition.
        v1 = Const(0, (4, False)) - Const(0, (6, False))
        self.assertEqual(repr(v1), "(- (const 4'd0) (const 6'd0))")
        self.assertEqual(v1.shape(), (7, False))
        v2 = Const(0, (4, True)) - Const(0, (6, True))
        self.assertEqual(v2.shape(), (7, True))
        v3 = Const(0, (4, True)) - Const(0, (4, False))
        self.assertEqual(v3.shape(), (6, True))
        v4 = Const(0, (4, False)) - Const(0, (4, True))
        self.assertEqual(v4.shape(), (6, True))
        v5 = 10 - Const(0, 4)
        self.assertEqual(v5.shape(), (5, False))

    def test_mul(self):
        # Width is the sum of the operand widths; the product is signed if
        # either operand is. Both mixed-signedness orders are checked, as in
        # the sibling arithmetic/bitwise tests.
        v1 = Const(0, (4, False)) * Const(0, (6, False))
        self.assertEqual(repr(v1), "(* (const 4'd0) (const 6'd0))")
        self.assertEqual(v1.shape(), (10, False))
        v2 = Const(0, (4, True)) * Const(0, (6, True))
        self.assertEqual(v2.shape(), (10, True))
        v3 = Const(0, (4, True)) * Const(0, (4, False))
        self.assertEqual(v3.shape(), (8, True))
        v4 = Const(0, (4, False)) * Const(0, (4, True))
        self.assertEqual(v4.shape(), (8, True))
        v5 = 10 * Const(0, 4)
        self.assertEqual(v5.shape(), (8, False))

    def test_and(self):
        # Bitwise ops take the widest operand's width, plus a sign bit when an
        # unsigned operand is mixed with a signed one.
        v1 = Const(0, (4, False)) & Const(0, (6, False))
        self.assertEqual(repr(v1), "(& (const 4'd0) (const 6'd0))")
        self.assertEqual(v1.shape(), (6, False))
        v2 = Const(0, (4, True)) & Const(0, (6, True))
        self.assertEqual(v2.shape(), (6, True))
        v3 = Const(0, (4, True)) & Const(0, (4, False))
        self.assertEqual(v3.shape(), (5, True))
        v4 = Const(0, (4, False)) & Const(0, (4, True))
        self.assertEqual(v4.shape(), (5, True))
        v5 = 10 & Const(0, 4)
        self.assertEqual(v5.shape(), (4, False))

    def test_or(self):
        v1 = Const(0, (4, False)) | Const(0, (6, False))
        self.assertEqual(repr(v1), "(| (const 4'd0) (const 6'd0))")
        self.assertEqual(v1.shape(), (6, False))
        v2 = Const(0, (4, True)) | Const(0, (6, True))
        self.assertEqual(v2.shape(), (6, True))
        v3 = Const(0, (4, True)) | Const(0, (4, False))
        self.assertEqual(v3.shape(), (5, True))
        v4 = Const(0, (4, False)) | Const(0, (4, True))
        self.assertEqual(v4.shape(), (5, True))
        v5 = 10 | Const(0, 4)
        self.assertEqual(v5.shape(), (4, False))

    def test_xor(self):
        v1 = Const(0, (4, False)) ^ Const(0, (6, False))
        self.assertEqual(repr(v1), "(^ (const 4'd0) (const 6'd0))")
        self.assertEqual(v1.shape(), (6, False))
        v2 = Const(0, (4, True)) ^ Const(0, (6, True))
        self.assertEqual(v2.shape(), (6, True))
        v3 = Const(0, (4, True)) ^ Const(0, (4, False))
        self.assertEqual(v3.shape(), (5, True))
        v4 = Const(0, (4, False)) ^ Const(0, (4, True))
        self.assertEqual(v4.shape(), (5, True))
        v5 = 10 ^ Const(0, 4)
        self.assertEqual(v5.shape(), (4, False))

    def test_shl(self):
        # The expected widths equal the operand width plus the largest positive
        # shift amount: 7 for an unsigned 3-bit amount, 3 for a signed one.
        v1 = Const(1, 4) << Const(4)
        self.assertEqual(repr(v1), "(<< (const 4'd1) (const 3'd4))")
        self.assertEqual(v1.shape(), (11, False))
        v2 = Const(1, 4) << Const(-3)
        self.assertEqual(v2.shape(), (7, False))

    def test_shr(self):
        # An unsigned right shift never widens. With a signed 3-bit amount the
        # expected width grows by 4, the magnitude of the most negative amount.
        v1 = Const(1, 4) >> Const(4)
        self.assertEqual(repr(v1), "(>> (const 4'd1) (const 3'd4))")
        self.assertEqual(v1.shape(), (4, False))
        v2 = Const(1, 4) >> Const(-3)
        self.assertEqual(v2.shape(), (8, False))

    # Comparisons always produce a single unsigned bit.

    def test_lt(self):
        v = Const(0, 4) < Const(0, 6)
        self.assertEqual(repr(v), "(< (const 4'd0) (const 6'd0))")
        self.assertEqual(v.shape(), (1, False))

    def test_le(self):
        v = Const(0, 4) <= Const(0, 6)
        self.assertEqual(repr(v), "(<= (const 4'd0) (const 6'd0))")
        self.assertEqual(v.shape(), (1, False))

    def test_gt(self):
        v = Const(0, 4) > Const(0, 6)
        self.assertEqual(repr(v), "(> (const 4'd0) (const 6'd0))")
        self.assertEqual(v.shape(), (1, False))

    def test_ge(self):
        v = Const(0, 4) >= Const(0, 6)
        self.assertEqual(repr(v), "(>= (const 4'd0) (const 6'd0))")
        self.assertEqual(v.shape(), (1, False))

    def test_eq(self):
        v = Const(0, 4) == Const(0, 6)
        self.assertEqual(repr(v), "(== (const 4'd0) (const 6'd0))")
        self.assertEqual(v.shape(), (1, False))

    def test_ne(self):
        v = Const(0, 4) != Const(0, 6)
        self.assertEqual(repr(v), "(!= (const 4'd0) (const 6'd0))")
        self.assertEqual(v.shape(), (1, False))

    def test_mux(self):
        # The selector does not affect the shape; the two data operands follow
        # the same rules as the bitwise operators.
        s = Const(0)
        v1 = Mux(s, Const(0, (4, False)), Const(0, (6, False)))
        self.assertEqual(repr(v1), "(m (const 1'd0) (const 4'd0) (const 6'd0))")
        self.assertEqual(v1.shape(), (6, False))
        v2 = Mux(s, Const(0, (4, True)), Const(0, (6, True)))
        self.assertEqual(v2.shape(), (6, True))
        v3 = Mux(s, Const(0, (4, True)), Const(0, (4, False)))
        self.assertEqual(v3.shape(), (5, True))
        v4 = Mux(s, Const(0, (4, False)), Const(0, (4, True)))
        self.assertEqual(v4.shape(), (5, True))

    def test_hash(self):
        with self.assertRaises(TypeError):
            hash(Const(0) + Const(0))
249
250
class SliceTestCase(FHDLTestCase):
    """Construction, bounds normalization and repr of ``Slice``."""

    def test_shape(self):
        # Slices are unsigned even when taken from a signed value.
        self.assertEqual(Const(10)[2].shape(), (1, False))
        self.assertEqual(Const(-10)[0:2].shape(), (2, False))

    def test_start_end_negative(self):
        # Negative bounds are resolved against the 8-bit width of the value.
        const = Const(0, 8)
        for start, end, resolved in ((0, -1, (0, 7)), (-4, -1, (4, 7))):
            sliced = Slice(const, start, end)
            self.assertEqual((sliced.start, sliced.end), resolved)

    def test_start_end_wrong(self):
        for args in ((0, "x", 1), (0, 1, "x")):
            with self.assertRaises(TypeError):
                Slice(*args)

    def test_start_end_out_of_range(self):
        const = Const(0, 8)
        for start, end in ((10, 12), (0, 12), (4, 2)):
            with self.assertRaises(IndexError):
                Slice(const, start, end)

    def test_repr(self):
        self.assertEqual(repr(Const(10)[2]), "(slice (const 4'd10) 2:3)")
283
284
class PartTestCase(FHDLTestCase):
    """Variable-offset part selection via ``Value.part``."""

    def setUp(self):
        self.c = Const(0, 8)
        # No explicit name is passed, so the signal is named after this
        # attribute ("s"); test_repr depends on that exact assignment form.
        self.s = Signal(max=self.c.nbits)

    def test_shape(self):
        # The part is exactly as wide as requested, and unsigned.
        for width in (2, 0):
            self.assertEqual(self.c.part(self.s, width).shape(), (width, False))

    def test_width_bad(self):
        with self.assertRaises(TypeError):
            self.c.part(self.s, -1)

    def test_repr(self):
        selected = self.c.part(self.s, 2)
        self.assertEqual(repr(selected), "(part (const 8'd0) (sig s) 2)")
303
304
class CatTestCase(FHDLTestCase):
    """Shape and repr of bit concatenation."""

    def test_shape(self):
        # Expected widths are the sum of the part widths (4, 4+1, 4+1+1).
        cases = (
            ((Const(10),), (4, False)),
            ((Const(10), Const(1)), (5, False)),
            ((Const(10), Const(1), Const(0)), (6, False)),
        )
        for parts, shape in cases:
            self.assertEqual(Cat(*parts).shape(), shape)

    def test_repr(self):
        joined = Cat(Const(10), Const(1))
        self.assertEqual(repr(joined), "(cat (const 4'd10) (const 1'd1))")
317
318
class ReplTestCase(FHDLTestCase):
    """Shape, argument validation and repr of ``Repl``."""

    def test_shape(self):
        # The 4-bit operand's width is multiplied by the repeat count.
        for count, width in ((3, 12), (0, 0)):
            self.assertEqual(Repl(Const(10), count).shape(), (width, False))

    def test_count_wrong(self):
        # Negative and non-integer counts are both rejected.
        for count in (-1, "str"):
            with self.assertRaises(TypeError):
                Repl(Const(10), count)

    def test_repr(self):
        self.assertEqual(repr(Repl(Const(10), 3)), "(repl (const 4'd10) 3)")
335
336
class ArrayTestCase(FHDLTestCase):
    """``Array`` as a Python sequence, and its freezing once indexed by a Value."""

    def test_acts_like_array(self):
        # Until it is indexed by a Value, an Array is a plain mutable sequence.
        arr = Array([1,2,3])
        self.assertSequenceEqual(arr, [1,2,3])
        self.assertEqual(arr[1], 2)
        arr[1] = 4
        self.assertSequenceEqual(arr, [1,4,3])
        del arr[1]
        self.assertSequenceEqual(arr, [1,3])
        arr.insert(1, 2)
        self.assertSequenceEqual(arr, [1,2,3])

    def test_becomes_immutable(self):
        arr = Array([1,2,3])
        first = Signal(max=len(arr))
        second = Signal(max=len(arr))
        # Indexing with a Value is what freezes the array; results are unused.
        arr[first]
        arr[second]
        frozen = r"^Array can no longer be mutated after it was indexed with a value at "
        with self.assertRaisesRegex(ValueError, regex=frozen):
            arr[1] = 2
        with self.assertRaisesRegex(ValueError, regex=frozen):
            del arr[1]
        with self.assertRaisesRegex(ValueError, regex=frozen):
            arr.insert(1, 2)

    def test_repr(self):
        # The repr drops the "mutable" marker once the array has been indexed.
        arr = Array([1,2,3])
        self.assertEqual(repr(arr), "(array mutable [1, 2, 3])")
        index = Signal(max=len(arr))
        arr[index]
        self.assertEqual(repr(arr), "(array [1, 2, 3])")
371
372
class ArrayProxyTestCase(FHDLTestCase):
    """Shape and repr of the proxy returned by indexing an ``Array`` with a Value."""

    def test_index_shape(self):
        # 3x3 multiplication table; its largest entry, 9, needs 4 unsigned bits.
        table = Array([Array([row * col for col in range(1, 4)])
                       for row in range(1, 4)])
        a = Signal(max=3)
        b = Signal(max=3)
        self.assertEqual(table[a][b].shape(), (4, False))

    def test_attr_shape(self):
        from collections import namedtuple
        pair = namedtuple("pair", ("p", "n"))
        arr = Array([pair(i, -i) for i in range(10)])
        index = Signal(max=len(arr))
        proxy = arr[index]
        # Attribute access on the proxy yields a value with its own shape.
        self.assertEqual(proxy.p.shape(), (4, False))
        self.assertEqual(proxy.n.shape(), (6, True))

    def test_repr(self):
        arr = Array([1, 2, 3])
        # The local name "s" shows up in the expected repr; keep it.
        s = Signal(max=3)
        self.assertEqual(repr(arr[s]), "(proxy (array [1, 2, 3]) (sig s))")
395
396
class SignalTestCase(FHDLTestCase):
    """Shape inference, naming, reset/attrs and ``Signal.like`` for ``Signal``."""

    def test_shape(self):
        # (positional args, keyword args, expected shape)
        cases = (
            ((), {}, (1, False)),
            ((2,), {}, (2, False)),
            (((2, False),), {}, (2, False)),
            (((2, True),), {}, (2, True)),
            ((), {"max": 16}, (4, False)),
            ((), {"min": 4, "max": 16}, (4, False)),
            ((), {"min": -4, "max": 16}, (5, True)),
            ((), {"min": -20, "max": 16}, (6, True)),
            ((0,), {}, (0, False)),
        )
        for args, kwargs, shape in cases:
            self.assertEqual(Signal(*args, **kwargs).shape(), shape)

    def test_shape_bad(self):
        with self.assertRaises(ValueError):
            Signal(min=10, max=4)
        with self.assertRaises(ValueError):
            Signal(2, min=10)
        with self.assertRaises(TypeError):
            Signal(-10)

    def test_name(self):
        # With no explicit name, the signal is named after its assignment
        # target, so this statement must stay a plain `s1 = Signal()`.
        s1 = Signal()
        self.assertEqual(s1.name, "s1")
        self.assertEqual(Signal(name="sig").name, "sig")

    def test_reset(self):
        sig = Signal(4, reset=0b111, reset_less=True)
        self.assertEqual((sig.reset, sig.reset_less), (0b111, True))

    def test_attrs(self):
        self.assertEqual(Signal().attrs, {})
        self.assertEqual(Signal(attrs={"no_retiming": True}).attrs,
                         {"no_retiming": True})

    def test_repr(self):
        s1 = Signal()
        self.assertEqual(repr(s1), "(sig s1)")

    def test_like(self):
        # The clone copies the template's shape...
        self.assertEqual(Signal.like(Signal(4)).shape(), (4, False))
        self.assertEqual(Signal.like(Signal(min=-15)).shape(), (5, True))
        # ...as well as its reset value, reset_less flag, attrs and decoder.
        template = Signal(4, reset=0b111, reset_less=True)
        clone = Signal.like(template)
        self.assertEqual(clone.reset, 0b111)
        self.assertEqual(clone.reset_less, True)
        self.assertEqual(Signal.like(Signal(attrs={"no_retiming": True})).attrs,
                         {"no_retiming": True})
        self.assertEqual(Signal.like(Signal(decoder=str)).decoder, str)
        # A plain integer is accepted as the template too.
        self.assertEqual(Signal.like(10).shape(), (4, False))
        # Not directly assigned to a name, so the fallback name "$like" is used.
        s7 = [Signal.like(Signal(4))][0]
        self.assertEqual(s7.name, "$like")
463
464
class ClockSignalTestCase(FHDLTestCase):
    """Domain selection, shape and repr of ``ClockSignal``."""

    def test_domain(self):
        # "sync" is the default domain; any other may be named explicitly.
        self.assertEqual(ClockSignal().domain, "sync")
        self.assertEqual(ClockSignal("pix").domain, "pix")
        # A non-string domain is rejected.
        with self.assertRaises(TypeError):
            ClockSignal(1)

    def test_shape(self):
        self.assertEqual(ClockSignal().shape(), (1, False))

    def test_repr(self):
        self.assertEqual(repr(ClockSignal()), "(clk sync)")
481
482
class ResetSignalTestCase(FHDLTestCase):
    """Domain selection, shape and repr of ``ResetSignal``."""

    def test_domain(self):
        # "sync" is the default domain; any other may be named explicitly.
        self.assertEqual(ResetSignal().domain, "sync")
        self.assertEqual(ResetSignal("pix").domain, "pix")
        # A non-string domain is rejected.
        with self.assertRaises(TypeError):
            ResetSignal(1)

    def test_shape(self):
        self.assertEqual(ResetSignal().shape(), (1, False))

    def test_repr(self):
        self.assertEqual(repr(ResetSignal()), "(rst sync)")
499
500
class SampleTestCase(FHDLTestCase):
    """Shape and argument validation of ``Sample``."""

    def test_const(self):
        s = Sample(1, 1, "sync")
        self.assertEqual(s.shape(), (1, False))

    def test_signal(self):
        # A sample has the same shape as the sampled value.
        s1 = Sample(Signal(2), 1, "sync")
        self.assertEqual(s1.shape(), (2, False))
        # Clock and reset signals (one bit each) can be sampled too; their
        # shapes are now asserted instead of only being constructed.
        s2 = Sample(ClockSignal(), 1, "sync")
        self.assertEqual(s2.shape(), (1, False))
        s3 = Sample(ResetSignal(), 1, "sync")
        self.assertEqual(s3.shape(), (1, False))

    def test_wrong_value_operator(self):
        # NOTE(review): the second positional argument is the expected exception
        # text. This presumably relies on FHDLTestCase overriding assertRaises;
        # stock unittest would treat it as a callable. Confirm in .tools.
        with self.assertRaises(TypeError,
                "Sampled value may only be a signal or a constant, not "
                "(+ (sig $signal) (const 1'd1))"):
            Sample(Signal() + 1, 1, "sync")

    def test_wrong_clocks_neg(self):
        # A negative clock count would mean sampling the future.
        with self.assertRaises(ValueError,
                "Cannot sample a value 1 cycles in the future"):
            Sample(Signal(), -1, "sync")
520 "Cannot sample a value 1 cycles in the future"):
521 Sample(Signal(), -1, "sync")