soc/cores/spi/SPIMaster: rewrite/simplify.
authorFlorent Kermarrec <florent@enjoy-digital.fr>
Mon, 20 Jul 2020 08:36:35 +0000 (10:36 +0200)
committerFlorent Kermarrec <florent@enjoy-digital.fr>
Mon, 20 Jul 2020 08:44:18 +0000 (10:44 +0200)
- Make sure MOSI is latched on start, MISO is stable during Xfer (last value).
- Allow clk_divider down to 2.
- improve test errors reporting with hex() on AssertEqual.

litex/soc/cores/spi.py
test/test_spi.py

index 1ebc3ff34ab36cd2aa42ea51286149b02b0707b2..d31812265362f48ab427f972113b4ccc2f720813 100644 (file)
@@ -41,94 +41,96 @@ class SPIMaster(Module, AutoCSR):
 
         # # #
 
-        done  = Signal()
-        bits  = Signal(8)
-        xfer  = Signal()
-        shift = Signal()
+        clk_enable = Signal()
+        cs_enable  = Signal()
+        count      = Signal(max=data_width)
+        mosi_latch = Signal()
+        miso_latch = Signal()
 
         # Clock generation -------------------------------------------------------------------------
         clk_divider = Signal(16)
         clk_rise    = Signal()
         clk_fall    = Signal()
