add missing nmutil dependency
[utils.git] / uart_demo.py
1 #!/usr/bin/env python3
2 import itertools
3 import sys
4 import unittest
5 from nmigen_boards.arty_a7 import ArtyA7_100Platform
6 from nmigen.hdl.dsl import Elaboratable, Module
7 from nmigen.hdl.ast import Signal, Array, Const
8 from nmigen.hdl.mem import Memory
9 from nmigen.sim import Tick
10 from nmigen.build import ResourceError
11 from nmutil.sim_util import do_sim
12 import enum
13
14
def get_all_resources(platform, name):
    """Request every numbered instance of resource ``name`` from ``platform``.

    Instances are requested as ``platform.request(name, 0)``,
    ``platform.request(name, 1)``, ... until the platform raises
    ``ResourceError``. Every successfully requested resource is returned in
    request order.

    When ``platform`` is None (the simulation case) nothing is requested and
    an empty list is returned.
    """
    found = []
    if platform is not None:
        number = 0
        while True:
            try:
                resource = platform.request(name, number)
            except ResourceError:
                # No instance with this number: all of them have been claimed.
                break
            found.append(resource)
            number += 1
    return found
26
27
# Clock frequency (Hz) assumed when elaborating without a platform,
# i.e. in simulation; also used for the simulator clock period.
SIM_CLOCK_FREQ = 1_000_000.0
29
30
class TickGenerator(Elaboratable):
    """Pulse ``tick`` high for one clock cycle, ``rate`` times per second.

    Each clock, a counter holding a value in ``[0, clk_freq)`` is reduced by
    ``rate``. When it is already below ``rate``, it wraps by adding
    ``clk_freq - rate`` instead, and that cycle counts as a tick.
    This happens exactly ``rate`` times every ``clk_freq`` cycles, so the
    ticks are spaced roughly ``clk_freq / rate`` cycles apart. ``tick`` is a
    registered copy of the wrap condition, so it lags it by one cycle.

    Without a platform (simulation), the clock is taken to run at
    ``SIM_CLOCK_FREQ``.
    """

    def __init__(self, rate):
        assert isinstance(rate, int) and rate > 0, "unsupported rate"
        self.rate = rate
        self.tick = Signal()

    def elaborate(self, platform):
        m = Module()

        if platform is not None:
            nominal_freq = platform.default_clk_frequency
        else:
            nominal_freq = SIM_CLOCK_FREQ
        clk_freq = int(nominal_freq)
        assert clk_freq == nominal_freq, \
            "non-integer clock frequencies are unsupported"
        assert self.rate <= clk_freq, \
            "rate can't be higher than the clock frequency"

        # Names of these locals become the signal names in the netlist.
        counter = Signal(range(clk_freq))
        next_count = Signal(range(clk_freq))
        underflow = Signal()
        wrap_step = clk_freq - self.rate

        m.d.comb += underflow.eq(counter < self.rate)
        m.d.sync += [
            counter.eq(next_count),
            self.tick.eq(underflow),
        ]
        with m.If(underflow):
            # counter < rate, so counter + (clk_freq - rate) < clk_freq.
            m.d.comb += next_count.eq(counter + wrap_step)
        with m.Else():
            m.d.comb += next_count.eq(counter - self.rate)
        return m
61
62
class SimpleUART(Elaboratable):
    """Transmit-only UART sending 8N1 frames, least-significant bit first.

    Ports:
      * ``data_in`` (8 bits), ``data_in_valid``, ``data_in_ready``: a byte
        is accepted on a cycle where both ``valid`` and ``ready`` are high.
        ``ready`` then stays low until that byte's stop bit has been
        driven onto ``tx``.
      * ``tx``: serial output, reset value 1 (line idle).

    Bit timing comes from an internal ``TickGenerator`` producing
    ``baud_rate`` ticks per second. ``tx`` is updated one cycle after each
    tick while a byte is buffered.
    """

    def __init__(self, baud_rate=9600):
        self.__tick_gen = TickGenerator(baud_rate)
        self.data_in = Signal(8)
        self.data_in_valid = Signal()
        self.data_in_ready = Signal()
        self.tx = Signal(reset=1)

    @property
    def baud_rate(self):
        """Bits per second, as passed to the constructor."""
        return self.__tick_gen.rate

    def elaborate(self, platform):
        m = Module()
        tick_gen = self.__tick_gen
        m.submodules.tick_gen = tick_gen

        # One-entry buffer holding the byte being shifted out. Local names
        # are kept because nmigen turns them into signal names.
        data = Signal.like(self.data_in)
        data_full = Signal()
        m.d.comb += self.data_in_ready.eq(~data_full)
        accept = self.data_in_ready & self.data_in_valid
        with m.If(accept):
            m.d.sync += [
                data.eq(self.data_in),
                data_full.eq(1),
            ]

        # Frame: start bit (0), the 8 data bits LSB first, stop bit (1).
        frame = Array([Const(0, 1)] + list(data) + [Const(1, 1)])
        last_index = len(frame) - 1
        current_bit_num = Signal(range(len(frame)))
        with m.If(tick_gen.tick & data_full):
            m.d.sync += self.tx.eq(frame[current_bit_num])
            with m.If(current_bit_num != last_index):
                m.d.sync += current_bit_num.eq(current_bit_num + 1)
            with m.Else():
                # Stop bit driven: rewind and free the buffer for a new byte.
                m.d.sync += [
                    current_bit_num.eq(0),
                    data_full.eq(0),
                ]
        return m
