Add tests for the Timeline elaboratable
[gram.git] / gram / compat.py
1 # This file is Copyright (c) 2020 LambdaConcept <contact@lambdaconcept.com>
2
3 import unittest
4
5 from nmigen import *
6 from nmigen import tracer
7 from nmigen.compat import Case
8 from nmigen.back.pysim import *
9
# Public API for ``from gram.compat import *``. Python only honours the
# lowercase ``__all__`` spelling; with any other spelling every public name in
# this module (including names re-exported from nmigen and the test case
# defined here) would leak to star-importers.
__all__ = ["delayed_enter", "RoundRobin", "Timeline", "CSRPrefixProxy"]
11
12
def delayed_enter(m, src, dst, delay):
    """Build a chain of ``delay`` FSM states leading from ``src`` to ``dst``.

    The chain starts at state ``src`` and goes through ``delay - 1``
    intermediate states named ``"<src>-1"``, ``"<src>-2"``, ... Every state
    unconditionally sets ``m.next`` to the following one, and the last one
    moves to ``dst``. Entering ``src`` therefore reaches ``dst`` after
    ``delay`` transitions.

    Presumably called inside a ``with m.FSM():`` block, since it opens
    ``m.State(...)`` contexts.

    ``delay`` must be at least 1 (checked with ``assert``).
    """
    assert delay > 0

    # Ordered list of every state in the chain, ending with the destination.
    hops = [src]
    hops.extend("{}-{}".format(src, k) for k in range(1, delay))
    hops.append(dst)

    # Each consecutive pair is one transition: (state, next state).
    for here, there in zip(hops, hops[1:]):
        with m.State(here):
            m.next = there
29
30 # Original nMigen implementation by HarryHo90sHK
31
32
class RoundRobin(Elaboratable):
    """A round-robin scheduler.

    Parameters
    ----------
    n : int
        Maximum number of requests to handle.

    Attributes
    ----------
    request : Signal(n)
        Signal where a '1' on the i-th bit represents an incoming request from
        the i-th device.
    grant : Signal(range(n))
        Index of the device which is currently granted access.
    stb : Signal()
        Strobe enabling the grant to move on to the next requesting device.
        Externally driven.
    """

    def __init__(self, n):
        self.n = n
        self.request = Signal(n)
        self.grant = Signal(range(n))
        self.stb = Signal()

    def elaborate(self, platform):
        m = Module()
        count = self.n

        with m.If(self.stb):
            with m.Switch(self.grant):
                for current in range(count):
                    with m.Case(current):
                        # Scan the other devices from the farthest to the
                        # nearest in cyclic order after `current`. In nmigen
                        # the last assignment to a signal wins, so the nearest
                        # requesting device is the one that gets the grant.
                        # If nobody else requests, `grant` keeps its value.
                        for offset in range(count - 1, 0, -1):
                            candidate = (current + offset) % count
                            with m.If(self.request[candidate]):
                                m.d.sync += self.grant.eq(candidate)

        return m
70
71
class Timeline(Elaboratable):
    """Clock in a fixed schedule of synchronous statements after a trigger.

    Parameters
    ----------
    events : iterable of (time, statements)
        ``time`` is a non-negative int. ``statements`` is anything accepted by
        ``m.d.sync +=``.

        - An event with ``time == 0`` is clocked in while ``trigger`` is high
          and the timeline is idle (internal counter at zero).
        - An event with ``time > 0`` is clocked in while the internal counter
          equals ``time``.

        The counter becomes 1 on the cycle after the trigger and then
        increments every cycle, so an event at ``time`` lands ``time`` cycles
        after the time-0 events.

    Attributes
    ----------
    trigger : Signal()
        Starts the sequence. It is only looked at while the counter is idle
        (zero), so a trigger arriving mid-sequence is ignored. Externally
        driven.
    """

    def __init__(self, events):
        self.trigger = Signal()
        self._events = events

    def elaborate(self, platform):
        m = Module()

        # Materialize once: a generator would otherwise be exhausted by the
        # time scan below, and no event would ever be emitted.
        events = list(self._events)
        if not events:
            raise ValueError("Timeline needs at least one event")
        times = [e[0] for e in events]
        if min(times) < 0:
            # An unsigned counter can never equal a negative time, so such an
            # event would silently never fire.
            raise ValueError(
                "Event times must be non-negative, got {}".format(min(times)))

        lastevent = max(times)
        counter = Signal(range(lastevent + 1))

        def count():
            # Idle at zero until triggered, then count up by one every cycle.
            with m.If(counter != 0):
                m.d.sync += counter.eq(counter + 1)
            with m.Elif(self.trigger):
                m.d.sync += counter.eq(1)

        # If lastevent + 1 is a power of two, the counter wraps from lastevent
        # back to zero on its own. Otherwise it must be reset explicitly.
        if (lastevent & (lastevent + 1)) != 0:
            with m.If(counter == lastevent):
                m.d.sync += counter.eq(0)
            with m.Else():
                count()
        else:
            count()

        for event in events:
            time, stmts = event[0], event[1]
            if time == 0:
                # The counter is still zero on the trigger cycle, so time-0
                # events key off the trigger itself.
                cond = self.trigger & (counter == 0)
            else:
                cond = counter == time
            with m.If(cond):
                m.d.sync += stmts

        return m