+        self.comb += clk_rise.eq(clk_divider == (self.clk_divider[1:] - 1))
+        self.comb += clk_fall.eq(clk_divider == (self.clk_divider     - 1))
         self.sync += [
-            If(clk_rise, pads.clk.eq(xfer)),
-            If(clk_fall, pads.clk.eq(0)),
-            If(clk_fall,
-                clk_divider.eq(0)
-            ).Else(
-                clk_divider.eq(clk_divider + 1)
+            clk_divider.eq(clk_divider + 1),
+            If(clk_rise,
+                pads.clk.eq(clk_enable),
+            ).Elif(clk_fall,
+                clk_divider.eq(0),
+                pads.clk.eq(0),
             )
         ]
-        self.comb += clk_rise.eq(clk_divider == (self.clk_divider[1:] - 1))
-        self.comb += clk_fall.eq(clk_divider == (self.clk_divider - 1))
 
         # Control FSM ------------------------------------------------------------------------------
         self.submodules.fsm = fsm = FSM(reset_state="IDLE")
         fsm.act("IDLE",
-            done.eq(1),
+            self.done.eq(1),
             If(self.start,
-                NextValue(bits, 0),
-                NextState("WAIT-CLK-FALL")
+                self.done.eq(0),
+                mosi_latch.eq(1),
+                NextState("START")
             )
         )
-        fsm.act("WAIT-CLK-FALL",
+        fsm.act("START",
+            NextValue(count, 0),
             If(clk_fall,
-                NextState("XFER")
+                cs_enable.eq(1),
+                NextState("RUN")
             )
         )
-        fsm.act("XFER",
-            If(bits == self.length,
-                NextState("END")
-            ).Elif(clk_fall,
-                NextValue(bits, bits + 1)
-            ),
-            xfer.eq(1),
-            shift.eq(1)
+        fsm.act("RUN",
+            clk_enable.eq(1),
+            cs_enable.eq(1),
+            If(clk_fall,
+                NextValue(count, count + 1),
+                If(count == (self.length - 1),
+                    NextState("STOP")
+                )
+            )
         )
-        fsm.act("END",
+        fsm.act("STOP",
+            cs_enable.eq(1),
             If(clk_rise,
+                miso_latch.eq(1),
+                self.irq.eq(1),
                 NextState("IDLE")
-            ),
-            shift.eq(1),
-            self.irq.eq(1)
+            )
         )
-        self.sync += self.done.eq(done & ~self.start)
 
         # Chip Select generation -------------------------------------------------------------------
         if hasattr(pads, "cs_n"):
             for i in range(len(pads.cs_n)):
-                self.comb += pads.cs_n[i].eq(~self.cs[i] | ~xfer)
+                self.sync += pads.cs_n[i].eq(~self.cs[i] | ~cs_enable)
 
         # Master Out Slave In (MOSI) generation (generated on spi_clk falling edge) ----------------
-        mosi_data = Array(self.mosi[i] for i in range(data_width))
-        mosi_bit  = Signal(max=data_width)
+        mosi_data  = Signal(data_width)
+        mosi_array = Array(mosi_data[i] for i in range(data_width))
+        mosi_sel   = Signal(max=data_width)
         self.sync += [
-            If(self.start,
-                mosi_bit.eq(self.length - 1 if mode == "aligned" else data_width - 1),
-            ).Elif(clk_rise & shift,
-                mosi_bit.eq(mosi_bit - 1)
+            If(mosi_latch,
+                mosi_data.eq(self.mosi),
+                mosi_sel.eq((self.length-1) if mode == "aligned" else (data_width-1)),
+            ).Elif(clk_fall,
+                If(cs_enable, pads.mosi.eq(mosi_array[mosi_sel])),
+                mosi_sel.eq(mosi_sel - 1)
             ),
-            If(clk_fall,
-                pads.mosi.eq(mosi_data[mosi_bit])
-            )
         ]
 
         # Master In Slave Out (MISO) capture (captured on spi_clk rising edge) --------------------
         miso      = Signal()
         miso_data = Signal(data_width)
         self.sync += [
-            If(clk_rise & shift,
+            If(clk_rise,
                 If(self.loopback,
-                    miso.eq(pads.mosi)
+                    miso_data.eq(Cat(pads.mosi, miso_data))
                 ).Else(
-                    miso.eq(pads.miso)
+                    miso_data.eq(Cat(pads.miso, miso_data))
                 )
-            ),
-            If(clk_fall & shift,
-                miso_data.eq(Cat(miso, miso_data))
-            ),
-            If(done, self.miso.eq(miso_data)),
+            )
         ]
+        self.sync += If(miso_latch, self.miso.eq(miso_data))
 
     def add_csr(self, with_cs=True, with_loopback=True):
         self._control  = CSRStorage(fields=[
index 92d1ad827b56060d94883ae8e039ad1221d0201c..e9684dae3211d24fe2136e61a4dd6368722b1b22 100644 (file)
@@ -16,6 +16,7 @@ class TestSPI(unittest.TestCase):
     def test_spi_master_xfer_loopback_32b_32b(self):
         def generator(dut):
             yield dut.loopback.eq(1)
+            yield dut.clk_divider.eq(2)
             yield dut.mosi.eq(0xdeadbeef)
             yield dut.length.eq(32)
             yield dut.start.eq(1)
@@ -24,7 +25,8 @@ class TestSPI(unittest.TestCase):
             yield
             while (yield dut.done) == 0:
                 yield
-            self.assertEqual((yield dut.miso), 0xdeadbeef)
+            yield
+            self.assertEqual(hex((yield dut.miso)), hex(0xdeadbeef))
 
         dut = SPIMaster(pads=None, data_width=32, sys_clk_freq=100e6, spi_clk_freq=5e6, with_csr=False)
         run_simulation(dut, generator(dut))
@@ -40,7 +42,8 @@ class TestSPI(unittest.TestCase):
             yield
             while (yield dut.done) == 0:
                 yield
-            self.assertEqual((yield dut.miso), 0xbeef)
+            yield
+            self.assertEqual(hex((yield dut.miso)), hex(0xbeef))
 
         dut = SPIMaster(pads=None, data_width=32, sys_clk_freq=100e6, spi_clk_freq=5e6, with_csr=False, mode="aligned")
         run_simulation(dut, generator(dut))
@@ -59,6 +62,8 @@ class TestSPI(unittest.TestCase):
                 self.submodules.slave  = SPISlave(pads, data_width=32)
 
         def master_generator(dut):
+            for i in range(8):
+                yield
             yield dut.master.mosi.eq(0xdeadbeef)
             yield dut.master.length.eq(32)
             yield dut.master.start.eq(1)
@@ -67,15 +72,19 @@ class TestSPI(unittest.TestCase):
             yield
             while (yield dut.master.done) == 0:
                 yield
-            self.assertEqual((yield dut.master.miso), 0x12345678)
+            yield
+            self.assertEqual(hex((yield dut.master.miso)), hex(0x12345678))
 
         def slave_generator(dut):
+            for i in range(8):
+                yield
             yield dut.slave.miso.eq(0x12345678)
             while (yield dut.slave.start) == 0:
                 yield
             while (yield dut.slave.done) == 0:
                 yield
-            self.assertEqual((yield dut.slave.mosi), 0xdeadbeef)
+            yield
+            self.assertEqual(hex((yield dut.slave.mosi)), hex(0xdeadbeef))
             self.assertEqual((yield dut.slave.length), 32)
 
         dut = DUT()