100
101
class UartDemo(Elaboratable):
    """Repeatedly transmit ``text`` (its ``str.encode()`` bytes) over a UART.

    The bytes are held in a ROM and streamed into a ``SimpleUART`` in an
    endless loop. The UART's ``tx`` output is driven onto every ``uart``
    resource the platform provides; there are none when simulating.

    Raises:
        ValueError: if ``text`` encodes to zero bytes.
    """

    def __init__(self, text):
        self.text = str(text)
        self.text_bytes = list(self.text.encode())
        # An empty ROM would have depth 0 and an empty address range, and
        # the wrap-around compare in elaborate() would target address -1.
        if not self.text_bytes:
            raise ValueError("UartDemo requires non-empty text")
        self.simple_uart = SimpleUART()

    def elaborate(self, platform):
        m = Module()
        m.submodules.simple_uart = self.simple_uart

        for uart in get_all_resources(platform, "uart"):
            m.d.comb += uart.tx.o.eq(self.simple_uart.tx)

        text_rom = Memory(width=8, depth=len(self.text_bytes),
                          init=self.text_bytes)
        text_read = text_rom.read_port()
        m.submodules.text_read = text_read
        addr = Signal(range(len(self.text_bytes)), reset=0)
        valid = Signal(reset=0)
        m.d.comb += [
            text_read.addr.eq(addr),
            self.simple_uart.data_in.eq(text_read.data),
            self.simple_uart.data_in_valid.eq(valid),
        ]

        with m.If(self.simple_uart.data_in_ready
                  & self.simple_uart.data_in_valid):
            # Byte accepted: advance to the next one, wrapping at the end.
            with m.If(addr == len(self.text_bytes) - 1):
                m.d.sync += addr.eq(0)
            with m.Else():
                m.d.sync += addr.eq(addr + 1)
            # The default read port is synchronous, so data for the new
            # address appears a cycle later; hold valid low until then.
            m.d.sync += valid.eq(0)
        with m.Else():
            m.d.sync += valid.eq(1)

        return m
