Add helper functions for multi-objective optimization + refactoring (#6473)
authorOuyancheng <1024842937@qq.com>
Wed, 5 May 2021 00:34:54 +0000 (17:34 -0700)
committerGitHub <noreply@github.com>
Wed, 5 May 2021 00:34:54 +0000 (00:34 +0000)
add 3 helper functions
judge whether a node is optimizable
make strong improvement expression according to optimization objective
make weak improvement expression according to optimization objective
optChecker is now created by optimizationSolver instead of the minimize/maximize functions
Slightly refactors function signatures so that they are accepting OptimizationObjective instead of accepting target, type in separate parameters.

src/omt/bitvector_optimizer.cpp
src/omt/bitvector_optimizer.h
src/omt/integer_optimizer.cpp
src/omt/integer_optimizer.h
src/omt/omt_optimizer.cpp
src/omt/omt_optimizer.h
src/smt/optimization_solver.cpp
src/smt/optimization_solver.h

index 7edecdb73a3f58de884a5d16237ce8fee906a09f..1c3faddff00fe4a353a10a8734a9f9501d6a92cb 100644 (file)
@@ -43,13 +43,11 @@ BitVector OMTOptimizerBitVector::computeAverage(const BitVector& a,
                         + aMod2PlusbMod2Div2));
 }
 
