8ae238103e32b3d7b7f176d6c2a838a517a372f4
[ieee754fpu.git] / src / ieee754 / part_mul_add / partpoints.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
4
5 from nmigen import Signal, Value, Cat, C
6
7
def make_partition(mask, width):
    """Build a partition-point dict from a mask and a total bit width.

    Break-points are placed every ``len(mask)`` bits, starting at bit
    ``len(mask)``, and are paired with successive mask entries.  Only
    positions strictly below ``width`` are emitted, so when the break-points
    are regularly spaced across ``width`` the last (MSB) mask entry is
    *ignored*.

    mask len = 4, width == 16 returns:
        {4: mask[0], 8: mask[1], 12: mask[2]}
    mask len = 8, width == 64 returns:
        {8: mask[0], 16: mask[1], 24: mask[2], .... 56: mask[6]}

    :param mask: indexable, sized collection of enable values.
    :param width: total bit width being partitioned.
    :returns: dict mapping bit position to the matching mask entry.
    """
    step = len(mask)
    points = {}
    for idx in range(step):
        position = step * (idx + 1)
        if position >= width:
            # reached the top of the word: remaining mask entries are unused
            break
        points[position] = mask[idx]
    return points
27
28
def make_partition2(mask, width):
    """Build a partition-point dict from break-point enables and a bit width.

    Here the mask represents the actual partition points, so it must hold
    ONE LESS entry than the number of required partitions.  The width is
    divided evenly into ``len(mask) + 1`` partitions and each mask entry is
    placed on the boundary between two neighbouring partitions.

    mask len = 3, width == 16 returns:
        {4: mask[0], 8: mask[1], 12: mask[2]}
    mask len = 7, width == 64 returns:
        {8: mask[0], 16: mask[1], 24: mask[2], .... 56: mask[6]}

    :param mask: indexable, sized collection of enable values, or a dict
                 (e.g. existing partition points) whose values are used in
                 iteration order.
    :param width: total bit width being partitioned.
    :returns: dict mapping bit position to the matching mask entry.
    :raises ValueError: if ``width`` is positive but too small to give each
                        partition at least one bit.
    """
    if isinstance(mask, dict):  # convert dict/partpoints to sequential list
        mask = list(mask.values())
    npoints = len(mask)
    # ONE MORE partition than break-points
    jumpsize = width // (npoints + 1)  # size of each partition
    if width > 0 and jumpsize == 0:
        # every break-point would collapse onto bit 0
        raise ValueError(
            f"width {width} is too small for {npoints + 1} partitions")
    ppoints = {}
    # Iterate over the mask entries themselves so that a width which is not
    # an exact multiple of the partition count can never index past the end.
    for midx in range(npoints):
        ppos = jumpsize * (midx + 1)
        if ppos >= width:  # only reachable for non-positive widths
            break
        ppoints[ppos] = mask[midx]
    return ppoints
55
56
class PartitionPoints(dict):
    """Partition points and corresponding ``Value``s.

    The points at where an ALU is partitioned along with ``Value``s that
    specify if the corresponding partition points are enabled.

    For example: ``{1: True, 5: True, 10: True}`` with
    ``width == 16`` specifies that the ALU is split into 4 sections:
    * bits 0 <= ``i`` < 1
    * bits 1 <= ``i`` < 5
    * bits 5 <= ``i`` < 10
    * bits 10 <= ``i`` < 16

    If the partition_points were instead ``{1: True, 5: a, 10: True}``
    where ``a`` is a 1-bit ``Signal``:
    * If ``a`` is asserted:
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 5
        * bits 5 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    * Otherwise
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    """

    def __init__(self, partition_points=None):
        """Create a new ``PartitionPoints``.

        :param partition_points: the input partition points to values mapping.
        :raises TypeError: if a point is not an ``int``.
        :raises ValueError: if a point is negative.
        """
        super().__init__()
        if partition_points is None:
            return
        for point, enabled in partition_points.items():
            if not isinstance(point, int):
                raise TypeError("point must be a non-negative integer")
            if point < 0:
                raise ValueError("point must be a non-negative integer")
            self[point] = Value.cast(enabled)

    def like(self, name=None, src_loc_at=0, mul=1):
        """Create a new ``PartitionPoints`` with ``Signal``s for all values.

        :param name: the base name for the new ``Signal``s.
        :param src_loc_at: extra offset passed on to ``Signal`` when the
                           base name has to be derived automatically.
        :param mul: a multiplication factor on the indices
        """
        if name is None:
            name = Signal(src_loc_at=1+src_loc_at).name  # get variable name
        retval = PartitionPoints()
        for point, enabled in self.items():
            point *= mul
            retval[point] = Signal(enabled.shape(), name=f"{name}_{point}")
        return retval

    def eq(self, rhs):
        """Assign ``PartitionPoints`` using ``Signal.eq``.

        This is a generator: the key-set check and the assignments are only
        produced once the result is iterated.

        :param rhs: mapping with exactly the same points as ``self``.
        :raises ValueError: if the two point sets differ.
        """
        if set(self.keys()) != set(rhs.keys()):
            raise ValueError("incompatible point set")
        for point, enabled in self.items():
            yield enabled.eq(rhs[point])

    def as_mask(self, width, mul=1):
        """Create a bit-mask from `self`.

        Each bit in the returned mask is clear only if the partition point at
        the same bit-index is enabled.

        :param width: the bit width of the resulting mask
        :param mul: a "multiplier" which in-place expands the partition points
                    typically set to "2" when used for multipliers
        """
        bits = []
        for i in range(width):
            # Integer arithmetic: bit i corresponds to a partition point only
            # when it is an exact multiple of mul.
            point, remainder = divmod(i, mul)
            if remainder == 0 and point in self:
                bits.append(~self[point])
            else:
                bits.append(True)
        return Cat(*bits)

    def as_sig(self):
        """Create a straight concatenation of `self` signals
        """
        return Cat(self.values())

    def get_max_partition_count(self, width):
        """Get the maximum number of partitions.

        Gets the number of partitions when all partition points are enabled.
        Only points strictly below ``width`` are counted.
        """
        return 1 + sum(1 for point in self if point < width)

    def fits_in_width(self, width):
        """Check if all partition points are smaller than `width`."""
        return all(point < width for point in self)

    def part_byte(self, index, mfactor=1):  # mfactor used for "expanding"
        """Get the enable for the partition point just above byte ``index``.

        Indices -1 and 7 have no stored point and always yield a constant
        true bit; any other index looks up key ``(index * 8 + 8) * mfactor``.
        NOTE(review): the 8-byte limit is hard-coded, which presumably
        reflects a 64-bit datapath; confirm against callers.

        :param index: byte index, -1 to 7 inclusive.
        :param mfactor: scale factor applied to the looked-up bit position.
        """
        if index == -1 or index == 7:
            return C(True, 1)
        assert index >= 0 and index < 8
        return self[(index * 8 + 8)*mfactor]
165
166