soc/interconnect/axi: implement AXI Lite decoder
author Jędrzej Boczar <jboczar@antmicro.com>
Tue, 21 Jul 2020 12:25:24 +0000 (14:25 +0200)
committer Jędrzej Boczar <jboczar@antmicro.com>
Wed, 22 Jul 2020 15:16:33 +0000 (17:16 +0200)
litex/soc/interconnect/axi.py
test/test_axi.py

index 40fae8be55f174675e0e4ce22f37d4611d1671da..e76abe301facb173a4484c7de0eebea7863ba553 100644 (file)
@@ -192,6 +192,7 @@ class AXILiteInterface:
     def write(self, addr, data, strb=None):
         if strb is None:
             strb = 2**len(self.w.strb) - 1
+        # aw + w
         yield self.aw.valid.eq(1)
         yield self.aw.addr.eq(addr)
         yield self.w.data.eq(data)
@@ -201,9 +202,12 @@ class AXILiteInterface:
         while not (yield self.aw.ready):
             yield
         yield self.aw.valid.eq(0)
+        yield self.aw.addr.eq(0)
         while not (yield self.w.ready):
             yield
         yield self.w.valid.eq(0)
+        yield self.w.strb.eq(0)
+        # b
         yield self.b.ready.eq(1)
         while not (yield self.b.valid):
             yield
@@ -212,12 +216,14 @@ class AXILiteInterface:
         return resp
 
     def read(self, addr):
+        # ar
         yield self.ar.valid.eq(1)
         yield self.ar.addr.eq(addr)
         yield
         while not (yield self.ar.ready):
             yield
         yield self.ar.valid.eq(0)
+        # r
         yield self.r.ready.eq(1)
         while not (yield self.r.valid):
             yield
@@ -943,6 +949,7 @@ class AXILiteRequestCounter(Module):
         self.full = full = Signal()
         self.empty = empty = Signal()
         self.stall = stall = Signal()
