Fix CEGIS refinement for recursive functions evaluation (#3555)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Thu, 12 Dec 2019 03:24:43 +0000 (21:24 -0600)
committerGitHub <noreply@github.com>
Thu, 12 Dec 2019 03:24:43 +0000 (21:24 -0600)
src/theory/quantifiers/sygus/term_database_sygus.cpp
test/regress/CMakeLists.txt
test/regress/regress1/sygus/node-discrete.sy [new file with mode: 0644]

index c5ea0f9f3af36612560e5ac5170f8f7f5864cba1..08fb58e40258510f0f7743eb78cb305d3aab13f0 100644 (file)
@@ -735,6 +735,11 @@ SygusTypeInfo& TermDbSygus::getTypeInfo(TypeNode tn)
 Node TermDbSygus::rewriteNode(Node n) const
 {
   Node res = Rewriter::rewrite(n);
+  if (res.isConst())
+  {
+    // constant, we are done
+    return res;
+  }
   if (options::sygusRecFun())
   {
     if (d_funDefEval->hasDefinitions())
@@ -1006,34 +1011,13 @@ Node TermDbSygus::evaluateWithUnfolding(
     {
       if (ret == n && ret[0].isConst())
       {
-        Trace("dt-eval-unfold-debug")
-            << "Optimize: evaluate constant head " << ret << std::endl;
-        // can just do direct evaluation here
-        // notice we prefer this code to the rewriter since it may use
-        // the evaluator
-        std::vector<Node> args;
-        bool success = true;
-        for (unsigned i = 1, nchild = ret.getNumChildren(); i < nchild; i++)
-        {
-          if (!ret[i].isConst())
-          {
-            success = false;
-            break;
-          }
-          args.push_back(ret[i]);
-        }
-        if (success)
-        {
-          TypeNode rt = ret[0].getType();
-          Node bret = sygusToBuiltin(ret[0], rt);
-          Node rete = evaluateBuiltin(rt, bret, args);
-          visited[n] = rete;
-          Trace("dt-eval-unfold-debug")
-              << "Return " << rete << " for " << n << std::endl;
-          return rete;
-        }
+        // use rewriting, possibly involving recursive functions
+        ret = rewriteNode(ret);
+      }
+      else
+      {
+        ret = d_eval_unfold->unfold(ret);
       }
-      ret = d_eval_unfold->unfold(ret);
     }    
     if( ret.getNumChildren()>0 ){
       std::vector< Node > children;
@@ -1050,6 +1034,8 @@ Node TermDbSygus::evaluateWithUnfolding(
         ret = NodeManager::currentNM()->mkNode( ret.getKind(), children );
       }
       ret = getExtRewriter()->extendedRewrite(ret);
+      // use rewriting, possibly involving recursive functions
+      ret = rewriteNode(ret);
     }
     visited[n] = ret;
     return ret;
index 814eaab497ead116b83c3cda156c8380bcb7c17a..7f07224e0729246aabec01993557eec1806de7cb 100644 (file)
@@ -1750,6 +1750,7 @@ set(regress_1_tests
   regress1/sygus/nflat-fwd-3.sy
   regress1/sygus/nflat-fwd.sy
   regress1/sygus/nia-max-square-ns.sy
+  regress1/sygus/node-discrete.sy
   regress1/sygus/no-flat-simp.sy
   regress1/sygus/no-mention.sy
   regress1/sygus/once_2.sy
diff --git a/test/regress/regress1/sygus/node-discrete.sy b/test/regress/regress1/sygus/node-discrete.sy
new file mode 100644 (file)
index 0000000..16f7243
--- /dev/null
@@ -0,0 +1,216 @@
+; EXPECT: unsat
+; COMMAND-LINE: --sygus-out=status
+(set-logic ALL)
+
+(declare-datatype Packet ((P1) (P2)))
+
+(declare-datatype Node ((A) (B) (C)))
+
+(declare-datatype SPair ((mkPair (pnode Node) (ppacket Packet))))
+
+(declare-datatype State ((mkState (rcv (Array SPair Bool)))))
+(declare-datatype StateList ((consSL (headSL State) (tailSL StateList)) (nilSL)))
+
+; C is destination of P1 and P2
+(define-fun h_State ((s State)) Real
+  (+ 
+    (ite (select (rcv s) (mkPair C P1)) 1.0 0.0)
+    (ite (select (rcv s) (mkPair C P2)) 1.0 0.0)
+  )
+)
+
+; reliability
+(define-fun rel () Real 0.7)
+
+; new chance of success
+(define-fun updateReal ((addP Real) (currP Real)) Real
+  (+ currP (* (- 1.0 currP) addP))
+)
+
+; Actions and how they are interpreted
+
+(declare-datatype Action (
+  (sleep) 
+  (pushPck (push_dst Node) (push_pck Packet))
+  (pullPck (pull_src Node) (pull_pck Packet))
+))
+(declare-datatype ActionList ((consAL (headA Action) (tailA ActionList)) (nilAL)))
+
+;; returns true if action is valid for actor in state s
+(define-fun preconditionAction ((actor Node) (a Action) (s State)) Bool
+  (let ((rcv (rcv s)))
+  (ite ((_ is pullPck) a)
+    (let ((pck (pull_pck a)))
+    ; don't pull if already recieved the packet
+    (not (select rcv (mkPair actor pck)))
+    )
+    true
+  )
+  )
+)
+
+; which action fires in state s?
+(define-fun-rec actionListToAction ((actor Node) (al ActionList) (s State)) Action
+  (ite ((_ is consAL) al)
+    (let ((a (headA al)))
+    (ite (preconditionAction actor a s)
+      a
+      (actionListToAction actor (tailA al) s)
+    )
+    )
+    sleep
+  )
+)
+
+(declare-datatype PState ((mkPState (states StateList) (prob (Array State Real)))))
+
+(define-fun-rec h_PState_rec ((pssl StateList) (pspb (Array State Real))) Real
+  (ite ((_ is consSL) pssl)
+    (let ((s (headSL pssl)))
+      (+ (* (select pspb s) (h_State s)) (h_PState_rec (tailSL pssl) pspb))
+    )
+    0.0)
+)
+(define-fun h_PState ((ps PState)) Real
+  (h_PState_rec (states ps) (prob ps))
+)
+
+(define-fun nilPState () PState
+  (mkPState nilSL ((as const (Array State Real)) 0))
+)
+(define-fun-rec appendStateToPState ((s State) (r Real) (p PState)) PState
+  (let ((pstates (states p)))
+  (let ((pprob (prob p)))
+  (let ((pr (select pprob s)))
+  (mkPState
+    ; add to list if not there already
+    (ite (= pr 0.0)
+      (consSL s pstates)
+      pstates
+    )
+    (store 
+      pprob
+      s (+ r pr)
+    )
+  )
+  )))
+)
+
+
+(define-fun transNode ((actor Node) (a Action) (r Real) (s State) (psp PState)) PState
+  (let ((prevRcv (rcv s)))
+  (ite ((_ is pushPck) a)
+    (let ((dst (push_dst a))) 
+    (let ((pck (push_pck a)))
+    (let ((dst_pair (mkPair dst pck)))
+    (let ((src_pair (mkPair actor pck)))
+    (let ((chSuccess (ite (select prevRcv src_pair) rel 0.0)))
+    ; success and failure
+    (appendStateToPState
+      (mkState (store prevRcv dst_pair true))
+      (* r chSuccess)
+    (appendStateToPState
+      s 
+      (* r (- 1.0 chSuccess))
+      psp
+    ))
+    )))))
+  (ite ((_ is pullPck) a)
+    (let ((src (pull_src a))) 
+    (let ((pck (pull_pck a)))
+    (let ((dst_pair (mkPair actor pck)))
+    (let ((src_pair (mkPair src pck)))
+    (let ((chSuccess (ite (select prevRcv src_pair) rel 0.0)))
+    ; success and failure
+    (appendStateToPState
+      (mkState (store prevRcv dst_pair true))
+      (* r chSuccess)
+    (appendStateToPState
+      s 
+      (* r (- 1.0 chSuccess))
+      psp
+    ))
+    )))))
+    (appendStateToPState
+      s 
+      r
+      psp)
+  ))
+  )
+)
+
+(define-fun-rec transNodeListRec ((actor Node) (al ActionList) (ps PState) (pssl StateList) (psp PState)) PState
+  ; if more states to consider in s
+  (ite ((_ is consSL) pssl) 
+    (let ((s (headSL pssl)))
+    (let ((r (select (prob ps) s)))
+      (transNode actor (actionListToAction actor al s) r s 
+        (transNodeListRec actor al ps (tailSL pssl) psp))
+    ))
+    psp)
+)
+
+(define-fun-rec transNodeList ((actor Node) (al ActionList) (ps PState) (psp PState)) PState
+  (transNodeListRec actor al ps (states ps) psp)
+)
+
+(define-fun trans ((aa ActionList) (ab ActionList) (ac ActionList) (ps PState)) PState
+  ;(transNodeList A aa ps
+  ;(transNodeList B ab ps
+  (transNodeList C ac ps
+    nilPState);))
+)
+
+(synth-fun actionA () ActionList
+  ((GAL ActionList) (GA Action) (GN Node) (GP Packet))
+  (
+  (GAL ActionList (nilAL))
+  (GA Action ((pushPck GN GP) (pullPck GN GP)))
+  (GN Node (B C))
+  (GP Packet (P1 P2))
+  )
+)
+(synth-fun actionB () ActionList
+  ((GAL ActionList) (GA Action) (GN Node) (GP Packet))
+  (
+  (GAL ActionList (nilAL))
+  (GA Action ((pushPck GN GP) (pullPck GN GP)))
+  (GN Node (A C))
+  (GP Packet (P1 P2))
+  )
+)
+(synth-fun actionC () ActionList
+  ((GAL ActionList) (GA Action) (GN Node) (GP Packet))
+  (
+  (GAL ActionList ((consAL GA GAL) nilAL))
+  (GA Action ((pushPck GN GP) (pullPck GN GP)))
+  (GN Node (A B))
+  (GP Packet (P1 P2))
+  )
+)
+
+
+; A and B initially have packets P1 and P2
+(define-fun init-state () State
+  (mkState
+    (store 
+      (store
+        ((as const (Array SPair Bool)) false)
+        (mkPair B P2) true
+      )
+      (mkPair A P1) true
+    )
+  )
+)
+
+(define-fun init-pstate () PState
+  (appendStateToPState init-state 1.0 nilPState)
+)
+
+; expected value of packets is greater than 1.0 after 2 time steps.
+(constraint
+  (< 1.0 (h_PState
+         (trans actionA actionB actionC (trans actionA actionB actionC init-pstate))
+         ))
+  )
+(check-synth)