24a47bd761883e738fb292290bd3cddaa06b743a
[cvc5.git] / test / python / unit / api / test_datatype_api.py
1 ###############################################################################
2 # Top contributors (to current version):
3 # Yoni Zohar
4 #
5 # This file is part of the cvc5 project.
6 #
7 # Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8 # in the top-level source directory and their institutional affiliations.
9 # All rights reserved. See the file COPYING in the top-level source
10 # directory for licensing information.
11 # #############################################################################
12 ##
13
14 import pytest
15 import pycvc5
16 from pycvc5 import kinds
17 from pycvc5 import Sort, Term, DatatypeDecl
18
19
@pytest.fixture
def solver():
    """Provide a freshly constructed pycvc5 solver to each test."""
    fresh_solver = pycvc5.Solver()
    return fresh_solver
23
24
def test_mk_datatype_sort(solver):
    """Build a list datatype and check its constructors are accessible.

    Constructors are added as ``cons`` then ``nil``; indexing past the
    last constructor must raise.
    """
    decl = solver.mkDatatypeDecl("list")

    cons_decl = solver.mkDatatypeConstructorDecl("cons")
    cons_decl.addSelector("head", solver.getIntegerSort())
    decl.addConstructor(cons_decl)

    nil_decl = solver.mkDatatypeConstructorDecl("nil")
    decl.addConstructor(nil_decl)

    datatype = solver.mkDatatypeSort(decl).getDatatype()
    cons_ctor, nil_ctor = datatype[0], datatype[1]

    # Only two constructors were added, so index 2 is out of range.
    with pytest.raises(RuntimeError):
        datatype[2]

    # Retrieving the constructor terms must not raise.
    for ctor in (cons_ctor, nil_ctor):
        ctor.getConstructorTerm()
40
41
def test_mk_datatype_sorts(solver):
    """Resolve two mutually recursive datatypes via mkDatatypeSorts.

    Also checks that mkDatatypeSorts rejects a datatype declaration that
    has no constructors.
    """
    # Create two mutual datatypes corresponding to this definition block:
    #
    #   DATATYPE
    #     tree = node(left: tree, right: tree) | leaf(data: list),
    #     list = cons(car: tree, cdr: list) | nil
    #   END

    # Make unresolved types as placeholders
    unresTree = solver.mkUninterpretedSort("tree")
    unresList = solver.mkUninterpretedSort("list")
    unresTypes = {unresTree, unresList}

    tree = solver.mkDatatypeDecl("tree")
    node = solver.mkDatatypeConstructorDecl("node")
    node.addSelector("left", unresTree)
    node.addSelector("right", unresTree)
    tree.addConstructor(node)

    leaf = solver.mkDatatypeConstructorDecl("leaf")
    leaf.addSelector("data", unresList)
    tree.addConstructor(leaf)

    llist = solver.mkDatatypeDecl("list")
    cons = solver.mkDatatypeConstructorDecl("cons")
    cons.addSelector("car", unresTree)
    # cdr is a list, as in the definition block above
    cons.addSelector("cdr", unresList)
    llist.addConstructor(cons)

    nil = solver.mkDatatypeConstructorDecl("nil")
    llist.addConstructor(nil)

    dtdecls = [tree, llist]
    dtsorts = solver.mkDatatypeSorts(dtdecls, unresTypes)
    assert len(dtsorts) == len(dtdecls)
    for dtsort, dtdecl in zip(dtsorts, dtdecls):
        assert dtsort.isDatatype()
        assert not dtsort.getDatatype().isFinite()
        assert dtsort.getDatatype().getName() == dtdecl.getName()

    # verify the resolution was correct
    dtTree = dtsorts[0].getDatatype()
    dtcTreeNode = dtTree[0]
    assert dtcTreeNode.getName() == "node"
    dtsTreeNodeLeft = dtcTreeNode[0]
    assert dtsTreeNodeLeft.getName() == "left"
    # argument type should have resolved to be recursive
    assert dtsTreeNodeLeft.getRangeSort().isDatatype()
    assert dtsTreeNodeLeft.getRangeSort() == dtsorts[0]

    # the list selectors resolve to the tree and list sorts respectively
    dtcListCons = dtsorts[1].getDatatype()[0]
    assert dtcListCons.getName() == "cons"
    assert dtcListCons[0].getRangeSort() == dtsorts[0]
    assert dtcListCons[1].getRangeSort() == dtsorts[1]

    # fails due to empty datatype (a declaration with no constructors)
    emptyD = solver.mkDatatypeDecl("emptyD")
    with pytest.raises(RuntimeError):
        solver.mkDatatypeSorts([emptyD])
