soc/tools: initialize wishbone remote control (for now only uart)
authorFlorent Kermarrec <florent@enjoy-digital.fr>
Mon, 16 Nov 2015 16:46:36 +0000 (17:46 +0100)
committerFlorent Kermarrec <florent@enjoy-digital.fr>
Tue, 17 Nov 2015 00:05:52 +0000 (01:05 +0100)
litex/soc/interconnect/wishbonebridge.py
litex/soc/tools/remote/__init__.py [new file with mode: 0644]
litex/soc/tools/remote/client.py [new file with mode: 0644]
litex/soc/tools/remote/comm_uart.py [new file with mode: 0644]
litex/soc/tools/remote/csr_builder.py [new file with mode: 0644]
litex/soc/tools/remote/etherbone.py [new file with mode: 0644]
litex/soc/tools/remote/server.py [new file with mode: 0644]
setup.py

index 51d84c8c71e5ee204e1bf89f5ff933310e3b4adc..7448733dff340d9f00bc67bcaf1321171d3ebb1d 100644 (file)
@@ -144,7 +144,7 @@ class WishboneStreamingBridge(Module):
             phy.sink.stb.eq(1),
             If(phy.sink.ack,
                 byte_counter_ce.eq(1),
-                If(byte_counter.value == 3,
+                If(byte_counter == 3,
                     word_counter_ce.eq(1),
                     If(word_counter == (length-1),
                         NextState("IDLE")
diff --git a/litex/soc/tools/remote/__init__.py b/litex/soc/tools/remote/__init__.py
new file mode 100644 (file)
index 0000000..f5a29bc
--- /dev/null
@@ -0,0 +1,3 @@
+from litex.soc.tools.remote.comm_uart import CommUART
+from litex.soc.tools.remote.server import RemoteServer
+from litex.soc.tools.remote.client import RemoteClient
diff --git a/litex/soc/tools/remote/client.py b/litex/soc/tools/remote/client.py
new file mode 100644 (file)
index 0000000..2c1e64d
--- /dev/null
@@ -0,0 +1,66 @@
+import socket
+
+from litex.soc.tools.remote.etherbone import EtherbonePacket, EtherboneRecord
+from litex.soc.tools.remote.etherbone import EtherboneReads, EtherboneWrites
+from litex.soc.tools.remote.etherbone import EtherboneIPC
+from litex.soc.tools.remote.csr_builder import CSRBuilder
+
+
+class RemoteClient(EtherboneIPC, CSRBuilder):
+    def __init__(self, host="localhost", port=1234, csr_csv="csr.csv", csr_data_width=32, debug=False):
+        CSRBuilder.__init__(self, self, csr_csv, csr_data_width)
+        self.host = host
+        self.port = port
+        self.debug = debug
+
+    def open(self):
+        if hasattr(self, "socket"):
+            return
+        self.socket = socket.create_connection((self.host, self.port), 5.0)
+        self.socket.settimeout(1.0)
+
+    def close(self):
+        if not hasattr(self, "socket"):
+            return
+        self.socket.close()
+        del self.socket
+
+    def read(self, addr, length=1):
+        # prepare packet
+        record = EtherboneRecord()
+        record.reads = EtherboneReads(addrs=[addr + 4*j for j in range(length)])
+        record.rcount = len(record.reads)
+
+        # send packet
+        packet = EtherbonePacket()
+        packet.records = [record]
+        packet.encode()
+        self.send_packet(self.socket, packet[:])
+
+        # receive response
+        packet = EtherbonePacket(self.receive_packet(self.socket))
+        packet.decode()
+        datas = packet.records.pop().writes.get_datas()
+        if self.debug:
+            for i, data in enumerate(datas):
+                print("read {:08x} @ {:08x}".format(data, addr + 4*i))
+        if length == 1:
+            return datas[0]
+        else:
+            return datas
+
+    def write(self, addr, datas):
+        if not isinstance(datas, list):
+            datas = [datas]
+        record = EtherboneRecord()
+        record.writes = EtherboneWrites(base_addr=addr, datas=[d for d in datas])
+        record.wcount = len(record.writes)
+
+        packet = EtherbonePacket()
+        packet.records = [record]
+        packet.encode()
+        self.send_packet(self.socket, packet)
+
+        if self.debug:
+            for i, data in enumerate(datas):
+                print("write {:08x} @ {:08x}".format(data, addr + 4*i))
diff --git a/litex/soc/tools/remote/comm_uart.py b/litex/soc/tools/remote/comm_uart.py
new file mode 100644 (file)
index 0000000..5845d10
--- /dev/null
@@ -0,0 +1,64 @@
+import serial
+import struct
+
+
+class CommUART:
+    msg_type = {
+        "write": 0x01,
+        "read":  0x02
+    }
+    def __init__(self, port, baudrate=115200, debug=False):
+        self.port = port
+        self.baudrate = str(baudrate)
+        self.csr_data_width = None
+        self.debug = debug
+        self.port = serial.serial_for_url(port, baudrate)
+
+    def open(self, csr_data_width):
+        self.csr_data_width = csr_data_width
+        if hasattr(self, "port"):
+            return
+        self.port.open()
+
+    def close(self):
+        if not hasattr(self, "port"):
+            return
+        del self.port
+
+    def _read(self, length):
+        r = bytes()
+        while len(r) < length:
+            r += self.port.read(length - len(r))
+        return r
+
+    def _write(self, data):
+        remaining = len(data)
+        pos = 0
+        while remaining:
+            written = self.port.write(data[pos:])
+            remaining -= written
+            pos += written
+
+    def read(self, addr, length=None):
+        r = []
+        length_int = 1 if length is None else length
+        self._write([self.msg_type["read"], length_int])
+        self._write(list((addr//4).to_bytes(4, byteorder="big")))
+        for i in range(length_int):
+            data = int.from_bytes(self._read(4), "big")
+            if self.debug:
+                print("read {:08x} @ {:08x}".format(data, addr + 4*i))
+            if length is None:
+                return data
+            r.append(data)
+        return r
+
+    def write(self, addr, data):
+        data = data if isinstance(data, list) else [data]
+        length = len(data)
+        self._write([self.msg_type["write"], length])
+        self._write(list((addr//4).to_bytes(4, byteorder="big")))
+        for i in range(len(data)):
+            self._write(list(data[i].to_bytes(4, byteorder="big")))
+            if self.debug:
+                print("write {:08x} @ {:08x}".format(data[i], addr + 4*i))
diff --git a/litex/soc/tools/remote/csr_builder.py b/litex/soc/tools/remote/csr_builder.py
new file mode 100644 (file)
index 0000000..9429291
--- /dev/null
@@ -0,0 +1,84 @@
+import csv
+
+
+class CSRElements:
+    def __init__(self, d):
+        self.d = d
+
+    def __getattr__(self, attr):
+        try:
+            return self.__dict__['d'][attr]
+        except KeyError:
+            pass
+        raise KeyError("No such element " + attr)
+
+
+class CSRRegister:
+    def __init__(self, readfn, writefn, name, addr, length, data_width, mode):
+        self.readfn = readfn
+        self.writefn = writefn
+        self.addr = addr
+        self.length = length
+        self.data_width = data_width
+        self.mode = mode
+
+    def read(self):
+        if self.mode not in ["rw", "ro"]:
+            raise KeyError(name + "register not readable")
+        datas = self.readfn(self.addr, length=self.length)
+        if isinstance(datas, int):
+            return datas
+        else:
+            data = 0
+            for i in range(self.length):
+                data = data << self.data_width
+                data |= datas[i]
+            return data
+
+    def write(self, value):
+        if self.mode not in ["rw", "wo"]:
+            raise KeyError(name + "register not writable")
+        datas = []
+        for i in range(self.length):
+            datas.append((value >> ((self.length-1-i)*self.data_width)) & (2**self.data_width-1))
+        self.writefn(self.addr, datas)
+
+
+class CSRBuilder:
+    def __init__(self, comm, csr_csv, csr_data_width):
+        self.csr_data_width = csr_data_width
+        self.constants = self.build_constants(csr_csv)
+        self.bases = self.build_bases(csr_csv)
+        self.regs = self.build_registers(csr_csv, comm.read, comm.write)
+
+    def build_bases(self, csr_csv):
+        csv_reader = csv.reader(open(csr_csv), delimiter=',', quotechar='#')
+        d = {}
+        for item in csv_reader:
+            group, name, addr, dummy0, dummy1 = item
+            if group == "csr_base":
+                d[name] = int(addr.replace("0x", ""), 16)
+        return CSRElements(d)
+
+    def build_registers(self, csr_csv, readfn, writefn):
+        csv_reader = csv.reader(open(csr_csv), delimiter=',', quotechar='#')
+        d = {}
+        for item in csv_reader:
+            group, name, addr, length, mode = item
+            if group == "csr_register":
+                addr = int(addr.replace("0x", ""), 16)
+                length = int(length)
+                d[name] = CSRRegister(readfn, writefn, name, addr, length, self.csr_data_width, mode)
+        return CSRElements(d)
+
+    def build_constants(self, csr_csv):
+        csv_reader = csv.reader(open(csr_csv), delimiter=',', quotechar='#')
+        d = {}
+        for item in csv_reader:
+            group, name, value, dummy0, dummy1 = item
+            if group == "constant":
+                try:
+                    d[name] = int(value)
+                except:
+                    d[name] = value
+        return CSRElements(d)
diff --git a/litex/soc/tools/remote/etherbone.py b/litex/soc/tools/remote/etherbone.py
new file mode 100644 (file)
index 0000000..e119d16
--- /dev/null
@@ -0,0 +1,376 @@
+import math
+from copy import deepcopy
+import struct
+
+from litex.soc.interconnect.stream_packet import HeaderField, Header
+
+etherbone_magic = 0x4e6f
+etherbone_version = 1
+etherbone_packet_header_length = 8
+etherbone_packet_header_fields = {
+    "magic":            HeaderField(0,  0, 16),
+
+    "version":          HeaderField(2,  4, 4),
+    "nr":               HeaderField(2,  2, 1),
+    "pr":               HeaderField(2,  1, 1),
+    "pf":               HeaderField(2,  0, 1),
+
+    "addr_size":        HeaderField(3,  4, 4),
+    "port_size":        HeaderField(3,  0, 4)
+}
+etherbone_packet_header = Header(etherbone_packet_header_fields,
+                                 etherbone_packet_header_length,
+                                 swap_field_bytes=True)
+
+etherbone_record_header_length = 4
+etherbone_record_header_fields = {
+    "bca":              HeaderField(0,  0, 1),
+    "rca":              HeaderField(0,  1, 1),
+    "rff":              HeaderField(0,  2, 1),
+    "cyc":              HeaderField(0,  4, 1),
+    "wca":              HeaderField(0,  5, 1),
+    "wff":              HeaderField(0,  6, 1),
+
+    "byte_enable":      HeaderField(1,  0, 8),
+
+    "wcount":           HeaderField(2,  0, 8),
+
+    "rcount":           HeaderField(3,  0, 8)
+}
+etherbone_record_header = Header(etherbone_record_header_fields,
+                                 etherbone_record_header_length,
+                                 swap_field_bytes=True)
+
+
+def split_bytes(v, n, endianness="big"):
+    r = []
+    r_bytes = v.to_bytes(n, byteorder=endianness)
+    for byte in r_bytes:
+        r.append(int(byte))
+    return r
+
+
+def merge_bytes(b, endianness="big"):
+    return int.from_bytes(bytes(b), endianness)
+
+
+def get_field_data(field, datas):
+    v = merge_bytes(datas[field.byte:field.byte+math.ceil(field.width/8)])
+    return (v >> field.offset) & (2**field.width-1)
+
+
+class Packet(list):
+    def __init__(self, init=[]):
+        self.ongoing = False
+        self.done = False
+        for data in init:
+            self.append(data)
+
+
+class EtherboneWrite:
+    def __init__(self, data):
+        self.data = data
+
+    def __repr__(self):
+        return "WR32 0x{:08x}".format(self.data)
+
+
+class EtherboneRead:
+    def __init__(self, addr):
+        self.addr = addr
+
+    def __repr__(self):
+        return "RD32 @ 0x{:08x}".format(self.addr)
+
+
+class EtherboneWrites(Packet):
+    def __init__(self, init=[], base_addr=0, datas=[]):
+        Packet.__init__(self, init)
+        self.base_addr = base_addr
+        self.writes = []
+        self.encoded = init != []
+        for data in datas:
+            self.add(EtherboneWrite(data))
+
+    def add(self, write):
+        self.writes.append(write)
+
+    def get_datas(self):
+        datas = []
+        for write in self.writes:
+            datas.append(write.data)
+        return datas
+
+    def encode(self):
+        if self.encoded:
+            raise ValueError
+        for byte in split_bytes(self.base_addr, 4):
+            self.append(byte)
+        for write in self.writes:
+            for byte in split_bytes(write.data, 4):
+                self.append(byte)
+        self.encoded = True
+
+    def decode(self):
+        if not self.encoded:
+            raise ValueError
+        base_addr = []
+        for i in range(4):
+            base_addr.append(self.pop(0))
+        self.base_addr = merge_bytes(base_addr)
+        self.writes = []
+        while len(self) != 0:
+            write = []
+            for i in range(4):
+                write.append(self.pop(0))
+            self.writes.append(EtherboneWrite(merge_bytes(write)))
+        self.encoded = False
+
+    def __repr__(self):
+        r = "Writes\n"
+        r += "--------\n"
+        r += "BaseAddr @ 0x{:08x}\n".format(self.base_addr)
+        for write in self.writes:
+            r += write.__repr__() + "\n"
+        return r
+
+
+class EtherboneReads(Packet):
+    def __init__(self, init=[], base_ret_addr=0, addrs=[]):
+        Packet.__init__(self, init)
+        self.base_ret_addr = base_ret_addr
+        self.reads = []
+        self.encoded = init != []
+        for addr in addrs:
+            self.add(EtherboneRead(addr))
+
+    def add(self, read):
+        self.reads.append(read)
+
+    def get_addrs(self):
+        addrs = []
+        for read in self.reads:
+            addrs.append(read.addr)
+        return addrs
+
+    def encode(self):
+        if self.encoded:
+            raise ValueError
+        for byte in split_bytes(self.base_ret_addr, 4):
+            self.append(byte)
+        for read in self.reads:
+            for byte in split_bytes(read.addr, 4):
+                self.append(byte)
+        self.encoded = True
+
+    def decode(self):
+        if not self.encoded:
+            raise ValueError
+        base_ret_addr = []
+        for i in range(4):
+            base_ret_addr.append(self.pop(0))
+        self.base_ret_addr = merge_bytes(base_ret_addr)
+        self.reads = []
+        while len(self) != 0:
+            read = []
+            for i in range(4):
+                read.append(self.pop(0))
+            self.reads.append(EtherboneRead(merge_bytes(read)))
+        self.encoded = False
+
+    def __repr__(self):
+        r = "Reads\n"
+        r += "--------\n"
+        r += "BaseRetAddr @ 0x{:08x}\n".format(self.base_ret_addr)
+        for read in self.reads:
+            r += read.__repr__() + "\n"
+        return r
+
+
+class EtherboneRecord(Packet):
+    def __init__(self, init=[]):
+        Packet.__init__(self, init)
+        self.writes = None
+        self.reads = None
+        self.bca = 0
+        self.rca = 0
+        self.rff = 0
+        self.cyc = 0
+        self.wca = 0
+        self.wff = 0
+        self.byte_enable = 0xf
+        self.wcount = 0
+        self.rcount = 0
+        self.encoded = init != []
+
+
+    def get_writes(self):
+        if self.wcount == 0:
+            return None
+        else:
+            writes = []
+            for i in range((self.wcount+1)*4):
+                writes.append(self.pop(0))
+            return EtherboneWrites(writes)
+
+    def get_reads(self):
+        if self.rcount == 0:
+            return None
+        else:
+            reads = []
+            for i in range((self.rcount+1)*4):
+                reads.append(self.pop(0))
+            return EtherboneReads(reads)
+
+    def decode(self):
+        if not self.encoded:
+            raise ValueError
+        header = []
+        for byte in self[:etherbone_record_header.length]:
+            header.append(self.pop(0))
+        for k, v in sorted(etherbone_record_header.fields.items()):
+            setattr(self, k, get_field_data(v, header))
+        self.writes = self.get_writes()
+        if self.writes is not None:
+            self.writes.decode()
+        self.reads = self.get_reads()
+        if self.reads is not None:
+            self.reads.decode()
+        self.encoded = False
+
+    def set_writes(self, writes):
+        self.wcount = len(writes.writes)
+        writes.encode()
+        for byte in writes:
+            self.append(byte)
+
+    def set_reads(self, reads):
+        self.rcount = len(reads.reads)
+        reads.encode()
+        for byte in reads:
+            self.append(byte)
+
+    def encode(self):
+        if self.encoded:
+            raise ValueError
+        if self.writes is not None:
+            self.set_writes(self.writes)
+        if self.reads is not None:
+            self.set_reads(self.reads)
+        header = 0
+        for k, v in sorted(etherbone_record_header.fields.items()):
+            value = merge_bytes(split_bytes(getattr(self, k),
+                                            math.ceil(v.width/8)),
+                                            "little")
+            header += (value << v.offset+(v.byte*8))
+        for d in split_bytes(header, etherbone_record_header.length):
+            self.insert(0, d)
+        self.encoded = True
+
+    def __repr__(self, n=0):
+        r = "Record {}\n".format(n)
+        r += "--------\n"
+        if self.encoded:
+            for d in self:
+                r += "{:02x}".format(d)
+        else:
+            for k in sorted(etherbone_record_header.fields.keys()):
+                r += k + " : 0x{:0x}\n".format(getattr(self, k))
+            if self.wcount != 0:
+                r += self.writes.__repr__()
+            if self.rcount != 0:
+                r += self.reads.__repr__()
+        return r
+
+
+class EtherbonePacket(Packet):
+    def __init__(self, init=[]):
+        Packet.__init__(self, init)
+        self.encoded = init != []
+        self.records = []
+
+        self.magic = etherbone_magic
+        self.version = etherbone_version
+        self.addr_size = 32//8
+        self.port_size = 32//8
+        self.nr = 0
+        self.pr = 0
+        self.pf = 0
+
+    def get_records(self):
+        records = []
+        done = False
+        payload = self
+        while len(payload) != 0:
+            record = EtherboneRecord(payload)
+            record.decode()
+            records.append(deepcopy(record))
+            payload = record
+        return records
+
+    def decode(self):
+        if not self.encoded:
+            raise ValueError
+        header = []
+        for byte in self[:etherbone_packet_header.length]:
+            header.append(self.pop(0))
+        for k, v in sorted(etherbone_packet_header.fields.items()):
+            setattr(self, k, get_field_data(v, header))
+        self.records = self.get_records()
+        self.encoded = False
+
+    def set_records(self, records):
+        for record in records:
+            record.encode()
+            for byte in record:
+                self.append(byte)
+
+    def encode(self):
+        if self.encoded:
+            raise ValueError
+        self.set_records(self.records)
+        header = 0
+        for k, v in sorted(etherbone_packet_header.fields.items()):
+            value = merge_bytes(split_bytes(getattr(self, k), math.ceil(v.width/8)), "little")
+            header += (value << v.offset+(v.byte*8))
+        for d in split_bytes(header, etherbone_packet_header.length):
+            self.insert(0, d)
+        self.encoded = True
+
+    def __repr__(self):
+        r = "Packet\n"
+        r += "--------\n"
+        if self.encoded:
+            for d in self:
+                r += "{:02x}".format(d)
+        else:
+            for k in sorted(etherbone_packet_header.fields.keys()):
+                r += k + " : 0x{:0x}\n".format(getattr(self, k))
+            for i, record in enumerate(self.records):
+                r += record.__repr__(i)
+        return r
+
+
+class EtherboneIPC:
+    def send_packet(self, socket, packet):
+        socket.sendall(bytes(packet))
+
+    def receive_packet(self, socket):
+        header_length = etherbone_packet_header_length + etherbone_record_header_length
+        packet = bytes()
+        while len(packet) < header_length:
+            chunk = socket.recv(header_length - len(packet))
+            if len(chunk) == 0:
+                return 0
+            else:
+                packet += chunk
+        wcount, rcount = struct.unpack(">BB", packet[header_length-2:])
+        counts = wcount + rcount
+        packet_size = header_length + 4*(counts + 1)
+        while len(packet) < packet_size:
+            chunk = socket.recv(packet_size - len(packet))
+            if len(chunk) == 0:
+                return 0
+            else:
+                packet += chunk
+        return packet
diff --git a/litex/soc/tools/remote/server.py b/litex/soc/tools/remote/server.py
new file mode 100644 (file)
index 0000000..5f6bb99
--- /dev/null
@@ -0,0 +1,104 @@
+import socket
+import threading
+import argparse
+
+from litex.soc.tools.remote.etherbone import EtherbonePacket, EtherboneRecord, EtherboneWrites
+from litex.soc.tools.remote.etherbone import EtherboneIPC
+
+
+class RemoteServer(EtherboneIPC):
+    def __init__(self, comm, port=1234, csr_data_width=32):
+        self.comm = comm
+        self.port = port
+        self.csr_data_width = 32
+
+    def open(self):
+        if hasattr(self, "socket"):
+            return
+        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        self.socket.bind(("localhost", self.port))
+        self.socket.listen(1)
+        self.comm.open(self.csr_data_width)
+
+    def close(self):
+        self.comm.close()
+        if not hasattr(self, "socket"):
+            return
+        self.socket.close()
+        del self.socket
+
+    def _serve_thread(self):
+        while True:
+            client_socket, addr = self.socket.accept()
+            print("Connected with " + addr[0] + ":" + str(addr[1]))
+            try:
+                while True:
+                    packet = self.receive_packet(client_socket)
+                    if packet == 0:
+                        break
+                    packet = EtherbonePacket(packet)
+                    packet.decode()
+
+                    record = packet.records.pop()
+
+                    # writes:
+                    if record.writes != None:
+                        self.comm.write(record.writes.base_addr, record.writes.get_datas())
+
+                    # reads
+                    if record.reads != None:
+                        reads = []
+                        for addr in record.reads.get_addrs():
+                            reads.append(self.comm.read(addr))
+
+                        record = EtherboneRecord()
+                        record.writes = EtherboneWrites(datas=reads)
+                        record.wcount = len(record.writes)
+
+                        packet = EtherbonePacket()
+                        packet.records = [record]
+                        packet.encode()
+                        self.send_packet(client_socket, packet)
+            finally:
+                print("Disconnect")
+                client_socket.close()
+
+    def start(self):
+        self.serve_thread = threading.Thread(target=self._serve_thread)
+        self.serve_thread.setDaemon(True)
+        self.serve_thread.start()
+
+    def join(self, writer_only=False):
+        if not hasattr(self, "serve_thread"):
+            return
+        self.serve_thread.join()
+
+def _get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--comm", default="uart", help="comm interface")
+    parser.add_argument("--port", default="2", help="UART port")
+    parser.add_argument("--baudrate", default=115200, help="UART baudrate")
+    parser.add_argument("--csr_data_width", default=32, help="CSR data_width")
+    return parser.parse_args()
+
+def main():
+    args = _get_args()
+    if args.comm == "uart":
+        from litex.soc.tools.remote import CommUART
+        port = args.port if not args.port.isdigit() else int(args.port)
+        comm = CommUART(args.port if not args.port.isdigit() else int(args.port),
+                        args.baudrate,
+                        debug=False)
+    else:
+        raise NotImplementedError
+
+    server = RemoteServer(comm, csr_data_width=args.csr_data_width)
+    server.open()
+    server.start()
+    try:
+        server.join(True)
+    except KeyboardInterrupt: # FIXME
+        pass
+
+if __name__ == "__main__":
+    main()
index f1bb2abe9cc7080fdd32e793ae816a8e24ecb4d0..e10cf41ecd1a9a6b1d21d818d8830bd7f76415f9 100755 (executable)
--- a/setup.py
+++ b/setup.py
@@ -36,6 +36,8 @@ setup(
         "console_scripts": [
             "flterm=litex.soc.tools.flterm:main",
             "mkmscimg=litex.soc.tools.mkmscimg:main",
+            "remote_server=litex.soc.tools.remote.server:main",
+            "remote_client=litex.soc.tools.remote.client:main"
         ],
     },
 )