base: Add an asymmetrical Coroutine class
authorGiacomo Travaglini <giacomo.travaglini@arm.com>
Thu, 14 Jun 2018 10:37:20 +0000 (11:37 +0100)
committerGiacomo Travaglini <giacomo.travaglini@arm.com>
Thu, 28 Jun 2018 09:15:27 +0000 (09:15 +0000)
This patch is providing gem5 a Coroutine class to be used for
instantiating asymmetrical coroutines. Coroutines are built on top of
gem5 fibers, which makes them ucontext based.

Change-Id: I7bb673a954d4a456997afd45b696933534f3e239
Signed-off-by: Giacomo Travaglini <giacomo.travaglini@arm.com>
Reviewed-on: https://gem5-review.googlesource.com/11195
Reviewed-by: Gabe Black <gabeblack@google.com>
Maintainer: Gabe Black <gabeblack@google.com>

src/base/SConscript
src/base/coroutine.hh [new file with mode: 0644]
src/base/coroutinetest.cc [new file with mode: 0644]

index b3205a6bb679aad8bbb10cdecd77286f483abf19..ea91f7011f7cd4bbf93a1be2ed0ec7152c461fa4 100644 (file)
@@ -48,6 +48,7 @@ if env['USE_PNG']:
     Source('pngwriter.cc')
 Source('fiber.cc')
 GTest('fibertest', 'fibertest.cc', 'fiber.cc')
+GTest('coroutinetest', 'coroutinetest.cc', 'fiber.cc')
 Source('framebuffer.cc')
 Source('hostinfo.cc')
 Source('inet.cc')
diff --git a/src/base/coroutine.hh b/src/base/coroutine.hh
new file mode 100644 (file)
index 0000000..d288892
--- /dev/null
@@ -0,0 +1,266 @@
+/*
+ * Copyright (c) 2018 ARM Limited
+ * All rights reserved
+ *
+ * The license below extends only to copyright in the software and shall
+ * not be construed as granting a license to any other intellectual
+ * property including but not limited to intellectual property relating
+ * to a hardware implementation of the functionality of the software
+ * licensed hereunder.  You may use the software subject to the license
+ * terms below provided that you ensure that this notice is replicated
+ * unmodified and in its entirety in all distributions of the software,
+ * modified or unmodified, in source code or in binary form.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met: redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer;
+ * redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in the
+ * documentation and/or other materials provided with the distribution;
+ * neither the name of the copyright holders nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ * Authors: Giacomo Travaglini
+ */
+
+#ifndef __BASE_COROUTINE_HH__
+#define __BASE_COROUTINE_HH__
+
+#include <functional>
+#include <stack>
+
+#include "base/fiber.hh"
+
+namespace m5
+{
+
+/**
+ * This template defines a Coroutine wrapper type with a Boost-like
+ * interface. It is built on top of the gem5 fiber class.
+ * The two template parameters (Arg and Ret) are the coroutine
+ * argument and coroutine return types which are passed between
+ * the coroutine and the caller via operator() and get() method.
+ * This implementation doesn't support passing multiple values,
+ * so a tuple must be used in that scenario.
+ *
+ * Most methods are templatized since it is relevant to distinguish
+ * the cases where one or both of the template parameters are void
+ */
+template <typename Arg, typename Ret>
+class Coroutine : public Fiber
+{
+
+    // This empty struct type is meant to replace coroutine channels
+    // in case the channel should be void (Coroutine template parameters
+    // are void. (See following ArgChannel, RetChannel typedef)
+    struct Empty {};
+    using ArgChannel = typename std::conditional<
+        std::is_same<Arg, void>::value, Empty, std::stack<Arg>>::type;
+
+    using RetChannel = typename std::conditional<
+        std::is_same<Ret, void>::value, Empty, std::stack<Ret>>::type;
+
+  public:
+    /**
+     * CallerType:
+     * A reference to an object of this class will be passed
+     * to the coroutine task. This is the way it is possible
+     * for the coroutine to interface (e.g. switch back)
+     * to the coroutine caller.
+     */
+    class CallerType
+    {
+        friend class Coroutine;
+      protected:
+        CallerType(Coroutine& _coro) : coro(_coro), callerFiber(nullptr) {}
+
+      public:
+        /**
+         * operator() is the way we can jump outside the coroutine
+         * and return a value to the caller.
+         *
+         * This method is generated only if the coroutine returns
+         * a value (Ret != void)
+         */
+        template <typename T = Ret>
+        CallerType&
+        operator()(typename std::enable_if<
+                   !std::is_same<T, void>::value, T>::type param)
+        {
+            retChannel.push(param);
+            callerFiber->run();
+            return *this;
+        }
+
+        /**
+         * operator() is the way we can jump outside the coroutine
+         *
+         * This method is generated only if the coroutine doesn't
+         * return a value (Ret = void)
+         */
+        template <typename T = Ret>
+        typename std::enable_if<std::is_same<T, void>::value,
+                                CallerType>::type&
+        operator()()
+        {
+            callerFiber->run();
+            return *this;
+        }
+
+        /**
+         * get() is the way we can extrapolate arguments from the
+         * coroutine caller.
+         * The coroutine blocks, waiting for the value, unless it is already
+         * available; otherwise caller execution is resumed,
+         * and coroutine won't execute until a value is pushed
+         * from the caller.
+         *
+         * @return arg coroutine argument
+         */
+        template <typename T = Arg>
+        typename std::enable_if<!std::is_same<T, void>::value, T>::type
+        get()
+        {
+            auto& args_channel = coro.argsChannel;
+            while (args_channel.empty()) {
+                callerFiber->run();
+            }
+
+            auto ret = args_channel.top();
+            args_channel.pop();
+            return ret;
+        }
+
+      private:
+        Coroutine& coro;
+        Fiber* callerFiber;
+        RetChannel retChannel;
+    };
+
+    Coroutine() = delete;
+    Coroutine(const Coroutine& rhs) = delete;
+    Coroutine& operator=(const Coroutine& rhs) = delete;
+
+    /**
+     * Coroutine constructor.
+     * The only way to construct a coroutine is to pass it the routine
+     * it needs to run. The first argument of the function should be a
+     * reference to the Coroutine<Arg,Ret>::caller_type which the
+     * routine will use as a way for yielding to the caller.
+     *
+     * @param f task run by the coroutine
+     */
+    Coroutine(std::function<void(CallerType&)> f)
+      : Fiber(), task(f), caller(*this)
+    {
+        // Create and Run the Coroutine
+        this->call();
+    }
+
+    virtual ~Coroutine() {}
+
+  public:
+    /** Coroutine interface */
+
+    /**
+     * operator() is the way we can jump inside the coroutine
+     * and passing arguments.
+     *
+     * This method is generated only if the coroutine takes
+     * arguments (Arg != void)
+     */
+    template <typename T = Arg>
+    Coroutine&
+    operator()(typename std::enable_if<
+               !std::is_same<T, void>::value, T>::type param)
+    {
+        argsChannel.push(param);
+        this->call();
+        return *this;
+    }
+
+    /**
+     * operator() is the way we can jump inside the coroutine.
+     *
+     * This method is generated only if the coroutine takes
+     * no arguments. (Arg = void)
+     */
+    template <typename T = Arg>
+    typename std::enable_if<std::is_same<T, void>::value, Coroutine>::type&
+    operator()()
+    {
+        this->call();
+        return *this;
+    }
+
+    /**
+     * get() is the way we can extrapolate return values
+     * (yielded) from the coroutine.
+     * The caller blocks, waiting for the value, unless it is already
+     * available; otherwise coroutine execution is resumed,
+     * and caller won't execute until a value is yielded back
+     * from the coroutine.
+     *
+     * @return ret yielded value
+     */
+    template <typename T = Ret>
+    typename std::enable_if<!std::is_same<T, void>::value, T>::type
+    get()
+    {
+        auto& ret_channel = caller.retChannel;
+        while (ret_channel.empty()) {
+            this->call();
+        }
+
+        auto ret = ret_channel.top();
+        ret_channel.pop();
+        return ret;
+    }
+
+    /** Check if coroutine is still running */
+    operator bool() const { return !this->finished(); }
+
+  private:
+    /**
+     * Overriding base (Fiber) main.
+     * This method will be automatically called by the Fiber
+     * running engine and it is a simple wrapper for the task
+     * that the coroutine is supposed to run.
+     */
+    void main() override { this->task(caller); }
+
+    void
+    call()
+    {
+        caller.callerFiber = currentFiber();
+        run();
+    }
+
+  private:
+    /** Arguments for the coroutine */
+    ArgChannel argsChannel;
+
+    /** Coroutine task */
+    std::function<void(CallerType&)> task;
+
+    /** Coroutine caller */
+    CallerType caller;
+};
+
+} //namespace m5
+
+#endif // __BASE_COROUTINE_HH__
diff --git a/src/base/coroutinetest.cc b/src/base/coroutinetest.cc
new file mode 100644 (file)
index 0000000..655bc25
--- /dev/null
@@ -0,0 +1,262 @@
+/*
+ * Copyright (c) 2018 ARM Limited
+ * All rights reserved
+ *
+ * The license below extends only to copyright in the software and shall
+ * not be construed as granting a license to any other intellectual
+ * property including but not limited to intellectual property relating
+ * to a hardware implementation of the functionality of the software
+ * licensed hereunder.  You may use the software subject to the license
+ * terms below provided that you ensure that this notice is replicated
+ * unmodified and in its entirety in all distributions of the software,
+ * modified or unmodified, in source code or in binary form.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met: redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer;
+ * redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in the
+ * documentation and/or other materials provided with the distribution;
+ * neither the name of the copyright holders nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ * Authors: Giacomo Travaglini
+ */
+
+#include <gtest/gtest.h>
+
+#include "base/coroutine.hh"
+
+using namespace m5;
+
+/**
+ * This test is checking if the Coroutine, once it yields
+ * back to the caller, it is still marked as not finished.
+ */
+TEST(Coroutine, Unfinished)
+{
+    auto yielding_task =
+    [] (Coroutine<void, void>::CallerType& yield)
+    {
+        yield();
+    };
+
+    Coroutine<void, void> coro(yielding_task);
+    ASSERT_TRUE(coro);
+}
+
+/**
+ * This test is checking the parameter passing interface of a
+ * coroutine which takes an integer as an argument.
+ * Coroutine::operator() and CallerType::get() are the tested
+ * APIS.
+ */
+TEST(Coroutine, Passing)
+{
+    const std::vector<int> input{ 1, 2, 3 };
+    const std::vector<int> expected_values = input;
+
+    auto passing_task =
+    [&expected_values] (Coroutine<int, void>::CallerType& yield)
+    {
+        int argument;
+
+        for (const auto expected : expected_values) {
+            argument = yield.get();
+            ASSERT_EQ(argument, expected);
+        }
+    };
+
+    Coroutine<int, void> coro(passing_task);
+    ASSERT_TRUE(coro);
+
+    for (const auto val : input) {
+        coro(val);
+    }
+}
+
+/**
+ * This test is checking the yielding interface of a coroutine
+ * which takes no argument and returns integers.
+ * Coroutine::get() and CallerType::operator() are the tested
+ * APIS.
+ */
+TEST(Coroutine, Returning)
+{
+    const std::vector<int> output{ 1, 2, 3 };
+    const std::vector<int> expected_values = output;
+
+    auto returning_task =
+    [&output] (Coroutine<void, int>::CallerType& yield)
+    {
+        for (const auto ret : output) {
+            yield(ret);
+        }
+    };
+
+    Coroutine<void, int> coro(returning_task);
+    ASSERT_TRUE(coro);
+
+    for (const auto expected : expected_values) {
+        int returned = coro.get();
+        ASSERT_EQ(returned, expected);
+    }
+}
+
+/**
+ * This test is still supposed to test the returning interface
+ * of the the Coroutine, proving how coroutine can be used
+ * for generators.
+ * The coroutine is computing the first #steps of the fibonacci
+ * sequence and it is yielding back results one number per time.
+ */
+TEST(Coroutine, Fibonacci)
+{
+    const std::vector<int> expected_values{
+        1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233 };
+
+    const int steps = expected_values.size();
+
+    auto fibonacci_task =
+    [steps] (Coroutine<void, int>::CallerType& yield)
+    {
+        int prev = 0;
+        int current = 1;
+
+        for (auto iter = 0; iter < steps; iter++) {
+            int sum = prev + current;
+            yield(sum);
+
+            prev = current;
+            current = sum;
+        }
+    };
+
+    Coroutine<void, int> coro(fibonacci_task);
+    ASSERT_TRUE(coro);
+
+    for (const auto expected : expected_values) {
+        ASSERT_TRUE(coro);
+        int returned = coro.get();
+        ASSERT_EQ(returned, expected);
+    }
+}
+
+/**
+ * This test is using a bi-channel coroutine (accepting and
+ * yielding values) for testing a cooperative task.
+ * The caller and the coroutine have a string each; they are
+ * composing a new string by merging the strings together one
+ * character per time.
+ * The result string is hence passed back and forth between the
+ * coroutine and the caller.
+ */
+TEST(Coroutine, Cooperative)
+{
+    const std::string caller_str("HloWrd");
+    const std::string coro_str("el ol!");
+    const std::string expected("Hello World!");
+
+    auto cooperative_task =
+    [&coro_str] (Coroutine<std::string, std::string>::CallerType& yield)
+    {
+        for (auto& appended_c : coro_str) {
+            auto old_str = yield.get();
+            yield(old_str + appended_c);
+        }
+    };
+
+    Coroutine<std::string, std::string> coro(cooperative_task);
+
+    std::string result;
+    for (auto& c : caller_str) {
+        ASSERT_TRUE(coro);
+        result += c;
+        result = coro(result).get();
+    }
+
+    ASSERT_EQ(result, expected);
+}
+
+/**
+ * This test is testing nested coroutines by using one inner and one
+ * outer coroutine. It basically ensures that yielding from the inner
+ * coroutine returns to the outer coroutine (mid-layer of execution) and
+ * not to the outer caller.
+ */
+TEST(Coroutine, Nested)
+{
+    const std::string wrong("Inner");
+    const std::string expected("Inner + Outer");
+
+    auto inner_task =
+    [] (Coroutine<void, std::string>::CallerType& yield)
+    {
+        std::string inner_string("Inner");
+        yield(inner_string);
+    };
+
+    auto outer_task =
+    [&inner_task] (Coroutine<void, std::string>::CallerType& yield)
+    {
+        Coroutine<void, std::string> coro(inner_task);
+        std::string inner_string = coro.get();
+
+        std::string outer_string("Outer");
+        yield(inner_string + " + " + outer_string);
+    };
+
+
+    Coroutine<void, std::string> coro(outer_task);
+    ASSERT_TRUE(coro);
+
+    std::string result = coro.get();
+
+    ASSERT_NE(result, wrong);
+    ASSERT_EQ(result, expected);
+}
+
+/**
+ * This test is stressing the scenario where two distinct fibers are
+ * calling the same coroutine.  First the test instantiates (and runs) a
+ * coroutine, then spawns another one and it passes it a reference to
+ * the first coroutine. Once the new coroutine calls the first coroutine
+ * and the first coroutine yields, we are expecting execution flow to
+ * be yielded to the second caller (the second coroutine) and not the
+ * original caller (the test itself)
+ */
+TEST(Coroutine, TwoCallers)
+{
+    bool valid_return = false;
+
+    Coroutine<void, void> callee{[]
+        (Coroutine<void, void>::CallerType& yield)
+    {
+        yield();
+        yield();
+    }};
+
+    Coroutine<void, void> other_caller{[&callee, &valid_return]
+        (Coroutine<void, void>::CallerType& yield)
+    {
+        callee();
+        valid_return = true;
+        yield();
+    }};
+
+    ASSERT_TRUE(valid_return);
+}