138
139
class TestUartDemo(unittest.TestCase):
    """Simulation test that decodes the serial stream ``UartDemo`` emits."""

    def test_uart_demo(self):
        """Decode three repetitions of the demo text from ``tx``.

        A software receiver samples ``tx`` once per bit period and re-aligns
        on every edge. For each byte of ``dut.text_bytes`` it checks the
        start bit, the 8 data bits (LSB first) and the stop bit.
        """
        # Label of the bit currently being decoded. It is only driven onto
        # the ``expected_state`` signal, which goes into do_sim's signal
        # list (presumably to annotate the waveform dump).
        class ExpectedState(enum.Enum):
            DATA0 = 0
            DATA1 = 1
            DATA2 = 2
            DATA3 = 3
            DATA4 = 4
            DATA5 = 5
            DATA6 = 6
            DATA7 = 7
            # enum.auto() continues numbering after DATA7.
            START = enum.auto()
            STOP = enum.auto()

        m = Module()
        dut = UartDemo("test text")
        # Pulsed high at the start of each read_bit() call, i.e. at each
        # sampling point; only written, never read, by the test.
        sample_event = Signal()
        expected_state = Signal(ExpectedState)
        m.submodules.dut = dut
        # do_sim comes from nmutil; the list is presumably the set of
        # signals traced in the waveform output — TODO confirm.
        with do_sim(self, m, [
            dut.simple_uart.tx,
            sample_event,
            expected_state,
        ]) as sim:
            # Clock cycles per UART bit, rounded to the nearest cycle.
            expected_bit_tick_count = round(
                SIM_CLOCK_FREQ / dut.simple_uart.baud_rate)

            def read_bit(is_initial=False):
                """Sample ``tx`` now, then advance roughly one bit period.

                Returns the value sampled at the start of the call.
                The caller is expected to be near the middle of a bit. The
                function waits up to one bit period for an edge. If one
                occurs, it asserts that the edge came about half a bit
                period in (+-1 cycle; any position when ``is_initial``),
                then waits another half bit, asserting ``tx`` does not
                change back. Without an edge, a full bit period has
                already elapsed.
                """
                yield sample_event.eq(1)
                start_value = yield dut.simple_uart.tx
                transition = None
                for i in range(expected_bit_tick_count):
                    yield Tick()
                    yield sample_event.eq(0)
                    value = yield dut.simple_uart.tx
                    if value != start_value:
                        transition = i
                        break
                if transition is not None:
                    # A delta of a whole bit period accepts any edge
                    # position, i.e. no timing check during initial sync.
                    delta = expected_bit_tick_count if is_initial else 1
                    self.assertAlmostEqual(
                        transition, expected_bit_tick_count / 2, delta=delta)
                    # Move on to the middle of the following bit.
                    for i in range(expected_bit_tick_count // 2):
                        yield Tick()
                        yield sample_event.eq(0)
                        value = yield dut.simple_uart.tx
                        self.assertNotEqual(value, start_value,
                                            "two transitions in one bit time")
                return start_value

            def process():
                # Initial synchronisation: keep reading until a 0 (the start
                # bit) has been consumed, allowing at most 4 reads.
                # The next read_bit() call then samples data bit 0.
                yield expected_state.eq(ExpectedState.START)
                start_bit = yield from read_bit(True)
                for i in range(3):
                    if start_bit == 0:
                        break
                    start_bit = yield from read_bit(True)
                # Check three full passes over the repeating text.
                for i in range(3):
                    for expected_byte in dut.text_bytes:
                        with self.subTest(i=i,
                                          expected_byte=hex(expected_byte)):
                            # start_bit was read at the end of the
                            # previous byte (or during sync above).
                            self.assertEqual(start_bit, 0, "missing start bit")
                            for bit_index in range(8):
                                with self.subTest(bit_index=bit_index):
                                    yield expected_state.eq(
                                        ExpectedState(bit_index))
                                    data_bit = yield from read_bit()
                                    # Data bits are sent LSB first.
                                    expected = (expected_byte >> bit_index) & 1
                                    self.assertEqual(data_bit, expected,
                                                     "wrong data bit")
                            yield expected_state.eq(ExpectedState.STOP)
                            stop_bit = yield from read_bit()
                            self.assertEqual(stop_bit, 1, "missing stop bit")
                            # Read ahead the next byte's start bit.
                            yield expected_state.eq(ExpectedState.START)
                            start_bit = yield from read_bit()

            sim.add_process(process)
            sim.add_clock(1 / SIM_CLOCK_FREQ)
            sim.run()
219
220
def build(platform, do_program, text=None):
    """Build a ``UartDemo`` bitstream on ``platform``, optionally programming it.

    Args:
        platform: nmigen platform instance; its ``build`` method is called.
        do_program: forwarded as ``do_program`` to ``platform.build``.
        text: text the demo sends repeatedly. ``None`` (the default)
            selects ``DEFAULT_TEXT``, matching the command-line default.
    """
    if text is None:
        text = DEFAULT_TEXT
    platform.build(UartDemo(text), do_program=do_program)
223
224
# Maps the <platform> command-line argument to its nmigen-boards class.
PLATFORMS = {
    "ArtyA7_100": ArtyA7_100Platform,
    # TODO: add more
}

# The first registered platform is the default.
DEFAULT_PLATFORM = next(iter(PLATFORMS.keys()))
DEFAULT_TOOLCHAIN = "yosys_nextpnr"
DEFAULT_TEXT = "Hello World!\n"

PLATFORMS_TEXT = '\n'.join(PLATFORMS.keys())
# Printed for -h/--help. The __main__ block calls unittest.main() right
# after printing it, so unittest's own help follows the final heading.
HELP_TEXT = f"""
usage: {sys.argv[0]} program|build [<platform> [<toolchain> [<text>]]]

generate an FPGA bitstream. If `program` is specified, also program that
bitstream to the FPGA plugged into this computer.

<platform>  the FPGA platform. Defaults to {DEFAULT_PLATFORM}.
<toolchain> the toolchain used. Defaults to {DEFAULT_TOOLCHAIN}.
<text>      the text that will be repeatedly sent out of all FPGA UARTs.
            Defaults to {DEFAULT_TEXT!r}.

Supported FPGA platforms:
{PLATFORMS_TEXT}

unittest usage:
""".lstrip()
251
if __name__ == "__main__":
    # Command-line entry point; see HELP_TEXT for the accepted arguments.
    if "-h" in sys.argv or "--help" in sys.argv:
        print(HELP_TEXT)
        unittest.main()  # get unittest's help too
    elif len(sys.argv) > 1 and sys.argv[1] in ("build", "program"):
        platform_str = sys.argv[2] if len(sys.argv) > 2 else DEFAULT_PLATFORM
        # User input is checked explicitly rather than with `assert`, which
        # `python -O` strips. sys.exit(msg) prints msg to stderr and exits
        # with status 1.
        if platform_str not in PLATFORMS:
            sys.exit(f"unsupported platform {platform_str}:\n"
                     f"valid platforms: {list(PLATFORMS.keys())}")
        platform_cls = PLATFORMS[platform_str]
        toolchain = sys.argv[3] if len(sys.argv) > 3 else DEFAULT_TOOLCHAIN
        text = sys.argv[4] if len(sys.argv) > 4 else DEFAULT_TEXT
        if text == "":
            sys.exit("empty text not supported")
        platform = platform_cls(toolchain=toolchain)
        top = UartDemo(text)
        platform.build(top,
                       do_program=sys.argv[1] == "program")
    else:
        # No build/program command: run the test suite.
        unittest.main()
270 else:
271 unittest.main()