Add the print benchmark utility (#7196)
author Andrew Reynolds <andrew.j.reynolds@gmail.com>
Fri, 1 Oct 2021 13:58:14 +0000 (08:58 -0500)
committer GitHub <noreply@github.com>
Fri, 1 Oct 2021 13:58:14 +0000 (13:58 +0000)
This utility can print a vector of Nodes as a valid (SMT-LIB) benchmark with no prior bookkeeping. It also optionally accepts a vector of Nodes corresponding to function definitions (define-fun), which are printed as definitions rather than as assertions.

It will be used to replace the old internal benchmark dumping infrastructure, which was error-prone.

src/smt/print_benchmark.cpp [new file with mode: 0644]
src/smt/print_benchmark.h [new file with mode: 0644]

diff --git a/src/smt/print_benchmark.cpp b/src/smt/print_benchmark.cpp
new file mode 100644 (file)
index 0000000..c1913e2
--- /dev/null
@@ -0,0 +1,278 @@
+/******************************************************************************
+ * Top contributors (to current version):
+ *   Andrew Reynolds
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved.  See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * Print benchmark utility.
+ */
+
+#include "smt/print_benchmark.h"
+
+#include "expr/dtype.h"
+#include "expr/node_algorithm.h"
+#include "printer/printer.h"
+
+using namespace cvc5::kind;
+
+namespace cvc5 {
+namespace smt {
+
+void PrintBenchmark::printAssertions(std::ostream& out,
+                                     const std::vector<Node>& defs,
+                                     const std::vector<Node>& assertions)
+{
+  std::unordered_set<TypeNode> types;
+  std::unordered_set<TNode> typeVisited;
+  for (const Node& a : defs)
+  {
+    expr::getTypes(a, types, typeVisited);
+  }
+  for (const Node& a : assertions)
+  {
+    expr::getTypes(a, types, typeVisited);
+  }
+  // print the declared types first
+  std::unordered_set<TypeNode> alreadyPrintedDeclSorts;
+  for (const TypeNode& st : types)
+  {
+    // note that we must get all "component types" of a type, so that
+    // e.g. U is printed as a sort declaration when we have type (Array U Int).
+    std::unordered_set<TypeNode> ctypes;
+    expr::getComponentTypes(st, ctypes);
+    for (const TypeNode& stc : ctypes)
+    {
+      // get all connected datatypes to this one
+      std::vector<TypeNode> connectedTypes;
+      getConnectedSubfieldTypes(stc, connectedTypes, alreadyPrintedDeclSorts);
+      // now, separate into sorts and datatypes
+      std::vector<TypeNode> datatypeBlock;
+      for (const TypeNode& ctn : connectedTypes)
+      {
+        if (stc.isSort())
+        {
+          d_printer->toStreamCmdDeclareType(out, stc);
+        }
+        else if (stc.isDatatype())
+        {
+          datatypeBlock.push_back(ctn);
+        }
+      }
+      // print the mutually recursive datatype block if necessary
+      if (!datatypeBlock.empty())
+      {
+        d_printer->toStreamCmdDatatypeDeclaration(out, datatypeBlock);
+      }
+    }
+  }
+
+  // global visited cache for expr::getSymbols calls
+  std::unordered_set<TNode> visited;
+
+  // print the definitions
+  std::unordered_map<Node, std::pair<bool, Node>> defMap;
+  std::vector<Node> defSyms;
+  // first, record all the defined symbols
+  for (const Node& a : defs)
+  {
+    bool isRec;
+    Node defSym;
+    Node defBody;
+    decomposeDefinition(a, isRec, defSym, defBody);
+    if (!defSym.isNull())
+    {
+      Assert(defMap.find(defSym) == defMap.end());
+      defMap[defSym] = std::pair<bool, Node>(isRec, defBody);
+      defSyms.push_back(defSym);
+    }
+  }
+  // go back and print the definitions
+  std::unordered_set<Node> alreadyPrintedDecl;
+  std::unordered_set<Node> alreadyPrintedDef;
+
+  std::unordered_map<Node, std::pair<bool, Node>>::const_iterator itd;
+  for (const Node& s : defSyms)
+  {
+    std::vector<Node> recDefs;
+    std::vector<Node> ordinaryDefs;
+    std::unordered_set<Node> syms;
+    getConnectedDefinitions(
+        s, recDefs, ordinaryDefs, syms, defMap, alreadyPrintedDef, visited);
+    // print the declarations that are encountered for the first time in this
+    // block
+    printDeclaredFuns(out, syms, alreadyPrintedDecl);
+    // print the ordinary definitions
+    for (const Node& f : ordinaryDefs)
+    {
+      itd = defMap.find(f);
+      Assert(itd != defMap.end());
+      Assert(!itd->second.first);
+      d_printer->toStreamCmdDefineFunction(out, f, itd->second.second);
+      // a definition is also a declaration
+      alreadyPrintedDecl.insert(f);
+    }
+    // print a recursive function definition block
+    if (!recDefs.empty())
+    {
+      std::vector<Node> lambdas;
+      for (const Node& f : recDefs)
+      {
+        lambdas.push_back(defMap[f].second);
+        // a recursive definition is also a declaration
+        alreadyPrintedDecl.insert(f);
+      }
+      d_printer->toStreamCmdDefineFunctionRec(out, recDefs, lambdas);
+    }
+  }
+
+  // print the remaining declared symbols
+  std::unordered_set<Node> syms;
+  for (const Node& a : assertions)
+  {
+    expr::getSymbols(a, syms, visited);
+  }
+  printDeclaredFuns(out, syms, alreadyPrintedDecl);
+
+  // print the assertions
+  for (const Node& a : assertions)
+  {
+    d_printer->toStreamCmdAssert(out, a);
+  }
+}
+/**
+ * Convenience overload: prints the assertions with an empty list of
+ * definitions.
+ */
+void PrintBenchmark::printAssertions(std::ostream& out,
+                                     const std::vector<Node>& assertions)
+{
+  // no symbol is treated as defined
+  const std::vector<Node> noDefs;
+  printAssertions(out, noDefs, assertions);
+}
+
+void PrintBenchmark::printDeclaredFuns(std::ostream& out,
+                                       const std::unordered_set<Node>& funs,
+                                       std::unordered_set<Node>& alreadyPrinted)
+{
+  for (const Node& f : funs)
+  {
+    Assert(f.isVar());
+    if (alreadyPrinted.find(f) == alreadyPrinted.end())
+    {
+      d_printer->toStreamCmdDeclareFunction(out, f);
+    }
+  }
+  alreadyPrinted.insert(funs.begin(), funs.end());
+}
+
+void PrintBenchmark::getConnectedSubfieldTypes(
+    TypeNode tn,
+    std::vector<TypeNode>& connectedTypes,
+    std::unordered_set<TypeNode>& processed)
+{
+  if (processed.find(tn) != processed.end())
+  {
+    return;
+  }
+  processed.insert(tn);
+  if (tn.isSort())
+  {
+    connectedTypes.push_back(tn);
+  }
+  else if (tn.isDatatype())
+  {
+    connectedTypes.push_back(tn);
+    std::unordered_set<TypeNode> subfieldTypes =
+        tn.getDType().getSubfieldTypes();
+    for (const TypeNode& ctn : subfieldTypes)
+    {
+      getConnectedSubfieldTypes(ctn, connectedTypes, processed);
+    }
+  }
+}
+
+void PrintBenchmark::getConnectedDefinitions(
+    Node n,
+    std::vector<Node>& recDefs,
+    std::vector<Node>& ordinaryDefs,
+    std::unordered_set<Node>& syms,
+    const std::unordered_map<Node, std::pair<bool, Node>>& defMap,
+    std::unordered_set<Node>& processedDefs,
+    std::unordered_set<TNode>& visited)
+{
+  // does it have a definition?
+  std::unordered_map<Node, std::pair<bool, Node>>::const_iterator it =
+      defMap.find(n);
+  if (it == defMap.end())
+  {
+    // an ordinary declared symbol
+    syms.insert(n);
+    return;
+  }
+  if (processedDefs.find(n) != processedDefs.end())
+  {
+    return;
+  }
+  processedDefs.insert(n);
+  if (!it->second.first)
+  {
+    // an ordinary define-fun symbol
+    ordinaryDefs.push_back(n);
+  }
+  else
+  {
+    // a recursively defined symbol
+    recDefs.push_back(n);
+    // get the symbols in the body
+    std::unordered_set<Node> symsBody;
+    expr::getSymbols(it->second.second, symsBody, visited);
+    for (const Node& s : symsBody)
+    {
+      getConnectedDefinitions(
+          s, recDefs, ordinaryDefs, syms, defMap, processedDefs, visited);
+    }
+  }
+}
+
+bool PrintBenchmark::decomposeDefinition(Node a,
+                                         bool& isRecDef,
+                                         Node& sym,
+                                         Node& body)
+{
+  if (a.getKind() == EQUAL && a[0].isVar())
+  {
+    // an ordinary define-fun
+    isRecDef = false;
+    sym = a[0];
+    body = a[1];
+    return true;
+  }
+  else if (a.getKind() == FORALL && a[1].getKind() == EQUAL
+           && a[1][0].getKind() == APPLY_UF)
+  {
+    isRecDef = true;
+    sym = a[1][0].getOperator();
+    body = NodeManager::currentNM()->mkNode(LAMBDA, a[0], a[1][1]);
+    return true;
+  }
+  else
+  {
+    Warning() << "Unhandled definition: " << a << std::endl;
+  }
+  return false;
+}
+
+/**
+ * Prints a complete benchmark on out: a set-logic command for the given
+ * logic string, then everything printed by printAssertions for defs and
+ * assertions, then a check-sat command.
+ */
+void PrintBenchmark::printBenchmark(std::ostream& out,
+                                    const std::string& logic,
+                                    const std::vector<Node>& defs,
+                                    const std::vector<Node>& assertions)
+{
+  // the logic comes first
+  d_printer->toStreamCmdSetBenchmarkLogic(out, logic);
+  // declarations, definitions and assertions
+  printAssertions(out, defs, assertions);
+  // the final query
+  d_printer->toStreamCmdCheckSat(out);
+}
+
+}  // namespace smt
+}  // namespace cvc5
diff --git a/src/smt/print_benchmark.h b/src/smt/print_benchmark.h
new file mode 100644 (file)
index 0000000..387b3cc
--- /dev/null
@@ -0,0 +1,137 @@
+/******************************************************************************
+ * Top contributors (to current version):
+ *   Andrew Reynolds
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved.  See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * Print benchmark utility.
+ */
+
+#include "cvc5_private.h"
+
+#ifndef CVC5__SMT__PRINT_BENCHMARK_H
+#define CVC5__SMT__PRINT_BENCHMARK_H
+
+#include <iosfwd>
+#include <vector>
+
+#include "expr/node.h"
+
+namespace cvc5 {
+
+class Printer;
+
+namespace smt {
+
+/**
+ * A utility for printing a benchmark. This utility requires no bookkeeping
+ * about which commands have been executed. It reconstructs the set of
+ * commands that would have been required for generating a benchmark based on
+ * a list of nodes.
+ */
+class PrintBenchmark
+{
+ public:
+  /**
+   * @param p The printer used to print the individual commands. The pointer
+   * is stored as given; this class declares no destructor and does not
+   * delete it.
+   */
+  PrintBenchmark(const Printer* p) : d_printer(p) {}
+  /**
+   * Print assertions. This prints a parsable set of commands on the output
+   * stream out that defines (recursive) functions in defs, and asserts
+   * assertions. It does not print a set-logic or check-sat command.
+   *
+   * Each node in defs is either of the form:
+   * (1) (= s t), where s is a (non-recursively) defined function, where
+   * the term t may be a lambda if s has non-zero arity.
+   * (2) (forall V (= (s V) t)), where s is a recursively defined function.
+   *
+   * @param out The output stream to print to
+   * @param defs The definitions, each of form (1) or (2) above
+   * @param assertions The assertions to print
+   */
+  void printAssertions(std::ostream& out,
+                       const std::vector<Node>& defs,
+                       const std::vector<Node>& assertions);
+  /**
+   * Print assertions, without special handling of defined functions. This is
+   * the same as calling the method above with an empty vector of
+   * definitions.
+   *
+   * @param out The output stream to print to
+   * @param assertions The assertions to print
+   */
+  void printAssertions(std::ostream& out, const std::vector<Node>& assertions);
+
+  /**
+   * Print benchmark, which prints a parsable benchmark on the output stream
+   * out. It relies on the printAssertions method above, as well as printing
+   * the logic based on given string and a final check-sat command.
+   *
+   * For the best printing, defs should be given in the order in which
+   * the symbols were declared. If this is not the case, then we may e.g.
+   * group blocks of definitions that were not grouped in the input.
+   *
+   * @param out The output stream to print to
+   * @param logic The logic string to print in the set-logic command
+   * @param defs The definitions, in the format expected by printAssertions
+   * @param assertions The assertions to print
+   */
+  void printBenchmark(std::ostream& out,
+                      const std::string& logic,
+                      const std::vector<Node>& defs,
+                      const std::vector<Node>& assertions);
+
+ private:
+  /**
+   * Print declarations for the symbols in funs that are not in processed.
+   * Afterwards, processed is updated to include all symbols in funs.
+   *
+   * @param out The output stream to print to
+   * @param funs The symbols to declare
+   * @param processed The symbols that have already been printed
+   */
+  void printDeclaredFuns(std::ostream& out,
+                         const std::unordered_set<Node>& funs,
+                         std::unordered_set<Node>& processed);
+  /**
+   * Get the connected types. This traverses subfield types of datatypes and
+   * adds to connectedTypes everything that is necessary for printing tn.
+   *
+   * @param tn The type to traverse
+   * @param connectedTypes The types that tn depends on
+   * @param processed The types we have already processed. We update this set
+   * with those added to connectedTypes.
+   */
+  void getConnectedSubfieldTypes(TypeNode tn,
+                                 std::vector<TypeNode>& connectedTypes,
+                                 std::unordered_set<TypeNode>& processed);
+  /**
+   * Get connected definitions for symbol v.
+   *
+   * @param v The symbol to get the connected definitions for
+   * @param recDefs The recursive function definitions that v depends on
+   * @param ordinaryDefs The non-recursive definitions that v depends on
+   * @param syms The declared symbols that v depends on
+   * @param defMap Map from symbols to their definitions, where a definition
+   * is a pair of a flag that is true for recursive definitions and the body
+   * @param processedDefs The (recursive or non-recursive) definitions we have
+   * processed already. We update this with symbols we add to recDefs and
+   * ordinaryDefs.
+   * @param visited The set of terms we have already visited when searching for
+   * free symbols. This set is updated for the bodies of definitions processed
+   * in this call.
+   */
+  void getConnectedDefinitions(
+      Node v,
+      std::vector<Node>& recDefs,
+      std::vector<Node>& ordinaryDefs,
+      std::unordered_set<Node>& syms,
+      const std::unordered_map<Node, std::pair<bool, Node>>& defMap,
+      std::unordered_set<Node>& processedDefs,
+      std::unordered_set<TNode>& visited);
+  /**
+   * Decompose definition assertion a.
+   *
+   * @param a The definition assertion
+   * @param isRecDef Updated to true if a is a recursive function definition (a
+   * quantified formula), and to false if a is an ordinary definition
+   * @param sym Updated to the symbol that a defines
+   * @param body Updated to the term that defines sym
+   * @return true if the definition was successfully inferred. If false is
+   * returned, a warning is printed and isRecDef, sym and body are not
+   * modified.
+   */
+  bool decomposeDefinition(Node a, bool& isRecDef, Node& sym, Node& body);
+  /**
+   * Pointer to the printer we are using, which is responsible for printing
+   * individual commands.
+   */
+  const Printer* d_printer;
+};
+
+}  // namespace smt
+}  // namespace cvc5
+
+#endif /* CVC5__SMT__PRINT_BENCHMARK_H */