108
class TimelineTestCase(unittest.TestCase):
    def test_sequence(self):
        """Check that Timeline applies each event at its scheduled time.

        ``sigA`` is set at t=1 (set again at t=5) and cleared at t=7.
        ``sigB`` is set at t=10 and cleared at t=11. After the trigger, every
        checked cycle is compared against that schedule. This catches spurious
        changes between events, not only wrong values on the event cycles.
        """
        sigA = Signal()
        sigB = Signal()
        timeline = Timeline([
            (1, sigA.eq(1)),
            (5, sigA.eq(1)),
            (7, sigA.eq(0)),
            (10, sigB.eq(1)),
            (11, sigB.eq(0)),
        ])
        m = Module()
        m.submodules.timeline = timeline

        def expected(t):
            # Expected (sigA, sigB) at loop index ``t`` of the checks below.
            # Both are registers, so they hold their value between events.
            return int(1 <= t < 7), int(t == 10)

        def process():
            # Unset signals start at their reset value.
            self.assertFalse((yield sigA))
            self.assertFalse((yield sigB))

            # The sequence must not start without the trigger signal.
            for _ in range(100):
                yield
                self.assertFalse((yield sigA))
                self.assertFalse((yield sigB))

            yield timeline.trigger.eq(1)
            yield

            for t in range(11 + 1):
                yield
                exp_a, exp_b = expected(t)
                self.assertEqual((yield sigA), exp_a,
                                 "sigA at t={}".format(t))
                self.assertEqual((yield sigB), exp_b,
                                 "sigB at t={}".format(t))

        sim = Simulator(m)
        with sim.write_vcd("test_compat.vcd"):
            sim.add_clock(1e-6)
            sim.add_sync_process(process)
            sim.run()
162
163
class CSRPrefixProxy:
    """Create CSRs in ``bank`` with names prefixed by ``prefix``.

    ``bank`` must provide ``csr(width=..., access=..., addr=..., alignment=...,
    name=...)``. Every register requested through this proxy is forwarded
    there with its name rewritten to ``"<prefix>_<name>"``.
    """

    def __init__(self, bank, prefix):
        self._bank = bank
        self._prefix = prefix

    def csr(self, width, access, *, addr=None, alignment=None, name=None,
            src_loc_at=0):
        """Create a CSR in the wrapped bank under a prefixed name.

        An explicit ``name`` must be a ``str``; otherwise ``TypeError`` is
        raised. When ``name`` is omitted (or empty), it is taken from nmigen's
        ``tracer.get_var_name``, presumably the variable the caller assigns
        the result to, with leading underscores stripped. ``src_loc_at``
        shifts the frame inspected for that lookup.

        Returns whatever ``bank.csr`` returns.
        """
        if name is not None and not isinstance(name, str):
            raise TypeError("Name must be a string, not {!r}".format(name))
        if not name:
            # Must stay in this frame: the depth below is relative to it.
            name = tracer.get_var_name(depth=2 + src_loc_at).lstrip("_")

        full_name = "{}_{}".format(self._prefix, name)
        return self._bank.csr(
            width=width,
            access=access,
            addr=addr,
            alignment=alignment,
            name=full_name,
        )