Delete solver pointer in Cython __dealloc__ (#4799)
authormakaimann <makaim@stanford.edu>
Mon, 3 Aug 2020 20:39:44 +0000 (13:39 -0700)
committerGitHub <noreply@github.com>
Mon, 3 Aug 2020 20:39:44 +0000 (13:39 -0700)
src/api/python/cvc4.pxi
test/unit/api/python/test_grammar.py

index b6e4616da4741b0692bfb1ec1c5a2918ea547edf..a51307d213875a6dbc14d7b1ac622eb142615c63 100644 (file)
@@ -54,6 +54,22 @@ def expand_list_arg(num_req_args=0):
 ### always use c++ default arguments
 #### only use default args of None at python level
 
+# References and pointers
+# The Solver object holds a pointer to a c_Solver.
+# This is because the assignment operator is deleted in the C++ API for solvers.
+# Cython has a limitation where you can't stack allocate objects
+# that have constructors with arguments:
+# https://groups.google.com/forum/#!topic/cython-users/fuKd-nQLpBs.
+# To get around that you can either have a nullary constructor and assignment
+# or, use a pointer (which is what we chose).
+# An additional complication of this is that to free up resources, you must
+# know when to delete the object.
+# Python will not follow the same scoping rules as in C++, so it must be
+# able to reference count. To do this correctly, the solver must be a
+# reference in the Python class for any class that keeps a pointer to
+# the solver in C++ (to ensure the solver is not deleted before something
+# that depends on it).
+
 
 ## Objects for hashing
 cdef c_OpHashFunction cophash = c_OpHashFunction()
@@ -63,11 +79,12 @@ cdef c_TermHashFunction ctermhash = c_TermHashFunction()
 
 cdef class Datatype:
     cdef c_Datatype cd
-    def __cinit__(self):
-        pass
+    cdef Solver solver
+    def __cinit__(self, Solver solver):
+        self.solver = solver
 
     def __getitem__(self, index):
-        cdef DatatypeConstructor dc = DatatypeConstructor()
+        cdef DatatypeConstructor dc = DatatypeConstructor(self.solver)
         if isinstance(index, int) and index >= 0:
             dc.cdc = self.cd[(<int?> index)]
         elif isinstance(index, str):
@@ -77,12 +94,12 @@ cdef class Datatype:
         return dc
 
     def getConstructor(self, str name):
-        cdef DatatypeConstructor dc = DatatypeConstructor()
+        cdef DatatypeConstructor dc = DatatypeConstructor(self.solver)
         dc.cdc = self.cd.getConstructor(name.encode())
         return dc
 
     def getConstructorTerm(self, str name):
-        cdef Term term = Term()
+        cdef Term term = Term(self.solver)
         term.cterm = self.cd.getConstructorTerm(name.encode())
         return term
 
@@ -100,18 +117,20 @@ cdef class Datatype:
 
     def __iter__(self):
         for ci in self.cd:
-            dc = DatatypeConstructor()
+            dc = DatatypeConstructor(self.solver)
             dc.cdc = ci
             yield dc
 
 
 cdef class DatatypeConstructor:
     cdef c_DatatypeConstructor cdc
-    def __cinit__(self):
+    cdef Solver solver
+    def __cinit__(self, Solver solver):
         self.cdc = c_DatatypeConstructor()
+        self.solver = solver
 
     def __getitem__(self, index):
-        cdef DatatypeSelector ds = DatatypeSelector()
+        cdef DatatypeSelector ds = DatatypeSelector(self.solver)
         if isinstance(index, int) and index >= 0:
             ds.cds = self.cdc[(<int?> index)]
         elif isinstance(index, str):
@@ -124,12 +143,12 @@ cdef class DatatypeConstructor:
         return self.cdc.getName().decode()
 
     def getConstructorTerm(self):
-        cdef Term term = Term()
+        cdef Term term = Term(self.solver)
         term.cterm = self.cdc.getConstructorTerm()
         return term
 
     def getTesterTerm(self):
-        cdef Term term = Term()
+        cdef Term term = Term(self.solver)
         term.cterm = self.cdc.getTesterTerm()
         return term
 
@@ -137,12 +156,12 @@ cdef class DatatypeConstructor:
         return self.cdc.getNumSelectors()
 
     def getSelector(self, str name):
-        cdef DatatypeSelector ds = DatatypeSelector()
+        cdef DatatypeSelector ds = DatatypeSelector(self.solver)
         ds.cds = self.cdc.getSelector(name.encode())
         return ds
 
     def getSelectorTerm(self, str name):
-        cdef Term term = Term()
+        cdef Term term = Term(self.solver)
         term.cterm = self.cdc.getSelectorTerm(name.encode())
         return term
 
@@ -154,16 +173,17 @@ cdef class DatatypeConstructor:
 
     def __iter__(self):
         for ci in self.cdc:
-            ds = DatatypeSelector()
+            ds = DatatypeSelector(self.solver)
             ds.cds = ci
             yield ds
 
 
 cdef class DatatypeConstructorDecl:
     cdef c_DatatypeConstructorDecl cddc
+    cdef Solver solver
 
-    def __cinit__(self):
-      pass
+    def __cinit__(self, Solver solver):
+        self.solver = solver
 
     def addSelector(self, str name, Sort sort):
         self.cddc.addSelector(name.encode(), sort.csort)
@@ -180,8 +200,9 @@ cdef class DatatypeConstructorDecl:
 
 cdef class DatatypeDecl:
     cdef c_DatatypeDecl cdd
-    def __cinit__(self):
-        pass
+    cdef Solver solver
+    def __cinit__(self, Solver solver):
+        self.solver = solver
 
     def addConstructor(self, DatatypeConstructorDecl ctor):
         self.cdd.addConstructor(ctor.cddc)
@@ -201,19 +222,21 @@ cdef class DatatypeDecl:
 
 cdef class DatatypeSelector:
     cdef c_DatatypeSelector cds
-    def __cinit__(self):
+    cdef Solver solver
+    def __cinit__(self, Solver solver):
         self.cds = c_DatatypeSelector()
+        self.solver = solver
 
     def getName(self):
         return self.cds.getName().decode()
 
     def getSelectorTerm(self):
-        cdef Term term = Term()
+        cdef Term term = Term(self.solver)
         term.cterm = self.cds.getSelectorTerm()
         return term
 
     def getRangeSort(self):
-        cdef Sort sort = Sort()
+        cdef Sort sort = Sort(self.solver)
         sort.csort = self.cds.getRangeSort()
         return sort
 
@@ -226,8 +249,10 @@ cdef class DatatypeSelector:
 
 cdef class Op:
     cdef c_Op cop
-    def __cinit__(self):
+    cdef Solver solver
+    def __cinit__(self, Solver solver):
         self.cop = c_Op()
+        self.solver = solver
 
     def __eq__(self, Op other):
         return self.cop == other.cop
@@ -274,7 +299,9 @@ cdef class Op:
 
 cdef class Grammar:
     cdef c_Grammar  cgrammar
-    def __cinit__(self):
+    cdef Solver solver
+    def __cinit__(self, Solver solver):
+        self.solver = solver
         self.cgrammar = c_Grammar()
 
     def addRule(self, Term ntSymbol, Term rule):
@@ -365,59 +392,62 @@ cdef class Solver:
     def __cinit__(self):
         self.csolver = new c_Solver(NULL)
 
+    def __dealloc__(self):
+        del self.csolver
+
     def getBooleanSort(self):
-        cdef Sort sort = Sort()
+        cdef Sort sort = Sort(self)
         sort.csort = self.csolver.getBooleanSort()
         return sort
 
     def getIntegerSort(self):
-        cdef Sort sort = Sort()
+        cdef Sort sort = Sort(self)
         sort.csort = self.csolver.getIntegerSort()
         return sort
 
     def getRealSort(self):
-        cdef Sort sort = Sort()
+        cdef Sort sort = Sort(self)
         sort.csort = self.csolver.getRealSort()
         return sort
 
     def getRegExpSort(self):
-        cdef Sort sort = Sort()
+        cdef Sort sort = Sort(self)
         sort.csort = self.csolver.getRegExpSort()
         return sort
 
     def getRoundingmodeSort(self):
-        cdef Sort sort = Sort()
+        cdef Sort sort = Sort(self)
         sort.csort = self.csolver.getRoundingmodeSort()
         return sort
 
     def getStringSort(self):
-        cdef Sort sort = Sort()
+        cdef Sort sort = Sort(self)
         sort.csort = self.csolver.getStringSort()
         return sort
 
     def mkArraySort(self, Sort indexSort, Sort elemSort):
-        cdef Sort sort = Sort()
+        cdef Sort sort = Sort(self)
         sort.csort = self.csolver.mkArraySort(indexSort.csort, elemSort.csort)
         return sort
 
     def mkBitVectorSort(self, uint32_t size):
-        cdef Sort sort = Sort()
+        cdef Sort sort = Sort(self)
         sort.csort = self.csolver.mkBitVectorSort(size)
         return sort
 
     def mkFloatingPointSort(self, uint32_t exp, uint32_t sig):
-        cdef Sort sort = Sort()
+        cdef Sort sort = Sort(self)
         sort.csort = self.csolver.mkFloatingPointSort(exp, sig)
         return sort
 
     def mkDatatypeSort(self, DatatypeDecl dtypedecl):
-        cdef Sort sort = Sort()
+        cdef Sort sort = Sort(self)
         sort.csort = self.csolver.mkDatatypeSort(dtypedecl.cdd)
         return sort
 
     def mkFunctionSort(self, sorts, Sort codomain):
 
-        cdef Sort sort = Sort()
+        cdef Sort sort = Sort(self)
         # populate a vector with dereferenced c_Sorts
         cdef vector[c_Sort] v
 
@@ -433,7 +463,7 @@ cdef class Solver:
         return sort
 
     def mkParamSort(self, symbolname):
-        cdef Sort sort = Sort()
+        cdef Sort sort = Sort(self)
         sort.csort = self.csolver.mkParamSort(symbolname.encode())
         return sort
 
@@ -446,7 +476,7 @@ cdef class Solver:
                  where sorts can also be comma-separated arguments of
                   type Sort
         '''
-        cdef Sort sort = Sort()
+        cdef Sort sort = Sort(self)
         cdef vector[c_Sort] v
         for s in sorts:
             v.push_back((<Sort?> s).csort)
@@ -462,7 +492,7 @@ cdef class Solver:
                   where fields can also be comma-separated arguments of
           type Tuple[str, Sort]
         '''
-        cdef Sort sort = Sort()
+        cdef Sort sort = Sort(self)
         cdef vector[pair[string, c_Sort]] v
         cdef pair[string, c_Sort] p
         for f in fields:
@@ -475,22 +505,22 @@ cdef class Solver:
         return sort
 
     def mkSetSort(self, Sort elemSort):
-        cdef Sort sort = Sort()
+        cdef Sort sort = Sort(self)
         sort.csort = self.csolver.mkSetSort(elemSort.csort)
         return sort
 
     def mkSequenceSort(self, Sort elemSort):
-        cdef Sort sort = Sort()
+        cdef Sort sort = Sort(self)
         sort.csort = self.csolver.mkSequenceSort(elemSort.csort)
         return sort
 
     def mkUninterpretedSort(self, str name):
-        cdef Sort sort = Sort()
+        cdef Sort sort = Sort(self)
         sort.csort = self.csolver.mkUninterpretedSort(name.encode())
         return sort
 
     def mkSortConstructorSort(self, str symbol, size_t arity):
-        cdef Sort sort = Sort()
+        cdef Sort sort = Sort(self)
         sort.csort =self.csolver.mkSortConstructorSort(symbol.encode(), arity)
         return sort
 
@@ -503,7 +533,7 @@ cdef class Solver:
                  where sorts can also be comma-separated arguments of
                  type Sort
         '''
-        cdef Sort sort = Sort()
+        cdef Sort sort = Sort(self)
         cdef vector[c_Sort] v
         for s in sorts:
             v.push_back((<Sort?> s).csort)
@@ -520,7 +550,7 @@ cdef class Solver:
 
                 where List[Term] can also be comma-separated arguments
         '''
-        cdef Term term = Term()
+        cdef Term term = Term(self)
         cdef vector[c_Term] v
 
         op = kind_or_op
@@ -544,7 +574,7 @@ cdef class Solver:
                 Op mkOp(Kind kind, uint32_t arg)
                 Op mkOp(Kind kind, uint32_t arg0, uint32_t arg1)
         '''
-        cdef Op op = Op()
+        cdef Op op = Op(self)
 
         if arg0 is None:
             op.cop = self.csolver.mkOp(k.k)
@@ -570,27 +600,27 @@ cdef class Solver:
         return op
 
     def mkTrue(self):
-        cdef Term term = Term()
+        cdef Term term = Term(self)
         term.cterm = self.csolver.mkTrue()
         return term
 
     def mkFalse(self):
-        cdef Term term = Term()
+        cdef Term term = Term(self)
         term.cterm = self.csolver.mkFalse()
         return term
 
     def mkBoolean(self, bint val):
-        cdef Term term = Term()
+        cdef Term term = Term(self)
         term.cterm = self.csolver.mkBoolean(val)
         return term
 
     def mkPi(self):
-        cdef Term term = Term()
+        cdef Term term = Term(self)
         term.cterm = self.csolver.mkPi()
         return term
 
     def mkReal(self, val, den=None):
-        cdef Term term = Term()
+        cdef Term term = Term(self)
         if den is None:
             term.cterm = self.csolver.mkReal(str(val).encode())
         else:
@@ -602,27 +632,27 @@ cdef class Solver:
         return term
 
     def mkRegexpEmpty(self):
-        cdef Term term = Term()
+        cdef Term term = Term(self)
         term.cterm = self.csolver.mkRegexpEmpty()
         return term
 
     def mkRegexpSigma(self):
-        cdef Term term = Term()
+        cdef Term term = Term(self)
         term.cterm = self.csolver.mkRegexpSigma()
         return term
 
     def mkEmptySet(self, Sort s):
-        cdef Term term = Term()
+        cdef Term term = Term(self)
         term.cterm = self.csolver.mkEmptySet(s.csort)
         return term
 
     def mkSepNil(self, Sort sort):
-        cdef Term term = Term()
+        cdef Term term = Term(self)
         term.cterm = self.csolver.mkSepNil(sort.csort)
         return term
 
     def mkString(self, str_or_vec):
-        cdef Term term = Term()
+        cdef Term term = Term(self)
         cdef vector[unsigned] v
         if isinstance(str_or_vec, str):
             term.cterm = self.csolver.mkString(<string &> str_or_vec.encode())
@@ -639,17 +669,17 @@ cdef class Solver:
         return term
 
     def mkEmptySequence(self, Sort sort):
-        cdef Term term = Term()
+        cdef Term term = Term(self)
         term.cterm = self.csolver.mkEmptySequence(sort.csort)
         return term
 
     def mkUniverseSet(self, Sort sort):
-        cdef Term term = Term()
+        cdef Term term = Term(self)
         term.cterm = self.csolver.mkUniverseSet(sort.csort)
         return term
 
     def mkBitVector(self, size_or_str, val = None):
-        cdef Term term = Term()
+        cdef Term term = Term(self)
         if isinstance(size_or_str, int):
             if val is None:
                 term.cterm = self.csolver.mkBitVector(<int> size_or_str)
@@ -670,47 +700,47 @@ cdef class Solver:
         return term
 
     def mkConstArray(self, Sort sort, Term val):
-        cdef Term term = Term()
+        cdef Term term = Term(self)
         term.cterm = self.csolver.mkConstArray(sort.csort, val.cterm)
         return term
 
     def mkPosInf(self, int exp, int sig):
-        cdef Term term = Term()
+        cdef Term term = Term(self)
         term.cterm = self.csolver.mkPosInf(exp, sig)
         return term
 
     def mkNegInf(self, int exp, int sig):
-        cdef Term term = Term()
+        cdef Term term = Term(self)
         term.cterm = self.csolver.mkNegInf(exp, sig)
         return term
 
     def mkNaN(self, int exp, int sig):
-        cdef Term term = Term()
+        cdef Term term = Term(self)
         term.cterm = self.csolver.mkNaN(exp, sig)
         return term
 
     def mkPosZero(self, int exp, int sig):
-        cdef Term term = Term()
+        cdef Term term = Term(self)
         term.cterm = self.csolver.mkPosZero(exp, sig)
         return term
 
     def mkNegZero(self, int exp, int sig):
-        cdef Term term = Term()
+        cdef Term term = Term(self)
         term.cterm = self.csolver.mkNegZero(exp, sig)
         return term
 
     def mkRoundingMode(self, RoundingMode rm):
-        cdef Term term = Term()
+        cdef Term term = Term(self)
         term.cterm = self.csolver.mkRoundingMode(<c_RoundingMode> rm.crm)
         return term
 
     def mkUninterpretedConst(self, Sort sort, int index):
-        cdef Term term = Term()
+        cdef Term term = Term(self)
         term.cterm = self.csolver.mkUninterpretedConst(sort.csort, index)
         return term
 
     def mkAbstractValue(self, index):
-        cdef Term term = Term()
+        cdef Term term = Term(self)
         try:
             term.cterm = self.csolver.mkAbstractValue(str(index).encode())
         except:
@@ -719,12 +749,12 @@ cdef class Solver:
         return term
 
     def mkFloatingPoint(self, int exp, int sig, Term val):
-        cdef Term term = Term()
+        cdef Term term = Term(self)
         term.cterm = self.csolver.mkFloatingPoint(exp, sig, val.cterm)
         return term
 
     def mkConst(self, Sort sort, symbol=None):
-        cdef Term term = Term()
+        cdef Term term = Term(self)
         if symbol is None:
             term.cterm = self.csolver.mkConst(sort.csort)
         else:
@@ -733,7 +763,7 @@ cdef class Solver:
         return term
 
     def mkVar(self, Sort sort, symbol=None):
-        cdef Term term = Term()
+        cdef Term term = Term(self)
         if symbol is None:
             term.cterm = self.csolver.mkVar(sort.csort)
         else:
@@ -742,12 +772,12 @@ cdef class Solver:
         return term
 
     def mkDatatypeConstructorDecl(self, str name):
-        cdef DatatypeConstructorDecl ddc = DatatypeConstructorDecl()
+        cdef DatatypeConstructorDecl ddc = DatatypeConstructorDecl(self)
         ddc.cddc = self.csolver.mkDatatypeConstructorDecl(name.encode())
         return ddc
 
     def mkDatatypeDecl(self, str name, sorts_or_bool=None, isCoDatatype=None):
-        cdef DatatypeDecl dd = DatatypeDecl()
+        cdef DatatypeDecl dd = DatatypeDecl(self)
         cdef vector[c_Sort] v
 
         # argument cases
@@ -791,7 +821,7 @@ cdef class Solver:
         return dd
 
     def simplify(self, Term t):
-        cdef Term term = Term()
+        cdef Term term = Term(self)
         term.cterm = self.csolver.simplify(t.cterm)
         return term
 
@@ -804,7 +834,7 @@ cdef class Solver:
         return r
 
     def mkSygusGrammar(self, boundVars, ntSymbols):
-        cdef Grammar grammar = Grammar()
+        cdef Grammar grammar = Grammar(self)
         cdef vector[c_Term] bvc
         cdef vector[c_Term] ntc
         for bv in boundVars:
@@ -815,7 +845,7 @@ cdef class Solver:
         return grammar
 
     def mkSygusVar(self, Sort sort, str symbol=""):
-        cdef Term term = Term()
+        cdef Term term = Term(self)
         term.cterm = self.csolver.mkSygusVar(sort.csort, symbol.encode())
         return term
 
@@ -826,7 +856,7 @@ cdef class Solver:
         self.csolver.addSygusInvConstraint(inv_f.cterm, pre_f.cterm, trans_f.cterm, post_f.cterm)
 
     def synthFun(self, str symbol, bound_vars, Sort sort, Grammar grammar=None):
-        cdef Term term = Term()
+        cdef Term term = Term(self)
         cdef vector[c_Term] v
         for bv in bound_vars:
             v.push_back((<Term?> bv).cterm)
@@ -842,12 +872,12 @@ cdef class Solver:
         return r
 
     def getSynthSolution(self, Term term):
-        cdef Term t = Term()
+        cdef Term t = Term(self)
         t.cterm = self.csolver.getSynthSolution(term.cterm)
         return t
 
     def synthInv(self, symbol, bound_vars, Grammar grammar=None):
-        cdef Term term = Term()
+        cdef Term term = Term(self)
         cdef vector[c_Term] v
         for bv in bound_vars:
             v.push_back((<Term?> bv).cterm)
@@ -903,7 +933,7 @@ cdef class Solver:
                  where ctors can also be comma-separated arguments of
                   type DatatypeConstructorDecl
         '''
-        cdef Sort sort = Sort()
+        cdef Sort sort = Sort(self)
         cdef vector[c_DatatypeConstructorDecl] v
 
         for c in ctors:
@@ -912,7 +942,7 @@ cdef class Solver:
         return sort
 
     def declareFun(self, str symbol, list sorts, Sort sort):
-        cdef Term term = Term()
+        cdef Term term = Term(self)
         cdef vector[c_Sort] v
         for s in sorts:
             v.push_back((<Sort?> s).csort)
@@ -922,7 +952,7 @@ cdef class Solver:
         return term
 
     def declareSort(self, str symbol, int arity):
-        cdef Sort sort = Sort()
+        cdef Sort sort = Sort(self)
         sort.csort = self.csolver.declareSort(symbol.encode(), arity)
         return sort
 
@@ -934,7 +964,7 @@ cdef class Solver:
                 Term defineFun(Term fun, List[Term] bound_vars,
                                Term term, bool glbl)
         '''
-        cdef Term term = Term()
+        cdef Term term = Term(self)
         cdef vector[c_Term] v
         for bv in bound_vars:
             v.push_back((<Term?> bv).cterm)
@@ -961,7 +991,7 @@ cdef class Solver:
                 Term defineFunRec(Term fun, List[Term] bound_vars,
                                Term term, bool glbl)
         '''
-        cdef Term term = Term()
+        cdef Term term = Term(self)
         cdef vector[c_Term] v
         for bv in bound_vars:
             v.push_back((<Term?> bv).cterm)
@@ -1001,7 +1031,7 @@ cdef class Solver:
     def getAssertions(self):
         assertions = []
         for a in self.csolver.getAssertions():
-            term = Term()
+            term = Term(self)
             term.cterm = a
             assertions.append(term)
         return assertions
@@ -1012,8 +1042,8 @@ cdef class Solver:
         '''
         assignments = {}
         for a in self.csolver.getAssignment():
-            varterm = Term()
-            valterm = Term()
+            varterm = Term(self)
+            valterm = Term(self)
             varterm.cterm = a.first
             valterm.cterm = a.second
             assignments[varterm] = valterm
@@ -1028,7 +1058,7 @@ cdef class Solver:
     def getUnsatAssumptions(self):
         assumptions = []
         for a in self.csolver.getUnsatAssumptions():
-            term = Term()
+            term = Term(self)
             term.cterm = a
             assumptions.append(term)
         return assumptions
@@ -1036,23 +1066,23 @@ cdef class Solver:
     def getUnsatCore(self):
         core = []
         for a in self.csolver.getUnsatCore():
-            term = Term()
+            term = Term(self)
             term.cterm = a
             core.append(term)
         return core
 
     def getValue(self, Term t):
-        cdef Term term = Term()
+        cdef Term term = Term(self)
         term.cterm = self.csolver.getValue(t.cterm)
         return term
 
     def getSeparationHeap(self):
-        cdef Term term = Term()
+        cdef Term term = Term(self)
         term.cterm = self.csolver.getSeparationHeap()
         return term
 
     def getSeparationNilTerm(self):
-        cdef Term term = Term()
+        cdef Term term = Term(self)
         term.cterm = self.csolver.getSeparationNilTerm()
         return term
 
@@ -1080,9 +1110,10 @@ cdef class Solver:
 
 cdef class Sort:
     cdef c_Sort csort
-    def __cinit__(self):
+    cdef Solver solver
+    def __cinit__(self, Solver solver):
         # csort always set by Solver
-        pass
+        self.solver = solver
 
     def __eq__(self, Sort other):
         return self.csort == other.csort
@@ -1190,12 +1221,12 @@ cdef class Sort:
         return self.csort.isComparableTo(sort.csort)
 
     def getDatatype(self):
-        cdef Datatype d = Datatype()
+        cdef Datatype d = Datatype(self.solver)
         d.cd = self.csort.getDatatype()
         return d
 
     def instantiate(self, params):
-        cdef Sort sort = Sort()
+        cdef Sort sort = Sort(self.solver)
         cdef vector[c_Sort] v
         for s in params:
             v.push_back((<Sort?> s).csort)
@@ -1208,13 +1239,13 @@ cdef class Sort:
     def getConstructorDomainSorts(self):
         domain_sorts = []
         for s in self.csort.getConstructorDomainSorts():
-            sort = Sort()
+            sort = Sort(self.solver)
             sort.csort = s
             domain_sorts.append(sort)
         return domain_sorts
 
     def getConstructorCodomainSort(self):
-        cdef Sort sort = Sort()
+        cdef Sort sort = Sort(self.solver)
         sort.csort = self.csort.getConstructorCodomainSort()
         return sort
 
@@ -1224,33 +1255,33 @@ cdef class Sort:
     def getFunctionDomainSorts(self):
         domain_sorts = []
         for s in self.csort.getFunctionDomainSorts():
-            sort = Sort()
+            sort = Sort(self.solver)
             sort.csort = s
             domain_sorts.append(sort)
         return domain_sorts
 
     def getFunctionCodomainSort(self):
-        cdef Sort sort = Sort()
+        cdef Sort sort = Sort(self.solver)
         sort.csort = self.csort.getFunctionCodomainSort()
         return sort
 
     def getArrayIndexSort(self):
-        cdef Sort sort = Sort()
+        cdef Sort sort = Sort(self.solver)
         sort.csort = self.csort.getArrayIndexSort()
         return sort
 
     def getArrayElementSort(self):
-        cdef Sort sort = Sort()
+        cdef Sort sort = Sort(self.solver)
         sort.csort = self.csort.getArrayElementSort()
         return sort
 
     def getSetElementSort(self):
-        cdef Sort sort = Sort()
+        cdef Sort sort = Sort(self.solver)
         sort.csort = self.csort.getSetElementSort()
         return sort
 
     def getSequenceElementSort(self):
-        cdef Sort sort = Sort()
+        cdef Sort sort = Sort(self.solver)
         sort.csort = self.csort.getSequenceElementSort()
         return sort
 
@@ -1263,7 +1294,7 @@ cdef class Sort:
     def getUninterpretedSortParamSorts(self):
         param_sorts = []
         for s in self.csort.getUninterpretedSortParamSorts():
-            sort = Sort()
+            sort = Sort(self.solver)
             sort.csort = s
             param_sorts.append(sort)
         return param_sorts
@@ -1286,7 +1317,7 @@ cdef class Sort:
     def getDatatypeParamSorts(self):
         param_sorts = []
         for s in self.csort.getDatatypeParamSorts():
-            sort = Sort()
+            sort = Sort(self.solver)
             sort.csort = s
             param_sorts.append(sort)
         return param_sorts
@@ -1300,7 +1331,7 @@ cdef class Sort:
     def getTupleSorts(self):
         tuple_sorts = []
         for s in self.csort.getTupleSorts():
-            sort = Sort()
+            sort = Sort(self.solver)
             sort.csort = s
             tuple_sorts.append(sort)
         return tuple_sorts
@@ -1308,9 +1339,10 @@ cdef class Sort:
 
 cdef class Term:
     cdef c_Term cterm
-    def __cinit__(self):
+    cdef Solver solver
+    def __cinit__(self, Solver solver):
         # cterm always set in the Solver object
-        pass
+        self.solver = solver
 
     def __eq__(self, Term other):
         return self.cterm == other.cterm
@@ -1326,7 +1358,7 @@ cdef class Term:
 
     def __iter__(self):
         for ci in self.cterm:
-            term = Term()
+            term = Term(self.solver)
             term.cterm = ci
             yield term
 
@@ -1337,14 +1369,14 @@ cdef class Term:
         return kind(<int> self.cterm.getKind())
 
     def getSort(self):
-        cdef Sort sort = Sort()
+        cdef Sort sort = Sort(self.solver)
         sort.csort = self.cterm.getSort()
         return sort
 
     def substitute(self, list es, list replacements):
         cdef vector[c_Term] ces
         cdef vector[c_Term] creplacements
-        cdef Term term = Term()
+        cdef Term term = Term(self.solver)
 
         if len(es) != len(replacements):
             raise RuntimeError("Expecting list inputs to substitute to "
@@ -1362,7 +1394,7 @@ cdef class Term:
         return self.cterm.hasOp()
 
     def getOp(self):
-        cdef Op op = Op()
+        cdef Op op = Op(self.solver)
         op.cop = self.cterm.getOp()
         return op
 
@@ -1373,50 +1405,50 @@ cdef class Term:
         return self.cterm.isConst()
 
     def getConstArrayBase(self):
-        cdef Term term = Term()
+        cdef Term term = Term(self.solver)
         term.cterm = self.cterm.getConstArrayBase()
         return term
 
     def getConstSequenceElements(self):
         elems = []
         for e in self.cterm.getConstSequenceElements():
-            term = Term()
+            term = Term(self.solver)
             term.cterm = e
             elems.append(term)
         return elems
 
     def notTerm(self):
-        cdef Term term = Term()
+        cdef Term term = Term(self.solver)
         term.cterm = self.cterm.notTerm()
         return term
 
     def andTerm(self, Term t):
-        cdef Term term = Term()
+        cdef Term term = Term(self.solver)
         term.cterm = self.cterm.andTerm((<Term> t).cterm)
         return term
 
     def orTerm(self, Term t):
-        cdef Term term = Term()
+        cdef Term term = Term(self.solver)
         term.cterm = self.cterm.orTerm(t.cterm)
         return term
 
     def xorTerm(self, Term t):
-        cdef Term term = Term()
+        cdef Term term = Term(self.solver)
         term.cterm = self.cterm.xorTerm(t.cterm)
         return term
 
     def eqTerm(self, Term t):
-        cdef Term term = Term()
+        cdef Term term = Term(self.solver)
         term.cterm = self.cterm.eqTerm(t.cterm)
         return term
 
     def impTerm(self, Term t):
-        cdef Term term = Term()
+        cdef Term term = Term(self.solver)
         term.cterm = self.cterm.impTerm(t.cterm)
         return term
 
     def iteTerm(self, Term then_t, Term else_t):
-        cdef Term term = Term()
+        cdef Term term = Term(self.solver)
         term.cterm = self.cterm.iteTerm(then_t.cterm, else_t.cterm)
         return term
 
index 3a703017ffac9a134cc8597e538ba5d2fc83fa1e..53e873cea52c39faaebd31fefb46a5d5548b3d80 100644 (file)
@@ -10,7 +10,7 @@ def test_add_rule():
   boolean = solver.getBooleanSort()
   integer = solver.getIntegerSort()
 
-  nullTerm = pycvc4.Term()
+  nullTerm = pycvc4.Term(solver)
   start = solver.mkVar(boolean)
   nts = solver.mkVar(boolean)
 
@@ -41,7 +41,7 @@ def test_add_rules():
   boolean = solver.getBooleanSort()
   integer = solver.getIntegerSort()
 
-  nullTerm = pycvc4.Term()
+  nullTerm = pycvc4.Term(solver)
   start = solver.mkVar(boolean)
   nts = solver.mkVar(boolean)
 
@@ -69,7 +69,7 @@ def testAddAnyConstant():
   solver = pycvc4.Solver()
   boolean = solver.getBooleanSort()
 
-  nullTerm = pycvc4.Term()
+  nullTerm = pycvc4.Term(solver)
   start = solver.mkVar(boolean)
   nts = solver.mkVar(boolean)
 
@@ -93,7 +93,7 @@ def testAddAnyVariable():
   solver = pycvc4.Solver()
   boolean = solver.getBooleanSort()
 
-  nullTerm = pycvc4.Term()
+  nullTerm = pycvc4.Term(solver)
   x = solver.mkVar(boolean)
   start = solver.mkVar(boolean)
   nts = solver.mkVar(boolean)