104
105
def test_datatype_structs(solver):
    """Check structural predicates across several kinds of datatypes.

    Covers an inductive list, an enumeration, a codatatype stream, a
    one-element tuple and a two-field record.
    """
    int_sort = solver.getIntegerSort()
    bool_sort = solver.getBooleanSort()

    # Inductive list: cons(head: Int, tail: list) | nil
    list_decl = solver.mkDatatypeDecl("list")
    cons_decl = solver.mkDatatypeConstructorDecl("cons")
    cons_decl.addSelector("head", int_sort)
    cons_decl.addSelectorSelf("tail")
    # A default-constructed (null) Sort is rejected as a selector range.
    null_sort = Sort(solver)
    with pytest.raises(RuntimeError):
        cons_decl.addSelector("null", null_sort)
    list_decl.addConstructor(cons_decl)
    list_decl.addConstructor(solver.mkDatatypeConstructorDecl("nil"))

    list_dt = solver.mkDatatypeSort(list_decl).getDatatype()
    assert not list_dt.isCodatatype()
    assert not list_dt.isTuple()
    assert not list_dt.isRecord()
    assert not list_dt.isFinite()
    assert list_dt.isWellFounded()

    cons_ctor = list_dt[0]
    cons_ctor.getConstructorTerm()
    assert cons_ctor.getNumSelectors() == 2

    # Enumeration with three nullary constructors: A | B | C
    enum_decl = solver.mkDatatypeDecl("enum")
    for ctor_name in ("A", "B", "C"):
        enum_decl.addConstructor(solver.mkDatatypeConstructorDecl(ctor_name))
    enum_dt = solver.mkDatatypeSort(enum_decl).getDatatype()
    assert not enum_dt.isTuple()
    assert enum_dt.isFinite()

    # Codatatype: stream = cons(head: Int, tail: stream)
    stream_decl = solver.mkDatatypeDecl("stream", True)
    stream_cons = solver.mkDatatypeConstructorDecl("cons")
    stream_cons.addSelector("head", int_sort)
    stream_cons.addSelectorSelf("tail")
    stream_decl.addConstructor(stream_cons)
    stream_dt = solver.mkDatatypeSort(stream_decl).getDatatype()
    assert stream_dt.isCodatatype()
    assert not stream_dt.isFinite()
    # Codatatypes may be well-founded.
    assert stream_dt.isWellFounded()

    # Tuple of a single Bool.
    tuple_dt = solver.mkTupleSort([bool_sort]).getDatatype()
    assert tuple_dt.isTuple()
    assert not tuple_dt.isRecord()
    assert tuple_dt.isFinite()
    assert tuple_dt.isWellFounded()

    # Record with fields b: Bool and i: Int.
    record_sort = solver.mkRecordSort([("b", bool_sort), ("i", int_sort)])
    assert record_sort.isDatatype()
    record_dt = record_sort.getDatatype()
    assert not record_dt.isTuple()
    assert record_dt.isRecord()
    assert not record_dt.isFinite()
    assert record_dt.isWellFounded()
