ip: add checksum functions to model
authorFlorent Kermarrec <florent@enjoy-digital.fr>
Wed, 4 Feb 2015 18:03:49 +0000 (19:03 +0100)
committerFlorent Kermarrec <florent@enjoy-digital.fr>
Wed, 4 Feb 2015 18:03:49 +0000 (19:03 +0100)
liteeth/common.py
liteeth/ip/__init__.py
liteeth/test/model/arp.py
liteeth/test/model/ip.py

index 09972a994d5ef60b001357f84ed23bed56edc245..1d32885e8aa0fbc4b8bb1cb14799212838fcaac6 100644 (file)
@@ -52,7 +52,7 @@ arp_header = {
        "destination_ip_address":       HField(24,  0, 32)
 }
 
-ipv4_header_len = 24
+ipv4_header_len = 20
 ipv4_header = {
        "version":                                      HField(0,  0, 4),
        "ihl":                                          HField(0,  4, 4),
@@ -66,8 +66,7 @@ ipv4_header = {
        "protocol":                                     HField(9,  0, 8),
        "header_checksum":                      HField(10,  0, 16),
        "source_ip_address":            HField(12,  0, 32),
-       "destination_ip_address":       HField(16,  0, 32),
-       "options":                                      HField(20,  0, 32)
+       "destination_ip_address":       HField(16,  0, 32)
 }
 
 udp_header_len = 8
index 1ffb52938cf247aa5e824e87ab15a899c65a8beb..973423563b0748baadea2f1085cd6a9c83647ae1 100644 (file)
@@ -41,8 +41,7 @@ class LiteEthIPTX(Module):
                        packetizer.sink.flags.eq(0),
                        packetizer.sink.fragment_offset.eq(0),
                        packetizer.sink.time_to_live.eq(0x80),
-                       packetizer.sink.source_ip_address.eq(ip_address),
-                       packetizer.sink.options.eq(0)
+                       packetizer.sink.source_ip_address.eq(ip_address)
                ]
                sink = packetizer.source
 
index 1dc84626564cd34dc3553e55e01d1798b6959fac..65f006f32387de00298797833857b9218a3c845f 100644 (file)
@@ -87,7 +87,7 @@ class ARP(Module):
                        self.process_request(packet)
                elif packet.operation == arp_opcode_reply:
                        self.process_reply(packet)
-       
+
        def process_request(self, request):
                if request.destination_ip_address == self.ip_address:
                        reply = ARPPacket([0]*(arp_packet_length-arp_header_len))
index 3b5aa6d5ce1014d5654244b268eef33e607b87ab..61a861be1851e4b3b3fc9325c707500165acc080 100644 (file)
@@ -10,11 +10,28 @@ def print_ip(s):
 
 preamble = split_bytes(eth_preamble, 8)
 
+def carry_around_add(a, b):
+    c = a + b
+    return (c & 0xffff) + (c >> 16)
+
+def checksum(msg):
+    s = 0
+    for i in range(0, len(msg), 2):
+        w = msg[i] + (msg[i+1] << 8)
+        s = carry_around_add(s, w)
+    return ~s & 0xffff
+
 # IP model
 class IPPacket(Packet):
        def __init__(self, init=[]):
                Packet.__init__(self, init)
 
+       def get_checksum(self):
+               return self[10] | (self[11] << 8)
+
+       def check_checksum(self):
+               return checksum(self[:ipv4_header_len]) == 0
+
        def decode(self):
                header = []
                for byte in self[:ipv4_header_len]:
@@ -30,6 +47,13 @@ class IPPacket(Packet):
                for d in split_bytes(header, ipv4_header_len):
                        self.insert(0, d)
 
+       def insert_checksum(self):
+               self[10] = 0
+               self[11] = 0
+               c = checksum(self[:ipv4_header_len])
+               self[10] = c & 0xff
+               self[11] = (c >> 8) & 0xff
+
        def __repr__(self):
                r = "--------\n"
                for k in sorted(ipv4_header.keys()):
@@ -56,6 +80,7 @@ class IP(Module):
 
        def send(self, packet):
                packet.encode()
+               packet.insert_checksum()
                if self.debug:
                        print_ip(">>>>>>>>")
                        print_ip(packet)
@@ -67,6 +92,11 @@ class IP(Module):
 
        def callback(self, packet):
                packet = IPPacket(packet)
+               if not packet.check_checksum():
+                       received = packet.get_checksum()
+                       packet.insert_checksum()
+                       expected = packet.get_checksum()
+                       raise ValueError("Checksum error received %04x / expected %04x" %(received, expected)) # XXX maybe too restrictive
                packet.decode()
                if self.debug:
                        print_ip("<<<<<<<<")
@@ -78,7 +108,7 @@ class IP(Module):
 
        def process(self, packet):
                pass
-       
+
 if __name__ == "__main__":
        from liteeth.test.model.dumps import *
        from liteeth.test.model.mac import *
@@ -89,11 +119,14 @@ if __name__ == "__main__":
        #print(packet)
        packet = IPPacket(packet)
        # check decoding
+       errors += not packet.check_checksum()
        packet.decode()
        #print(packet)
        errors += verify_packet(packet, {})
        # check encoding
        packet.encode()
+       packet.insert_checksum()
+       errors += not packet.check_checksum()
        packet.decode()
        #print(packet)
        errors += verify_packet(packet, {})