self.n_inputs = n_inputs
self.n_parts = n_parts
self.output_width = output_width
- self.out_part_ops = [Signal(2, name=f"out_part_ops_{i}")
- for i in range(n_parts)]
- self._resized_inputs = [
- Signal(output_width, name=f"resized_inputs[{i}]")
- for i in range(n_inputs)]
+ self.i = AddReduceData(partition_points, n_inputs,
+ output_width, n_parts)
+ self.out_part_ops = self.i.part_ops
+ self._resized_inputs = self.i.inputs
self.register_levels = list(register_levels)
self.partition_points = PartitionPoints(partition_points)
if not self.partition_points.fits_in_width(output_width):
raise ValueError("partition_points doesn't fit in output_width")
- self._reg_partition_points = self.partition_points.like()
+ self._reg_partition_points = self.i.reg_partition_points
max_level = AddReduceSingle.get_max_level(n_inputs)
for level in self.register_levels: