Improve p_* functions in smtio.py
authorClifford Wolf <clifford@clifford.at>
Wed, 25 Oct 2017 13:45:32 +0000 (15:45 +0200)
committerClifford Wolf <clifford@clifford.at>
Wed, 25 Oct 2017 13:45:32 +0000 (15:45 +0200)
backends/smt2/smtio.py

index 7cd7d0d5997a43421e1b30ce6fcb1554b5451d0d..18dee4e95900ce38bbf54661edfde48cbc4446bd 100644 (file)
@@ -127,7 +127,7 @@ class SmtIo:
             if self.dummy_file is not None:
                 self.dummy_fd = open(self.dummy_file, "w")
             if not self.noincr:
-                self.p = subprocess.Popen(self.popen_vargs, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+                self.p_open()
 
         if self.unroll:
             self.logic_uf = False
@@ -210,19 +210,24 @@ class SmtIo:
 
         return stmt
 
-    def p_write(self, data):
-        # select_result = select([self.p.stdout], [self.p.stdin], [], 0.1):
-        wlen = self.p.stdin.write(data)
-        assert wlen == len(data)
+    def p_open(self):
+        self.p = subprocess.Popen(self.popen_vargs, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
 
-    def p_flush(self):
-        self.p.stdin.flush()
+    def p_write(self, data, flush):
+        self.p.stdin.write(bytes(data, "ascii"))
+        if flush:
+            self.p.stdin.flush()
 
     def p_read(self):
         if len(self.p_buffer) == 0:
             return self.p.stdout.readline().decode("ascii")
         assert 0
 
+    def p_close(self):
+        self.p.stdin.close()
+        self.p = None
+        self.p_buffer = []
+
     def write(self, stmt, unroll=True):
         if stmt.startswith(";"):
             self.info(stmt)
@@ -295,21 +300,17 @@ class SmtIo:
         if self.solver != "dummy":
             if self.noincr:
                 if self.p is not None and not stmt.startswith("(get-"):
-                    self.p.stdin.close()
-                    self.p = None
-                    self.p_buffer = []
+                    self.p_close()
                 if stmt == "(push 1)":
                     self.smt2cache.append(list())
                 elif stmt == "(pop 1)":
                     self.smt2cache.pop()
                 else:
                     if self.p is not None:
-                        self.p_write(bytes(stmt + "\n", "ascii"))
-                        self.p_flush()
+                        self.p_write(stmt + "\n", True)
                     self.smt2cache[-1].append(stmt)
             else:
-                self.p_write(bytes(stmt + "\n", "ascii"))
-                self.p_flush()
+                self.p_write(stmt + "\n", True)
 
     def info(self, stmt):
         if not stmt.startswith("; yosys-smt2-"):
@@ -456,16 +457,13 @@ class SmtIo:
         if self.solver != "dummy":
             if self.noincr:
                 if self.p is not None:
-                    self.p.stdin.close()
-                    self.p = None
-                    self.p_buffer = []
-                self.p = subprocess.Popen(self.popen_vargs, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+                    self.p_close()
+                self.p_open()
                 for cache_ctx in self.smt2cache:
                     for cache_stmt in cache_ctx:
-                        self.p_write(bytes(cache_stmt + "\n", "ascii"))
+                        self.p_write(cache_stmt + "\n", False)
 
-            self.p_write(bytes("(check-sat)\n", "ascii"))
-            self.p_flush()
+            self.p_write("(check-sat)\n", True)
 
             if self.timeinfo:
                 i = 0