Refactor initial phase of transcendental solver (#5599)
authorGereon Kremer <nafur42@gmail.com>
Mon, 7 Dec 2020 19:59:10 +0000 (20:59 +0100)
committerGitHub <noreply@github.com>
Mon, 7 Dec 2020 19:59:10 +0000 (20:59 +0100)
This PR refactors the initialization of the transcendental solver, decoupling the setup of generic caches from initial lemmas for exponential and sine functions.

src/theory/arith/nl/nonlinear_extension.cpp
src/theory/arith/nl/transcendental/exponential_solver.cpp
src/theory/arith/nl/transcendental/exponential_solver.h
src/theory/arith/nl/transcendental/sine_solver.cpp
src/theory/arith/nl/transcendental/sine_solver.h
src/theory/arith/nl/transcendental/transcendental_solver.cpp
src/theory/arith/nl/transcendental/transcendental_solver.h
src/theory/arith/nl/transcendental/transcendental_state.cpp
src/theory/arith/nl/transcendental/transcendental_state.h

index 22a69cadb4ed14709a39d300e62d1aa73bcba24c..9434a049105a5d1532d9bcaa345bbe8d6b9033de 100644 (file)
@@ -656,7 +656,7 @@ void NonlinearExtension::runStrategy(Theory::Effort effort,
         d_tangentPlaneSlv.check(true);
         break;
       case InferStep::TRANS_INIT:
-        d_trSlv.initLastCall(assertions, false_asserts, xts);
+        d_trSlv.initLastCall(xts);
         break;
       case InferStep::TRANS_INITIAL:
         d_trSlv.checkTranscendentalInitialRefine();
index 0726e03b33e8ab93363a6f6e10e7771c246e5986..243fa251b7ed5974b624997f08b30618eb000ee2 100644 (file)
@@ -37,6 +37,19 @@ ExponentialSolver::ExponentialSolver(TranscendentalState* tstate)
 
 ExponentialSolver::~ExponentialSolver() {}
 
+void ExponentialSolver::doPurification(TNode a, TNode new_a, TNode y)
+{
+  NodeManager* nm = NodeManager::currentNM();
+  // do both equalities to ensure that new_a becomes a preregistered term
+  Node lem = nm->mkNode(Kind::AND, a.eqNode(new_a), a[0].eqNode(y));
+  // note we must do preprocess on this lemma
+  Trace("nl-ext-lemma") << "NonlinearExtension::Lemma : purify : " << lem
+                        << std::endl;
+  NlLemma nlem(
+      lem, LemmaProperty::PREPROCESS, nullptr, InferenceId::NL_T_PURIFY_ARG);
+  d_data->d_im.addPendingArithLemma(nlem);
+}
+
 void ExponentialSolver::checkInitialRefine()
 {
   NodeManager* nm = NodeManager::currentNM();
index f1a30b17746d31398f7d22e35763fef4930f88c2..1fad896e4cafc6bf3c8653306feddf9ec44bae0e 100644 (file)
@@ -49,9 +49,11 @@ class ExponentialSolver
   ExponentialSolver(TranscendentalState* tstate);
   ~ExponentialSolver();
 
-  void initLastCall(const std::vector<Node>& assertions,
-                    const std::vector<Node>& false_asserts,
-                    const std::vector<Node>& xts);
+  /**
+   * Ensures that new_a is properly registered as a term where new_a is the
+   * purified version of a, y being the new skolem used for purification.
+   */
+  void doPurification(TNode a, TNode new_a, TNode y);
 
   /**
    * check initial refine
index 5eed810b37d2567e8f83405338f4a3e5c5cf536b..1747077bd88611b23c4b1683e93fd73c823e3552 100644 (file)
@@ -29,11 +29,55 @@ namespace theory {
 namespace arith {
 namespace nl {
 namespace transcendental {
+namespace {
+
+/**
+ * Ensure a is in the main phase:
+ *   -pi <= a <= pi
+ */
+inline Node mkValidPhase(TNode a, TNode pi)
+{
+  return mkBounded(
+      NodeManager::currentNM()->mkNode(Kind::MULT, mkRationalNode(-1), pi),
+      a,
+      pi);
+}
+}  // namespace
 
 SineSolver::SineSolver(TranscendentalState* tstate) : d_data(tstate) {}
 
 SineSolver::~SineSolver() {}
 
+void SineSolver::doPhaseShift(TNode a, TNode new_a, TNode y)
+{
+  NodeManager* nm = NodeManager::currentNM();
+  Assert(a.getKind() == Kind::SINE);
+  Trace("nl-ext-tf") << "Basis sine : " << new_a << " for " << a << std::endl;
+  Assert(!d_data->d_pi.isNull());
+  Node shift = nm->mkSkolem("s", nm->integerType(), "number of shifts");
+  // TODO (cvc4-projects #47) : do not introduce shift here, instead needs model-based
+  // refinement for constant shifts (cvc4-projects #1284)
+  Node lem = nm->mkNode(
+      Kind::AND,
+      mkValidPhase(y, d_data->d_pi),
+      nm->mkNode(Kind::ITE,
+                 mkValidPhase(a[0], d_data->d_pi),
+                 a[0].eqNode(y),
+                 a[0].eqNode(nm->mkNode(Kind::PLUS,
+                                        y,
+                                        nm->mkNode(Kind::MULT,
+                                                   nm->mkConst(Rational(2)),
+                                                   shift,
+                                                   d_data->d_pi)))),
+      new_a.eqNode(a));
+  // note we must do preprocess on this lemma
+  Trace("nl-ext-lemma") << "NonlinearExtension::Lemma : purify : " << lem
+                        << std::endl;
+  NlLemma nlem(
+      lem, LemmaProperty::PREPROCESS, nullptr, InferenceId::NL_T_PURIFY_ARG);
+  d_data->d_im.addPendingArithLemma(nlem);
+}
+
 void SineSolver::checkInitialRefine()
 {
   NodeManager* nm = NodeManager::currentNM();
@@ -52,7 +96,6 @@ void SineSolver::checkInitialRefine()
       if (d_tf_initial_refine.find(t) == d_tf_initial_refine.end())
       {
         d_tf_initial_refine[t] = true;
-        Node lem;
         Node symn = nm->mkNode(Kind::SINE,
                                nm->mkNode(Kind::MULT, d_data->d_neg_one, t[0]));
         symn = Rewriter::rewrite(symn);
@@ -61,54 +104,70 @@ void SineSolver::checkInitialRefine()
         d_data->d_trMaster[symn] = symn;
         d_data->d_trSlaves[symn].insert(symn);
         Assert(d_data->d_trSlaves.find(t) != d_data->d_trSlaves.end());
-        std::vector<Node> children;
 
-        lem =
-            nm->mkNode(Kind::AND,
-                       // bounds
-                       nm->mkNode(Kind::AND,
-                                  nm->mkNode(Kind::LEQ, t, d_data->d_one),
-                                  nm->mkNode(Kind::GEQ, t, d_data->d_neg_one)),
-                       // symmetry
-                       nm->mkNode(Kind::PLUS, t, symn).eqNode(d_data->d_zero),
-                       // sign
-                       nm->mkNode(Kind::EQUAL,
-                                  nm->mkNode(Kind::LT, t[0], d_data->d_zero),
-                                  nm->mkNode(Kind::LT, t, d_data->d_zero)),
-                       // zero val
-                       nm->mkNode(Kind::EQUAL,
-                                  nm->mkNode(Kind::GT, t[0], d_data->d_zero),
-                                  nm->mkNode(Kind::GT, t, d_data->d_zero)));
-        lem = nm->mkNode(
-            Kind::AND,
-            lem,
-            // zero tangent
-            nm->mkNode(Kind::AND,
-                       nm->mkNode(Kind::IMPLIES,
-                                  nm->mkNode(Kind::GT, t[0], d_data->d_zero),
-                                  nm->mkNode(Kind::LT, t, t[0])),
-                       nm->mkNode(Kind::IMPLIES,
-                                  nm->mkNode(Kind::LT, t[0], d_data->d_zero),
-                                  nm->mkNode(Kind::GT, t, t[0]))),
-            // pi tangent
-            nm->mkNode(
-                Kind::AND,
-                nm->mkNode(
-                    Kind::IMPLIES,
-                    nm->mkNode(Kind::LT, t[0], d_data->d_pi),
-                    nm->mkNode(Kind::LT,
-                               t,
-                               nm->mkNode(Kind::MINUS, d_data->d_pi, t[0]))),
-                nm->mkNode(
-                    Kind::IMPLIES,
-                    nm->mkNode(Kind::GT, t[0], d_data->d_pi_neg),
-                    nm->mkNode(
-                        Kind::GT,
-                        t,
-                        nm->mkNode(Kind::MINUS, d_data->d_pi_neg, t[0])))));
-        if (!lem.isNull())
         {
-          d_data->d_im.addPendingArithLemma(lem, InferenceId::NL_T_INIT_REFINE);
+          // sine bounds: -1 <= sin(t) <= 1
+          Node lem = nm->mkNode(Kind::AND,
+                                nm->mkNode(Kind::LEQ, t, d_data->d_one),
+                                nm->mkNode(Kind::GEQ, t, d_data->d_neg_one));
+          d_data->d_im.addPendingArithLemma(
+              lem, InferenceId::NL_T_INIT_REFINE);
+        }
+        {
+          // sine symmetry: sin(t) - sin(-t) = 0
+          Node lem = nm->mkNode(Kind::PLUS, t, symn).eqNode(d_data->d_zero);
+          d_data->d_im.addPendingArithLemma(
+              lem, InferenceId::NL_T_INIT_REFINE);
+        }
+        {
+          // sine zero tangent:
+          //   t > 0  =>  sin(t) < t
+          //   t < 0  =>  sin(t) > t
+          Node lem =
+              nm->mkNode(Kind::AND,
+                         nm->mkNode(Kind::IMPLIES,
+                                    nm->mkNode(Kind::GT, t[0], d_data->d_zero),
+                                    nm->mkNode(Kind::LT, t, t[0])),
+                         nm->mkNode(Kind::IMPLIES,
+                                    nm->mkNode(Kind::LT, t[0], d_data->d_zero),
+                                    nm->mkNode(Kind::GT, t, t[0])));
+          d_data->d_im.addPendingArithLemma(
+              lem, InferenceId::NL_T_INIT_REFINE);
+        }
+        {
+          // sine pi tangent:
+          //   t > -pi  =>  sin(t) > -pi-t
+          //   t <  pi  =>  sin(t) <  pi-t
+          Node lem = nm->mkNode(
+              Kind::AND,
+              nm->mkNode(
+                  Kind::IMPLIES,
+                  nm->mkNode(Kind::GT, t[0], d_data->d_pi_neg),
+                  nm->mkNode(Kind::GT,
+                             t,
+                             nm->mkNode(Kind::MINUS, d_data->d_pi_neg, t[0]))),
+              nm->mkNode(
+                  Kind::IMPLIES,
+                  nm->mkNode(Kind::LT, t[0], d_data->d_pi),
+                  nm->mkNode(Kind::LT,
+                             t,
+                             nm->mkNode(Kind::MINUS, d_data->d_pi, t[0]))));
+          d_data->d_im.addPendingArithLemma(
+              lem, InferenceId::NL_T_INIT_REFINE);
+        }
+        {
+          Node lem =
+              nm->mkNode(Kind::AND,
+                         // sign
+                         nm->mkNode(Kind::EQUAL,
+                                    nm->mkNode(Kind::LT, t[0], d_data->d_zero),
+                                    nm->mkNode(Kind::LT, t, d_data->d_zero)),
+                         // zero val
+                         nm->mkNode(Kind::EQUAL,
+                                    nm->mkNode(Kind::GT, t[0], d_data->d_zero),
+                                    nm->mkNode(Kind::GT, t, d_data->d_zero)));
+          d_data->d_im.addPendingArithLemma(
+              lem, InferenceId::NL_T_INIT_REFINE);
         }
       }
     }
@@ -287,7 +346,8 @@ void SineSolver::doTangentLemma(TNode e, TNode c, TNode poly_approx, int region)
   Trace("nl-ext-sine") << "*** Tangent plane lemma : " << lem << std::endl;
   Assert(d_data->d_model.computeAbstractModelValue(lem) == d_data->d_false);
   // Figure 3 : line 9
-  d_data->d_im.addPendingArithLemma(lem, InferenceId::NL_T_TANGENT, nullptr, true);
+  d_data->d_im.addPendingArithLemma(
+      lem, InferenceId::NL_T_TANGENT, nullptr, true);
 }
 
 void SineSolver::doSecantLemmas(TNode e,
index 15f7d46e838dccedb4ff62520710feaf731af1bb..5eace6104d773f7de37255fbb68ae7d8be28929b 100644 (file)
@@ -49,9 +49,11 @@ class SineSolver
   SineSolver(TranscendentalState* tstate);
   ~SineSolver();
 
-  void initLastCall(const std::vector<Node>& assertions,
-                    const std::vector<Node>& false_asserts,
-                    const std::vector<Node>& xts);
+  /**
+   * Introduces new_a as purified version of a which is also shifted to the main
+   * phase (from -pi to pi). y is the new skolem used for purification.
+   */
+  void doPhaseShift(TNode a, TNode new_a, TNode y);
 
   /**
    * check initial refine
index 2a22853a2cffa347b245a2641f37cddba2ecdef4..c2841c13505c1b336a3146744707de554aa8fe42 100644 (file)
@@ -41,11 +41,36 @@ TranscendentalSolver::TranscendentalSolver(InferenceManager& im, NlModel& m)
 
 TranscendentalSolver::~TranscendentalSolver() {}
 
-void TranscendentalSolver::initLastCall(const std::vector<Node>& assertions,
-                                        const std::vector<Node>& false_asserts,
-                                        const std::vector<Node>& xts)
+void TranscendentalSolver::initLastCall(const std::vector<Node>& xts)
 {
-  d_tstate.init(assertions, false_asserts, xts);
+  std::vector<Node> needsMaster;
+  d_tstate.init(xts, needsMaster);
+
+  if (d_tstate.d_im.hasUsed()) {
+    return;
+  }
+
+  NodeManager* nm = NodeManager::currentNM();
+  for (const Node& a : needsMaster)
+  {
+    // should not have processed this already
+    Assert(d_tstate.d_trMaster.find(a) == d_tstate.d_trMaster.end());
+    Kind k = a.getKind();
+    Assert(k == Kind::SINE || k == Kind::EXPONENTIAL);
+    Node y =
+        nm->mkSkolem("y", nm->realType(), "phase shifted trigonometric arg");
+    Node new_a = nm->mkNode(k, y);
+    d_tstate.d_trSlaves[new_a].insert(new_a);
+    d_tstate.d_trSlaves[new_a].insert(a);
+    d_tstate.d_trMaster[a] = new_a;
+    d_tstate.d_trMaster[new_a] = new_a;
+    switch (k)
+    {
+      case Kind::SINE: d_sineSlv.doPhaseShift(a, new_a, y); break;
+      case Kind::EXPONENTIAL: d_expSlv.doPurification(a, new_a, y); break;
+      default: AlwaysAssert(false) << "Unexpected Kind " << k;
+    }
+  }
 }
 
 bool TranscendentalSolver::preprocessAssertionsCheckModel(
index 80def6f05aebd7c2f13a6e65dbf1faa63fe9e4f2..64f6db1633a036f609e823c5a653b3c4ee01a142 100644 (file)
@@ -79,9 +79,7 @@ class TranscendentalSolver
    * This call may add lemmas to lems based on registering term
    * information (for example, purification of sine terms).
    */
-  void initLastCall(const std::vector<Node>& assertions,
-                    const std::vector<Node>& false_asserts,
-                    const std::vector<Node>& xts);
+  void initLastCall(const std::vector<Node>& xts);
   /** increment taylor degree */
   void incrementTaylorDegree();
   /** get taylor degree */
index 0e47f425773ac5692b40cf1759caa02e9df43dd1..ba60b6a0ee4504d4c533e829a537f025df6dd496 100644 (file)
@@ -48,64 +48,59 @@ TranscendentalState::TranscendentalState(InferenceManager& im, NlModel& model)
   d_neg_one = NodeManager::currentNM()->mkConst(Rational(-1));
 }
 
-void TranscendentalState::init(const std::vector<Node>& assertions,
-                               const std::vector<Node>& false_asserts,
-                               const std::vector<Node>& xts)
+void TranscendentalState::init(const std::vector<Node>& xts,
+                               std::vector<Node>& needsMaster)
 {
   d_funcCongClass.clear();
   d_funcMap.clear();
   d_tf_region.clear();
 
-  NodeManager* nm = NodeManager::currentNM();
-
-  // register the extended function terms
-  std::vector<Node> trNeedsMaster;
   bool needPi = false;
   // for computing congruence
   std::map<Kind, ArgTrie> argTrie;
-  for (unsigned i = 0, xsize = xts.size(); i < xsize; i++)
+  for (std::size_t i = 0, xsize = xts.size(); i < xsize; ++i)
   {
+    // Ignore if it is not a transcendental
+    if (!isTranscendentalKind(xts[i].getKind()))
+    {
+      continue;
+    }
     Node a = xts[i];
     Kind ak = a.getKind();
     bool consider = true;
-    // if is an unpurified application of SINE, or it is a transcendental
-    // applied to a trancendental, purify.
-    if (isTranscendentalKind(ak))
+    // if we've already computed master for a
+    if (d_trMaster.find(a) != d_trMaster.end())
+    {
+      // a master has at least one slave
+      consider = (d_trSlaves.find(a) != d_trSlaves.end());
+    }
+    else
     {
-      // if we've already computed master for a
-      if (d_trMaster.find(a) != d_trMaster.end())
+      if (ak == Kind::SINE)
       {
-        // a master has at least one slave
-        consider = (d_trSlaves.find(a) != d_trSlaves.end());
+        // always not a master
+        consider = false;
       }
       else
       {
-        if (ak == Kind::SINE)
-        {
-          // always not a master
-          consider = false;
-        }
-        else
+        for (const Node& ac : a)
         {
-          for (const Node& ac : a)
+          if (isTranscendentalKind(ac.getKind()))
           {
-            if (isTranscendentalKind(ac.getKind()))
-            {
-              consider = false;
-              break;
-            }
+            consider = false;
+            break;
           }
         }
-        if (!consider)
-        {
-          // wait to assign a master below
-          trNeedsMaster.push_back(a);
-        }
-        else
-        {
-          d_trMaster[a] = a;
-          d_trSlaves[a].insert(a);
-        }
+      }
+      if (!consider)
+      {
+        // wait to assign a master below
+        needsMaster.push_back(a);
+      }
+      else
+      {
+        d_trMaster[a] = a;
+        d_trSlaves[a].insert(a);
       }
     }
     if (ak == Kind::EXPONENTIAL || ak == Kind::SINE)
@@ -114,38 +109,7 @@ void TranscendentalState::init(const std::vector<Node>& assertions,
       // if we didn't indicate that it should be purified above
       if (consider)
       {
-        std::vector<Node> repList;
-        for (const Node& ac : a)
-        {
-          Node r = d_model.computeConcreteModelValue(ac);
-          repList.push_back(r);
-        }
-        Node aa = argTrie[ak].add(a, repList);
-        if (aa != a)
-        {
-          // apply congruence to pairs of terms that are disequal and congruent
-          Assert(aa.getNumChildren() == a.getNumChildren());
-          Node mvaa = d_model.computeAbstractModelValue(a);
-          Node mvaaa = d_model.computeAbstractModelValue(aa);
-          if (mvaa != mvaaa)
-          {
-            std::vector<Node> exp;
-            for (unsigned j = 0, size = a.getNumChildren(); j < size; j++)
-            {
-              exp.push_back(a[j].eqNode(aa[j]));
-            }
-            Node expn = exp.size() == 1 ? exp[0] : nm->mkNode(Kind::AND, exp);
-            Node cong_lemma = nm->mkNode(Kind::OR, expn.negate(), a.eqNode(aa));
-            d_im.addPendingArithLemma(cong_lemma, InferenceId::NL_CONGRUENCE);
-          }
-        }
-        else
-        {
-          // new representative of congruence class
-          d_funcMap[ak].push_back(a);
-        }
-        // add to congruence class
-        d_funcCongClass[aa].push_back(a);
+        ensureCongruence(a, argTrie);
       }
     }
     else if (ak == Kind::PI)
@@ -163,61 +127,6 @@ void TranscendentalState::init(const std::vector<Node>& assertions,
     getCurrentPiBounds();
   }
 
-  if (d_im.hasUsed())
-  {
-    return;
-  }
-
-  // process SINE phase shifting
-  for (const Node& a : trNeedsMaster)
-  {
-    // should not have processed this already
-    Assert(d_trMaster.find(a) == d_trMaster.end());
-    Kind k = a.getKind();
-    Assert(k == Kind::SINE || k == Kind::EXPONENTIAL);
-    Node y =
-        nm->mkSkolem("y", nm->realType(), "phase shifted trigonometric arg");
-    Node new_a = nm->mkNode(k, y);
-    d_trSlaves[new_a].insert(new_a);
-    d_trSlaves[new_a].insert(a);
-    d_trMaster[a] = new_a;
-    d_trMaster[new_a] = new_a;
-    Node lem;
-    if (k == Kind::SINE)
-    {
-      Trace("nl-ext-tf") << "Basis sine : " << new_a << " for " << a
-                         << std::endl;
-      Assert(!d_pi.isNull());
-      Node shift = nm->mkSkolem("s", nm->integerType(), "number of shifts");
-      // TODO : do not introduce shift here, instead needs model-based
-      // refinement for constant shifts (cvc4-projects #1284)
-      lem = nm->mkNode(
-          Kind::AND,
-          transcendental::mkValidPhase(y, d_pi),
-          nm->mkNode(
-              Kind::ITE,
-              transcendental::mkValidPhase(a[0], d_pi),
-              a[0].eqNode(y),
-              a[0].eqNode(nm->mkNode(
-                  Kind::PLUS,
-                  y,
-                  nm->mkNode(
-                      Kind::MULT, nm->mkConst(Rational(2)), shift, d_pi)))),
-          new_a.eqNode(a));
-    }
-    else
-    {
-      // do both equalities to ensure that new_a becomes a preregistered term
-      lem = nm->mkNode(Kind::AND, a.eqNode(new_a), a[0].eqNode(y));
-    }
-    // note we must do preprocess on this lemma
-    Trace("nl-ext-lemma") << "NonlinearExtension::Lemma : purify : " << lem
-                          << std::endl;
-    NlLemma nlem(
-        lem, LemmaProperty::PREPROCESS, nullptr, InferenceId::NL_T_PURIFY_ARG);
-    d_im.addPendingArithLemma(nlem);
-  }
-
   if (Trace.isOn("nl-ext-mv"))
   {
     Trace("nl-ext-mv") << "Arguments of trancendental functions : "
@@ -239,6 +148,44 @@ void TranscendentalState::init(const std::vector<Node>& assertions,
   }
 }
 
+void TranscendentalState::ensureCongruence(TNode a,
+                                           std::map<Kind, ArgTrie>& argTrie)
+{
+  NodeManager* nm = NodeManager::currentNM();
+  std::vector<Node> repList;
+  for (const Node& ac : a)
+  {
+    Node r = d_model.computeConcreteModelValue(ac);
+    repList.push_back(r);
+  }
+  Node aa = argTrie[a.getKind()].add(a, repList);
+  if (aa != a)
+  {
+    // apply congruence to pairs of terms that are disequal and congruent
+    Assert(aa.getNumChildren() == a.getNumChildren());
+    Node mvaa = d_model.computeAbstractModelValue(a);
+    Node mvaaa = d_model.computeAbstractModelValue(aa);
+    if (mvaa != mvaaa)
+    {
+      std::vector<Node> exp;
+      for (unsigned j = 0, size = a.getNumChildren(); j < size; j++)
+      {
+        exp.push_back(a[j].eqNode(aa[j]));
+      }
+      Node expn = exp.size() == 1 ? exp[0] : nm->mkNode(Kind::AND, exp);
+      Node cong_lemma = expn.impNode(a.eqNode(aa));
+      d_im.addPendingArithLemma(cong_lemma, InferenceId::NL_CONGRUENCE);
+    }
+  }
+  else
+  {
+    // new representative of congruence class
+    d_funcMap[a.getKind()].push_back(a);
+  }
+  // add to congruence class
+  d_funcCongClass[aa].push_back(a);
+}
+
 void TranscendentalState::mkPi()
 {
   NodeManager* nm = NodeManager::currentNM();
index 7062e8183307b1535cdb93cfb1918ce9644d71ed..f940ae2e3f5b56fb022c7a2bfccd131824ddd202 100644 (file)
@@ -56,18 +56,23 @@ struct TranscendentalState
 
   /** init last call
    *
-   * This is called at the beginning of last call effort check, where
-   * assertions are the set of assertions belonging to arithmetic,
-   * false_asserts is the subset of assertions that are false in the current
-   * model, and xts is the set of extended function terms that are active in
-   * the current context.
+   * This is called at the beginning of last call effort check xts is the set of
+   * extended function terms that are active in the current context.
    *
    * This call may add lemmas to lems based on registering term
-   * information (for example, purification of sine terms).
+   * information (for example to ensure congruence of terms).
+   * It puts terms that need to be treated further as a master term on their own
+   * (for example purification of sine terms) into needsMaster.
    */
-  void init(const std::vector<Node>& assertions,
-            const std::vector<Node>& false_asserts,
-            const std::vector<Node>& xts);
+  void init(const std::vector<Node>& xts, std::vector<Node>& needsMaster);
+
+  /**
+   * Checks for terms that are congruent but disequal to a.
+   * If any are found, appropriate lemmas are sent.
+   * @param a Some node
+   * @param argTrie Lookup for equivalence classes
+   */
+  void ensureCongruence(TNode a, std::map<Kind, ArgTrie>& argTrie);
 
   /** Initialize members for pi-related values */
   void mkPi();