Simplify co_await_expander.
authorBin Cheng <bin.cheng@linux.alibaba.com>
Fri, 10 Apr 2020 04:38:53 +0000 (12:38 +0800)
committerBin Cheng <bin.cheng@linux.alibaba.com>
Fri, 10 Apr 2020 04:53:00 +0000 (12:53 +0800)
gcc/cp
2020-04-10  Bin Cheng  <bin.cheng@linux.alibaba.com>

    * coroutines.cc (co_await_expander): Simplify.

gcc/testsuite
2020-04-10  Bin Cheng  <bin.cheng@linux.alibaba.com>

    * g++.dg/coroutines/co-await-syntax-10.C: New test.
    * g++.dg/coroutines/co-await-syntax-11.C: New test.

gcc/cp/ChangeLog
gcc/cp/coroutines.cc
gcc/testsuite/ChangeLog
gcc/testsuite/g++.dg/coroutines/co-await-syntax-10.C [new file with mode: 0644]
gcc/testsuite/g++.dg/coroutines/co-await-syntax-11.C [new file with mode: 0644]

index 49246e8fa2d8553e3b5fc9a270dd6797f45f2db3..38f86cd3e87296a3a89a9545f8ce0ae53c6719c7 100644 (file)
@@ -1,3 +1,7 @@
+2020-04-10  Bin Cheng  <bin.cheng@linux.alibaba.com>
+
+       * coroutines.cc (co_await_expander): Simplify.
+
 2020-04-09  Jason Merrill  <jason@redhat.com>
 
        PR c++/94523
index 936be06c33629806a320b45c3676cb8f677b6dd1..ab06c0aef54a2fd2233167769febc6dcbe61051d 100644 (file)
@@ -1389,34 +1389,13 @@ co_await_expander (tree *stmt, int * /*do_subtree*/, void *d)
     return NULL_TREE;
 
   coro_aw_data *data = (coro_aw_data *) d;
-
   enum tree_code stmt_code = TREE_CODE (*stmt);
   tree stripped_stmt = *stmt;
-
-  /* Look inside <(void) (expr)> cleanup */
-  if (stmt_code == CLEANUP_POINT_EXPR)
-    {
-      stripped_stmt = TREE_OPERAND (*stmt, 0);
-      stmt_code = TREE_CODE (stripped_stmt);
-      if (stmt_code == EXPR_STMT
-         && (TREE_CODE (EXPR_STMT_EXPR (stripped_stmt)) == CONVERT_EXPR
-             || TREE_CODE (EXPR_STMT_EXPR (stripped_stmt)) == CAST_EXPR)
-         && VOID_TYPE_P (TREE_TYPE (EXPR_STMT_EXPR (stripped_stmt))))
-       {
-         stripped_stmt = TREE_OPERAND (EXPR_STMT_EXPR (stripped_stmt), 0);
-         stmt_code = TREE_CODE (stripped_stmt);
-       }
-    }
-
   tree *buried_stmt = NULL;
   tree saved_co_await = NULL_TREE;
   enum tree_code sub_code = NOP_EXPR;
 