176
177
def test_datatype_names(solver):
    """Check name-based lookup of datatypes, constructors and selectors."""
    intSort = solver.getIntegerSort()

    # create datatype sort to test
    dtypeSpec = solver.mkDatatypeDecl("list")
    assert dtypeSpec.getName() == "list"
    cons = solver.mkDatatypeConstructorDecl("cons")
    cons.addSelector("head", intSort)
    cons.addSelectorSelf("tail")
    dtypeSpec.addConstructor(cons)
    nil = solver.mkDatatypeConstructorDecl("nil")
    dtypeSpec.addConstructor(nil)
    dtypeSort = solver.mkDatatypeSort(dtypeSpec)
    dt = dtypeSort.getDatatype()
    assert dt.getName() == "list"

    # constructors can be looked up by name via getConstructor and []
    assert dt.getConstructor("nil").getName() == "nil"
    assert dt["cons"].getName() == "cons"
    # a selector name or the empty string is not a constructor name
    with pytest.raises(RuntimeError):
        dt.getConstructor("head")
    with pytest.raises(RuntimeError):
        dt.getConstructor("")

    dcons = dt[0]
    assert dcons.getName() == "cons"
    # selectors can be looked up by name via getSelector and []
    assert dcons.getSelector("head").getName() == "head"
    assert dcons["tail"].getName() == "tail"
    with pytest.raises(RuntimeError):
        dcons.getSelector("cons")

    # get selector
    dselTail = dcons[1]
    assert dselTail.getName() == "tail"
    # tail was added via addSelectorSelf, so its range is the datatype itself
    assert dselTail.getRangeSort() == dtypeSort

    # get selector from datatype
    assert dt.getSelector("head").getName() == "head"
    with pytest.raises(RuntimeError):
        dt.getSelector("cons")

    # a null DatatypeDecl (not created via mkDatatypeDecl) is not usable:
    # querying its name raises
    with pytest.raises(RuntimeError):
        DatatypeDecl(solver).getName()
221
222
def test_parametric_datatype(solver):
    """Instantiate a parametric pair datatype and compare instantiations.

    pair[T1, T2] = mk-pair(first: T1, second: T2)

    Four instantiations over {Int, Real} x {Int, Real} are built. They must
    be pairwise distinct, and each must be comparable to (and a subsort of)
    only itself.
    """
    t1 = solver.mkParamSort("T1")
    t2 = solver.mkParamSort("T2")
    pairSpec = solver.mkDatatypeDecl("pair", [t1, t2])

    mkpair = solver.mkDatatypeConstructorDecl("mk-pair")
    mkpair.addSelector("first", t1)
    mkpair.addSelector("second", t2)
    pairSpec.addConstructor(mkpair)

    pairType = solver.mkDatatypeSort(pairSpec)
    assert pairType.getDatatype().isParametric()

    int_sort = solver.getIntegerSort()
    real_sort = solver.getRealSort()
    pairIntInt = pairType.instantiate([int_sort, int_sort])
    pairRealReal = pairType.instantiate([real_sort, real_sort])
    pairRealInt = pairType.instantiate([real_sort, int_sort])
    pairIntReal = pairType.instantiate([int_sort, real_sort])

    instances = [pairRealReal, pairRealInt, pairIntReal, pairIntInt]

    # All four instantiations are pairwise distinct sorts.
    for i, lhs in enumerate(instances):
        for rhs in instances[i + 1:]:
            assert lhs != rhs

    # Each instantiation is comparable to, and a subsort of, only itself.
    for i, lhs in enumerate(instances):
        for j, rhs in enumerate(instances):
            same = i == j
            assert bool(lhs.isComparableTo(rhs)) == same
            assert bool(lhs.isSubsortOf(rhs)) == same
297
298
def test_datatype_simply_rec(solver):
    """Check well-foundedness and nested recursion of recursive datatypes.

    Each helper below declares an independent block of (mutually) recursive
    datatypes on the same solver and checks ``isWellFounded()`` and
    ``hasNestedRecursion()`` on the resolved sorts.
    """
    _check_mutual_without_nested_recursion(solver)
    _check_recursion_through_array(solver)
    _check_recursion_through_set_in_ns(solver)
    _check_recursion_through_set_in_list(solver)
    _check_parametric_nested_recursion(solver)


