nmigen.hdl.rec: restore Record.shape().
[nmigen.git] / tests / test_hdl_rec.py
1 from enum import Enum
2
3 from nmigen.hdl.ast import *
4 from nmigen.hdl.rec import *
5
6 from .utils import *
7
8
class UnsignedEnum(Enum):
    """Small enum with positive values 1..3, used as an enum-typed field shape.

    The layout and record tests below expect a field of this type to get an
    unsigned 2-bit shape and a decoder that renders members as ``NAME/value``.
    """

    # One assignment, still stored in FOO, BAR, BAZ order, so member order and
    # values are the same as three separate assignments.
    FOO, BAR, BAZ = 1, 2, 3
13
14
class LayoutTestCase(FHDLTestCase):
    """Tests for ``Layout.cast``: field parsing, lookup, tuple slicing, ``repr``,
    and the error messages raised for malformed field descriptions."""

    def assertFieldEqual(self, field, expected):
        """Assert that a layout entry ``(shape, direction)`` equals ``expected``.

        The shape half of ``field`` is normalized with ``Shape.cast`` before the
        comparison; the callers below write the expected shape as a
        ``(width, signed)`` tuple.
        """
        # Unpack into ``direction`` rather than ``dir`` so the builtin ``dir``
        # is not shadowed.
        shape, direction = field
        self.assertEqual((Shape.cast(shape), direction), expected)

    def test_fields(self):
        # Plain-int, signed, explicitly directed and nested fields; fields
        # declared without a direction come back as DIR_NONE.
        layout = Layout.cast([
            ("cyc", 1),
            ("data", signed(32)),
            ("stb", 1, DIR_FANOUT),
            ("ack", 1, DIR_FANIN),
            ("info", [
                ("a", 1),
                ("b", 1),
            ])
        ])

        self.assertFieldEqual(layout["cyc"], ((1, False), DIR_NONE))
        self.assertFieldEqual(layout["data"], ((32, True), DIR_NONE))
        self.assertFieldEqual(layout["stb"], ((1, False), DIR_FANOUT))
        self.assertFieldEqual(layout["ack"], ((1, False), DIR_FANIN))
        # A nested field's first element is itself a layout that can be indexed.
        sublayout = layout["info"][0]
        self.assertEqual(layout["info"][1], DIR_NONE)
        self.assertFieldEqual(sublayout["a"], ((1, False), DIR_NONE))
        self.assertFieldEqual(sublayout["b"], ((1, False), DIR_NONE))

    def test_enum_field(self):
        # UnsignedEnum's values are 1..3, so its shape is 2 unsigned bits,
        # with or without an explicit direction.
        layout = Layout.cast([
            ("enum", UnsignedEnum),
            ("enum_dir", UnsignedEnum, DIR_FANOUT),
        ])
        self.assertFieldEqual(layout["enum"], ((2, False), DIR_NONE))
        self.assertFieldEqual(layout["enum_dir"], ((2, False), DIR_FANOUT))

    def test_range_field(self):
        # range(0, 7) spans 0..6, which needs 3 unsigned bits.
        layout = Layout.cast([
            ("range", range(0, 7)),
        ])
        self.assertFieldEqual(layout["range"], ((3, False), DIR_NONE))

    def test_slice_tuple(self):
        # Indexing with a tuple of names yields a layout with only those fields.
        layout = Layout.cast([
            ("a", 1),
            ("b", 2),
            ("c", 3)
        ])
        expect = Layout.cast([
            ("a", 1),
            ("c", 3)
        ])
        self.assertEqual(layout["a", "c"], expect)

    def test_repr(self):
        # Nested layouts appear as nested Layout(...) in the repr.
        self.assertEqual(repr(Layout([("a", unsigned(1)), ("b", signed(2))])),
                         "Layout([('a', unsigned(1)), ('b', signed(2))])")
        self.assertEqual(repr(Layout([("a", unsigned(1)), ("b", [("c", signed(3))])])),
                         "Layout([('a', unsigned(1)), "
                         "('b', Layout([('c', signed(3))]))])")

    def test_wrong_field(self):
        # A field must be a 2- or 3-tuple.
        with self.assertRaisesRegex(TypeError,
                (r"^Field \(1,\) has invalid layout: should be either \(name, shape\) or "
                 r"\(name, shape, direction\)$")):
            Layout.cast([(1,)])

    def test_wrong_name(self):
        with self.assertRaisesRegex(TypeError,
                r"^Field \(1, 1\) has invalid name: should be a string$"):
            Layout.cast([(1, 1)])

    def test_wrong_name_duplicate(self):
        # Duplicate names raise NameError rather than TypeError.
        with self.assertRaisesRegex(NameError,
                r"^Field \('a', 2\) has a name that is already present in the layout$"):
            Layout.cast([("a", 1), ("a", 2)])

    def test_wrong_direction(self):
        # A plain int is not accepted as a direction.
        with self.assertRaisesRegex(TypeError,
                (r"^Field \('a', 1, 0\) has invalid direction: should be a Direction "
                 r"instance like DIR_FANIN$")):
            Layout.cast([("a", 1, 0)])

    def test_wrong_shape(self):
        with self.assertRaisesRegex(TypeError,
                (r"^Field \('a', 'x'\) has invalid shape: should be castable to Shape or "
                 r"a list of fields of a nested record$")):
            Layout.cast([("a", "x")])
