From: Andres Noetzli Date: Wed, 20 Apr 2022 21:00:31 +0000 (-0700) Subject: Improve handling of `(push)` and `(pop)` (#8641) X-Git-Tag: cvc5-1.0.1~238 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=918059c0700b438930a5f667272e168578f5ff7e;p=cvc5.git Improve handling of `(push)` and `(pop)` (#8641) This extends PushCommand and PopCommand to take a number of levels to push/pop. We have support for pushing an arbitrary number of levels at the API level, so this simplifies the parser code and makes dumping more precise (previously, we were dumping (push 2) as two (push 1) commands). --- diff --git a/src/parser/smt2/Smt2.g b/src/parser/smt2/Smt2.g index e66636e22..2a5dcf162 100644 --- a/src/parser/smt2/Smt2.g +++ b/src/parser/smt2/Smt2.g @@ -406,66 +406,25 @@ command [std::unique_ptr* cmd] GET_LEARNED_LITERALS_TOK { PARSER_STATE->checkThatLogicIsSet(); } { cmd->reset(new GetLearnedLiteralsCommand); } | /* push */ - PUSH_TOK { PARSER_STATE->checkThatLogicIsSet(); } + PUSH_TOK ( k=INTEGER_LITERAL - { unsigned num = AntlrInput::tokenToUnsigned(k); - if(num == 0) { - cmd->reset(new EmptyCommand()); - } else if(num == 1) { - PARSER_STATE->pushScope(true); - cmd->reset(new PushCommand()); - } else { - std::unique_ptr seq(new CommandSequence()); - do { - PARSER_STATE->pushScope(true); - Command* push_cmd = new PushCommand(); - push_cmd->setMuted(num > 1); - seq->addCommand(push_cmd); - --num; - } while(num > 0); - cmd->reset(seq.release()); - } + { + uint32_t num = AntlrInput::tokenToUnsigned(k); + *cmd = PARSER_STATE->handlePush(num); } - | { if(PARSER_STATE->strictModeEnabled()) { - PARSER_STATE->parseError( - "Strict compliance mode demands an integer to be provided to " - "PUSH. Maybe you want (push 1)?"); - } else { - PARSER_STATE->pushScope(true); - cmd->reset(new PushCommand()); - } - } ) - | POP_TOK { PARSER_STATE->checkThatLogicIsSet(); } + | { + *cmd = PARSER_STATE->handlePush(std::nullopt); + } + ) + | /* pop */ + POP_TOK ( k=INTEGER_LITERAL - { unsigned num = AntlrInput::tokenToUnsigned(k); - // we don't compare num to PARSER_STATE->scopeLevel() here, since - // when global declarations is true, the scope level of the parser - // is not indicative of the context level. - if(num == 0) { - cmd->reset(new EmptyCommand()); - } else if(num == 1) { - PARSER_STATE->popScope(); - cmd->reset(new PopCommand()); - } else { - std::unique_ptr seq(new CommandSequence()); - do { - PARSER_STATE->popScope(); - Command* pop_command = new PopCommand(); - pop_command->setMuted(num > 1); - seq->addCommand(pop_command); - --num; - } while(num > 0); - cmd->reset(seq.release()); - } + { + uint32_t num = AntlrInput::tokenToUnsigned(k); + *cmd = PARSER_STATE->handlePop(num); } - | { if(PARSER_STATE->strictModeEnabled()) { - PARSER_STATE->parseError( - "Strict compliance mode demands an integer to be provided to POP." - "Maybe you want (pop 1)?"); - } else { - PARSER_STATE->popScope(); - cmd->reset(new PopCommand()); - } + | { + *cmd = PARSER_STATE->handlePop(std::nullopt); } ) /* exit */ diff --git a/src/parser/smt2/smt2.cpp b/src/parser/smt2/smt2.cpp index 05a1e1982..93a518df0 100644 --- a/src/parser/smt2/smt2.cpp +++ b/src/parser/smt2/smt2.cpp @@ -1264,6 +1264,50 @@ cvc5::Term Smt2::applyParseOp(ParseOp& p, std::vector& args) return ret; } +std::unique_ptr Smt2::handlePush(std::optional nscopes) +{ + checkThatLogicIsSet(); + + if (!nscopes) + { + if (strictModeEnabled()) + { + parseError( + "Strict compliance mode demands an integer to be provided to " + "(push). Maybe you want (push 1)?"); + } + nscopes = 1; + } + + for (uint32_t i = 0; i < *nscopes; i++) + { + pushScope(true); + } + return std::make_unique(*nscopes); +} + +std::unique_ptr Smt2::handlePop(std::optional nscopes) +{ + checkThatLogicIsSet(); + + if (!nscopes) + { + if (strictModeEnabled()) + { + parseError( + "Strict compliance mode demands an integer to be provided to " + "(pop). Maybe you want (pop 1)?"); + } + nscopes = 1; + } + + for (uint32_t i = 0; i < *nscopes; i++) + { + popScope(); + } + return std::make_unique(*nscopes); +} + void Smt2::notifyNamedExpression(cvc5::Term& expr, std::string name) { checkUserSymbol(name); diff --git a/src/parser/smt2/smt2.h b/src/parser/smt2/smt2.h index 1e83375e1..f1491912a 100644 --- a/src/parser/smt2/smt2.h +++ b/src/parser/smt2/smt2.h @@ -18,6 +18,7 @@ #ifndef CVC5__PARSER__SMT2_H #define CVC5__PARSER__SMT2_H +#include #include #include #include @@ -382,6 +383,20 @@ class Smt2 : public Parser */ cvc5::Term applyParseOp(ParseOp& p, std::vector& args); //------------------------- end processing parse operators + + /** + * Handles a push command. + * + * @return An instance of `PushCommand` + */ + std::unique_ptr handlePush(std::optional nscopes); + /** + * Handles a pop command. + * + * @return An instance of `PopCommand` + */ + std::unique_ptr handlePop(std::optional nscopes); + private: void addArithmeticOperators(); diff --git a/src/printer/ast/ast_printer.cpp b/src/printer/ast/ast_printer.cpp index 50d96865d..8048d31c1 100644 --- a/src/printer/ast/ast_printer.cpp +++ b/src/printer/ast/ast_printer.cpp @@ -187,13 +187,14 @@ void AstPrinter::toStreamCmdAssert(std::ostream& out, Node n) const out << "Assert(" << n << ')' << std::endl; } -void AstPrinter::toStreamCmdPush(std::ostream& out) const +void AstPrinter::toStreamCmdPush(std::ostream& out, uint32_t nscopes) const { - out << "Push()" << std::endl; + out << "Push(" << nscopes << ")" << std::endl; } -void AstPrinter::toStreamCmdPop(std::ostream& out) const { - out << "Pop()" << std::endl; +void AstPrinter::toStreamCmdPop(std::ostream& out, uint32_t nscopes) const +{ + out << "Pop(" << nscopes << ")" << std::endl; } void AstPrinter::toStreamCmdCheckSat(std::ostream& out) const diff --git a/src/printer/ast/ast_printer.h b/src/printer/ast/ast_printer.h index 54d3e6c94..a57ed0e7c 100644 --- a/src/printer/ast/ast_printer.h +++ b/src/printer/ast/ast_printer.h @@ -52,10 +52,10 @@ class AstPrinter : public cvc5::internal::Printer void toStreamCmdAssert(std::ostream& out, Node n) const override; /** Print push command */ - void toStreamCmdPush(std::ostream& out) const override; + void toStreamCmdPush(std::ostream& out, uint32_t nscopes) const override; /** Print pop command */ - void toStreamCmdPop(std::ostream& out) const override; + void toStreamCmdPop(std::ostream& out, uint32_t nscopes) const override; /** Print declare-fun command */ void toStreamCmdDeclareFunction(std::ostream& out, diff --git a/src/printer/printer.cpp b/src/printer/printer.cpp index 23864e71d..eea164b20 100644 --- a/src/printer/printer.cpp +++ b/src/printer/printer.cpp @@ -183,12 +183,12 @@ void Printer::toStreamCmdAssert(std::ostream& out, Node n) const printUnknownCommand(out, "assert"); } -void Printer::toStreamCmdPush(std::ostream& out) const +void Printer::toStreamCmdPush(std::ostream& out, uint32_t nscopes) const { printUnknownCommand(out, "push"); } -void Printer::toStreamCmdPop(std::ostream& out) const +void Printer::toStreamCmdPop(std::ostream& out, uint32_t nscopes) const { printUnknownCommand(out, "pop"); } diff --git a/src/printer/printer.h b/src/printer/printer.h index e3e1eef91..424726fdd 100644 --- a/src/printer/printer.h +++ b/src/printer/printer.h @@ -81,10 +81,10 @@ class Printer virtual void toStreamCmdAssert(std::ostream& out, Node n) const; /** Print push command */ - virtual void toStreamCmdPush(std::ostream& out) const; + virtual void toStreamCmdPush(std::ostream& out, uint32_t nscopes) const; /** Print pop command */ - virtual void toStreamCmdPop(std::ostream& out) const; + virtual void toStreamCmdPop(std::ostream& out, uint32_t nscopes) const; /** Print declare-fun command */ virtual void toStreamCmdDeclareFunction(std::ostream& out, diff --git a/src/printer/smt2/smt2_printer.cpp b/src/printer/smt2/smt2_printer.cpp index fe54a6b8d..bc35f639f 100644 --- a/src/printer/smt2/smt2_printer.cpp +++ b/src/printer/smt2/smt2_printer.cpp @@ -1505,14 +1505,14 @@ void Smt2Printer::toStreamCmdAssert(std::ostream& out, Node n) const out << "(assert " << n << ')' << std::endl; } -void Smt2Printer::toStreamCmdPush(std::ostream& out) const +void Smt2Printer::toStreamCmdPush(std::ostream& out, uint32_t nscopes) const { - out << "(push 1)" << std::endl; + out << "(push " << nscopes << ")" << std::endl; } -void Smt2Printer::toStreamCmdPop(std::ostream& out) const +void Smt2Printer::toStreamCmdPop(std::ostream& out, uint32_t nscopes) const { - out << "(pop 1)" << std::endl; + out << "(pop " << nscopes << ")" << std::endl; } void Smt2Printer::toStreamCmdCheckSat(std::ostream& out) const diff --git a/src/printer/smt2/smt2_printer.h b/src/printer/smt2/smt2_printer.h index 72b6a0eb3..57688255d 100644 --- a/src/printer/smt2/smt2_printer.h +++ b/src/printer/smt2/smt2_printer.h @@ -64,10 +64,10 @@ class Smt2Printer : public cvc5::internal::Printer void toStreamCmdAssert(std::ostream& out, Node n) const override; /** Print push command */ - void toStreamCmdPush(std::ostream& out) const override; + void toStreamCmdPush(std::ostream& out, uint32_t nscopes) const override; /** Print pop command */ - void toStreamCmdPop(std::ostream& out) const override; + void toStreamCmdPop(std::ostream& out, uint32_t nscopes) const override; /** Print declare-fun command */ void toStreamCmdDeclareFunction(std::ostream& out, diff --git a/src/smt/command.cpp b/src/smt/command.cpp index a5956d63a..d3a5702f3 100644 --- a/src/smt/command.cpp +++ b/src/smt/command.cpp @@ -367,11 +367,13 @@ void AssertCommand::toStream(std::ostream& out, /* class PushCommand */ /* -------------------------------------------------------------------------- */ +PushCommand::PushCommand(uint32_t nscopes) : d_nscopes(nscopes) {} + void PushCommand::invoke(cvc5::Solver* solver, SymbolManager* sm) { try { - solver->push(); + solver->push(d_nscopes); d_commandStatus = CommandSuccess::instance(); } catch (exception& e) @@ -380,7 +382,7 @@ void PushCommand::invoke(cvc5::Solver* solver, SymbolManager* sm) } } -Command* PushCommand::clone() const { return new PushCommand(); } +Command* PushCommand::clone() const { return new PushCommand(d_nscopes); } std::string PushCommand::getCommandName() const { return "push"; } void PushCommand::toStream(std::ostream& out, @@ -388,18 +390,20 @@ void PushCommand::toStream(std::ostream& out, size_t dag, Language language) const { - Printer::getPrinter(language)->toStreamCmdPush(out); + Printer::getPrinter(language)->toStreamCmdPush(out, d_nscopes); } /* -------------------------------------------------------------------------- */ /* class PopCommand */ /* -------------------------------------------------------------------------- */ +PopCommand::PopCommand(uint32_t nscopes) : d_nscopes(nscopes) {} + void PopCommand::invoke(cvc5::Solver* solver, SymbolManager* sm) { try { - solver->pop(); + solver->pop(d_nscopes); d_commandStatus = CommandSuccess::instance(); } catch (exception& e) @@ -408,7 +412,7 @@ void PopCommand::invoke(cvc5::Solver* solver, SymbolManager* sm) } } -Command* PopCommand::clone() const { return new PopCommand(); } +Command* PopCommand::clone() const { return new PopCommand(d_nscopes); } std::string PopCommand::getCommandName() const { return "pop"; } void PopCommand::toStream(std::ostream& out, @@ -416,7 +420,7 @@ void PopCommand::toStream(std::ostream& out, size_t dag, Language language) const { - Printer::getPrinter(language)->toStreamCmdPop(out); + Printer::getPrinter(language)->toStreamCmdPop(out, d_nscopes); } /* -------------------------------------------------------------------------- */ diff --git a/src/smt/command.h b/src/smt/command.h index ddc9ca82e..9faae1fca 100644 --- a/src/smt/command.h +++ b/src/smt/command.h @@ -370,6 +370,8 @@ class CVC5_EXPORT AssertCommand : public Command class CVC5_EXPORT PushCommand : public Command { public: + PushCommand(uint32_t nscopes); + void invoke(cvc5::Solver* solver, SymbolManager* sm) override; Command* clone() const override; std::string getCommandName() const override; @@ -378,11 +380,16 @@ class CVC5_EXPORT PushCommand : public Command size_t dag = 1, internal::Language language = internal::Language::LANG_AUTO) const override; + + private: + uint32_t d_nscopes; }; /* class PushCommand */ class CVC5_EXPORT PopCommand : public Command { public: + PopCommand(uint32_t nscopes); + void invoke(cvc5::Solver* solver, SymbolManager* sm) override; Command* clone() const override; std::string getCommandName() const override; @@ -391,6 +398,9 @@ class CVC5_EXPORT PopCommand : public Command size_t dag = 1, internal::Language language = internal::Language::LANG_AUTO) const override; + + private: + uint32_t d_nscopes; }; /* class PopCommand */ class CVC5_EXPORT DeclarationDefinitionCommand : public Command