+        self.ready = self.empty
 
         self.comb += [
             full.eq(counter == max_requests - 1),
@@ -994,15 +1001,15 @@ class AXILiteArbiter(Module):
                         self.comb += dest.eq(source)
 
         # allow to change rr.grant only after all requests from a master have been responded to
-        self.submodules.wr_counter = wr_counter = AXILiteRequestCounter(
+        self.submodules.wr_lock = wr_lock = AXILiteRequestCounter(
             request=target.aw.valid & target.aw.ready, response=target.b.valid & target.b.ready)
-        self.submodules.rd_counter = rd_counter = AXILiteRequestCounter(
+        self.submodules.rd_lock = rd_lock = AXILiteRequestCounter(
             request=target.ar.valid & target.ar.ready, response=target.r.valid & target.r.ready)
 
         # switch to next request only if there are no responses pending
         self.comb += [
-            self.rr_write.ce.eq(~(target.aw.valid | target.w.valid | target.b.valid) & wr_counter.empty),
-            self.rr_read.ce.eq(~(target.ar.valid | target.r.valid) & rd_counter.empty),
+            self.rr_write.ce.eq(~(target.aw.valid | target.w.valid | target.b.valid) & wr_lock.ready),
+            self.rr_read.ce.eq(~(target.ar.valid | target.r.valid) & rd_lock.ready),
         ]
 
         # connect bus requests to round-robin selectors
@@ -1012,54 +1019,103 @@ class AXILiteArbiter(Module):
         ]
 
 class AXILiteDecoder(Module):
-    # slaves is a list of pairs:
-    # 0) function that takes the address signal and returns a FHDL expression
-    #    that evaluates to 1 when the slave is selected and 0 otherwise.
-    # 1) wishbone.Slave reference.
-    # register adds flip-flops after the address comparators. Improves timing,
-    # but breaks Wishbone combinatorial feedback.
-    def __init__(self, master, slaves, register=False):
+    _doc_slaves = """
+    slaves: [(decoder, slave), ...]
+        List of slaves with address decoders, where `decoder` is a function:
+            decoder(Signal(address_width - log2(data_width//8))) -> Signal(1)
+        that returns 1 when the slave is selected and 0 otherwise.
+    """.strip()
+
+    __doc__ = """AXI Lite decoder
+
+    Decode master access to particular slave based on its decoder function.
+
+    {slaves}
+    """.format(slaves=_doc_slaves)
+
+    def __init__(self, master, slaves):
         addr_shift = log2_int(master.data_width//8)
-        ns = len(slaves)
-        slave_sel = Signal(ns)
-        slave_sel_r = Signal(ns)
+
+        channels = {
+            "write": {"aw", "w", "b"},
+            "read":  {"ar", "r"},
+        }
+        # reverse mapping: directions[channel] -> "write"/"read"
+        directions = {ch: d for d, chs in channels.items() for ch in chs}
+
+        def new_slave_sel():
+            return {"write": Signal(len(slaves)), "read":  Signal(len(slaves))}
+
+        slave_sel_dec = new_slave_sel()
+        slave_sel_reg = new_slave_sel()
+        slave_sel     = new_slave_sel()
+
+        # we need to hold the slave selected until all responses come back
+        # TODO: check if this will break Timeout if a slave does not respond?
+        # should probably work correctly as it uses master signals
+        # TODO: we could reuse arbiter counters
+        locks = {
+            "write": AXILiteRequestCounter(
+                request=master.aw.valid & master.aw.ready,
+                response=master.b.valid & master.b.ready),
+            "read": AXILiteRequestCounter(
+                request=master.ar.valid & master.ar.ready,
+                response=master.r.valid & master.r.ready),
+        }
+        self.submodules += locks.values()
 
         def get_sig(interface, channel, name):
             return getattr(getattr(interface, channel), name)
 
-        # decode slave addresses
-        self.comb += [slave_sel[i].eq(fun(master.aw.addr[addr_shift:]) | fun(master.aw.addr[addr_shift:]))
-            for i, (fun, bus) in enumerate(slaves)]
-        if register:
-            self.sync += slave_sel_r.eq(slave_sel)
-        else:
-            self.comb += slave_sel_r.eq(slave_sel)
+        # # #
 
-        # connect master->slaves signals except valid
-        for fun, slave in slaves:
+        # decode slave addresses
+        for i, (decoder, bus) in enumerate(slaves):
+            self.comb += [
+                slave_sel_dec["write"][i].eq(decoder(master.aw.addr[addr_shift:])),
+                slave_sel_dec["read"][i].eq(decoder(master.ar.addr[addr_shift:])),
+            ]
+
+        # change the current selection only when we've got all responses
+        for channel in locks.keys():
+            self.sync += If(locks[channel].ready, slave_sel_reg[channel].eq(slave_sel_dec[channel]))
+        # we have to cut the delaying select
+        for ch, final in slave_sel.items():
+            self.comb += If(locks[ch].ready,
+                             final.eq(slave_sel_dec[ch])
+                         ).Else(
+                             final.eq(slave_sel_reg[ch])
+                         )
+
+        # connect master->slaves signals except valid/ready
+        for i, (_, slave) in enumerate(slaves):
             for channel, name, direction in master.layout_flat():
-                if direction == DIR_M_TO_S and name != "valid":
-                    self.comb += get_sig(slave, channel, name).eq(get_sig(master, channel, name))
-
-        # combine cyc with slave selection signals
-        for i, (fun, slave) in enumerate(slaves):
-            for ch in ["aw", "w", "ar"]:
-                slave_valid = get_sig(slave, ch, "valid")
-                master_valid = get_sig(master, ch, "valid")
-                self.comb += slave_valid.eq(master_valid & slave_sel[i])
+                if direction == DIR_M_TO_S:
+                    src = get_sig(master, channel, name)
+                    dst = get_sig(slave, channel, name)
+                    # mask master control signals depending on slave selection
+                    if name in ["valid", "ready"]:
+                        src = src & slave_sel[directions[channel]][i]
+                    self.comb += dst.eq(src)
+
+        # connect slave->master signals masking not selected slaves
+        for channel, name, direction in master.layout_flat():
+            if direction == DIR_S_TO_M:
+                dst = get_sig(master, channel, name)
+                masked = []
+                for i, (_, slave) in enumerate(slaves):
+                    src = get_sig(slave, channel, name)
+                    # mask depending on channel
+                    mask = Replicate(slave_sel[directions[channel]][i], len(dst))
+                    masked.append(src & mask)
+                self.comb += dst.eq(reduce(or_, masked))
 
-        # generate master ready by ORing all slave readys
-        self.comb += [
-            master.aw.ready.eq(reduce(or_, [slave.aw.ready for fun, slave in slaves])),
-            master.w.ready.eq(reduce(or_, [slave.w.ready for fun, slave in slaves])),
-            master.ar.ready.eq(reduce(or_, [slave.ar.ready for fun, slave in slaves])),
-        ]
+class AXILiteInterconnectShared(Module):
+    __doc__ = """AXI Lite shared interconnect
 
-        # mux (1-hot) slave data return
-        masked = [Replicate(slave_sel_r[i], len(master.r.data)) & slaves[i][1].r.data for i in range(ns)]
-        self.comb += master.r.data.eq(reduce(or_, masked))
+    {slaves}
+    """.format(slaves=AXILiteDecoder._doc_slaves)
 
-class AXILiteInterconnectShared(Module):
     def __init__(self, masters, slaves, register=False, timeout_cycles=1e6):
         # TODO: data width
         shared = AXILiteInterface()
@@ -1069,13 +1125,21 @@ class AXILiteInterconnectShared(Module):
             self.submodules.timeout = AXILiteTimeout(shared, timeout_cycles)
 
 class AXILiteCrossbar(Module):
+    __doc__ = """AXI Lite crossbar
+
+    MxN crossbar for M masters and N slaves.
+
+    {slaves}
+    """.format(slaves=AXILiteDecoder._doc_slaves)
+
     def __init__(self, masters, slaves, register=False):
         matches, busses = zip(*slaves)
-        access = [[AXILiteInterface() for j in slaves] for i in masters]
+        access_m_s = [[AXILiteInterface() for j in slaves] for i in masters]  # a[master][slave]
+        access_s_m = list(zip(*access_m_s))  # a[slave][master]
         # decode each master into its access row
-        for row, master in zip(access, masters):
-            row = list(zip(matches, row))
-            self.submodules += AXILiteDecoder(master, row, register)
+        for slaves, master in zip(access_m_s, masters):
+            slaves = list(zip(matches, slaves))
+            self.submodules += AXILiteDecoder(master, slaves, register)
         # arbitrate each access column onto its slave
-        for column, bus in zip(zip(*access), busses):
-            self.submodules += AXILiteArbiter(column, bus)
+        for masters, bus in zip(access_s_m, busses):
+            self.submodules += AXILiteArbiter(masters, bus)
index ee5d4107dfbba27f783c76c18a8ceaa933232c64..1c8345569460ce1960f7ee59ccc9b493129abfb7 100644 (file)
@@ -342,6 +342,7 @@ class AXILiteChecker:
             yield
 
     def handle_write(self, axi_lite):
+        # aw
         while not (yield axi_lite.aw.valid):
             yield
         yield from self.delay(self.ready_latency)
@@ -352,12 +353,14 @@ class AXILiteChecker:
         while not (yield axi_lite.w.valid):
             yield
         yield from self.delay(self.ready_latency)
+        # w
         data = (yield axi_lite.w.data)
         strb = (yield axi_lite.w.strb)
         yield axi_lite.w.ready.eq(1)
         yield
         yield axi_lite.w.ready.eq(0)
         yield from self.delay(self.response_latency)
+        # b
         yield axi_lite.b.valid.eq(1)
         yield axi_lite.b.resp.eq(RESP_OKAY)
         yield
@@ -367,6 +370,7 @@ class AXILiteChecker:
         self.writes.append((addr, data, strb))
 
     def handle_read(self, axi_lite):
+        # ar
         while not (yield axi_lite.ar.valid):
             yield
         yield from self.delay(self.ready_latency)
@@ -375,6 +379,7 @@ class AXILiteChecker:
         yield
         yield axi_lite.ar.ready.eq(0)
         yield from self.delay(self.response_latency)
+        # r
         data = self.rdata_generator(addr)
         yield axi_lite.r.valid.eq(1)
         yield axi_lite.r.resp.eq(RESP_OKAY)
@@ -383,6 +388,7 @@ class AXILiteChecker:
         while not (yield axi_lite.r.ready):
             yield
         yield axi_lite.r.valid.eq(0)
+        yield axi_lite.r.data.eq(0)
         self.reads.append((addr, data))
 
     @passive
@@ -650,7 +656,7 @@ class TestAXILite(unittest.TestCase):
 # TestAXILiteInterconnet ---------------------------------------------------------------------------
 
 class TestAXILiteInterconnect(unittest.TestCase):
-    def axilite_pattern_generator(self, axi_lite, pattern):
+    def axilite_pattern_generator(self, axi_lite, pattern, delay=0):
         for rw, addr, data in pattern:
             assert rw in ["w", "r"]
             if rw == "w":
@@ -660,6 +666,8 @@ class TestAXILiteInterconnect(unittest.TestCase):
                 rdata, resp = (yield from axi_lite.read(addr))
                 self.assertEqual(resp, RESP_OKAY)
                 self.assertEqual(rdata, data)
+            for _ in range(delay):
+                yield
         for _ in range(16):
             yield
 
@@ -776,7 +784,7 @@ class TestAXILiteInterconnect(unittest.TestCase):
             checker = AXILiteChecker()
             generators = [generator(i, master, delay=1) for i, master in enumerate(dut.masters)]
             generators += [timeout(300), checker.handler(dut.slave)]
-            run_simulation(dut, generators, vcd_name='sim.vcd')
+            run_simulation(dut, generators)
             order = [0, 100, 200, 1, 101, 201, 2, 102, 202, 3, 103, 203]
             self.assertEqual([addr for addr, data, strb in checker.writes], order)
             self.assertEqual([addr for addr, data in checker.reads], order)
@@ -805,7 +813,6 @@ class TestAXILiteInterconnect(unittest.TestCase):
             for _ in range(8):
                 yield
 
-
         n_masters = 3
 
         # with no delay each master will do all transfers at once
@@ -825,7 +832,117 @@ class TestAXILiteInterconnect(unittest.TestCase):
             checker = AXILiteChecker(response_latency=lambda: 3)
             generators = [generator(i, master, delay=1) for i, master in enumerate(dut.masters)]
             generators += [timeout(300), checker.handler(dut.slave)]
-            run_simulation(dut, generators, vcd_name='sim.vcd')
+            run_simulation(dut, generators)
             order = [0, 100, 200, 1, 101, 201, 2, 102, 202, 3, 103, 203]
             self.assertEqual([addr for addr, data, strb in checker.writes], order)
             self.assertEqual([addr for addr, data in checker.reads], order)
+
+    def decoder_test(self, n_slaves, pattern, generator_delay=0):  # returns one AXILiteChecker per slave
+        class DUT(Module):
+            def __init__(self, decoders):
+                self.master = AXILiteInterface()
+                self.slaves = [AXILiteInterface() for _ in range(len(decoders))]
+                slaves = list(zip(decoders, self.slaves))
+                self.submodules.decoder = AXILiteDecoder(self.master, slaves)
+
+        def decoder(i):  # slave i matches byte addresses 0x100*i .. 0x100*i + 0xff (assuming 32-bit data)
+            # bytes to 32-bit words aligned
+            size   = (0x100) >> 2
+            origin = (0x100 * i) >> 2
+            return lambda a: (a[log2_int(size):] == (origin >> log2_int(size)))
+
+        def rdata_generator(adr):  # read data = value of the first "r" entry of pattern with this address
+            for rw, a, v in pattern:
+                if rw == "r" and a == adr:
+                    return v
+            return 0xbaadc0de  # marker returned for addresses that pattern never reads
+
+        dut = DUT([decoder(i) for i in range(n_slaves)])
+        checkers = [AXILiteChecker(rdata_generator=rdata_generator) for _ in dut.slaves]
+
+        generators = [self.axilite_pattern_generator(dut.master, pattern, delay=generator_delay)]
+        generators += [checker.handler(slave) for (slave, checker) in zip(dut.slaves, checkers)]
+        generators += [timeout(300)]
+        run_simulation(dut, generators)  # no VCD dump, like the other interconnect tests
+
+        return checkers
+
+    def test_decoder_write(self):  # write-only pattern: each slave must see exactly its own addresses, in order
+        for delay in [0, 1]:  # delay = idle cycles the generator inserts after each transfer
+            with self.subTest(delay=delay):
+                slaves = self.decoder_test(n_slaves=3, pattern=[
+                    ("w", 0x010, 1),
+                    ("w", 0x110, 2),
+                    ("w", 0x210, 3),
+                    ("w", 0x011, 1),
+                    ("w", 0x012, 1),
+                    ("w", 0x111, 2),
+                    ("w", 0x112, 2),
+                    ("w", 0x211, 3),
+                    ("w", 0x212, 3),
+                ], generator_delay=delay)
+
+                def addr(checker_list):  # keep only the address field of each recorded access
+                    return [entry[0] for entry in checker_list]
+
+                self.assertEqual(addr(slaves[0].writes), [0x010, 0x011, 0x012])
+                self.assertEqual(addr(slaves[1].writes), [0x110, 0x111, 0x112])
+                self.assertEqual(addr(slaves[2].writes), [0x210, 0x211, 0x212])
+                for slave in slaves:  # no checker may have recorded a read
+                    self.assertEqual(slave.reads, [])
+
+    def test_decoder_read(self):  # read-only pattern: each slave must see exactly its own addresses, in order
+        for delay in [0, 1]:  # delay = idle cycles the generator inserts after each transfer
+            with self.subTest(delay=delay):
+                slaves = self.decoder_test(n_slaves=3, pattern=[
+                    ("r", 0x010, 1),
+                    ("r", 0x110, 2),
+                    ("r", 0x210, 3),
+                    ("r", 0x011, 1),
+                    ("r", 0x012, 1),
+                    ("r", 0x111, 2),
+                    ("r", 0x112, 2),
+                    ("r", 0x211, 3),
+                    ("r", 0x212, 3),
+                ], generator_delay=delay)
+
+                def addr(checker_list):  # keep only the address field of each recorded access
+                    return [entry[0] for entry in checker_list]
+
+                self.assertEqual(addr(slaves[0].reads), [0x010, 0x011, 0x012])
+                self.assertEqual(addr(slaves[1].reads), [0x110, 0x111, 0x112])
+                self.assertEqual(addr(slaves[2].reads), [0x210, 0x211, 0x212])
+                for slave in slaves:  # no checker may have recorded a write
+                    self.assertEqual(slave.writes, [])
+
+    def test_decoder_read_write(self):  # mixed pattern: one write and one read must reach each of the 3 slaves
+        for delay in [0, 1]:  # delay = idle cycles the generator inserts after each transfer
+            with self.subTest(delay=delay):
+                slaves = self.decoder_test(n_slaves=3, pattern=[
+                    ("w", 0x010, 1),
+                    ("w", 0x110, 2),
+                    ("r", 0x111, 2),
+                    ("r", 0x011, 1),
+                    ("r", 0x211, 3),
+                    ("w", 0x210, 3),
+                ], generator_delay=delay)
+
+                def addr(checker_list):  # keep only the address field of each recorded access
+                    return [entry[0] for entry in checker_list]
+
+                self.assertEqual(addr(slaves[0].writes), [0x010])
+                self.assertEqual(addr(slaves[0].reads),  [0x011])
+                self.assertEqual(addr(slaves[1].writes), [0x110])
+                self.assertEqual(addr(slaves[1].reads),  [0x111])
+                self.assertEqual(addr(slaves[2].writes), [0x210])
+                self.assertEqual(addr(slaves[2].reads),  [0x211])
+
+    def test_decoder_stall(self):  # an access that selects no slave must never complete
+        with self.assertRaises(TimeoutError):  # presumably raised by the timeout(300) generator - confirm
+            self.decoder_test(n_slaves=3, pattern=[
+                ("w", 0x300, 1),  # 0x300 lies above the windows of slaves 0..2 (0x000-0x2ff)
+            ])
+        with self.assertRaises(TimeoutError):  # same check for the read channels
+            self.decoder_test(n_slaves=3, pattern=[
+                ("r", 0x300, 1),  # 0x300 lies above the windows of slaves 0..2 (0x000-0x2ff)
+            ])