soc/interconnect/axi: add AXIBurst2Beat
authorFlorent Kermarrec <florent@enjoy-digital.fr>
Fri, 19 Apr 2019 10:13:16 +0000 (12:13 +0200)
committerFlorent Kermarrec <florent@enjoy-digital.fr>
Fri, 19 Apr 2019 10:13:16 +0000 (12:13 +0200)
Converts AXI bursts commands to AXI beats.

litex/soc/interconnect/axi.py
test/test_axi.py [new file with mode: 0644]

index 4692cbfdca6c2fc57ca36b03ba0005b7176c0871..5cb1e9adda9c3aaefec63199e6bd30a317c691d9 100644 (file)
@@ -1,3 +1,5 @@
+"""AXI4 support for LiteX"""
+
 from migen import *
 
 from litex.soc.interconnect import stream
@@ -61,6 +63,70 @@ class AXIInterface(Record):
         self.ar = stream.Endpoint(ax_description(address_width, id_width))
         self.r = stream.Endpoint(r_description(data_width, id_width))
 
+# AXI Bursts to Beats ------------------------------------------------------------------------------
+
+class AXIBurst2Beat(Module):
+    def __init__(self, ax_burst, ax_beat):
+
+        # # #
+
+        self.count = count = Signal(8)
+        size = Signal(8 + 4)
+        offset = Signal(8 + 4)
+
+        # convert burst size to bytes
+        cases = {}
+        cases["default"] = size.eq(1024)
+        for i in range(10):
+            cases[i] = size.eq(2**i)
+        self.comb += Case(ax_burst.size, cases)
+
+        # fsm
+        self.submodules.fsm = fsm = FSM(reset_state="IDLE")
+        fsm.act("IDLE",
+            ax_beat.valid.eq(ax_burst.valid),
+            ax_beat.first.eq(1),
+            ax_beat.last.eq(ax_burst.len == 0),
+            ax_beat.addr.eq(ax_burst.addr),
+            ax_beat.id.eq(ax_burst.id),
+            If(ax_beat.valid & ax_beat.ready,
+                If(ax_burst.len != 0,
+                    NextState("BURST2BEAT")
+                ).Else(
+                    ax_burst.ready.eq(1)
+                )
+            ),
+            NextValue(count, 1),
+            NextValue(offset, size),
+        )
+        wrap_offset = Signal(8 + 4)
+        self.sync += wrap_offset.eq((ax_burst.len - 1)*size)
+        fsm.act("BURST2BEAT",
+            ax_beat.valid.eq(1),
+            ax_beat.first.eq(0),
+            ax_beat.last.eq(count == ax_burst.len),
+            If((ax_burst.burst == BURST_INCR) |
+               (ax_burst.burst == BURST_WRAP),
+                ax_beat.addr.eq(ax_burst.addr + offset)
+            ).Else(
+                ax_beat.addr.eq(ax_burst.addr)
+            ),
+            ax_beat.id.eq(ax_burst.id),
+            If(ax_beat.valid & ax_beat.ready,
+                If(ax_beat.last,
+                    ax_burst.ready.eq(1),
+                    NextState("IDLE")
+                ),
+                NextValue(count, count + 1),
+                NextValue(offset, offset + size),
+                If(ax_burst.burst == BURST_WRAP,
+                    If(offset == wrap_offset,
+                        NextValue(offset, 0)
+                    )
+                )
+            )
+        )
+
 # AXI to Wishbone ----------------------------------------------------------------------------------
 
 class AXI2Wishbone(Module):
diff --git a/test/test_axi.py b/test/test_axi.py
new file mode 100644 (file)
index 0000000..6c51caa
--- /dev/null
@@ -0,0 +1,96 @@
+import unittest
+import random
+
+from migen import *
+
+from litedram.common import *
+from litedram.frontend.axi import *
+
+from litex.gen.sim import *
+
+
+class Burst:
+    def __init__(self, addr, type=BURST_FIXED, len=0, size=0):
+        self.addr = addr
+        self.type = type
+        self.len = len
+        self.size = size
+
+    def to_beats(self):
+        r = []
+        for i in range(self.len + 1):
+            if self.type == BURST_INCR:
+                offset = i*2**(self.size)
+                r += [Beat(self.addr + offset)]
+            elif self.type == BURST_WRAP:
+                offset = (i*2**(self.size))%((2**self.size)*(self.len))
+                r += [Beat(self.addr + offset)]
+            else:
+                r += [Beat(self.addr)]
+        return r
+
+
+class Beat:
+    def __init__(self, addr):
+        self.addr = addr
+
+
+class TestAXI(unittest.TestCase):
+    def test_burst2beat(self):
+        def bursts_generator(ax, bursts, valid_rand=50):
+            prng = random.Random(42)
+            for burst in bursts:
+                yield ax.valid.eq(1)
+                yield ax.addr.eq(burst.addr)
+                yield ax.burst.eq(burst.type)
+                yield ax.len.eq(burst.len)
+                yield ax.size.eq(burst.size)
+                while (yield ax.ready) == 0:
+                    yield
+                yield ax.valid.eq(0)
+                while prng.randrange(100) < valid_rand:
+                    yield
+                yield
+
+        @passive
+        def beats_checker(ax, beats, ready_rand=50):
+            self.errors = 0
+            yield ax.ready.eq(0)
+            prng = random.Random(42)
+            for beat in beats:
+                while ((yield ax.valid) and (yield ax.ready)) == 0:
+                    if prng.randrange(100) > ready_rand:
+                        yield ax.ready.eq(1)
+                    else:
+                        yield ax.ready.eq(0)
+                    yield
+                ax_addr = (yield ax.addr)
+                if ax_addr != beat.addr:
+                    self.errors += 1
+                yield
+
+        # dut
+        ax_burst = stream.Endpoint(ax_description(32, 32))
+        ax_beat = stream.Endpoint(ax_description(32, 32))
+        dut =  AXIBurst2Beat(ax_burst, ax_beat)
+
+        # generate dut input (bursts)
+        prng = random.Random(42)
+        bursts = []
+        for i in range(32):
+            bursts.append(Burst(prng.randrange(2**32), BURST_FIXED, prng.randrange(255), log2_int(32//8)))
+            bursts.append(Burst(prng.randrange(2**32), BURST_INCR, prng.randrange(255), log2_int(32//8)))
+        bursts.append(Burst(4, BURST_WRAP, 4-1, log2_int(2)))
+
+        # generate expected dut output (beats for reference)
+        beats = []
+        for burst in bursts:
+            beats += burst.to_beats()
+
+        # simulation
+        generators = [
+            bursts_generator(ax_burst, bursts),
+            beats_checker(ax_beat, beats)
+        ]
+        run_simulation(dut, generators)
+        self.assertEqual(self.errors, 0)