101 Layout.cast([("a", "x")])
102
103
class RecordTestCase(FHDLTestCase):
    """Tests for ``Record``: naming, field access, construction from existing
    fields, ``shape()``, ``like()``, tuple slicing, enum decoders, and the
    ``Value`` operators a record supports.

    The expected names in these tests follow the local variable each record
    is assigned to (``r``, ``r1``, ``r2``, ...), so those variable names are
    part of what is being tested.
    """

    def test_basic(self):
        # The record takes its name from the variable it is assigned to; field
        # signals (including nested records) are named "<record>__<field>".
        r = Record([
            ("stb", 1),
            ("data", 32),
            ("info", [
                ("a", 1),
                ("b", 1),
            ])
        ])

        self.assertEqual(repr(r), "(rec r stb data (rec r__info a b))")
        # Total width is the sum of all leaf fields: 1 + 32 + 1 + 1.
        self.assertEqual(len(r), 35)
        self.assertIsInstance(r.stb, Signal)
        # Fields are reachable both as attributes and by item lookup.
        self.assertEqual(r.stb.name, "r__stb")
        self.assertEqual(r["stb"].name, "r__stb")

        self.assertTrue(hasattr(r, "stb"))
        self.assertFalse(hasattr(r, "xxx"))

    def test_unnamed(self):
        # Built inside a list literal, so there is no assignment target to take
        # a name from: the record is unnamed and field signals use bare names.
        r = [Record([
            ("stb", 1)
        ])][0]

        self.assertEqual(repr(r), "(rec <unnamed> stb)")
        self.assertEqual(r.stb.name, "stb")

    def test_iter(self):
        # NOTE(review): despite its name, this test covers integer indexing and
        # slicing, which act on the concatenation of the fields in layout order.
        r = Record([
            ("data", 4),
            ("stb", 1),
        ])

        self.assertEqual(repr(r[0]), "(slice (cat (sig r__data) (sig r__stb)) 0:1)")
        self.assertEqual(repr(r[0:3]), "(slice (cat (sig r__data) (sig r__stb)) 0:3)")

    def test_wrong_field(self):
        # Both item and attribute lookup of a missing field raise
        # AttributeError listing the fields that do exist.
        r = Record([
            ("stb", 1),
            ("ack", 1),
        ])
        with self.assertRaisesRegex(AttributeError,
                r"^Record 'r' does not have a field 'en'\. Did you mean one of: stb, ack\?$"):
            r["en"]
        with self.assertRaisesRegex(AttributeError,
                r"^Record 'r' does not have a field 'en'\. Did you mean one of: stb, ack\?$"):
            r.en

    def test_wrong_field_unnamed(self):
        # Same error for an unnamed record; the message says "Unnamed record".
        r = [Record([
            ("stb", 1),
            ("ack", 1),
        ])][0]
        with self.assertRaisesRegex(AttributeError,
                r"^Unnamed record does not have a field 'en'\. Did you mean one of: stb, ack\?$"):
            r.en

    def test_construct_with_fields(self):
        # `fields` maps field names to existing objects; the record uses those
        # exact objects (checked by identity) instead of creating new ones.
        ns = Signal(1)
        nr = Record([
            ("burst", 1)
        ])
        r = Record([
            ("stb", 1),
            ("info", [
                ("burst", 1)
            ])
        ], fields={
            "stb": ns,
            "info": nr
        })
        self.assertIs(r.stb, ns)
        self.assertIs(r.info, nr)

    def test_shape(self):
        # shape() is unsigned, with the total width of all fields (1 + 2).
        r1 = Record([("a", 1), ("b", 2)])
        self.assertEqual(r1.shape(), unsigned(3))

    def test_like(self):
        # like() copies the layout. The new name comes from the assignment
        # target, from an explicit `name`, or is the source name + `name_suffix`.
        r1 = Record([("a", 1), ("b", 2)])
        r2 = Record.like(r1)
        self.assertEqual(r1.layout, r2.layout)
        self.assertEqual(r2.name, "r2")
        r3 = Record.like(r1, name="foo")
        self.assertEqual(r3.name, "foo")
        r4 = Record.like(r1, name_suffix="foo")
        self.assertEqual(r4.name, "r1foo")

    def test_like_modifications(self):
        # like() carries over reset values changed after the source was built,
        # and renames (nested) field signals under the new record's name.
        r1 = Record([("a", 1), ("b", [("s", 1)])])
        self.assertEqual(r1.a.name, "r1__a")
        self.assertEqual(r1.b.name, "r1__b")
        self.assertEqual(r1.b.s.name, "r1__b__s")
        r1.a.reset = 1
        r1.b.s.reset = 1
        r2 = Record.like(r1)
        self.assertEqual(r2.a.reset, 1)
        self.assertEqual(r2.b.s.reset, 1)
        self.assertEqual(r2.a.name, "r2__a")
        self.assertEqual(r2.b.name, "r2__b")
        self.assertEqual(r2.b.s.name, "r2__b__s")

    def test_slice_tuple(self):
        # Indexing with a tuple of names yields a record with only those fields
        # that shares the original field signals.
        r1 = Record([("a", 1), ("b", 2), ("c", 3)])
        r2 = r1["a", "c"]
        self.assertEqual(r2.layout, Layout([("a", 1), ("c", 3)]))
        self.assertIs(r2.a, r1.a)
        self.assertIs(r2.c, r1.c)

    def test_enum_decoder(self):
        # An enum-shaped field's signal renders members as "NAME/value".
        r1 = Record([("a", UnsignedEnum)])
        self.assertEqual(r1.a.decoder(UnsignedEnum.FOO), "FOO/1")

    def test_operators(self):
        # A record behaves as a Value: each operator below is applied to the
        # Cat of its fields, with the record on either side of binary operators.
        r1 = Record([("a", 1)])
        s1 = Signal()

        # __bool__
        with self.assertRaisesRegex(TypeError,
                r"^Attempted to convert nMigen value to Python boolean$"):
            not r1

        # __invert__, __neg__
        self.assertEqual(repr(~r1), "(~ (cat (sig r1__a)))")
        self.assertEqual(repr(-r1), "(- (cat (sig r1__a)))")

        # __add__, __radd__, __sub__, __rsub__
        self.assertEqual(repr(r1 + 1), "(+ (cat (sig r1__a)) (const 1'd1))")
        self.assertEqual(repr(r1 + s1), "(+ (cat (sig r1__a)) (sig s1))")
        self.assertEqual(repr(1 + r1), "(+ (const 1'd1) (cat (sig r1__a)))")
        self.assertEqual(repr(s1 + r1), "(+ (sig s1) (cat (sig r1__a)))")
        self.assertEqual(repr(r1 - 1), "(- (cat (sig r1__a)) (const 1'd1))")
        self.assertEqual(repr(r1 - s1), "(- (cat (sig r1__a)) (sig s1))")
        self.assertEqual(repr(1 - r1), "(- (const 1'd1) (cat (sig r1__a)))")
        self.assertEqual(repr(s1 - r1), "(- (sig s1) (cat (sig r1__a)))")

        # __mul__, __rmul__
        self.assertEqual(repr(r1 * 1), "(* (cat (sig r1__a)) (const 1'd1))")
        self.assertEqual(repr(r1 * s1), "(* (cat (sig r1__a)) (sig s1))")
        self.assertEqual(repr(1 * r1), "(* (const 1'd1) (cat (sig r1__a)))")
        self.assertEqual(repr(s1 * r1), "(* (sig s1) (cat (sig r1__a)))")

        # __mod__, __rmod__, __floordiv__, __rfloordiv__
        self.assertEqual(repr(r1 % 1), "(% (cat (sig r1__a)) (const 1'd1))")
        self.assertEqual(repr(r1 % s1), "(% (cat (sig r1__a)) (sig s1))")
        self.assertEqual(repr(1 % r1), "(% (const 1'd1) (cat (sig r1__a)))")
        self.assertEqual(repr(s1 % r1), "(% (sig s1) (cat (sig r1__a)))")
        self.assertEqual(repr(r1 // 1), "(// (cat (sig r1__a)) (const 1'd1))")
        self.assertEqual(repr(r1 // s1), "(// (cat (sig r1__a)) (sig s1))")
        self.assertEqual(repr(1 // r1), "(// (const 1'd1) (cat (sig r1__a)))")
        self.assertEqual(repr(s1 // r1), "(// (sig s1) (cat (sig r1__a)))")

        # __lshift__, __rlshift__, __rshift__, __rrshift__
        self.assertEqual(repr(r1 >> 1), "(>> (cat (sig r1__a)) (const 1'd1))")
        self.assertEqual(repr(r1 >> s1), "(>> (cat (sig r1__a)) (sig s1))")
        self.assertEqual(repr(1 >> r1), "(>> (const 1'd1) (cat (sig r1__a)))")
        self.assertEqual(repr(s1 >> r1), "(>> (sig s1) (cat (sig r1__a)))")
        self.assertEqual(repr(r1 << 1), "(<< (cat (sig r1__a)) (const 1'd1))")
        self.assertEqual(repr(r1 << s1), "(<< (cat (sig r1__a)) (sig s1))")
        self.assertEqual(repr(1 << r1), "(<< (const 1'd1) (cat (sig r1__a)))")
        self.assertEqual(repr(s1 << r1), "(<< (sig s1) (cat (sig r1__a)))")

        # __and__, __rand__, __xor__, __rxor__, __or__, __ror__
        self.assertEqual(repr(r1 & 1), "(& (cat (sig r1__a)) (const 1'd1))")
        self.assertEqual(repr(r1 & s1), "(& (cat (sig r1__a)) (sig s1))")
        self.assertEqual(repr(1 & r1), "(& (const 1'd1) (cat (sig r1__a)))")
        self.assertEqual(repr(s1 & r1), "(& (sig s1) (cat (sig r1__a)))")
        self.assertEqual(repr(r1 ^ 1), "(^ (cat (sig r1__a)) (const 1'd1))")
        self.assertEqual(repr(r1 ^ s1), "(^ (cat (sig r1__a)) (sig s1))")
        self.assertEqual(repr(1 ^ r1), "(^ (const 1'd1) (cat (sig r1__a)))")
        self.assertEqual(repr(s1 ^ r1), "(^ (sig s1) (cat (sig r1__a)))")
        self.assertEqual(repr(r1 | 1), "(| (cat (sig r1__a)) (const 1'd1))")
        self.assertEqual(repr(r1 | s1), "(| (cat (sig r1__a)) (sig s1))")
        self.assertEqual(repr(1 | r1), "(| (const 1'd1) (cat (sig r1__a)))")
        self.assertEqual(repr(s1 | r1), "(| (sig s1) (cat (sig r1__a)))")

        # __eq__, __ne__, __lt__, __le__, __gt__, __ge__
        self.assertEqual(repr(r1 == 1), "(== (cat (sig r1__a)) (const 1'd1))")
        self.assertEqual(repr(r1 == s1), "(== (cat (sig r1__a)) (sig s1))")
        self.assertEqual(repr(s1 == r1), "(== (sig s1) (cat (sig r1__a)))")
        self.assertEqual(repr(r1 != 1), "(!= (cat (sig r1__a)) (const 1'd1))")
        self.assertEqual(repr(r1 != s1), "(!= (cat (sig r1__a)) (sig s1))")
        self.assertEqual(repr(s1 != r1), "(!= (sig s1) (cat (sig r1__a)))")
        self.assertEqual(repr(r1 < 1), "(< (cat (sig r1__a)) (const 1'd1))")
        self.assertEqual(repr(r1 < s1), "(< (cat (sig r1__a)) (sig s1))")
        self.assertEqual(repr(s1 < r1), "(< (sig s1) (cat (sig r1__a)))")
        self.assertEqual(repr(r1 <= 1), "(<= (cat (sig r1__a)) (const 1'd1))")
        self.assertEqual(repr(r1 <= s1), "(<= (cat (sig r1__a)) (sig s1))")
        self.assertEqual(repr(s1 <= r1), "(<= (sig s1) (cat (sig r1__a)))")
        self.assertEqual(repr(r1 > 1), "(> (cat (sig r1__a)) (const 1'd1))")
        self.assertEqual(repr(r1 > s1), "(> (cat (sig r1__a)) (sig s1))")
        self.assertEqual(repr(s1 > r1), "(> (sig s1) (cat (sig r1__a)))")
        self.assertEqual(repr(r1 >= 1), "(>= (cat (sig r1__a)) (const 1'd1))")
        self.assertEqual(repr(r1 >= s1), "(>= (cat (sig r1__a)) (sig s1))")
        self.assertEqual(repr(s1 >= r1), "(>= (sig s1) (cat (sig r1__a)))")

        # __abs__, __len__
        self.assertEqual(repr(abs(r1)), "(cat (sig r1__a))")
        self.assertEqual(len(r1), 1)

        # as_unsigned, as_signed, bool, any, all, xor, implies
        self.assertEqual(repr(r1.as_unsigned()), "(u (cat (sig r1__a)))")
        self.assertEqual(repr(r1.as_signed()), "(s (cat (sig r1__a)))")
        self.assertEqual(repr(r1.bool()), "(b (cat (sig r1__a)))")
        self.assertEqual(repr(r1.any()), "(r| (cat (sig r1__a)))")
        self.assertEqual(repr(r1.all()), "(r& (cat (sig r1__a)))")
        self.assertEqual(repr(r1.xor()), "(r^ (cat (sig r1__a)))")
        self.assertEqual(repr(r1.implies(1)), "(| (~ (cat (sig r1__a))) (const 1'd1))")
        self.assertEqual(repr(r1.implies(s1)), "(| (~ (cat (sig r1__a))) (sig s1))")

        # bit_select, word_select, matches
        self.assertEqual(repr(r1.bit_select(0, 1)), "(slice (cat (sig r1__a)) 0:1)")
        self.assertEqual(repr(r1.word_select(0, 1)), "(slice (cat (sig r1__a)) 0:1)")
        self.assertEqual(repr(r1.matches("1")),
                         "(== (& (cat (sig r1__a)) (const 1'd1)) (const 1'd1))")

        # shift_left, shift_right, rotate_left, rotate_right, eq
        self.assertEqual(repr(r1.shift_left(1)), "(cat (const 1'd0) (cat (sig r1__a)))")
        self.assertEqual(repr(r1.shift_right(1)), "(slice (cat (sig r1__a)) 1:1)")
        self.assertEqual(repr(r1.rotate_left(1)), "(cat (slice (cat (sig r1__a)) 0:1) (slice (cat (sig r1__a)) 0:0))")
        self.assertEqual(repr(r1.rotate_right(1)), "(cat (slice (cat (sig r1__a)) 0:1) (slice (cat (sig r1__a)) 0:0))")
        self.assertEqual(repr(r1.eq(1)), "(eq (cat (sig r1__a)) (const 1'd1))")
        self.assertEqual(repr(r1.eq(s1)), "(eq (cat (sig r1__a)) (sig s1))")
328
329
class ConnectTestCase(FHDLTestCase):
    """Tests for ``Record.connect``.

    DIR_FANOUT fields of the main record drive the same field on every
    subordinate. DIR_FANIN fields on the main record are driven by the OR of
    the subordinates' fields. The tests also cover include/exclude filtering
    and error reporting. Expected statements embed the record names taken
    from the local variables ``core``, ``periph1`` and ``periph2``.
    """

    def setUp_flat(self):
        # Helper called explicitly by the tests (not unittest's setUp hook).
        # Core and peripheral layouts are identical. Note that data_w is
        # DIR_FANIN, like data_r; the expected statements below rely on that.
        self.core_layout = [
            ("addr", 32, DIR_FANOUT),
            ("data_r", 32, DIR_FANIN),
            ("data_w", 32, DIR_FANIN),
        ]
        self.periph_layout = [
            ("addr", 32, DIR_FANOUT),
            ("data_r", 32, DIR_FANIN),
            ("data_w", 32, DIR_FANIN),
        ]

    def setUp_nested(self):
        # Same fields as setUp_flat, but data_r/data_w are grouped into a
        # nested "data" sub-record.
        self.core_layout = [
            ("addr", 32, DIR_FANOUT),
            ("data", [
                ("r", 32, DIR_FANIN),
                ("w", 32, DIR_FANIN),
            ]),
        ]
        self.periph_layout = [
            ("addr", 32, DIR_FANOUT),
            ("data", [
                ("r", 32, DIR_FANIN),
                ("w", 32, DIR_FANIN),
            ]),
        ]

    def test_flat(self):
        # Fan-out: core's addr drives each peripheral's addr.
        # Fan-in: core's data_* are the OR of the peripherals' data_*.
        self.setUp_flat()

        core = Record(self.core_layout)
        periph1 = Record(self.periph_layout)
        periph2 = Record(self.periph_layout)

        stmts = core.connect(periph1, periph2)
        self.assertRepr(stmts, """(
            (eq (sig periph1__addr) (sig core__addr))
            (eq (sig periph2__addr) (sig core__addr))
            (eq (sig core__data_r) (| (sig periph1__data_r) (sig periph2__data_r)))
            (eq (sig core__data_w) (| (sig periph1__data_w) (sig periph2__data_w)))
        )""")

    def test_flat_include(self):
        # Only the fields named in `include` are connected.
        self.setUp_flat()

        core = Record(self.core_layout)
        periph1 = Record(self.periph_layout)
        periph2 = Record(self.periph_layout)

        stmts = core.connect(periph1, periph2, include={"addr": True})
        self.assertRepr(stmts, """(
            (eq (sig periph1__addr) (sig core__addr))
            (eq (sig periph2__addr) (sig core__addr))
        )""")

    def test_flat_exclude(self):
        # Fields named in `exclude` are skipped.
        self.setUp_flat()

        core = Record(self.core_layout)
        periph1 = Record(self.periph_layout)
        periph2 = Record(self.periph_layout)

        stmts = core.connect(periph1, periph2, exclude={"addr": True})
        self.assertRepr(stmts, """(
            (eq (sig core__data_r) (| (sig periph1__data_r) (sig periph2__data_r)))
            (eq (sig core__data_w) (| (sig periph1__data_w) (sig periph2__data_w)))
        )""")

    def test_nested(self):
        # Nested sub-records are connected field by field; their signals are
        # named "<record>__data__<field>".
        self.setUp_nested()

        core = Record(self.core_layout)
        periph1 = Record(self.periph_layout)
        periph2 = Record(self.periph_layout)

        stmts = core.connect(periph1, periph2)
        # Show the full diff of the long repr if the assertion fails.
        self.maxDiff = None
        self.assertRepr(stmts, """(
            (eq (sig periph1__addr) (sig core__addr))
            (eq (sig periph2__addr) (sig core__addr))
            (eq (sig core__data__r) (| (sig periph1__data__r) (sig periph2__data__r)))
            (eq (sig core__data__w) (| (sig periph1__data__w) (sig periph2__data__w)))
        )""")

    def test_wrong_include_exclude(self):
        # Naming a field in include/exclude that the main record lacks raises
        # AttributeError.
        self.setUp_flat()

        core = Record(self.core_layout)
        periph = Record(self.periph_layout)

        with self.assertRaisesRegex(AttributeError,
                r"^Cannot include field 'foo' because it is not present in record 'core'$"):
            core.connect(periph, include={"foo": True})

        with self.assertRaisesRegex(AttributeError,
                r"^Cannot exclude field 'foo' because it is not present in record 'core'$"):
            core.connect(periph, exclude={"foo": True})

    def test_wrong_direction(self):
        # Records built inside a comprehension are unnamed. A field declared
        # without a direction cannot be connected.
        recs = [Record([("x", 1)]) for _ in range(2)]

        with self.assertRaisesRegex(TypeError,
                (r"^Cannot connect field 'x' of unnamed record because it does not have "
                 r"a direction$")):
            recs[0].connect(recs[1])

    def test_wrong_missing_field(self):
        # Every field of the main record must exist in each subordinate record.
        core = Record([("addr", 32, DIR_FANOUT)])
        periph = Record([])

        with self.assertRaisesRegex(AttributeError,
                (r"^Cannot connect field 'addr' of record 'core' to subordinate record 'periph' "
                 r"because the subordinate record does not have this field$")):
            core.connect(periph)
444 r"because the subordinate record does not have this field$")):
445 core.connect(periph)