-OptimizationResult OMTOptimizerBitVector::minimize(SmtEngine* parentSMTSolver,
+OptimizationResult OMTOptimizerBitVector::minimize(SmtEngine* optChecker,
                                                    TNode target)
 {
   // the smt engine to which we send intermediate queries
   // for the binary search.
-  std::unique_ptr<SmtEngine> optChecker =
-      OMTOptimizer::createOptCheckerWithTimeout(parentSMTSolver, false);
   NodeManager* nm = optChecker->getNodeManager();
   Result intermediateSatResult = optChecker->checkSat();
   // Model-value of objective (used in optimization loop)
@@ -137,13 +135,11 @@ OptimizationResult OMTOptimizerBitVector::minimize(SmtEngine* parentSMTSolver,
   return OptimizationResult(OptimizationResult::OPTIMAL, value);
 }
 
-OptimizationResult OMTOptimizerBitVector::maximize(SmtEngine* parentSMTSolver,
+OptimizationResult OMTOptimizerBitVector::maximize(SmtEngine* optChecker,
                                                    TNode target)
 {
   // the smt engine to which we send intermediate queries
   // for the binary search.
-  std::unique_ptr<SmtEngine> optChecker =
-      OMTOptimizer::createOptCheckerWithTimeout(parentSMTSolver, false);
   NodeManager* nm = optChecker->getNodeManager();
   Result intermediateSatResult = optChecker->checkSat();
   // Model-value of objective (used in optimization loop)
index b95c185f8e82ca0c8dac46869b726d03a2386a90..3b1bdebca565612e3bc0ed25bf4cda6f0fc173a8 100644 (file)
@@ -28,9 +28,9 @@ class OMTOptimizerBitVector : public OMTOptimizer
  public:
   OMTOptimizerBitVector(bool isSigned);
   virtual ~OMTOptimizerBitVector() = default;
-  smt::OptimizationResult minimize(SmtEngine* parentSMTSolver,
+  smt::OptimizationResult minimize(SmtEngine* optChecker,
                                    TNode target) override;
-  smt::OptimizationResult maximize(SmtEngine* parentSMTSolver,
+  smt::OptimizationResult maximize(SmtEngine* optChecker,
                                    TNode target) override;
 
  private:
index 8fbfff1a2b475702863e310b507af9c190da6e7d..5e3dc15fffdeaa3bc1a7bfbeba68330975e2b70c 100644 (file)
 using namespace cvc5::smt;
 namespace cvc5::omt {
 
-OptimizationResult OMTOptimizerInteger::optimize(
-    SmtEngine* parentSMTSolver,
-    TNode target,
-    OptimizationObjective::ObjectiveType objType)
+OptimizationResult OMTOptimizerInteger::optimize(SmtEngine* optChecker,
+                                                 TNode target,
+                                                 bool isMinimize)
 {
   // linear search for integer goal
   // the smt engine to which we send intermediate queries
   // for the linear search.
-  std::unique_ptr<SmtEngine> optChecker =
-      OMTOptimizer::createOptCheckerWithTimeout(parentSMTSolver, false);
   NodeManager* nm = optChecker->getNodeManager();
 
   Result intermediateSatResult = optChecker->checkSat();
@@ -47,16 +44,16 @@ OptimizationResult OMTOptimizerInteger::optimize(
   // asserts objective > old_value (used in optimization loop)
   Node increment;
   Kind incrementalOperator = kind::NULL_EXPR;
-  if (objType == OptimizationObjective::MINIMIZE)
+  if (isMinimize)
   {
-    // if objective is MIN, then assert optimization_target <
-    // current_model_value
+    // if objective is minimize,
+    // then assert optimization_target < current_model_value
     incrementalOperator = kind::LT;
   }
-  else if (objType == OptimizationObjective::MAXIMIZE)
+  else
   {
-    // if objective is MAX, then assert optimization_target >
-    // current_model_value
+    // if objective is maximize,
+    // then assert optimization_target > current_model_value
     incrementalOperator = kind::GT;
   }
   // Workhorse of linear search:
@@ -74,17 +71,15 @@ OptimizationResult OMTOptimizerInteger::optimize(
   return OptimizationResult(OptimizationResult::OPTIMAL, value);
 }
 
-OptimizationResult OMTOptimizerInteger::minimize(
-    SmtEngine* parentSMTSolver, TNode target)
+OptimizationResult OMTOptimizerInteger::minimize(SmtEngine* optChecker,
+                                                 TNode target)
 {
-  return this->optimize(
-      parentSMTSolver, target, OptimizationObjective::MINIMIZE);
+  return this->optimize(optChecker, target, true);
 }
-OptimizationResult OMTOptimizerInteger::maximize(
-    SmtEngine* parentSMTSolver, TNode target)
+OptimizationResult OMTOptimizerInteger::maximize(SmtEngine* optChecker,
+                                                 TNode target)
 {
-  return this->optimize(
-      parentSMTSolver, target, OptimizationObjective::MAXIMIZE);
+  return this->optimize(optChecker, target, false);
 }
 
 }  // namespace cvc5::omt
index 48d16249460195c1519a3c9ffac06ab63f51cc96..34605cc71267b27b5c53c89e112bac982c9e5158 100644 (file)
@@ -28,20 +28,21 @@ class OMTOptimizerInteger : public OMTOptimizer
  public:
   OMTOptimizerInteger() = default;
   virtual ~OMTOptimizerInteger() = default;
-  smt::OptimizationResult minimize(SmtEngine* parentSMTSolver,
+  smt::OptimizationResult minimize(SmtEngine* optChecker,
                                    TNode target) override;
-  smt::OptimizationResult maximize(SmtEngine* parentSMTSolver,
+  smt::OptimizationResult maximize(SmtEngine* optChecker,
                                    TNode target) override;
 
  private:
   /**
    * Handles the optimization query specified by objType
-   * (objType = MINIMIZE / MAXIMIZE)
+   * isMinimize = true will trigger minimization, 
+   * otherwise trigger maximization
    **/
   smt::OptimizationResult optimize(
-      SmtEngine* parentSMTSolver,
+      SmtEngine* optChecker,
       TNode target,
-      smt::OptimizationObjective::ObjectiveType objType);
+      bool isMinimize);
 };
 
 }  // namespace cvc5::omt
index 49b07fe26ab8d206900883dd27375be0e5ae0f3f..bcf84cb538181fc4f1e1f375929e03ccdad3e81c 100644 (file)
 
 #include "omt/bitvector_optimizer.h"
 #include "omt/integer_optimizer.h"
-#include "options/smt_options.h"
-#include "smt/smt_engine.h"
-#include "theory/quantifiers/quantifiers_attributes.h"
-#include "theory/smt_engine_subsolver.h"
 
 using namespace cvc5::theory;
 using namespace cvc5::smt;
 namespace cvc5::omt {
 
-std::unique_ptr<OMTOptimizer> OMTOptimizer::getOptimizerForNode(TNode targetNode,
-                                                                bool isSigned)
+bool OMTOptimizer::nodeSupportsOptimization(TNode node)
+{
+  TypeNode type = node.getType();
+  // only supports Integer and BitVectors as of now
+  return (type.isInteger() || type.isBitVector());
+}
+
+std::unique_ptr<OMTOptimizer> OMTOptimizer::getOptimizerForObjective(
+    OptimizationObjective& objective)
 {
   // the datatype of the target node
-  TypeNode objectiveType = targetNode.getType(true);
+  TypeNode objectiveType = objective.getTarget().getType(true);
   if (objectiveType.isInteger())
   {
     // integer type: use OMTOptimizerInteger
@@ -39,7 +42,8 @@ std::unique_ptr<OMTOptimizer> OMTOptimizer::getOptimizerForNode(TNode targetNode
   else if (objectiveType.isBitVector())
   {
     // bitvector type: use OMTOptimizerBitVector
-    return std::unique_ptr<OMTOptimizer>(new OMTOptimizerBitVector(isSigned));
+    return std::unique_ptr<OMTOptimizer>(
+        new OMTOptimizerBitVector(objective.bvIsSigned()));
   }
   else
   {
@@ -47,24 +51,126 @@ std::unique_ptr<OMTOptimizer> OMTOptimizer::getOptimizerForNode(TNode targetNode
   }
 }
 
-std::unique_ptr<SmtEngine> OMTOptimizer::createOptCheckerWithTimeout(
-    SmtEngine* parentSMTSolver, bool needsTimeout, unsigned long timeout)
+Node OMTOptimizer::mkStrongIncrementalExpression(
+    NodeManager* nm, TNode lhs, TNode rhs, OptimizationObjective& objective)
+{
+  constexpr const char lhsTypeError[] =
+      "lhs type does not match or is not implicitly convertable to the target "
+      "type";
+  constexpr const char rhsTypeError[] =
+      "rhs type does not match or is not implicitly convertable to the target "
+      "type";
+  TypeNode targetType = objective.getTarget().getType();
+  switch (objective.getType())
+  {
+    case OptimizationObjective::MINIMIZE:
+    {
+      if (targetType.isInteger())
+      {
+        Assert(lhs.getType().isInteger()) << lhsTypeError;
+        Assert(rhs.getType().isInteger()) << rhsTypeError;
+        return nm->mkNode(Kind::LT, lhs, rhs);
+      }
+      else if (targetType.isBitVector())
+      {
+        Assert(lhs.getType() == targetType) << lhsTypeError;
+        Assert(rhs.getType() == targetType) << rhsTypeError;
+        return (objective.bvIsSigned())
+                   ? (nm->mkNode(Kind::BITVECTOR_SLT, lhs, rhs))
+                   : (nm->mkNode(Kind::BITVECTOR_ULT, lhs, rhs));
+      }
+      else
+      {
+        Unimplemented() << "Target type does not support optimization";
+      }
+    }
+    case OptimizationObjective::MAXIMIZE:
+    {
+      if (targetType.isInteger())
+      {
+        Assert(lhs.getType().isInteger()) << lhsTypeError;
+        Assert(rhs.getType().isInteger()) << rhsTypeError;
+        return nm->mkNode(Kind::GT, lhs, rhs);
+      }
+      else if (targetType.isBitVector())
+      {
+        Assert(lhs.getType() == targetType) << lhsTypeError;
+        Assert(rhs.getType() == targetType) << rhsTypeError;
+        return (objective.bvIsSigned())
+                   ? (nm->mkNode(Kind::BITVECTOR_SGT, lhs, rhs))
+                   : (nm->mkNode(Kind::BITVECTOR_UGT, lhs, rhs));
+      }
+      else
+      {
+        Unimplemented() << "Target type does not support optimization";
+      }
+    }
+    default:
+      CVC5_FATAL() << "Optimization objective is neither MAXIMIZE nor MINIMIZE";
+  }
+  Unreachable();
+}
+
+Node OMTOptimizer::mkWeakIncrementalExpression(NodeManager* nm,
+                                               TNode lhs,
+                                               TNode rhs,
+                                               OptimizationObjective& objective)
 {
-  std::unique_ptr<SmtEngine> optChecker;
-  // initializeSubSolver will copy the options and theories enabled
-  // from the current solver to optChecker and adds timeout
-  theory::initializeSubsolver(optChecker, needsTimeout, timeout);
-  // we need to be in incremental mode for multiple objectives since we need to
-  // push pop we need to produce models to inrement on our objective
-  optChecker->setOption("incremental", "true");
-  optChecker->setOption("produce-models", "true");
-  // Move assertions from the parent solver to the subsolver
-  std::vector<Node> p_assertions = parentSMTSolver->getExpandedAssertions();
-  for (const Node& e : p_assertions)
+  constexpr const char lhsTypeError[] =
+      "lhs type does not match or is not implicitly convertable to the target "
+      "type";
+  constexpr const char rhsTypeError[] =
+      "rhs type does not match or is not implicitly convertable to the target "
+      "type";
+  TypeNode targetType = objective.getTarget().getType();
+  switch (objective.getType())
   {
-    optChecker->assertFormula(e);
+    case OptimizationObjective::MINIMIZE:
+    {
+      if (targetType.isInteger())
+      {
+        Assert(lhs.getType().isInteger()) << lhsTypeError;
+        Assert(rhs.getType().isInteger()) << rhsTypeError;
+        return nm->mkNode(Kind::LEQ, lhs, rhs);
+      }
+      else if (targetType.isBitVector())
+      {
+        Assert(lhs.getType() == targetType) << lhsTypeError;
+        Assert(rhs.getType() == targetType) << rhsTypeError;
+        return (objective.bvIsSigned())
+                   ? (nm->mkNode(Kind::BITVECTOR_SLE, lhs, rhs))
+                   : (nm->mkNode(Kind::BITVECTOR_ULE, lhs, rhs));
+      }
+      else
+      {
+        Unimplemented() << "Target type does not support optimization";
+      }
+    }
+    case OptimizationObjective::MAXIMIZE:
+    {
+      if (targetType.isInteger())
+      {
+        Assert(lhs.getType().isInteger()) << lhsTypeError;
+        Assert(rhs.getType().isInteger()) << rhsTypeError;
+        return nm->mkNode(Kind::GEQ, lhs, rhs);
+      }
+      else if (targetType.isBitVector())
+      {
+        Assert(lhs.getType() == targetType) << lhsTypeError;
+        Assert(rhs.getType() == targetType) << rhsTypeError;
+        return (objective.bvIsSigned())
+                   ? (nm->mkNode(Kind::BITVECTOR_SGE, lhs, rhs))
+                   : (nm->mkNode(Kind::BITVECTOR_UGE, lhs, rhs));
+      }
+      else
+      {
+        Unimplemented() << "Target type does not support optimization";
+      }
+    }
+    default:
+      CVC5_FATAL() << "Optimization objective is neither MAXIMIZE nor MINIMIZE";
   }
-  return optChecker;
+  Unreachable();
 }
 
 }  // namespace cvc5::omt
index 792a60169262d97a197ebb5c762d43414fb9afb3..1052865b0f878e6e0487dac947a3de1811b5f1c5 100644 (file)
@@ -27,54 +27,96 @@ class OMTOptimizer
 {
  public:
   virtual ~OMTOptimizer() = default;
+
   /**
-   * Given a target node, retrieve an optimizer specific for the node's type
-   * the second field isSigned specifies whether we should use signed comparison
-   * for BitVectors and it's only valid when the type is BitVector
-   *
-   * @param targetNode the target node for the expression to be optimized
-   * @param isSigned speficies whether to use signed comparison for BitVectors
-   *   and it's only valid when the type of targetNode is BitVector
+   * Returns whether node supports optimization
+   * Currently supported: BitVectors, Integers (preliminary).
+   * @param node the target node to check for optimizability
+   * @return whether node supports optimization
+   **/
+  static bool nodeSupportsOptimization(TNode node);
+
+  /**
+   * Given an optimization objective,
+   * retrieve an optimizer specific for the optimization target
+   * @param objective the an OptimizationObjective object containing
+   *   the optimization target, whether it's maximized or minimized
+   *   and whether it's signed for BV (only applies when the target type is BV)
    * @return a unique_pointer pointing to a derived class of OMTOptimizer
    *   and this is the optimizer for targetNode
    **/
-  static std::unique_ptr<OMTOptimizer> getOptimizerForNode(
-      TNode targetNode, bool isSigned = false);
+  static std::unique_ptr<OMTOptimizer> getOptimizerForObjective(
+      smt::OptimizationObjective& objective);
+
+  /**
+   * Given the lhs and rhs expressions, with an optimization objective,
+   * makes an incremental expression stating that
+   *   lhs `better than` rhs
+   * under the context specified by objective
+   * for minimize, it would be lhs < rhs
+   * for maximize, it would be lhs > rhs
+   *
+   * Note: the types of lhs and rhs nodes must match or be convertable
+   *   to the type of the optimization target!
+   *
+   * @param nm the NodeManager to manage the made expression
+   * @param lhs the left hand side of the expression
+   * @param rhs the right hand side of the expression
+   * @param objective the optimization objective
+   *   stating whether it's maximize / minimize etc.
+   * @return an expression stating lhs `better than` rhs,
+   **/
+  static Node mkStrongIncrementalExpression(
+      NodeManager* nm,
+      TNode lhs,
+      TNode rhs,
+      smt::OptimizationObjective& objective);
 
   /**
-   * Initialize an SMT subsolver for offline optimization purpose
-   * @param parentSMTSolver the parental solver containing the assertions
-   * @param needsTimeout specifies whether it needs timeout for each single
-   *    query
-   * @param timeout the timeout value, given in milliseconds (ms)
-   * @return a unique_pointer of SMT subsolver
+   * Given the lhs and rhs expressions, with an optimization objective,
+   * makes an incremental expression stating that
+   *   lhs `better than or equal to` rhs
+   * under the context specified by objective
+   * for minimize, it would be lhs <= rhs
+   * for maximize, it would be lhs >= rhs
+   *
+   * Note: the types of lhs and rhs nodes must match or be convertable
+   *   to the type of the optimization target!
+   *
+   * @param nm the NodeManager to manage the made expression
+   * @param lhs the left hand side of the expression
+   * @param rhs the right hand side of the expression
+   * @param objective the optimization objective
+   *   stating whether it's maximize / minimize etc.
+   * @return an expression stating lhs `better than or equal to` rhs,
    **/
-  static std::unique_ptr<SmtEngine> createOptCheckerWithTimeout(
-      SmtEngine* parentSMTSolver,
-      bool needsTimeout = false,
-      unsigned long timeout = 0);
+  static Node mkWeakIncrementalExpression(
+      NodeManager* nm,
+      TNode lhs,
+      TNode rhs,
+      smt::OptimizationObjective& objective);
 
   /**
-   * Minimize the target node with constraints encoded in parentSMTSolver
+   * Minimize the target node with constraints encoded in optChecker
    *
-   * @param parentSMTSolver an SMT solver encoding the assertions as the
+   * @param optChecker an SMT solver encoding the assertions as the
    *   constraints
    * @param target the target expression to optimize
    * @return smt::OptimizationResult the result of optimization, containing
    *   whether it's optimal and the optimized value.
    **/
-  virtual smt::OptimizationResult minimize(SmtEngine* parentSMTSolver,
+  virtual smt::OptimizationResult minimize(SmtEngine* optChecker,
                                            TNode target) = 0;
   /**
-   * Maximize the target node with constraints encoded in parentSMTSolver
+   * Maximize the target node with constraints encoded in optChecker
    *
-   * @param parentSMTSolver an SMT solver encoding the assertions as the
+   * @param optChecker an SMT solver encoding the assertions as the
    *   constraints
    * @param target the target expression to optimize
    * @return smt::OptimizationResult the result of optimization, containing
    *   whether it's optimal and the optimized value.
    **/
-  virtual smt::OptimizationResult maximize(SmtEngine* parentSMTSolver,
+  virtual smt::OptimizationResult maximize(SmtEngine* optChecker,
                                            TNode target) = 0;
 };
 
index 1c8fe6514c097bf02c95a3ab82e01d3bae861c32..e66e4e2ca4460c63bc55acc63d8b318f2e968d51 100644 (file)
 #include "smt/optimization_solver.h"
 
 #include "omt/omt_optimizer.h"
+#include "options/smt_options.h"
 #include "smt/assertions.h"
+#include "smt/smt_engine.h"
+#include "theory/smt_engine_subsolver.h"
 
 using namespace cvc5::theory;
 using namespace cvc5::omt;
@@ -27,19 +30,22 @@ OptimizationResult::ResultType OptimizationSolver::checkOpt()
 {
   Assert(d_objectives.size() == 1);
   // NOTE: currently we are only dealing with single obj
-  std::unique_ptr<OMTOptimizer> optimizer = OMTOptimizer::getOptimizerForNode(
-      d_objectives[0].getTarget(), d_objectives[0].bvIsSigned());
+  std::unique_ptr<OMTOptimizer> optimizer =
+      OMTOptimizer::getOptimizerForObjective(d_objectives[0]);
 
   if (!optimizer) return OptimizationResult::UNSUPPORTED;
 
   OptimizationResult optResult;
+  std::unique_ptr<SmtEngine> optChecker = createOptCheckerWithTimeout(d_parent);
   if (d_objectives[0].getType() == OptimizationObjective::MAXIMIZE)
   {
-    optResult = optimizer->maximize(d_parent, d_objectives[0].getTarget());
+    optResult =
+        optimizer->maximize(optChecker.get(), d_objectives[0].getTarget());
   }
   else if (d_objectives[0].getType() == OptimizationObjective::MINIMIZE)
   {
-    optResult = optimizer->minimize(d_parent, d_objectives[0].getTarget());
+    optResult =
+        optimizer->minimize(optChecker.get(), d_objectives[0].getTarget());
   }
 
   d_results[0] = optResult;
@@ -65,5 +71,25 @@ std::vector<OptimizationResult> OptimizationSolver::getValues()
   return d_results;
 }
 
+std::unique_ptr<SmtEngine> OptimizationSolver::createOptCheckerWithTimeout(
+    SmtEngine* parentSMTSolver, bool needsTimeout, unsigned long timeout)
+{
+  std::unique_ptr<SmtEngine> optChecker;
+  // initializeSubSolver will copy the options and theories enabled
+  // from the current solver to optChecker and adds timeout
+  theory::initializeSubsolver(optChecker, needsTimeout, timeout);
+  // we need to be in incremental mode for multiple objectives since we need to
+  // push pop we need to produce models to inrement on our objective
+  optChecker->setOption("incremental", "true");
+  optChecker->setOption("produce-models", "true");
+  // Move assertions from the parent solver to the subsolver
+  std::vector<Node> p_assertions = parentSMTSolver->getExpandedAssertions();
+  for (const Node& e : p_assertions)
+  {
+    optChecker->assertFormula(e);
+  }
+  return optChecker;
+}
+
 }  // namespace smt
 }  // namespace cvc5
index 0babd7a4a7be65b5349cfd30192c212ebdaec720..3037c29248e74e7be283632efcb5e6c90b8512b1 100644 (file)
@@ -208,6 +208,19 @@ class OptimizationSolver
   std::vector<OptimizationResult> getValues();
 
  private:
+  /**
+   * Initialize an SMT subsolver for offline optimization purpose
+   * @param parentSMTSolver the parental solver containing the assertions
+   * @param needsTimeout specifies whether it needs timeout for each single
+   *    query
+   * @param timeout the timeout value, given in milliseconds (ms)
+   * @return a unique_pointer of SMT subsolver
+   **/
+  static std::unique_ptr<SmtEngine> createOptCheckerWithTimeout(
+      SmtEngine* parentSMTSolver,
+      bool needsTimeout = false,
+      unsigned long timeout = 0);
+
   /** The parent SMT engine **/
   SmtEngine* d_parent;