Make LFSC printer robust to internal types (#8616)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Thu, 14 Apr 2022 21:13:04 +0000 (16:13 -0500)
committerGitHub <noreply@github.com>
Thu, 14 Apr 2022 21:13:04 +0000 (21:13 +0000)
This makes the LFSC node converter track the "user declared" symbols and types that it encounters.

It furthermore makes the "dry run" phase of proof printing happen before types and symbols are declared, so that all declared symbols are found before the preamble of LFSC proofs are printed.

These changes are specifically to fix cases where a internal type is generated that does not appear in the input. For example, some preprocessing passes may construct auxiliary uninterpreted sorts.

This fixes 6 more LFSC failures from our regressions.

src/proof/lfsc/lfsc_node_converter.cpp
src/proof/lfsc/lfsc_node_converter.h
src/proof/lfsc/lfsc_printer.cpp

index 25ae48d9213244280fc97d7a6dbb4b64d2a95f63..12302bd7adcda659288983a98fc1105e6b5af92b 100644 (file)
@@ -156,7 +156,8 @@ Node LfscNodeConverter::postConvert(Node n)
     }
     // Otherwise, it is an uncategorized skolem, must use a fresh variable.
     // This case will only apply for terms originating from places with no
-    // proof support.
+    // proof support. Note it is not added as a declared variable, instead it
+    // is used as (var N T) throughout.
     TypeNode intType = nm->integerType();
     TypeNode varType = nm->mkFunctionType({intType, d_sortType}, tn);
     Node var = mkInternalSymbol("var", varType);
@@ -166,6 +167,7 @@ Node LfscNodeConverter::postConvert(Node n)
   }
   else if (n.isVar())
   {
+    d_declVars.insert(n);
     return mkInternalSymbol(getNameForUserNameOf(n), tn);
   }
   else if (k == CARDINALITY_CONSTRAINT)
@@ -188,6 +190,9 @@ Node LfscNodeConverter::postConvert(Node n)
   else if (k == APPLY_CONSTRUCTOR || k == APPLY_SELECTOR || k == APPLY_TESTER
            || k == APPLY_UPDATER)
   {
+    // must add to declared types
+    const DType& dt = DType::datatypeOf(n.getOperator());
+    d_declTypes.insert(dt.getTypeNode());
     // must convert other kinds of apply to functions, since we convert to
     // HO_APPLY
     Node opc = getOperatorOfTerm(n, true);
@@ -533,6 +538,8 @@ TypeNode LfscNodeConverter::postConvertType(TypeNode tn)
   }
   else if (tn.getNumChildren() == 0)
   {
+    // an uninterpreted sort, or an uninstantiatied (maybe parametric) datatype
+    d_declTypes.insert(tn);
     // special case: tuples must be distinguished by their arity
     if (tn.isTuple())
     {
@@ -593,6 +600,8 @@ TypeNode LfscNodeConverter::postConvertType(TypeNode tn)
     Node op;
     if (k == PARAMETRIC_DATATYPE)
     {
+      // note we don't add to declared types here, since the parametric
+      // datatype is traversed and will be declared as a type constructor
       // erase first child, which repeats the datatype
       targs.erase(targs.begin(), targs.begin() + 1);
       types.erase(types.begin(), types.begin() + 1);
@@ -605,6 +614,10 @@ TypeNode LfscNodeConverter::postConvertType(TypeNode tn)
     }
     else if (k == SORT_TYPE)
     {
+      // Add its uninterpreted sort constructor to the list of declared types.
+      // This is required since the (type) operator is not part of the AST of
+      // the TypeNode.
+      d_declTypes.insert(tn.getUninterpretedSortConstructor());
       TypeNode ftype = nm->mkFunctionType(types, d_sortType);
       std::string name;
       tn.getAttribute(expr::VarNameAttr(), name);
@@ -1214,5 +1227,15 @@ size_t LfscNodeConverter::getOrAssignIndexForVar(Node v)
   return id;
 }
 
+const std::unordered_set<Node>& LfscNodeConverter::getDeclaredSymbols() const
+{
+  return d_declVars;
+}
+
+const std::unordered_set<TypeNode>& LfscNodeConverter::getDeclaredTypes() const
+{
+  return d_declTypes;
+}
+
 }  // namespace proof
 }  // namespace cvc5::internal
index f7fcc6b7ede8e8b9e4c820fd09f39c92079726af..5d2861af2d017eee89a85e1123a61eabae6b0a33 100644 (file)
@@ -115,6 +115,10 @@ class LfscNodeConverter : public NodeConverter
                                         size_t variant = 0);
   /** get name for the name of node v, where v should be a variable */
   std::string getNameForUserNameOf(Node v);
