def __init__(self, search_self, search_up):
self._search_self = search_self
self._search_up = search_up
- self._multipliers = []
+ self._ops = []
def __str__(self):
if self._search_self and not self._search_up:
"cannot set attribute '%s' on proxy object" % attr)
super(BaseProxy, self).__setattr__(attr, value)
- # support for multiplying proxies by constants or other proxies to
- # other params
- def __mul__(self, other):
- if not (isinstance(other, (int, long, float)) or isproxy(other)):
- raise TypeError(
- "Proxy multiplier must be a constant or a proxy to a param")
- self._multipliers.append(other)
- return self
-
+ def _gen_op(operation):
+ def op(self, operand):
+ if not (isinstance(operand, (int, long, float)) or \
+ isproxy(operand)):
+ raise TypeError(
+ "Proxy operand must be a constant or a proxy to a param")
+ self._ops.append((operation, operand))
+ return self
+ return op
+
+ # Support for multiplying proxies by either constants or other proxies
+ __mul__ = _gen_op(lambda operand_a, operand_b : operand_a * operand_b)
__rmul__ = __mul__
- def _mulcheck(self, result, base):
+ # Support for dividing proxies by either constants or other proxies
+ __truediv__ = _gen_op(lambda operand_a, operand_b :
+ operand_a / operand_b)
+ __floordiv__ = _gen_op(lambda operand_a, operand_b :
+ operand_a // operand_b)
+
+ # Support for dividing constants by proxies
+ __rtruediv__ = _gen_op(lambda operand_a, operand_b :
+ operand_b / operand_a.getValue())
+ __rfloordiv__ = _gen_op(lambda operand_a, operand_b :
+ operand_b // operand_a.getValue())
+
+ # After all the operators and operands have been defined, this function
+ # should be called to perform the actual operation
+ def _opcheck(self, result, base):
from . import params
- for multiplier in self._multipliers:
- if isproxy(multiplier):
- multiplier = multiplier.unproxy(base)
- # assert that we are multiplying with a compatible
- # param
- if not isinstance(multiplier, params.NumericParamValue):
- raise TypeError(
- "Proxy multiplier must be a numerical param")
- multiplier = multiplier.getValue()
- result = result * multiplier
+ for operation, operand in self._ops:
+ # Get the operand's value
+ if isproxy(operand):
+ operand = operand.unproxy(base)
+ # assert that we are operating with a compatible param
+ if not isinstance(operand, params.NumericParamValue):
+ raise TypeError("Proxy operand must be a numerical param")
+ operand = operand.getValue()
+
+ # Apply the operation
+ result = operation(result, operand)
+
return result
def unproxy(self, base):
raise RuntimeError("Cycle in unproxy")
result = result.unproxy(obj)
- return self._mulcheck(result, base)
+ return self._opcheck(result, base)
def getindex(obj, index):
if index == None: