--- /dev/null
+# This file is Copyright (c) 2019 Florent Kermarrec <florent@enjoy-digital.fr>
+# License: BSD
+
+import unittest
+import random
+
+from migen import *
+
+from litex.soc.interconnect.stream import *
+from litex.soc.interconnect.packet import *
+
+packet_header_length = 31
+packet_header_fields = {
+ "field_8b" : HeaderField(0, 0, 8),
+ "field_16b" : HeaderField(1, 0, 16),
+ "field_32b" : HeaderField(3, 0, 32),
+ "field_64b" : HeaderField(7, 0, 64),
+ "field_128b": HeaderField(15, 0, 128),
+}
+packet_header = Header(
+ fields = packet_header_fields,
+ length = packet_header_length,
+ swap_field_bytes = True)
+
+def packet_description(dw):
+ param_layout = packet_header.get_layout()
+ payload_layout = [("data", dw)]
+ return EndpointDescription(payload_layout, param_layout)
+
+def raw_description(dw):
+ payload_layout = [("data", dw)]
+ return EndpointDescription(payload_layout)
+
+class Packet:
+ def __init__(self, header, datas):
+ self.header = header
+ self.datas = datas
+
+
+class TestPacket(unittest.TestCase):
+ def test_loopback(self):
+ prng = random.Random(42)
+ # Prepare packets
+ npackets = 8
+ packets = []
+ for n in range(npackets):
+ header = {}
+ header["field_8b"] = prng.randrange(2**8)
+ header["field_16b"] = prng.randrange(2**16)
+ header["field_32b"] = prng.randrange(2**32)
+ header["field_64b"] = prng.randrange(2**64)
+ header["field_128b"] = prng.randrange(2**128)
+ datas = [prng.randrange(2**8) for _ in range(prng.randrange(2**7))]
+ packets.append(Packet(header, datas))
+
+ def generator(dut):
+ # Send packets
+ for packet in packets:
+ yield dut.sink.field_8b.eq(packet.header["field_8b"])
+ yield dut.sink.field_16b.eq(packet.header["field_16b"])
+ yield dut.sink.field_32b.eq(packet.header["field_32b"])
+ yield dut.sink.field_64b.eq(packet.header["field_64b"])
+ yield dut.sink.field_128b.eq(packet.header["field_128b"])
+ yield
+ for n, data in enumerate(packet.datas):
+ yield dut.sink.valid.eq(1)
+ yield dut.sink.last.eq(n == (len(packet.datas) - 1))
+ yield dut.sink.data.eq(data)
+ yield
+ while (yield dut.sink.ready) == 0:
+ yield
+ dut.sink.valid.eq(0)
+
+ def checker(dut):
+ dut.header_errors = 0
+ dut.data_errors = 0
+ dut.last_errors = 0
+ # Receive and check packets
+ yield dut.source.ready.eq(1)
+ for packet in packets:
+ for n, data in enumerate(packet.datas):
+ while (yield dut.source.valid) == 0:
+ yield
+ for field in ["field_8b", "field_16b", "field_32b", "field_64b", "field_128b"]:
+ if (yield getattr(dut.source, field)) != packet.header[field]:
+ dut.header_errors += 1
+ #print("{:02x} vs {:02x}".format((yield dut.source.data), data))
+ if ((yield dut.source.data) != data):
+ dut.data_errors += 1
+ if ((yield dut.source.last) != (n == (len(packet.datas) - 1))):
+ dut.last_errors += 1
+ yield
+ yield
+
+ class DUT(Module):
+ def __init__(self):
+ packetizer = Packetizer(packet_description(8), raw_description(8), packet_header)
+ depacketizer = Depacketizer(raw_description(8), packet_description(8), packet_header)
+ self.submodules += packetizer, depacketizer
+ self.comb += packetizer.source.connect(depacketizer.sink)
+ self.sink, self.source = packetizer.sink, depacketizer.source
+
+ dut = DUT()
+ run_simulation(dut, [generator(dut), checker(dut)])
+ self.assertEqual(dut.header_errors, 0)
+ self.assertEqual(dut.data_errors, 0)
+ self.assertEqual(dut.last_errors, 0)