+  /** Get the declared symbols (variables) that we have converted */
+  const std::unordered_set<Node>& getDeclaredSymbols() const;
+  /** Get the declared types that we have converted */
+  const std::unordered_set<TypeNode>& getDeclaredTypes() const;
 
  private:
   /** Should we traverse n? */
@@ -180,6 +184,10 @@ class LfscNodeConverter : public NodeConverter
   std::map<TypeNode, Node> d_typeAsNode;
   /** Used for interpreted builtin parametric sorts */
   std::map<Kind, Node> d_typeKindToNodeCons;
+  /** The set of declared variables */
+  std::unordered_set<Node> d_declVars;
+  /** The set of declared types */
+  std::unordered_set<TypeNode> d_declTypes;
 };
 
 }  // namespace proof
index e41813ee376b593ede631d0a5ebb76241a1ab65c..a2cabd7b6706304550059f44fe58bdfcf2f8efb0 100644 (file)
@@ -53,29 +53,84 @@ void LfscPrinter::print(std::ostream& out,
   // clear the rules we have warned about
   d_trustWarned.clear();
 
+  // [1] convert assertions to internal and set up assumption map
   Trace("lfsc-print-debug") << "; print declarations" << std::endl;
-  // [1] compute and print the declarations
-  std::unordered_set<Node> syms;
-  std::unordered_set<TNode> visited;
   std::vector<Node> iasserts;
   std::map<Node, size_t> passumeMap;
-  std::unordered_set<TypeNode> types;
-  std::unordered_set<TNode> typeVisited;
   for (size_t i = 0, nasserts = assertions.size(); i < nasserts; i++)
   {
     Node a = assertions[i];
-    expr::getSymbols(a, syms, visited);
-    expr::getTypes(a, types, typeVisited);
     iasserts.push_back(d_tproc.convert(a));
     // remember the assumption name
     passumeMap[a] = i;
   }
   d_assumpCounter = assertions.size();
-  Trace("lfsc-print-debug") << "; print sorts" << std::endl;
-  // [1a] user declared sorts
+
+  // [2] compute the proof letification
+  Trace("lfsc-print-debug") << "; compute proof letification" << std::endl;
+  std::vector<const ProofNode*> pletList;
+  std::map<const ProofNode*, size_t> pletMap;
+  computeProofLetification(pnBody, pletList, pletMap);
+
+  // [3] compute the global term letification and declared symbols and types
+  Trace("lfsc-print-debug")
+      << "; compute global term letification and declared symbols" << std::endl;
+  LetBinding lbind;
+  for (const Node& ia : iasserts)
+  {
+    lbind.process(ia);
+  }
+  // We do a "dry-run" of proof printing here, using the LetBinding print
+  // channel. This pass traverses the proof but does not print it, but instead
+  // updates the let binding data structure for all nodes that appear anywhere
+  // in the proof. It is also important for the term processor for collecting
+  // symbols and types that are used in the proof.
+  LfscPrintChannelPre lpcp(lbind);
+  LetBinding emptyLetBind;
+  std::map<const ProofNode*, size_t>::iterator itp;
+  for (const ProofNode* p : pletList)
+  {
+    itp = pletMap.find(p);
+    Assert(itp != pletMap.end());
+    size_t pid = itp->second;
+    pletMap.erase(p);
+    printProofInternal(&lpcp, p, emptyLetBind, pletMap, passumeMap);
+    pletMap[p] = pid;
+  }
+  // Print the body of the outermost scope
+  printProofInternal(&lpcp, pnBody, emptyLetBind, pletMap, passumeMap);
+
+  // [4] print declared sorts and symbols
+  // [4a] user declare function symbols
+  // Note that this is buffered into an output stream preambleSymDecl and then
+  // printed after types. We require printing the declared symbols here so that
+  // the set of collected declared types is complete at [4b].
+  Trace("lfsc-print-debug") << "; print user symbols" << std::endl;
+  std::stringstream preambleSymDecl;
+  const std::unordered_set<Node>& syms = d_tproc.getDeclaredSymbols();
+  for (const Node& s : syms)
+  {
+    TypeNode st = s.getType();
+    if (st.isDatatypeConstructor() || st.isDatatypeSelector()
+        || st.isDatatypeTester() || st.isDatatypeUpdater())
+    {
+      // constructors, selector, testers, updaters are defined by the datatype
+      continue;
+    }
+    Node si = d_tproc.convert(s);
+    preambleSymDecl << "(define " << si << " (var "
+                    << d_tproc.getOrAssignIndexForVar(s) << " ";
+    printType(preambleSymDecl, st);
+    preambleSymDecl << "))" << std::endl;
+  }
+  // [4b] user declared sorts
+  Trace("lfsc-print-debug") << "; print user sorts" << std::endl;
   std::stringstream preamble;
   std::unordered_set<TypeNode> sts;
   std::unordered_set<size_t> tupleArity;
+  // get the types from the term processor, which has seen all terms occurring
+  // in the proof at this point
+  const std::unordered_set<TypeNode>& types = d_tproc.getDeclaredTypes();
   for (const TypeNode& st : types)
   {
     // note that we must get all "component types" of a type, so that
@@ -113,66 +168,19 @@ void LfscPrinter::print(std::ostream& out,
     // shared selectors are instance of parametric symbol "sel"
     preamble << "; END DATATYPE " << std::endl;
   }
-  Trace("lfsc-print-debug") << "; print user symbols" << std::endl;
-  // [1b] user declare function symbols
-  for (const Node& s : syms)
-  {
-    TypeNode st = s.getType();
-    if (st.isDatatypeConstructor() || st.isDatatypeSelector()
-        || st.isDatatypeTester() || st.isDatatypeUpdater())
-    {
-      // constructors, selector, testers, updaters are defined by the datatype
-      continue;
-    }
-    Node si = d_tproc.convert(s);
-    preamble << "(define " << si << " (var "
-             << d_tproc.getOrAssignIndexForVar(s) << " ";
-    printType(preamble, st);
-    preamble << "))" << std::endl;
-  }
-
-  Trace("lfsc-print-debug") << "; compute proof letification" << std::endl;
-  // [2] compute the proof letification
-  std::vector<const ProofNode*> pletList;
-  std::map<const ProofNode*, size_t> pletMap;
-  computeProofLetification(pnBody, pletList, pletMap);
-
-  Trace("lfsc-print-debug") << "; compute term lets" << std::endl;
-  // compute the term lets
-  LetBinding lbind;
-  for (const Node& ia : iasserts)
-  {
-    lbind.process(ia);
-  }
-  // We do a "dry-run" of proof printing here, using the LetBinding print
-  // channel. This pass traverses the proof but does not print it, but instead
-  // updates the let binding data structure for all nodes that appear anywhere
-  // in the proof.
-  LfscPrintChannelPre lpcp(lbind);
-  LetBinding emptyLetBind;
-  std::map<const ProofNode*, size_t>::iterator itp;
-  for (const ProofNode* p : pletList)
-  {
-    itp = pletMap.find(p);
-    Assert(itp != pletMap.end());
-    size_t pid = itp->second;
-    pletMap.erase(p);
-    printProofInternal(&lpcp, p, emptyLetBind, pletMap, passumeMap);
-    pletMap[p] = pid;
-  }
-  // Print the body of the outermost scope
-  printProofInternal(&lpcp, pnBody, emptyLetBind, pletMap, passumeMap);
+  // [4c] user declared function symbols
+  preamble << preambleSymDecl.str();
 
-  // [3] print warnings
+  // [5] print warnings
   for (PfRule r : d_trustWarned)
   {
     out << "; WARNING: adding trust step for " << r << std::endl;
   }
 
-  // [4] print the DSL rewrite rule declarations
+  // [6] print the DSL rewrite rule declarations
   // TODO cvc5-projects #285.
 
-  // [5] print the check command and term lets
+  // [7] print the check command and term lets
   out << preamble.str();
   out << "(check" << std::endl;
   cparen << ")";
@@ -180,7 +188,7 @@ void LfscPrinter::print(std::ostream& out,
   printLetList(out, cparen, lbind);
 
   Trace("lfsc-print-debug") << "; print asserts" << std::endl;
-  // [6] print the assertions, with letification
+  // [8] print the assertions, with letification
   // the assumption identifier mapping
   for (size_t i = 0, nasserts = iasserts.size(); i < nasserts; i++)
   {
@@ -194,18 +202,19 @@ void LfscPrinter::print(std::ostream& out,
   }
 
   Trace("lfsc-print-debug") << "; print annotation" << std::endl;
-  // [7] print the annotation
+  // [9] print the annotation
   out << "(: (holds false)" << std::endl;
   cparen << ")";
 
   Trace("lfsc-print-debug") << "; print proof body" << std::endl;
-  // [8] print the proof body
+  // [10] print the proof body
   Assert(pn->getRule() == PfRule::SCOPE);
   // the outermost scope can be ignored (it is the scope of the assertions,
   // which are already printed above).
   LfscPrintChannelOut lout(out);
   printProofLetify(&lout, pnBody, lbind, pletList, pletMap, passumeMap);
 
+  // [11] print closing parantheses
   out << cparen.str() << std::endl;
 }
 
@@ -245,12 +254,12 @@ void LfscPrinter::printTypeDefinition(
   }
   else if (tn.isDatatype())
   {
-    const DType& dt = tn.getDType();
     if (tn.getKind() == PARAMETRIC_DATATYPE)
     {
       // skip the instance of a parametric datatype
       return;
     }
+    const DType& dt = tn.getDType();
     if (dt.isTuple())
     {
       const DTypeConstructor& cons = dt[0];