add PartitionedSignal.__Mux__ using existing PMux function
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Tue, 28 Sep 2021 14:17:15 +0000 (15:17 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Tue, 28 Sep 2021 14:17:15 +0000 (15:17 +0100)
src/ieee754/part/partsig.py

index 8a7cd57c04c5a0ebb568a54e33d25d72b903a88a..f68cf8ccd9dd33fc3a215cbe3a233605afd9fe41 100644 (file)
@@ -22,6 +22,7 @@ from ieee754.part_bits.xor import PartitionedXOR
 from ieee754.part_shift.part_shift_dynamic import PartitionedDynamicShift
 from ieee754.part_shift.part_shift_scalar import PartitionedScalarShift
 from ieee754.part_mul_add.partpoints import make_partition, PartitionPoints
 from ieee754.part_shift.part_shift_dynamic import PartitionedDynamicShift
 from ieee754.part_shift.part_shift_scalar import PartitionedScalarShift
 from ieee754.part_mul_add.partpoints import make_partition, PartitionPoints
+from ieee754.part_mux.part_mux import PMux
 from operator import or_, xor, and_, not_
 
 from nmigen import (Signal, Const)
 from operator import or_, xor, and_, not_
 
 from nmigen import (Signal, Const)
@@ -54,6 +55,8 @@ class PartitionedSignal(UserValue):
         else:
             self.partpoints = make_partition(mask, width)
         self.modnames = {}
         else:
             self.partpoints = make_partition(mask, width)
         self.modnames = {}
+        # for sub-modules to be created on-demand. Mux is done slightly
+        # differently
         for name in ['add', 'eq', 'gt', 'ge', 'ls', 'xor']:
             self.modnames[name] = 0
 
         for name in ['add', 'eq', 'gt', 'ge', 'ls', 'xor']:
             self.modnames[name] = 0
 
@@ -79,6 +82,14 @@ class PartitionedSignal(UserValue):
         result.m = other.m
         return result
 
         result.m = other.m
         return result
 
+    # nmigen-redirected constructs (Mux, Cat, Switch, Assign)
+
+    def __Mux__(self, val1, val2):
+        assert len(val1) == len(val2), \
+            "PartitionedSignal width sources must be the same " \
+            "val1 == %d, val2 == %d" % (len(val1), len(val2))
+        return PMux(self.m, self.partpoints, self, val1, val2)
+
     # unary ops that do not require partitioning
 
     def __invert__(self):
     # unary ops that do not require partitioning
 
     def __invert__(self):