python: Add support for Proxy division
authorDaniel R. Carvalho <odanrc@yahoo.com.br>
Thu, 22 Oct 2020 11:34:24 +0000 (13:34 +0200)
committerDaniel Carvalho <odanrc@yahoo.com.br>
Fri, 23 Oct 2020 21:49:11 +0000 (21:49 +0000)
Allow proxies to use python3's division operations. The dividends
and divisors can be either a proxy or a constant.

Change-Id: I96b854355b8f593edfb1ea52a52548b855b05fc0
Signed-off-by: Daniel R. Carvalho <odanrc@yahoo.com.br>
Reviewed-on: https://gem5-review.googlesource.com/c/public/gem5/+/36496
Reviewed-by: Andreas Sandberg <andreas.sandberg@arm.com>
Reviewed-by: Jason Lowe-Power <power.jg@gmail.com>
Reviewed-by: Nikos Nikoleris <nikos.nikoleris@arm.com>
Maintainer: Jason Lowe-Power <power.jg@gmail.com>
Tested-by: kokoro <noreply+kokoro@google.com>
src/python/m5/proxy.py

index 9d91b84d9ccde0f1767044c5b8a300b9f4d61e42..d15b6f297c3732cb439e50f290c759b737d5e8e9 100644 (file)
@@ -55,7 +55,7 @@ class BaseProxy(object):
     def __init__(self, search_self, search_up):
         self._search_self = search_self
         self._search_up = search_up
-        self._multipliers = []
+        self._ops = []
 
     def __str__(self):
         if self._search_self and not self._search_up:
@@ -72,29 +72,48 @@ class BaseProxy(object):
                 "cannot set attribute '%s' on proxy object" % attr)
         super(BaseProxy, self).__setattr__(attr, value)
 
-    # support for multiplying proxies by constants or other proxies to
-    # other params
-    def __mul__(self, other):
-        if not (isinstance(other, (int, long, float)) or isproxy(other)):
-            raise TypeError(
-                "Proxy multiplier must be a constant or a proxy to a param")
-        self._multipliers.append(other)
-        return self
-
+    def _gen_op(operation):
+        def op(self, operand):
+            if not (isinstance(operand, (int, long, float)) or \
+                isproxy(operand)):
+                raise TypeError(
+                    "Proxy operand must be a constant or a proxy to a param")
+            self._ops.append((operation, operand))
+            return self
+        return op
+
+    # Support for multiplying proxies by either constants or other proxies
+    __mul__ = _gen_op(lambda operand_a, operand_b : operand_a * operand_b)
     __rmul__ = __mul__
 
-    def _mulcheck(self, result, base):
+    # Support for dividing proxies by either constants or other proxies
+    __truediv__ = _gen_op(lambda operand_a, operand_b :
+        operand_a / operand_b)
+    __floordiv__ = _gen_op(lambda operand_a, operand_b :
+        operand_a // operand_b)
+
+    # Support for dividing constants by proxies
+    __rtruediv__ = _gen_op(lambda operand_a, operand_b :
+        operand_b / operand_a.getValue())
+    __rfloordiv__ = _gen_op(lambda operand_a, operand_b :
+        operand_b // operand_a.getValue())
+
+    # After all the operators and operands have been defined, this function
+    # should be called to perform the actual operation
+    def _opcheck(self, result, base):
         from . import params
-        for multiplier in self._multipliers:
-            if isproxy(multiplier):
-                multiplier = multiplier.unproxy(base)
-                # assert that we are multiplying with a compatible
-                # param
-                if not isinstance(multiplier, params.NumericParamValue):
-                    raise TypeError(
-                        "Proxy multiplier must be a numerical param")
-                multiplier = multiplier.getValue()
-            result = result * multiplier
+        for operation, operand in self._ops:
+            # Get the operand's value
+            if isproxy(operand):
+                operand = operand.unproxy(base)
+                # assert that we are operating with a compatible param
+                if not isinstance(operand, params.NumericParamValue):
+                    raise TypeError("Proxy operand must be a numerical param")
+                operand = operand.getValue()
+
+            # Apply the operation
+            result = operation(result, operand)
+
         return result
 
     def unproxy(self, base):
@@ -128,7 +147,7 @@ class BaseProxy(object):
                 raise RuntimeError("Cycle in unproxy")
             result = result.unproxy(obj)
 
-        return self._mulcheck(result, base)
+        return self._opcheck(result, base)
 
     def getindex(obj, index):
         if index == None: