TOOM-2 multiplication works for all sizes
authorJacob Lifshay <programmerjake@gmail.com>
Mon, 28 Nov 2022 07:41:18 +0000 (23:41 -0800)
committerJacob Lifshay <programmerjake@gmail.com>
Mon, 28 Nov 2022 07:41:18 +0000 (23:41 -0800)
src/bigint_presentation_code/_tests/test_toom_cook.py
src/bigint_presentation_code/toom_cook.py

index 41a4e2236b159ad89d8481fc00a2740113dfb5a3..42ab769e037207982d0b83fc77ce50d2fe023656 100644 (file)
@@ -11,6 +11,7 @@ from bigint_presentation_code.compiler_ir import (GPR_SIZE_IN_BITS,
 from bigint_presentation_code.register_allocator import allocate_registers
 from bigint_presentation_code.toom_cook import (ToomCookInstance, ToomCookMul,
                                                 simple_mul)
+from bigint_presentation_code.util import OSet
 
 _StateFactory = Callable[[], ContextManager[BaseSimState]]
 
@@ -1896,52 +1897,94 @@ class TestToomCook(unittest.TestCase):
                     self.assertEqual(hex(prod), hex(prod_value),
                                      f"failed: state={state}")
 
-    def tst_toom_mul_all_sizes_pre_ra_sim(self, instances):
-        # type: (tuple[ToomCookInstance, ...]) -> None
-        for lhs_signed in False, True:
-            for rhs_signed in False, True:
-                def mul(fn, lhs, rhs):
-                    # type: (Fn, SSAVal, SSAVal) -> tuple[SSAVal, ToomCookMul]
-                    v = ToomCookMul(
-                        fn=fn, lhs=lhs, lhs_signed=lhs_signed, rhs=rhs,
-                        rhs_signed=rhs_signed, instances=instances)
-                    return v.retval, v
-                for lhs_size_in_words in range(1, 32):
-                    for rhs_size_in_words in range(1, 32):
-                        lhs_size_in_bits = GPR_SIZE_IN_BITS * lhs_size_in_words
-                        rhs_size_in_bits = GPR_SIZE_IN_BITS * rhs_size_in_words
-                        with self.subTest(lhs_size_in_words=lhs_size_in_words,
-                                          rhs_size_in_words=rhs_size_in_words,
-                                          lhs_signed=lhs_signed,
-                                          rhs_signed=rhs_signed):
-                            test_cases = []  # type: list[tuple[int, int]]
-                            test_cases.append((-1, -1))
-                            test_cases.append(((0x80 << 2048) // 0xFF,
-                                               (0x80 << 2048) // 0xFF))
-                            test_cases.append(((0x40 << 2048) // 0xFF,
-                                               (0x80 << 2048) // 0xFF))
-                            test_cases.append(((0x80 << 2048) // 0xFF,
-                                               (0x40 << 2048) // 0xFF))
-                            test_cases.append(((0x40 << 2048) // 0xFF,
-                                               (0x40 << 2048) // 0xFF))
-                            test_cases.append((1 << (lhs_size_in_bits - 1),
-                                               1 << (rhs_size_in_bits - 1)))
-                            test_cases.append((1, 1 << (rhs_size_in_bits - 1)))
-                            test_cases.append((1 << (lhs_size_in_bits - 1), 1))
-                            test_cases.append((1, 1))
-                            self.tst_toom_mul_sim(
-                                code=Mul(mul=mul,
-                                         lhs_size_in_words=lhs_size_in_words,
-                                         rhs_size_in_words=rhs_size_in_words),
-                                lhs_signed=lhs_signed, rhs_signed=rhs_signed,
-                                get_state_factory=get_pre_ra_state_factory,
-                                test_cases=test_cases)
+    def tst_toom_mul_all_sizes_pre_ra_sim(self, instances, lhs_signed, rhs_signed):
+        # type: (tuple[ToomCookInstance, ...], bool, bool) -> None
+        def mul(fn, lhs, rhs):
+            # type: (Fn, SSAVal, SSAVal) -> tuple[SSAVal, ToomCookMul]
+            v = ToomCookMul(
+                fn=fn, lhs=lhs, lhs_signed=lhs_signed, rhs=rhs,
+                rhs_signed=rhs_signed, instances=instances)
+            return v.retval, v
+        sizes_in_words = OSet()  # type: OSet[int]
+        for i in range(6):
+            sizes_in_words.add(1 << i)
+            sizes_in_words.add(3 << i)
+        sizes_in_words = OSet(
+            i for i in sorted(sizes_in_words) if 1 <= i <= 16)
+        for lhs_size_in_words in sizes_in_words:
+            for rhs_size_in_words in sizes_in_words:
+                lhs_size_in_bits = GPR_SIZE_IN_BITS * lhs_size_in_words
+                rhs_size_in_bits = GPR_SIZE_IN_BITS * rhs_size_in_words
+                with self.subTest(lhs_size_in_words=lhs_size_in_words,
+                                  rhs_size_in_words=rhs_size_in_words,
+                                  lhs_signed=lhs_signed,
+                                  rhs_signed=rhs_signed):
+                    test_cases = []  # type: list[tuple[int, int]]
+                    test_cases.append((-1, -1))
+                    test_cases.append(((0x80 << 2048) // 0xFF,
+                                       (0x80 << 2048) // 0xFF))
+                    test_cases.append(((0x40 << 2048) // 0xFF,
+                                       (0x80 << 2048) // 0xFF))
+                    test_cases.append(((0x80 << 2048) // 0xFF,
+                                       (0x40 << 2048) // 0xFF))
+                    test_cases.append(((0x40 << 2048) // 0xFF,
+                                       (0x40 << 2048) // 0xFF))
+                    test_cases.append((1 << (lhs_size_in_bits - 1),
+                                       1 << (rhs_size_in_bits - 1)))
+                    test_cases.append((1, 1 << (rhs_size_in_bits - 1)))
+                    test_cases.append((1 << (lhs_size_in_bits - 1), 1))
+                    test_cases.append((1, 1))
+                    self.tst_toom_mul_sim(
+                        code=Mul(mul=mul,
+                                 lhs_size_in_words=lhs_size_in_words,
+                                 rhs_size_in_words=rhs_size_in_words),
+                        lhs_signed=lhs_signed, rhs_signed=rhs_signed,
+                        get_state_factory=get_pre_ra_state_factory,
+                        test_cases=test_cases)
+
+    def test_toom_2_once_mul_uu_all_sizes_pre_ra_sim(self):
+        TOOM_2 = ToomCookInstance.make_toom_2()
+        self.tst_toom_mul_all_sizes_pre_ra_sim(
+            (TOOM_2,), lhs_signed=False, rhs_signed=False)
+
+    def test_toom_2_once_mul_us_all_sizes_pre_ra_sim(self):
+        TOOM_2 = ToomCookInstance.make_toom_2()
+        self.tst_toom_mul_all_sizes_pre_ra_sim(
+            (TOOM_2,), lhs_signed=False, rhs_signed=True)
+
+    def test_toom_2_once_mul_su_all_sizes_pre_ra_sim(self):
+        TOOM_2 = ToomCookInstance.make_toom_2()
+        self.tst_toom_mul_all_sizes_pre_ra_sim(
+            (TOOM_2,), lhs_signed=True, rhs_signed=False)
+
+    def test_toom_2_once_mul_ss_all_sizes_pre_ra_sim(self):
+        TOOM_2 = ToomCookInstance.make_toom_2()
+        self.tst_toom_mul_all_sizes_pre_ra_sim(
+            (TOOM_2,), lhs_signed=True, rhs_signed=True)
+
+    def test_toom_2_mul_uu_all_sizes_pre_ra_sim(self):
+        TOOM_2 = ToomCookInstance.make_toom_2()
+        instances = TOOM_2, TOOM_2, TOOM_2, TOOM_2
+        self.tst_toom_mul_all_sizes_pre_ra_sim(
+            instances, lhs_signed=False, rhs_signed=False)
+
+    def test_toom_2_mul_us_all_sizes_pre_ra_sim(self):
+        TOOM_2 = ToomCookInstance.make_toom_2()
+        instances = TOOM_2, TOOM_2, TOOM_2, TOOM_2
+        self.tst_toom_mul_all_sizes_pre_ra_sim(
+            instances, lhs_signed=False, rhs_signed=True)
+
+    def test_toom_2_mul_su_all_sizes_pre_ra_sim(self):
+        TOOM_2 = ToomCookInstance.make_toom_2()
+        instances = TOOM_2, TOOM_2, TOOM_2, TOOM_2
+        self.tst_toom_mul_all_sizes_pre_ra_sim(
+            instances, lhs_signed=True, rhs_signed=False)
 
-    def test_toom_2_mul_all_sizes_pre_ra_sim(self):
-        self.skipTest("broken")  # FIXME: fix
+    def test_toom_2_mul_ss_all_sizes_pre_ra_sim(self):
         TOOM_2 = ToomCookInstance.make_toom_2()
+        instances = TOOM_2, TOOM_2, TOOM_2, TOOM_2
         self.tst_toom_mul_all_sizes_pre_ra_sim(
-            (TOOM_2, TOOM_2, TOOM_2, TOOM_2))
+            instances, lhs_signed=True, rhs_signed=True)
 
 
 if __name__ == "__main__":
index a7c5450f6f88ec5e3f4d204ad5449d7eccb669e4..4ceb3d7e0a0df582b72db31075e89a1d2fb1ec2a 100644 (file)
@@ -165,7 +165,7 @@ class EvalOpPoly:
 @final
 class EvalOpValueRange:
     __slots__ = ("eval_op", "inputs", "min_value", "max_value",
-                 "is_signed", "output_size")
+                 "is_signed", "output_size", "name_part")
 
     def __init__(self, eval_op, inputs):
         # type: (EvalOp | int, tuple[EvalOpGenIrInput, ...]) -> None
@@ -199,6 +199,10 @@ class EvalOpValueRange:
             min_v <<= GPR_SIZE_IN_BITS
             max_v <<= GPR_SIZE_IN_BITS
         self.output_size = output_size
+        if isinstance(eval_op, int):
+            self.name_part = f"const_{eval_op}"
+        else:
+            self.name_part = eval_op.name_part
 
     @cached_property
     def poly(self):
@@ -314,13 +318,14 @@ class EvalOpGenIrInput:
 @plain_data(frozen=True)
 @final
 class EvalOpGenIrState:
-    __slots__ = "fn", "inputs", "outputs_map"
+    __slots__ = "fn", "inputs", "outputs_map", "name"
 
-    def __init__(self, fn, inputs):
-        # type: (Fn, Iterable[EvalOpGenIrInput]) -> None
+    def __init__(self, fn, inputs, name):
+        # type: (Fn, Iterable[EvalOpGenIrInput], str) -> None
         super().__init__()
         self.fn = fn
         self.inputs = tuple(inputs)
+        self.name = name
         self.outputs_map = {}  # type: dict[EvalOp | int, EvalOpGenIrOutput]
 
     def get_output(self, eval_op):
@@ -330,12 +335,13 @@ class EvalOpGenIrState:
             return retval
         value_range = EvalOpValueRange(eval_op=eval_op, inputs=self.inputs)
         if isinstance(eval_op, int):
+            name = f"{self.name}_{EvalOp.get_name_part(eval_op)}"
             li = self.fn.append_new_op(OpKind.LI, immediates=[eval_op],
-                                       name=f"li_{eval_op}")
+                                       name=f"{name}_li")
             output = cast_to_size(
                 fn=self.fn, ssa_val=li.outputs[0],
                 dest_size=value_range.output_size,
-                src_signed=value_range.is_signed, name=f"cast_{eval_op}")
+                src_signed=value_range.is_signed, name=f"{name}_case")
             retval = EvalOpGenIrOutput(output=output, value_range=value_range)
         else:
             retval = eval_op.make_output(state=self,
@@ -373,6 +379,31 @@ class EvalOp(metaclass=InternedMeta):
         # type: (EvalOpGenIrState, EvalOpValueRange) -> EvalOpGenIrOutput
         ...
 
+    @cached_property
+    @abstractmethod
+    def name_part(self):
+        # type: () -> str
+        ...
+
+    @staticmethod
+    def get_name_part(eval_op):
+        # type: (EvalOp | int) -> str
+        if isinstance(eval_op, int):
+            return f"const_{eval_op}"
+        return eval_op.name_part
+
+    @property
+    @final
+    def lhs_name_part(self):
+        # type: () -> str
+        return EvalOp.get_name_part(self.lhs)
+
+    @property
+    @final
+    def rhs_name_part(self):
+        # type: () -> str
+        return EvalOp.get_name_part(self.rhs)
+
     def __init__(self, lhs, rhs):
         # type: (EvalOp | int, EvalOp | int) -> None
         super().__init__()
@@ -390,26 +421,34 @@ class EvalOpAdd(EvalOp):
         # type: () -> EvalOpPoly
         return self.lhs_poly + self.rhs_poly
 
+    @cached_property
+    def name_part(self):
+        # type: () -> str
+        return f"({self.lhs_name_part}+{self.rhs_name_part})"
+
     def make_output(self, state, output_value_range):
         # type: (EvalOpGenIrState, EvalOpValueRange) -> EvalOpGenIrOutput
         lhs = state.get_output(self.lhs)
         lhs_output = cast_to_size(
             fn=state.fn, ssa_val=lhs.output,
             dest_size=output_value_range.output_size, src_signed=lhs.is_signed,
-            name="add_lhs_cast")
+            name=f"{state.name}_{self.name_part}_lhs_cast")
         rhs = state.get_output(self.rhs)
         rhs_output = cast_to_size(
             fn=state.fn, ssa_val=rhs.output,
             dest_size=output_value_range.output_size, src_signed=rhs.is_signed,
-            name="add_rhs_cast")
+            name=f"{state.name}_{self.name_part}_rhs_cast")
         setvl = state.fn.append_new_op(
             OpKind.SetVLI, immediates=[output_value_range.output_size],
-            name="setvl", maxvl=output_value_range.output_size)
-        clear_ca = state.fn.append_new_op(OpKind.ClearCA, name="clear_ca")
+            name=f"{state.name}_{self.name_part}_setvl",
+            maxvl=output_value_range.output_size)
+        clear_ca = state.fn.append_new_op(
+            OpKind.ClearCA, name=f"{state.name}_{self.name_part}_clear_ca")
         add = state.fn.append_new_op(
             OpKind.SvAddE, input_vals=[
                 lhs_output, rhs_output, clear_ca.outputs[0], setvl.outputs[0]],
-            maxvl=output_value_range.output_size, name="add")
+            maxvl=output_value_range.output_size,
+            name=f"{state.name}_{self.name_part}_add")
         return EvalOpGenIrOutput(
             output=add.outputs[0], value_range=output_value_range)
 
@@ -423,26 +462,34 @@ class EvalOpSub(EvalOp):
         # type: () -> EvalOpPoly
         return self.lhs_poly - self.rhs_poly
 
+    @cached_property
+    def name_part(self):
+        # type: () -> str
+        return f"({self.lhs_name_part}-{self.rhs_name_part})"
+
     def make_output(self, state, output_value_range):
         # type: (EvalOpGenIrState, EvalOpValueRange) -> EvalOpGenIrOutput
         lhs = state.get_output(self.lhs)
         lhs_output = cast_to_size(
             fn=state.fn, ssa_val=lhs.output,
             dest_size=output_value_range.output_size, src_signed=lhs.is_signed,
-            name="add_lhs_cast")
+            name=f"{state.name}_{self.name_part}_lhs_cast")
         rhs = state.get_output(self.rhs)
         rhs_output = cast_to_size(
             fn=state.fn, ssa_val=rhs.output,
             dest_size=output_value_range.output_size, src_signed=rhs.is_signed,
-            name="add_rhs_cast")
+            name=f"{state.name}_{self.name_part}_rhs_cast")
         setvl = state.fn.append_new_op(
             OpKind.SetVLI, immediates=[output_value_range.output_size],
-            name="setvl", maxvl=output_value_range.output_size)
-        set_ca = state.fn.append_new_op(OpKind.SetCA, name="set_ca")
+            name=f"{state.name}_{self.name_part}_setvl",
+            maxvl=output_value_range.output_size)
+        set_ca = state.fn.append_new_op(
+            OpKind.SetCA, name=f"{state.name}_{self.name_part}_set_ca")
         sub = state.fn.append_new_op(
             OpKind.SvSubFE, input_vals=[
                 rhs_output, lhs_output, set_ca.outputs[0], setvl.outputs[0]],
-            maxvl=output_value_range.output_size, name="sub")
+            maxvl=output_value_range.output_size,
+            name=f"{state.name}_{self.name_part}_sub")
         return EvalOpGenIrOutput(
             output=sub.outputs[0], value_range=output_value_range)
 
@@ -459,6 +506,11 @@ class EvalOpMul(EvalOp):
             raise TypeError("invalid rhs type")
         return self.lhs_poly * self.rhs
 
+    @cached_property
+    def name_part(self):
+        # type: () -> str
+        return f"({self.lhs_name_part}*{self.rhs})"
+
     def make_output(self, state, output_value_range):
         # type: (EvalOpGenIrState, EvalOpValueRange) -> EvalOpGenIrOutput
         raise NotImplementedError  # FIXME: finish
@@ -476,6 +528,11 @@ class EvalOpExactDiv(EvalOp):
             raise TypeError("invalid rhs type")
         return self.lhs_poly / self.rhs
 
+    @cached_property
+    def name_part(self):
+        # type: () -> str
+        return f"({self.lhs_name_part}/{self.rhs})"
+
     def make_output(self, state, output_value_range):
         # type: (EvalOpGenIrState, EvalOpValueRange) -> EvalOpGenIrOutput
         raise NotImplementedError  # FIXME: finish
@@ -500,6 +557,11 @@ class EvalOpInput(EvalOp):
     def part_index(self):
         return self.lhs
 
+    @cached_property
+    def name_part(self):
+        # type: () -> str
+        return f"part[{self.part_index}]"
+
     def _make_poly(self):
         # type: () -> EvalOpPoly
         return EvalOpPoly({self.part_index: 1})
@@ -510,7 +572,7 @@ class EvalOpInput(EvalOp):
         output = cast_to_size(
             fn=state.fn, ssa_val=inp.ssa_val, src_signed=inp.is_signed,
             dest_size=output_value_range.output_size,
-            name=f"input_{self.part_index}_cast")
+            name=f"{state.name}_{self.name_part}_cast")
         return EvalOpGenIrOutput(output=output, value_range=output_value_range)
 
 
@@ -951,6 +1013,8 @@ def split_into_exact_sized_parts(fn, ssa_val, part_count, part_size, name):
     for part in range(part_count):
         start = part * part_size
         stop = min(maxvl, start + part_size)
+        if part == part_count - 1:
+            stop = maxvl
         part_maxvl = stop - start
         part_setvl = fn.append_new_op(
             OpKind.SetVLI, immediates=[part_size], maxvl=part_size,
@@ -971,16 +1035,25 @@ _TCIs = Tuple[ToomCookInstance, ...]
 class ToomCookMul:
     __slots__ = (
         "fn", "lhs", "lhs_signed", "rhs", "rhs_signed", "instances",
-        "retval_size", "start_instance_index", "instance", "part_size",
+        "retval_size", "start_instance_index", "name", "instance", "part_size",
         "lhs_parts", "lhs_inputs", "lhs_eval_state", "lhs_outputs",
         "rhs_parts", "rhs_inputs", "rhs_eval_state", "rhs_outputs",
         "prod_inputs", "prod_eval_state", "prod_parts",
         "partial_products", "retval",
     )
 
-    def __init__(self, fn, lhs, lhs_signed, rhs, rhs_signed, instances,
-                 retval_size=None, start_instance_index=0):
-        # type: (Fn, SSAVal, bool, SSAVal, bool, _TCIs, None | int, int) -> None
+    def __init__(
+        self, fn,  # type: Fn
+        lhs,  # type: SSAVal
+        lhs_signed,  # type: bool
+        rhs,  # type: SSAVal
+        rhs_signed,  # type: bool
+        instances,  # type: _TCIs
+        retval_size=None,  # type: None | int
+        name=None,  # type: None | str
+        start_instance_index=0,  # type: int
+    ):
+        # type: (...) -> None
         self.fn = fn
         self.lhs = lhs
         self.lhs_signed = lhs_signed
@@ -990,6 +1063,9 @@ class ToomCookMul:
         if retval_size is None:
             retval_size = lhs.ty.reg_len + rhs.ty.reg_len
         self.retval_size = retval_size
+        if name is None:
+            name = "mul"
+        self.name = name
         if start_instance_index < 0:
             raise ValueError("start_instance_index must be non-negative")
         self.start_instance_index = start_instance_index
@@ -997,9 +1073,16 @@ class ToomCookMul:
         self.part_size = 0  # type: int
         while start_instance_index < len(instances):
             self.instance = instances[start_instance_index]
-            self.part_size = max(
-                lhs.ty.reg_len // self.instance.lhs_part_count,
-                rhs.ty.reg_len // self.instance.rhs_part_count)
+            self.part_size = 0
+            # FIXME: this loop is some kind of integer division,
+            # figure out the correct formula
+            for shift in reversed(range(6)):
+                next_part_size = self.part_size + (1 << shift)
+                if (lhs.ty.reg_len > (
+                        self.instance.lhs_part_count - 1) * next_part_size
+                        and rhs.ty.reg_len > (
+                        self.instance.rhs_part_count - 1) * next_part_size):
+                    self.part_size = next_part_size
             if self.part_size <= 0:
                 self.instance = None
                 start_instance_index += 1
@@ -1009,40 +1092,45 @@ class ToomCookMul:
             self.retval = simple_mul(fn=fn,
                                      lhs=lhs, lhs_signed=lhs_signed,
                                      rhs=rhs, rhs_signed=rhs_signed,
-                                     name="toom_cook_base_case")
+                                     name=f"{name}_base_case")
             return
         self.lhs_parts = split_into_exact_sized_parts(
             fn=fn, ssa_val=lhs, part_count=self.instance.lhs_part_count,
-            part_size=self.part_size, name="lhs")
+            part_size=self.part_size, name=f"{name}_lhs")
         self.lhs_inputs = []  # type: list[EvalOpGenIrInput]
         for part, ssa_val in enumerate(self.lhs_parts):
             self.lhs_inputs.append(EvalOpGenIrInput(
                 ssa_val=ssa_val,
                 is_signed=lhs_signed and part == len(self.lhs_parts) - 1))
-        self.lhs_eval_state = EvalOpGenIrState(fn=fn, inputs=self.lhs_inputs)
+        self.lhs_eval_state = EvalOpGenIrState(
+            fn=fn, inputs=self.lhs_inputs, name=f"{name}_lhs_eval")
         lhs_eval_ops = self.instance.lhs_eval_ops
         self.lhs_outputs = [
             self.lhs_eval_state.get_output(i) for i in lhs_eval_ops]
         self.rhs_parts = split_into_exact_sized_parts(
             fn=fn, ssa_val=rhs, part_count=self.instance.rhs_part_count,
-            part_size=self.part_size, name="rhs")
+            part_size=self.part_size, name=f"{name}_rhs")
         self.rhs_inputs = []  # type: list[EvalOpGenIrInput]
         for part, ssa_val in enumerate(self.rhs_parts):
             self.rhs_inputs.append(EvalOpGenIrInput(
                 ssa_val=ssa_val,
                 is_signed=rhs_signed and part == len(self.rhs_parts) - 1))
-        self.rhs_eval_state = EvalOpGenIrState(fn=fn, inputs=self.rhs_inputs)
+        self.rhs_eval_state = EvalOpGenIrState(
+            fn=fn, inputs=self.rhs_inputs, name=f"{name}_rhs_eval")
         rhs_eval_ops = self.instance.rhs_eval_ops
         self.rhs_outputs = [
             self.rhs_eval_state.get_output(i) for i in rhs_eval_ops]
         self.prod_inputs = []  # type: list[EvalOpGenIrInput]
-        for lhs_output, rhs_output in zip(self.lhs_outputs, self.rhs_outputs):
-            ssa_val = toom_cook_mul(
+        for point_index, (lhs_output, rhs_output) in enumerate(
+                zip(self.lhs_outputs, self.rhs_outputs)):
+            ssa_val = ToomCookMul(
                 fn=fn,
                 lhs=lhs_output.output, lhs_signed=lhs_output.is_signed,
                 rhs=rhs_output.output, rhs_signed=rhs_output.is_signed,
                 instances=instances,
-                start_instance_index=start_instance_index + 1)
+                start_instance_index=start_instance_index + 1,
+                retval_size=None,
+                name=f"{name}_pt{point_index}").retval
             products = (lhs_output.min_value * rhs_output.min_value,
                         lhs_output.min_value * rhs_output.max_value,
                         lhs_output.max_value * rhs_output.min_value,
@@ -1052,7 +1140,8 @@ class ToomCookMul:
                 is_signed=None,
                 min_value=min(products),
                 max_value=max(products)))
-        self.prod_eval_state = EvalOpGenIrState(fn=fn, inputs=self.prod_inputs)
+        self.prod_eval_state = EvalOpGenIrState(
+            fn=fn, inputs=self.prod_inputs, name=f"{name}_prod_eval")
         prod_eval_ops = self.instance.prod_eval_ops
         self.prod_parts = [
             self.prod_eval_state.get_output(i) for i in prod_eval_ops]
@@ -1063,24 +1152,31 @@ class ToomCookMul:
                 part_maxvl = prod_part.output.ty.reg_len
                 part_setvl = fn.append_new_op(
                     OpKind.SetVLI, immediates=[part_maxvl],
-                    name=f"prod_{part}_setvl", maxvl=part_maxvl)
+                    name=f"{name}_prod_{part}_setvl", maxvl=part_maxvl)
                 spread_part = fn.append_new_op(
                     OpKind.Spread,
                     input_vals=[prod_part.output, part_setvl.outputs[0]],
-                    name=f"prod_{part}_spread", maxvl=part_maxvl)
+                    name=f"{name}_prod_{part}_spread", maxvl=part_maxvl)
                 yield PartialProduct(
                     spread_part.outputs, shift_in_words=part * self.part_size,
                     is_signed=prod_part.is_signed, subtract=False)
         self.partial_products = tuple(partial_products())
         self.retval = sum_partial_products(
             fn=fn, partial_products=self.partial_products,
-            retval_size=retval_size, name="prod")
-
-
-def toom_cook_mul(fn, lhs, lhs_signed, rhs, rhs_signed, instances,
-                  retval_size=None, start_instance_index=0):
-    # type: (Fn, SSAVal, bool, SSAVal, bool, _TCIs, None | int, int) -> SSAVal
+            retval_size=retval_size, name=f"{name}_sum_p_prods")
+
+
+def toom_cook_mul(
+    fn,  # type: Fn
+    lhs,  # type: SSAVal
+    lhs_signed,  # type: bool
+    rhs,  # type: SSAVal
+    rhs_signed,  # type: bool
+    instances,  # type: _TCIs
+    retval_size=None,  # type: None | int
+    name=None,  # type: None | str
+):
+    # type: (...) -> SSAVal
     return ToomCookMul(
         fn=fn, lhs=lhs, lhs_signed=lhs_signed, rhs=rhs, rhs_signed=rhs_signed,
-        instances=instances, retval_size=retval_size,
-        start_instance_index=start_instance_index).retval
+        instances=instances, retval_size=retval_size, name=name).retval