-  if (stmt_code == EXPR_STMT
-      && TREE_CODE (EXPR_STMT_EXPR (stripped_stmt)) == CO_AWAIT_EXPR)
-    saved_co_await
-      = EXPR_STMT_EXPR (stripped_stmt); /* hopefully, a void exp.  */
-  else if (stmt_code == MODIFY_EXPR || stmt_code == INIT_EXPR)
+  if (stmt_code == MODIFY_EXPR || stmt_code == INIT_EXPR)
     {
       sub_code = TREE_CODE (TREE_OPERAND (stripped_stmt, 1));
       if (sub_code == CO_AWAIT_EXPR)
@@ -1435,6 +1414,8 @@ co_await_expander (tree *stmt, int * /*do_subtree*/, void *d)
   else if ((stmt_code == CONVERT_EXPR || stmt_code == NOP_EXPR)
           && TREE_CODE (TREE_OPERAND (stripped_stmt, 0)) == CO_AWAIT_EXPR)
     saved_co_await = TREE_OPERAND (stripped_stmt, 0);
+  else if (stmt_code == CO_AWAIT_EXPR)
+    saved_co_await = stripped_stmt;
 
   if (!saved_co_await)
     return NULL_TREE;
index d8fa35aca9c25eb15ef21bfee20af0edb79ca5b8..72e0d519bb95d80d4f4f5a556250bec0fdb15ce2 100644 (file)
@@ -1,3 +1,8 @@
+2020-04-10  Bin Cheng  <bin.cheng@linux.alibaba.com>
+
+       * g++.dg/coroutines/co-await-syntax-10.C: New test.
+       * g++.dg/coroutines/co-await-syntax-11.C: New test.
+
 2020-04-09  Fritz Reese  <foreese@gcc.gnu.org>
 
        PR fortran/87923
diff --git a/gcc/testsuite/g++.dg/coroutines/co-await-syntax-10.C b/gcc/testsuite/g++.dg/coroutines/co-await-syntax-10.C
new file mode 100644 (file)
index 0000000..8304344
--- /dev/null
@@ -0,0 +1,40 @@
+//  { dg-additional-options "-std=c++17 -w" }
+
+#include "coro.h"
+
+class await {
+public:
+  class promise_type {
+  public:
+    std::suspend_always initial_suspend() const noexcept { return {}; }
+    std::suspend_always final_suspend() const noexcept { return {}; }
+    void unhandled_exception() noexcept { }
+    await get_return_object() { return await{}; }
+    void return_void() {}
+  };
+  bool await_ready() const noexcept { return false; }
+  bool await_suspend(std::coroutine_handle<>) noexcept {return true;}
+  void await_resume() { }
+};
+
+class mycoro {
+public:
+  class promise_type {
+  public:
+    std::suspend_always initial_suspend() const noexcept { return {}; }
+    std::suspend_always final_suspend() const noexcept { return {}; }
+    void unhandled_exception() noexcept { }
+    mycoro get_return_object() { return mycoro{}; }
+    void return_void() {}
+  };
+};
+mycoro foo(await awaitable) {
+  co_return co_await awaitable;
+}
+
+mycoro bar()
+{
+ auto t = [&]() -> await { co_return; }();
+ return foo (t);
+}
+
diff --git a/gcc/testsuite/g++.dg/coroutines/co-await-syntax-11.C b/gcc/testsuite/g++.dg/coroutines/co-await-syntax-11.C
new file mode 100644 (file)
index 0000000..69810ab
--- /dev/null
@@ -0,0 +1,205 @@
+//  { dg-additional-options "-std=c++17 -w" }
+
+#include <utility>
+#include <type_traits>
+#include <tuple>
+#include <functional>
+#include <coroutine>
+
+struct any {
+  template <typename T> any(T &&) noexcept;
+};
+
+template <typename T>
+auto get_awaiter_impl(T &&value, int) noexcept
+    -> decltype(static_cast<T &&>(value).operator co_await()) {
+  return static_cast<T &&>(value).operator co_await();
+}
+template <typename T, int = 0>
+T &&get_awaiter_impl(T &&value, any) noexcept;
+template <typename T>
+auto get_awaiter(T &&value) noexcept
+    -> decltype(get_awaiter_impl(static_cast<T &&>(value), 123)) {
+  return get_awaiter_impl(static_cast<T &&>(value), 123);
+}
+
+template <typename T, typename = void> struct awaitable_traits {
+  using awaiter_t = decltype(get_awaiter(std::declval<T>()));
+  using await_result_t = decltype(std::declval<awaiter_t>().await_resume());
+};
+
+template <typename TASK_CONTAINER> class when_all_ready_awaitable;
+template <typename... TASKS>
+class when_all_ready_awaitable<std::tuple<TASKS...>> {
+public:
+  explicit when_all_ready_awaitable(std::tuple<TASKS...> &&tasks) noexcept
+    : m_tasks(std::move(tasks)) {}
+  auto operator co_await() &&noexcept {
+    struct awaiter {
+      awaiter(when_all_ready_awaitable &awaitable) noexcept
+        : m_awaitable(awaitable) {}
+      bool await_ready() const noexcept { return false; }
+      bool await_suspend() noexcept { return false; }
+      std::tuple<TASKS...> &&await_resume() noexcept {
+        return std::move(m_awaitable.m_tasks);
+      }
+      when_all_ready_awaitable& m_awaitable;
+    };
+    return awaiter{*this};
+  }
+  std::tuple<TASKS...> m_tasks;
+};
+
+inline void *operator new(std::size_t, void *__p) noexcept;
+
+template <typename RESULT>
+class when_all_task_promise final{
+public:
+  using coroutine_handle_t = std::coroutine_handle<when_all_task_promise>;
+  RESULT &&result() &&;
+};
+template <typename RESULT> class when_all_task final {
+public:
+  using promise_type = when_all_task_promise<RESULT>;
+  using coroutine_handle_t = typename promise_type::coroutine_handle_t;
+  decltype(auto) result() &;
+  decltype(auto) result() && {
+    return std::move(m_coroutine.promise()).result();
+  }
+  decltype(auto) non_void_result() && {
+    if constexpr (std::is_void_v<decltype(0)>)
+      ;
+    else
+      return std::move(*this).result();
+  }
+  coroutine_handle_t m_coroutine;
+};
+class task;
+template <typename AWAITABLE,
+          typename RESULT = 
+              typename awaitable_traits<AWAITABLE &&>::await_result_t,
+          std::enable_if_t<!std::is_void_v<RESULT>, int> = 0>
+when_all_task<RESULT> make_when_all_task(AWAITABLE awaitable);
+
+template <typename... AWAITABLES>
+inline auto when_all_ready(AWAITABLES &&... awaitables) {
+  return when_all_ready_awaitable<
+      std::tuple<when_all_task<typename awaitable_traits<
+          std::remove_reference_t<AWAITABLES>>::await_result_t>...>>(
+      std::make_tuple(
+          make_when_all_task(std::forward<AWAITABLES>(awaitables))...));
+}
+
+template <typename FUNC, typename AWAITABLE> class fmap_awaiter {
+  using awaiter_t = typename awaitable_traits<AWAITABLE &&>::awaiter_t;
+
+public:
+  fmap_awaiter(FUNC &&func, AWAITABLE &&awaitable) noexcept
+      : m_func(static_cast<FUNC &&>(func)),
+        m_awaiter(get_awaiter(static_cast<AWAITABLE &&>(awaitable))) {}
+  decltype(auto) await_ready() noexcept {
+    return static_cast<awaiter_t &&>(m_awaiter).await_ready();
+  }
+  template <typename PROMISE>
+  decltype(auto) await_suspend(std::coroutine_handle<PROMISE> coro) noexcept {}
+  template <typename AWAIT_RESULT =
+                decltype(std::declval<awaiter_t>().await_resume()),
+            std::enable_if_t<!std::is_void_v<AWAIT_RESULT>, int> = 0>
+  decltype(auto) await_resume() noexcept {
+    return std::invoke(static_cast<FUNC &&>(m_func),
+                       static_cast<awaiter_t &&>(m_awaiter).await_resume());
+  }
+
+private:
+  FUNC &&m_func;
+  awaiter_t m_awaiter;
+};
+template <typename FUNC, typename AWAITABLE> class fmap_awaitable {
+public:
+  template <
+      typename FUNC_ARG, typename AWAITABLE_ARG,
+      std::enable_if_t<std::is_constructible_v<FUNC, FUNC_ARG &&> &&
+                           std::is_constructible_v<AWAITABLE, AWAITABLE_ARG &&>,
+                       int> = 0>
+  explicit fmap_awaitable(FUNC_ARG &&func, AWAITABLE_ARG &&awaitable) noexcept
+      : m_func(static_cast<FUNC_ARG &&>(func)),
+        m_awaitable(static_cast<AWAITABLE_ARG &&>(awaitable)) {}
+  auto operator co_await() && {
+    return fmap_awaiter(static_cast<FUNC &&>(m_func),
+                        static_cast<AWAITABLE &&>(m_awaitable));
+  }
+
+private:
+  FUNC m_func;
+  AWAITABLE m_awaitable;
+};
+
+template <typename FUNC, typename AWAITABLE>
+auto fmap(FUNC &&func, AWAITABLE &&awaitable) {
+  return fmap_awaitable<std::remove_cv_t<std::remove_reference_t<FUNC>>,
+                        std::remove_cv_t<std::remove_reference_t<AWAITABLE>>>(
+      std::forward<FUNC>(func), std::forward<AWAITABLE>(awaitable));
+}
+template <typename... AWAITABLES>
+auto when_all(AWAITABLES &&... awaitables) {
+  return fmap(
+      [](auto &&taskTuple) {
+        decltype(auto) __trans_tmp_1 = std::apply(
+            [](auto &&... tasks) {
+              return std::make_tuple(
+                  static_cast<decltype(tasks)>(tasks).non_void_result()...);
+            },
+            static_cast<decltype(taskTuple)>(taskTuple));
+        return __trans_tmp_1;
+      },
+      when_all_ready(std::forward<AWAITABLES>(awaitables)...));
+}
+class async_mutex_scoped_lock_operation;
+class async_mutex {
+public:
+  async_mutex() noexcept;
+  async_mutex_scoped_lock_operation scoped_lock_async() noexcept;
+};
+class async_mutex_lock {
+public:
+  explicit async_mutex_lock();
+  ~async_mutex_lock();
+
+private:
+  async_mutex *m_mutex;
+};
+class async_mutex_scoped_lock_operation {
+public:
+  async_mutex_lock await_resume() const noexcept;
+};
+class task {
+public:
+  class promise_type {
+  public:
+    auto initial_suspend() noexcept { return std::suspend_always{}; }
+    auto final_suspend() noexcept { return std::suspend_always{}; }
+    task get_return_object() noexcept { return task{}; }
+    void unhandled_exception() noexcept {}
+    void return_value(int value) noexcept { v = value; }
+    int result(){ return v; }
+    int v = 0;
+  };
+public:
+  task() noexcept {}
+  auto operator co_await() const &noexcept {
+    struct awaitable {
+      std::coroutine_handle<promise_type> m_coroutine;
+      decltype(auto) await_resume() {
+        return this->m_coroutine.promise().result();
+      }
+    };
+    return awaitable{};
+  }
+};
+void foo() {
+  (void) []() -> task {
+    auto makeTask = [](int x) -> task { co_return x; };
+    async_mutex_scoped_lock_operation op;
+    co_await when_all(std::move(op), makeTask(123));
+  }();
+}