From a9cfdf6710e6a1bc4dd49bf09263fd8bce1af6b5 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Wed, 11 Dec 2019 21:24:43 -0600 Subject: [PATCH] Fix CEGIS refinement for recursive functions evaluation (#3555) --- .../quantifiers/sygus/term_database_sygus.cpp | 40 ++-- test/regress/CMakeLists.txt | 1 + test/regress/regress1/sygus/node-discrete.sy | 216 ++++++++++++++++++ 3 files changed, 230 insertions(+), 27 deletions(-) create mode 100644 test/regress/regress1/sygus/node-discrete.sy diff --git a/src/theory/quantifiers/sygus/term_database_sygus.cpp b/src/theory/quantifiers/sygus/term_database_sygus.cpp index c5ea0f9f3..08fb58e40 100644 --- a/src/theory/quantifiers/sygus/term_database_sygus.cpp +++ b/src/theory/quantifiers/sygus/term_database_sygus.cpp @@ -735,6 +735,11 @@ SygusTypeInfo& TermDbSygus::getTypeInfo(TypeNode tn) Node TermDbSygus::rewriteNode(Node n) const { Node res = Rewriter::rewrite(n); + if (res.isConst()) + { + // constant, we are done + return res; + } if (options::sygusRecFun()) { if (d_funDefEval->hasDefinitions()) @@ -1006,34 +1011,13 @@ Node TermDbSygus::evaluateWithUnfolding( { if (ret == n && ret[0].isConst()) { - Trace("dt-eval-unfold-debug") - << "Optimize: evaluate constant head " << ret << std::endl; - // can just do direct evaluation here - // notice we prefer this code to the rewriter since it may use - // the evaluator - std::vector args; - bool success = true; - for (unsigned i = 1, nchild = ret.getNumChildren(); i < nchild; i++) - { - if (!ret[i].isConst()) - { - success = false; - break; - } - args.push_back(ret[i]); - } - if (success) - { - TypeNode rt = ret[0].getType(); - Node bret = sygusToBuiltin(ret[0], rt); - Node rete = evaluateBuiltin(rt, bret, args); - visited[n] = rete; - Trace("dt-eval-unfold-debug") - << "Return " << rete << " for " << n << std::endl; - return rete; - } + // use rewriting, possibly involving recursive functions + ret = rewriteNode(ret); + } + else + { + ret = d_eval_unfold->unfold(ret); } - ret = d_eval_unfold->unfold(ret); } if( ret.getNumChildren()>0 ){ std::vector< Node > children; @@ -1050,6 +1034,8 @@ Node TermDbSygus::evaluateWithUnfolding( ret = NodeManager::currentNM()->mkNode( ret.getKind(), children ); } ret = getExtRewriter()->extendedRewrite(ret); + // use rewriting, possibly involving recursive functions + ret = rewriteNode(ret); } visited[n] = ret; return ret; diff --git a/test/regress/CMakeLists.txt b/test/regress/CMakeLists.txt index 814eaab49..7f07224e0 100644 --- a/test/regress/CMakeLists.txt +++ b/test/regress/CMakeLists.txt @@ -1750,6 +1750,7 @@ set(regress_1_tests regress1/sygus/nflat-fwd-3.sy regress1/sygus/nflat-fwd.sy regress1/sygus/nia-max-square-ns.sy + regress1/sygus/node-discrete.sy regress1/sygus/no-flat-simp.sy regress1/sygus/no-mention.sy regress1/sygus/once_2.sy diff --git a/test/regress/regress1/sygus/node-discrete.sy b/test/regress/regress1/sygus/node-discrete.sy new file mode 100644 index 000000000..16f7243d1 --- /dev/null +++ b/test/regress/regress1/sygus/node-discrete.sy @@ -0,0 +1,216 @@ +; EXPECT: unsat +; COMMAND-LINE: --sygus-out=status +(set-logic ALL) + +(declare-datatype Packet ((P1) (P2))) + +(declare-datatype Node ((A) (B) (C))) + +(declare-datatype SPair ((mkPair (pnode Node) (ppacket Packet)))) + +(declare-datatype State ((mkState (rcv (Array SPair Bool))))) +(declare-datatype StateList ((consSL (headSL State) (tailSL StateList)) (nilSL))) + +; C is destination of P1 and P2 +(define-fun h_State ((s State)) Real + (+ + (ite (select (rcv s) (mkPair C P1)) 1.0 0.0) + (ite (select (rcv s) (mkPair C P2)) 1.0 0.0) + ) +) + +; reliability +(define-fun rel () Real 0.7) + +; new chance of success +(define-fun updateReal ((addP Real) (currP Real)) Real + (+ currP (* (- 1.0 currP) addP)) +) + +; Actions and how they are interpreted + +(declare-datatype Action ( + (sleep) + (pushPck (push_dst Node) (push_pck Packet)) + (pullPck (pull_src Node) (pull_pck Packet)) +)) +(declare-datatype ActionList ((consAL (headA Action) (tailA ActionList)) (nilAL))) + +;; returns true if action is valid for actor in state s +(define-fun preconditionAction ((actor Node) (a Action) (s State)) Bool + (let ((rcv (rcv s))) + (ite ((_ is pullPck) a) + (let ((pck (pull_pck a))) + ; don't pull if already recieved the packet + (not (select rcv (mkPair actor pck))) + ) + true + ) + ) +) + +; which action fires in state s? +(define-fun-rec actionListToAction ((actor Node) (al ActionList) (s State)) Action + (ite ((_ is consAL) al) + (let ((a (headA al))) + (ite (preconditionAction actor a s) + a + (actionListToAction actor (tailA al) s) + ) + ) + sleep + ) +) + +(declare-datatype PState ((mkPState (states StateList) (prob (Array State Real))))) + +(define-fun-rec h_PState_rec ((pssl StateList) (pspb (Array State Real))) Real + (ite ((_ is consSL) pssl) + (let ((s (headSL pssl))) + (+ (* (select pspb s) (h_State s)) (h_PState_rec (tailSL pssl) pspb)) + ) + 0.0) +) +(define-fun h_PState ((ps PState)) Real + (h_PState_rec (states ps) (prob ps)) +) + +(define-fun nilPState () PState + (mkPState nilSL ((as const (Array State Real)) 0)) +) +(define-fun-rec appendStateToPState ((s State) (r Real) (p PState)) PState + (let ((pstates (states p))) + (let ((pprob (prob p))) + (let ((pr (select pprob s))) + (mkPState + ; add to list if not there already + (ite (= pr 0.0) + (consSL s pstates) + pstates + ) + (store + pprob + s (+ r pr) + ) + ) + ))) +) + + +(define-fun transNode ((actor Node) (a Action) (r Real) (s State) (psp PState)) PState + (let ((prevRcv (rcv s))) + (ite ((_ is pushPck) a) + (let ((dst (push_dst a))) + (let ((pck (push_pck a))) + (let ((dst_pair (mkPair dst pck))) + (let ((src_pair (mkPair actor pck))) + (let ((chSuccess (ite (select prevRcv src_pair) rel 0.0))) + ; success and failure + (appendStateToPState + (mkState (store prevRcv dst_pair true)) + (* r chSuccess) + (appendStateToPState + s + (* r (- 1.0 chSuccess)) + psp + )) + ))))) + (ite ((_ is pullPck) a) + (let ((src (pull_src a))) + (let ((pck (pull_pck a))) + (let ((dst_pair (mkPair actor pck))) + (let ((src_pair (mkPair src pck))) + (let ((chSuccess (ite (select prevRcv src_pair) rel 0.0))) + ; success and failure + (appendStateToPState + (mkState (store prevRcv dst_pair true)) + (* r chSuccess) + (appendStateToPState + s + (* r (- 1.0 chSuccess)) + psp + )) + ))))) + (appendStateToPState + s + r + psp) + )) + ) +) + +(define-fun-rec transNodeListRec ((actor Node) (al ActionList) (ps PState) (pssl StateList) (psp PState)) PState + ; if more states to consider in s + (ite ((_ is consSL) pssl) + (let ((s (headSL pssl))) + (let ((r (select (prob ps) s))) + (transNode actor (actionListToAction actor al s) r s + (transNodeListRec actor al ps (tailSL pssl) psp)) + )) + psp) +) + +(define-fun-rec transNodeList ((actor Node) (al ActionList) (ps PState) (psp PState)) PState + (transNodeListRec actor al ps (states ps) psp) +) + +(define-fun trans ((aa ActionList) (ab ActionList) (ac ActionList) (ps PState)) PState + ;(transNodeList A aa ps + ;(transNodeList B ab ps + (transNodeList C ac ps + nilPState);)) +) + +(synth-fun actionA () ActionList + ((GAL ActionList) (GA Action) (GN Node) (GP Packet)) + ( + (GAL ActionList (nilAL)) + (GA Action ((pushPck GN GP) (pullPck GN GP))) + (GN Node (B C)) + (GP Packet (P1 P2)) + ) +) +(synth-fun actionB () ActionList + ((GAL ActionList) (GA Action) (GN Node) (GP Packet)) + ( + (GAL ActionList (nilAL)) + (GA Action ((pushPck GN GP) (pullPck GN GP))) + (GN Node (A C)) + (GP Packet (P1 P2)) + ) +) +(synth-fun actionC () ActionList + ((GAL ActionList) (GA Action) (GN Node) (GP Packet)) + ( + (GAL ActionList ((consAL GA GAL) nilAL)) + (GA Action ((pushPck GN GP) (pullPck GN GP))) + (GN Node (A B)) + (GP Packet (P1 P2)) + ) +) + + +; A and B initially have packets P1 and P2 +(define-fun init-state () State + (mkState + (store + (store + ((as const (Array SPair Bool)) false) + (mkPair B P2) true + ) + (mkPair A P1) true + ) + ) +) + +(define-fun init-pstate () PState + (appendStateToPState init-state 1.0 nilPState) +) + +; expected value of packets is greater than 1.0 after 2 time steps. +(constraint + (< 1.0 (h_PState + (trans actionA actionB actionC (trans actionA actionB actionC init-pstate)) + )) + ) +(check-synth) -- 2.30.2