soc/tools/litex_term: continue cleanup
authorFlorent Kermarrec <florent@enjoy-digital.fr>
Fri, 19 Feb 2016 13:35:18 +0000 (14:35 +0100)
committerFlorent Kermarrec <florent@enjoy-digital.fr>
Fri, 19 Feb 2016 13:35:18 +0000 (14:35 +0100)
litex/soc/tools/litex_term.py
litex/soc/tools/remote/comm_uart.py

index bdc0c056818ff2fc6a00410d378fec3c029b5642..06d8843c4c465c53adcb08dea816132218f95ee7 100644 (file)
@@ -7,12 +7,11 @@ import serial
 import threading
 import argparse
 
-
+# TODO: cleanup getkey function
 if sys.platform == "win32":
     def getkey():
         import msvcrt
         return msvcrt.getch()
-
 else:
     def getkey():
         import termios
@@ -31,7 +30,9 @@ else:
         return c
 
 
-sfl_magic_len = 14
+sfl_prompt_req = b"F7:    boot from serial\n"
+sfl_prompt_ack = b"\x06"
+
 sfl_magic_req = b"sL5DdSMmkekro\n"
 sfl_magic_ack = b"z6IHG7cYDID6o\n"
 
@@ -105,39 +106,36 @@ class SFLFrame:
         packet += self.payload
         return packet
 
+
 class LiteXTerm:
-    def __init__(self, kernel_image, kernel_address):
+    def __init__(self, serial_boot, kernel_image, kernel_address):
+        self.serial_boot = serial_boot
         self.kernel_image = kernel_image
         self.kernel_address = kernel_address
 
         self.reader_alive = False
         self.writer_alive = False
 
-        self.detect_magic_bytes = bytes(len(sfl_magic_req))
-
-    def open(self, port, speed):
-        self.serial = serial.serial_for_url(
-            port,
-            baudrate=speed,
-            bytesize=8,
-            parity="N",
-            stopbits=1,
-            xonxoff=0,
-            timeout=0.25)
-        self.serial.flushOutput()
-        self.serial.flushInput()
-        self.serial.close()  # in case port was not correctly closed
-        self.serial.open()
+        self.promp_detect_buffer = bytes(len(sfl_prompt_req))
+        self.magic_detect_buffer = bytes(len(sfl_magic_req))
+
+    def open(self, port, baudrate):
+        if hasattr(self, "port"):
+            return
+        self.port = serial.serial_for_url(port, baudrate)
 
     def close(self):
-        self.serial.close()
+        if not hasattr(self, "port"):
+            return
+        self.port.close()
+        del self.port
 
     def send_frame(self, frame):
         retry = 1
         while retry:
-            self.serial.write(frame.encode())
+            self.port.write(frame.encode())
             # Get the reply from the device
-            reply = self.serial.read()
+            reply = self.port.read()
             if reply == sfl_ack_success:
                 retry = 0
             elif reply == sfl_ack_crcerror:
@@ -182,17 +180,28 @@ class LiteXTerm:
         frame.payload = self.kernel_address.to_bytes(4, "big") 
         self.send_frame(frame)
 
+    def detect_prompt(self, data):
+        if len(data):
+            self.promp_detect_buffer = self.promp_detect_buffer[1:] + data
+            return self.promp_detect_buffer == sfl_prompt_req
+        else:
+            return False
+
+    def answer_prompt(self):
+        print("[TERM] Received serial boot prompt from the device.")
+        self.port.write(sfl_prompt_ack)
+
     def detect_magic(self, data):
         if len(data):
-            self.detect_magic_bytes = self.detect_magic_bytes[1:] + data
-            return self.detect_magic_bytes == sfl_magic_req
+            self.magic_detect_buffer = self.magic_detect_buffer[1:] + data
+            return self.magic_detect_buffer == sfl_magic_req
         else:
             return False
 
     def answer_magic(self):
         print("[TERM] Received firmware download request from the device.")
         if os.path.exists(self.kernel_image):
-            self.serial.write(sfl_magic_ack)
+            self.port.write(sfl_magic_ack)
             self.upload(self.kernel_image, self.kernel_address)
             self.boot()
         print("[TERM] Done.");
@@ -200,14 +209,20 @@ class LiteXTerm:
     def reader(self):
         try:
             while self.reader_alive:
-                c = self.serial.read()
+                c = self.port.read()
                 if c == b"\r":
                     sys.stdout.write(b"\n")
                 else:
-                    sys.stdout.write(c.decode())
+                    try:
+                        # TODO: cleanup
+                        sys.stdout.write(c.decode())
+                    except:
+                        pass
                 sys.stdout.flush()
 
                 if self.kernel_image is not None:
+                    if self.serial_boot and self.detect_prompt(c):
+                        self.answer_prompt()
                     if self.detect_magic(c):
                         self.answer_magic()
 
@@ -231,14 +246,13 @@ class LiteXTerm:
                 try:
                     b = getkey()
                 except KeyboardInterrupt:
-                    b = serial.to_bytes([3])
-                c = b.decode()
-                if c == chr(0x03):
+                    b = b"\x03"
+                if b == b"\x03":
                     self.stop()
-                elif c == '\n':
-                    self.serial.write(serial.to_bytes([10]))
+                elif b == b"\n":
+                    self.port.write(b"\x0a")
                 else:
-                    self.serial.write(b)
+                    self.port.write(b)
         except:
             self.writer_alive = False
             raise
@@ -272,6 +286,8 @@ def _get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("port", help="serial port")
     parser.add_argument("--speed", default=115200, help="serial baudrate")
+    parser.add_argument("--serial-boot", default=False, action='store_true',
+                        help="automatically initiate serial boot")
     parser.add_argument("--kernel", default=None, help="kernel image")
     parser.add_argument("--kernel-adr", type=lambda a: int(a, 0), default=0x40000000, help="kernel address")
     return parser.parse_args()
@@ -279,13 +295,14 @@ def _get_args():
 
 def main():
     args = _get_args()
-    term = LiteXTerm(args.kernel, args.kernel_adr)
+    term = LiteXTerm(args.serial_boot, args.kernel, args.kernel_adr)
     term.open(args.port, args.speed)
     term.start()
     try:
         term.join(True)
     except KeyboardInterrupt:
         pass
+    term.close()
 
 
 if __name__ == "__main__":
index 0f7987d12377ce9f773916c9cb3cd754eca9d4ec..10fa37e8db4d792632bfe87f2c830b1a33ff0208 100644 (file)
@@ -21,6 +21,7 @@ class CommUART:
     def close(self):
         if not hasattr(self, "port"):
             return
+        self.port.close()
         del self.port
 
     def _read(self, length):