def _check_mutual_without_nested_recursion(solver):
    """wlist/list/ns: all well-founded, none has nested recursion."""
    # DATATYPE
    #   wlist = leaf(data: list),
    #   list = cons(car: wlist, cdr: list) | nil,
    #   ns = elem(ndata: set(wlist)) | elemArray(ndata2: array(list, list))
    # END

    # Make unresolved types as placeholders
    unresWList = solver.mkUninterpretedSort("wlist")
    unresList = solver.mkUninterpretedSort("list")
    unresNs = solver.mkUninterpretedSort("ns")
    unresTypes = {unresWList, unresList, unresNs}

    wlist = solver.mkDatatypeDecl("wlist")
    leaf = solver.mkDatatypeConstructorDecl("leaf")
    leaf.addSelector("data", unresList)
    wlist.addConstructor(leaf)

    llist = solver.mkDatatypeDecl("list")
    cons = solver.mkDatatypeConstructorDecl("cons")
    cons.addSelector("car", unresWList)
    cons.addSelector("cdr", unresList)
    llist.addConstructor(cons)
    nil = solver.mkDatatypeConstructorDecl("nil")
    llist.addConstructor(nil)

    ns = solver.mkDatatypeDecl("ns")
    elem = solver.mkDatatypeConstructorDecl("elem")
    elem.addSelector("ndata", solver.mkSetSort(unresWList))
    ns.addConstructor(elem)
    elemArray = solver.mkDatatypeConstructorDecl("elemArray")
    elemArray.addSelector("ndata2", solver.mkArraySort(unresList, unresList))
    ns.addConstructor(elemArray)

    # this is well-founded and has no nested recursion
    dtsorts = solver.mkDatatypeSorts([wlist, llist, ns], unresTypes)
    assert len(dtsorts) == 3
    for dtsort in dtsorts:
        assert dtsort.getDatatype().isWellFounded()
        assert not dtsort.getDatatype().hasNestedRecursion()


def _check_recursion_through_array(solver):
    """ns2: recursion through an array element sort."""
    # DATATYPE
    #   ns2 = elem2(ndata: array(int, ns2)) | nil2
    # END
    unresNs2 = solver.mkUninterpretedSort("ns2")

    ns2 = solver.mkDatatypeDecl("ns2")
    elem2 = solver.mkDatatypeConstructorDecl("elem2")
    elem2.addSelector("ndata",
                      solver.mkArraySort(solver.getIntegerSort(), unresNs2))
    ns2.addConstructor(elem2)
    nil2 = solver.mkDatatypeConstructorDecl("nil2")
    ns2.addConstructor(nil2)

    # this is well-founded and has nested recursion
    dtsorts = solver.mkDatatypeSorts([ns2], {unresNs2})
    assert len(dtsorts) == 1
    ndataRange = dtsorts[0].getDatatype()[0][0].getRangeSort()
    assert ndataRange.isArray()
    # the placeholder inside the array resolved to ns2 itself
    assert ndataRange.getArrayElementSort() == dtsorts[0]
    assert dtsorts[0].getDatatype().isWellFounded()
    assert dtsorts[0].getDatatype().hasNestedRecursion()


def _check_recursion_through_set_in_ns(solver):
    """list3/ns3: mutual recursion where ns3 holds a set of list3."""
    # DATATYPE
    #   list3 = cons3(car: ns3, cdr: list3) | nil3,
    #   ns3 = elem3(ndata: set(list3))
    # END
    unresNs3 = solver.mkUninterpretedSort("ns3")
    unresList3 = solver.mkUninterpretedSort("list3")
    unresTypes = {unresNs3, unresList3}

    list3 = solver.mkDatatypeDecl("list3")
    cons3 = solver.mkDatatypeConstructorDecl("cons3")
    cons3.addSelector("car", unresNs3)
    cons3.addSelector("cdr", unresList3)
    list3.addConstructor(cons3)
    nil3 = solver.mkDatatypeConstructorDecl("nil3")
    list3.addConstructor(nil3)

    ns3 = solver.mkDatatypeDecl("ns3")
    elem3 = solver.mkDatatypeConstructorDecl("elem3")
    elem3.addSelector("ndata", solver.mkSetSort(unresList3))
    ns3.addConstructor(elem3)

    # both are well-founded and have nested recursion
    dtsorts = solver.mkDatatypeSorts([list3, ns3], unresTypes)
    assert len(dtsorts) == 2
    for dtsort in dtsorts:
        assert dtsort.getDatatype().isWellFounded()
        assert dtsort.getDatatype().hasNestedRecursion()


def _check_recursion_through_set_in_list(solver):
    """list4/ns4: mutual recursion where list4 holds a set of ns4."""
    # DATATYPE
    #   list4 = cons4(car: set(ns4), cdr: list4) | nil4,
    #   ns4 = elem4(ndata: list4)
    # END
    unresNs4 = solver.mkUninterpretedSort("ns4")
    unresList4 = solver.mkUninterpretedSort("list4")
    unresTypes = {unresNs4, unresList4}

    list4 = solver.mkDatatypeDecl("list4")
    cons4 = solver.mkDatatypeConstructorDecl("cons4")
    cons4.addSelector("car", solver.mkSetSort(unresNs4))
    cons4.addSelector("cdr", unresList4)
    list4.addConstructor(cons4)
    nil4 = solver.mkDatatypeConstructorDecl("nil4")
    list4.addConstructor(nil4)

    ns4 = solver.mkDatatypeDecl("ns4")
    elem4 = solver.mkDatatypeConstructorDecl("elem4")
    elem4.addSelector("ndata", unresList4)
    ns4.addConstructor(elem4)

    # both are well-founded and have nested recursion
    dtsorts = solver.mkDatatypeSorts([list4, ns4], unresTypes)
    assert len(dtsorts) == 2
    for dtsort in dtsorts:
        assert dtsort.getDatatype().isWellFounded()
        assert dtsort.getDatatype().hasNestedRecursion()


def _check_parametric_nested_recursion(solver):
    """list5[X]: parametric datatype whose cdr is list5[list5[X]]."""
    # DATATYPE
    #   list5[X] = cons5(car: X, cdr: list5[list5[X]]) | nil5
    # END
    unresList5 = solver.mkSortConstructorSort("list5", 1)

    x = solver.mkParamSort("X")
    list5 = solver.mkDatatypeDecl("list5", [x])

    urListX = unresList5.instantiate([x])
    urListListX = unresList5.instantiate([urListX])

    cons5 = solver.mkDatatypeConstructorDecl("cons5")
    cons5.addSelector("car", x)
    cons5.addSelector("cdr", urListListX)
    list5.addConstructor(cons5)
    nil5 = solver.mkDatatypeConstructorDecl("nil5")
    list5.addConstructor(nil5)

    # well-founded and has nested recursion
    dtsorts = solver.mkDatatypeSorts([list5], {unresList5})
    assert len(dtsorts) == 1
    assert dtsorts[0].getDatatype().isWellFounded()
    assert dtsorts[0].getDatatype().hasNestedRecursion()
484
485
def test_datatype_specialized_cons(solver):
    """Check specialized constructor terms of a parametric datatype.

    DATATYPE
      plist[X] = pcons(car: X, cdr: plist[X]) | pnil
    END
    """
    # Make unresolved types as placeholders
    unresList = solver.mkSortConstructorSort("plist", 1)
    unresTypes = {unresList}

    x = solver.mkParamSort("X")
    plist = solver.mkDatatypeDecl("plist", [x])

    urListX = unresList.instantiate([x])

    pcons = solver.mkDatatypeConstructorDecl("pcons")
    pcons.addSelector("car", x)
    pcons.addSelector("cdr", urListX)
    plist.addConstructor(pcons)
    pnil = solver.mkDatatypeConstructorDecl("pnil")
    plist.addConstructor(pnil)

    # make the datatype sorts
    dtsorts = solver.mkDatatypeSorts([plist], unresTypes)
    assert len(dtsorts) == 1
    d = dtsorts[0].getDatatype()
    # constructor 0 is pcons (pnil was added second)
    consc = d[0]
    assert consc.getName() == "pcons"

    isort = solver.getIntegerSort()
    listInt = dtsorts[0].instantiate([isort])

    # the specialized constructor term for plist[Int] differs from the
    # unspecialized constructor term
    testConsTerm = consc.getSpecializedConstructorTerm(listInt)
    assert testConsTerm != consc.getConstructorTerm()
    # error to get the specialized constructor term for Int
    with pytest.raises(RuntimeError):
        consc.getSpecializedConstructorTerm(isort)