axi: add to_pads method
authorKarol Gugala <kgugala@antmicro.com>
Mon, 3 Feb 2020 13:38:24 +0000 (14:38 +0100)
committerPiotr Binkowski <pbinkowski@antmicro.com>
Fri, 21 Feb 2020 11:22:18 +0000 (12:22 +0100)
Signed-off-by: Karol Gugala <kgugala@antmicro.com>
litex/soc/interconnect/axi.py

index 2681ff5ae3c27094844d8774d7ccc590c69846ec..a118970eaab5e8652b01775bfd48af0840c99dae 100644 (file)
@@ -1,4 +1,5 @@
 # This file is Copyright (c) 2018-2019 Florent Kermarrec <florent@enjoy-digital.fr>
+# This file is Copyright (c) 2020 Karol Gugala <kgugala@antmicro.com>
 # License: BSD
 
 """AXI4 Full/Lite support for LiteX"""
@@ -6,6 +7,7 @@
 from migen import *
 
 from litex.soc.interconnect import stream
+from litex.build.generic_platform import *
 
 # AXI Definition -----------------------------------------------------------------------------------
 
@@ -54,7 +56,7 @@ def r_description(data_width, id_width):
     ]
 
 class AXIInterface(Record):
-    def __init__(self, data_width, address_width, id_width=1, clock_domain="sys"):
+    def __init__(self, data_width, address_width, mode="master", id_width=1, clock_domain="sys"):
         self.data_width    = data_width
         self.address_width = address_width
         self.id_width      = id_width
@@ -66,6 +68,80 @@ class AXIInterface(Record):
         self.ar = stream.Endpoint(ax_description(address_width, id_width))
         self.r  = stream.Endpoint(r_description(data_width, id_width))
 
+    def _signals_in_channels(self, channels):
+        for channel_name in channels:
+            channel = getattr(self, channel_name)
+            for signal in channel.layout:
+                if signal[0] == 'param':
+                    continue
+                if signal[0] == 'payload':
+                    for s in signal[1]:
+                        yield s[0], channel_name, s[1], s[2]
+                else:
+                    if signal[0] == 'first':
+                        continue
+                    if signal[0] == 'last' and channel_name != 'w' and channel_name != 'r':
+                        continue
+                    yield signal[0], channel_name, signal[1], signal[2]
+
+
+    def to_pads(self, bus_name='axi'):
+        axi_bus = {}
+        for signal, channel, width, direction in self._signals_in_channels(['aw', 'w', 'b', 'ar', 'r']):
+            signal_name = channel + signal
+            axi_bus[signal_name] = width
+
+        signals = []
+        for pad in axi_bus:
+            signals.append(Subsignal(pad, Pins(axi_bus[pad])))
+
+        pads = [
+                (bus_name , 0) + tuple(signals)
+                ]
+        return pads
+
+    def connect_to_pads(self, module, platform, bus_name, mode='master'):
+
+        def _get_signals(pads, channel, signal):
+            signal_name = channel + signal
+            channel = getattr(self, channel)
+            axi_signal = getattr(channel, signal)
+            pads_signal = getattr(pads, signal_name)
+            return pads_signal, axi_signal
+
+        axi_pads = self.to_pads(bus_name)
+        platform.add_extension(axi_pads)
+        pads = platform.request(bus_name)
+
+        for signal, channel, width, direction in self._signals_in_channels(['aw', 'w', 'ar']):
+            pads_signal, axi_signal = _get_signals(pads, channel, signal)
+
+            if mode == 'master':
+                if direction == DIR_M_TO_S:
+                    module.comb += pads_signal.eq(axi_signal)
+                else:
+                    module.comb += axi_signal.eq(pads_signal)
+            else:
+                if direction == DIR_S_TO_M:
+                    module.comb += pads_signal.eq(axi_signal)
+                else:
+                    module.comb += axi_signal.eq(pads_signal)
+
+        for signal, channel, width, direction in self._signals_in_channels(['r', 'b']):
+            pads_signal, axi_signal = _get_signals(pads, channel, signal)
+
+            if mode == 'master':
+                if direction == DIR_S_TO_M:
+                    module.comb += pads_signal.eq(axi_signal)
+                else:
+                    module.comb += axi_signal.eq(pads_signal)
+            else:
+                if direction == DIR_M_TO_S:
+                    module.comb += pads_signal.eq(axi_signal)
+                else:
+                    module.comb += axi_signal.eq(pads_signal)
+
+
 # AXI Lite Definition ------------------------------------------------------------------------------
 
 def ax_lite_description(address_width):