--- /dev/null
+#!/usr/bin/env python3
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Callable, Dict, Sequence, Iterable, NoReturn, Set
+from itertools import permutations
+import pprint
+import enum
+import random
+import sys
+import unittest
+
+
+def unreachable(v):
+ # type: (NoReturn) -> NoReturn
+ raise TypeError("v was expected to be a never type")
+
+
+def lut3(imm8, a, b, c):
+ # type: (int, int, int, int) -> int
+ a &= 1 # ensure a, b, c are in-range
+ b &= 1
+ c &= 1
+ idx = c << 2 | b << 1 | a
+ return (imm8 >> idx) & 1 # idx by LSB0 order
+
+
+def ternlogi_8(RT, RA, RB, imm8):
+ # type: (int, int, int, int) -> int
+ for i in range(64):
+ RT |= lut3(imm8, (RB >> i) & 1, (RA >> i) & 1, (RT >> i) & 1) << i
+ return RT
+
+
+# found using: python3 ternlogi_simplification_experiment.py generate 24798
+def expand_encoded_imm(imm7):
+ imm7 &= 0x7F
+ v0 = imm7
+ v1 = v0 >> 0x2
+ v3 = v1 ^ v0
+ v7 = v1 << 0x4
+ v8 = v3 ^ v7
+ return v8 & 0xFF
+
+
+def ternlogi_7(RT, RA, RB, imm7):
+ # type: (int, int, int, int) -> int
+ imm8 = expand_encoded_imm(imm7)
+ for i in range(64):
+ RT |= lut3(imm8, (RB >> i) & 1, (RA >> i) & 1, (RT >> i) & 1) << i
+ return RT
+
+
+@dataclass(unsafe_hash=True, order=True, frozen=True)
+class TruthTable:
+ truth_table: "tuple[bool, ...]"
+ num_vars: int
+
+ def __post_init__(self):
+ assert len(self.truth_table) == 1 << self.num_vars
+
+ @staticmethod
+ def all_truth_tables(num_vars):
+ # type: (int) -> Iterable[TruthTable]
+ for bits in range(1 << (1 << num_vars)):
+ yield TruthTable.from_bits(num_vars, bits)
+
+ @property
+ def bits(self):
+ # type: () -> int
+ retval = 0
+ for i, v in enumerate(self.truth_table):
+ if v:
+ retval |= 1 << i
+ return retval
+
+ @staticmethod
+ def from_bits(num_vars, bits):
+ # type: (int, int) -> TruthTable
+ return TruthTable.from_fn_with_index(
+ num_vars, lambda index: bool((bits >> index) & 1))
+
+ def vars_to_index(self, variables):
+ # type: (Sequence[bool]) -> int
+ index = 0
+ for i in range(self.num_vars):
+ if variables[i]:
+ index |= 1 << i
+ return index
+
+ @staticmethod
+ def index_to_vars(num_vars, index):
+ # type: (int, int) -> tuple[bool, ...]
+ return tuple(bool((index >> i) & 1) for i in range(num_vars))
+
+ def __getitem__(self, variables):
+ # type: (Sequence[bool]) -> bool
+ return self.truth_table[self.vars_to_index(variables)]
+
+ @staticmethod
+ def from_fn_with_index(num_vars, fn):
+ # type: (int, Callable[[int], bool]) -> TruthTable
+ truth_table = [] # type: list[bool]
+ for i in range(1 << num_vars):
+ truth_table.append(fn(i))
+ return TruthTable(truth_table=tuple(truth_table), num_vars=num_vars)
+
+ @staticmethod
+ def from_fn(num_vars, fn):
+ # type: (int, Callable[[tuple[bool, ...]], bool]) -> TruthTable
+ return TruthTable.from_fn_with_index(
+ num_vars, lambda i: fn(TruthTable.index_to_vars(num_vars, i)))
+
+ @staticmethod
+ def for_const(num_vars, v):
+ # type: (int, bool) -> TruthTable
+ return TruthTable.from_fn_with_index(num_vars, lambda _: v)
+
+ @staticmethod
+ def for_var(num_vars, var_index):
+ # type: (int, int) -> TruthTable
+ return TruthTable.from_fn(
+ num_vars, lambda variables: variables[var_index])
+
+ def __invert__(self):
+ # type: () -> TruthTable
+ return TruthTable(tuple(not v for v in self.truth_table),
+ num_vars=self.num_vars)
+
+ def __and__(self, rhs):
+ # type: (TruthTable) -> TruthTable
+ assert self.num_vars == rhs.num_vars, "mismatched num_vars"
+ return TruthTable(tuple(
+ l & r for l, r in zip(self.truth_table, rhs.truth_table)),
+ num_vars=self.num_vars)
+
+ __rand__ = __and__
+
+ def __or__(self, rhs):
+ # type: (TruthTable) -> TruthTable
+ assert self.num_vars == rhs.num_vars, "mismatched num_vars"
+ return TruthTable(tuple(
+ l | r for l, r in zip(self.truth_table, rhs.truth_table)),
+ num_vars=self.num_vars)
+
+ __ror__ = __or__
+
+ def __xor__(self, rhs):
+ # type: (TruthTable) -> TruthTable
+ assert self.num_vars == rhs.num_vars, "mismatched num_vars"
+ return TruthTable(tuple(
+ l ^ r for l, r in zip(self.truth_table, rhs.truth_table)),
+ num_vars=self.num_vars)
+
+ __rxor__ = __xor__
+
+
+class TestTruthTable(unittest.TestCase):
+ def test_all_truth_tables(self):
+ self.assertEqual(list(TruthTable.all_truth_tables(num_vars=2)), [
+ TruthTable(truth_table=(False, False, False, False), num_vars=2),
+ TruthTable(truth_table=(True, False, False, False), num_vars=2),
+ TruthTable(truth_table=(False, True, False, False), num_vars=2),
+ TruthTable(truth_table=(True, True, False, False), num_vars=2),
+ TruthTable(truth_table=(False, False, True, False), num_vars=2),
+ TruthTable(truth_table=(True, False, True, False), num_vars=2),
+ TruthTable(truth_table=(False, True, True, False), num_vars=2),
+ TruthTable(truth_table=(True, True, True, False), num_vars=2),
+ TruthTable(truth_table=(False, False, False, True), num_vars=2),
+ TruthTable(truth_table=(True, False, False, True), num_vars=2),
+ TruthTable(truth_table=(False, True, False, True), num_vars=2),
+ TruthTable(truth_table=(True, True, False, True), num_vars=2),
+ TruthTable(truth_table=(False, False, True, True), num_vars=2),
+ TruthTable(truth_table=(True, False, True, True), num_vars=2),
+ TruthTable(truth_table=(False, True, True, True), num_vars=2),
+ TruthTable(truth_table=(True, True, True, True), num_vars=2),
+ ])
+
+ def test_bits(self):
+ bits = [i.bits for i in TruthTable.all_truth_tables(num_vars=2)]
+ expected = list(range(16))
+ self.assertEqual(bits, expected)
+
+ def test_const(self):
+ self.assertEqual(TruthTable.for_const(3, False), TruthTable(
+ truth_table=(False,) * 8, num_vars=3))
+ self.assertEqual(TruthTable.for_const(3, True), TruthTable(
+ truth_table=(True,) * 8, num_vars=3))
+
+ def test_var(self):
+ self.assertEqual(TruthTable.for_var(3, 0), TruthTable(
+ truth_table=(False, True) * 4, num_vars=3))
+ self.assertEqual(TruthTable.for_var(3, 1), TruthTable(
+ truth_table=(False, False, True, True) * 2, num_vars=3))
+ self.assertEqual(TruthTable.for_var(3, 2), TruthTable(
+ truth_table=(False,) * 4 + (True,) * 4, num_vars=3))
+
+ def test_invert(self):
+ self.assertEqual(~TruthTable.for_var(3, 0), TruthTable(
+ truth_table=(True, False) * 4, num_vars=3))
+
+ def test_and(self):
+ self.assertEqual(
+ TruthTable.for_var(3, 0) & TruthTable.for_var(3, 1),
+ TruthTable(truth_table=(False, False, False, True) * 2,
+ num_vars=3))
+
+ def test_or(self):
+ self.assertEqual(
+ TruthTable.for_var(3, 0) | TruthTable.for_var(3, 1),
+ TruthTable(truth_table=(False, True, True, True) * 2,
+ num_vars=3))
+
+ def test_xor(self):
+ self.assertEqual(
+ TruthTable.for_var(3, 0) ^ TruthTable.for_var(3, 1),
+ TruthTable(truth_table=(False, True, True, False) * 2,
+ num_vars=3))
+
+
+_VARS_PERMS_SET = Set["tuple[int, ...]"]
+_VARS_PERMS = Dict[TruthTable, _VARS_PERMS_SET]
+
+
+@dataclass(unsafe_hash=True, frozen=True)
+class EquivalenceClass:
+ """ an equivalence class of all truth tables where variables can be
+ arbitrarily permuted.
+ """
+ truth_tables: "frozenset[TruthTable]"
+
+ @staticmethod
+ def get_with_vars_perms(truth_table):
+ # type: (TruthTable) -> tuple[EquivalenceClass, _VARS_PERMS]
+ truth_tables = set() # type: set[TruthTable]
+ vars_perms = defaultdict(set) # type: _VARS_PERMS
+ iter_ = EquivalenceClass.permuted_truth_tables(truth_table)
+ for tt, vars_perm in iter_:
+ truth_tables.add(tt)
+ vars_perms[tt].add(vars_perm)
+ return EquivalenceClass(frozenset(truth_tables)), vars_perms
+
+ @staticmethod
+ def get(truth_table):
+ # type: (TruthTable) -> EquivalenceClass
+ retval, _ = EquivalenceClass.get_with_vars_perms(truth_table)
+ return retval
+
+ @staticmethod
+ def permuted_truth_tables(truth_table):
+ # type: (TruthTable) -> Iterable[tuple[TruthTable, tuple[int, ...]]]
+ for vars_perm in permutations(range(truth_table.num_vars)):
+ def fn(variables):
+ # type: (Sequence[bool]) -> bool
+ return truth_table[[variables[i] for i in vars_perm]]
+ perm_tt = TruthTable.from_fn(truth_table.num_vars, fn)
+ yield (perm_tt, vars_perm)
+
+ @property
+ def representative_truth_table(self):
+ # type: () -> TruthTable
+ return min(self.truth_tables)
+
+ def __repr__(self):
+ # type: () -> str
+ return f"EquivalenceClass({self.representative_truth_table!r})"
+
+ def __lt__(self, rhs):
+ # type: (EquivalenceClass) -> bool
+ return self.representative_truth_table < rhs.representative_truth_table
+
+ def __gt__(self, rhs):
+ # type: (EquivalenceClass) -> bool
+ return self.representative_truth_table > rhs.representative_truth_table
+
+ def __le__(self, rhs):
+ # type: (EquivalenceClass) -> bool
+ return self.representative_truth_table <= rhs.representative_truth_table
+
+ def __ge__(self, rhs):
+ # type: (EquivalenceClass) -> bool
+ return self.representative_truth_table >= rhs.representative_truth_table
+
+
+class TestEquivalenceClass(unittest.TestCase):
+ maxDiff = None
+
+ def test_get_with_vars_perms(self):
+ truth_table = TruthTable.from_bits(2, 0x4)
+ eqv_cls, vars_perms = EquivalenceClass.get_with_vars_perms(truth_table)
+ self.assertEqual(eqv_cls.truth_tables, frozenset(vars_perms.keys()))
+ self.assertEqual(dict(vars_perms), {
+ TruthTable(truth_table=(False, False, True, False),
+ num_vars=2): {(0, 1)},
+ TruthTable(truth_table=(False, True, False, False),
+ num_vars=2): {(1, 0)},
+ })
+ for tt in TruthTable.all_truth_tables(3):
+ with self.subTest(tt=repr(tt)):
+ eqv_cls, vars_perms = EquivalenceClass.get_with_vars_perms(tt)
+ self.assertEqual(eqv_cls.truth_tables,
+ frozenset(vars_perms.keys()))
+ expected_vars_perms = defaultdict(set) # type: _VARS_PERMS
+ expected_vars_perms[tt].add((0, 1, 2))
+ perm = TruthTable.from_fn(3, lambda v: tt[(v[0], v[2], v[1])])
+ expected_vars_perms[perm].add((0, 2, 1))
+ perm = TruthTable.from_fn(3, lambda v: tt[(v[1], v[0], v[2])])
+ expected_vars_perms[perm].add((1, 0, 2))
+ perm = TruthTable.from_fn(3, lambda v: tt[(v[1], v[2], v[0])])
+ expected_vars_perms[perm].add((1, 2, 0))
+ perm = TruthTable.from_fn(3, lambda v: tt[(v[2], v[0], v[1])])
+ expected_vars_perms[perm].add((2, 0, 1))
+ perm = TruthTable.from_fn(3, lambda v: tt[(v[2], v[1], v[0])])
+ expected_vars_perms[perm].add((2, 1, 0))
+ l = dict(vars_perms)
+ r = dict(expected_vars_perms)
+ if l != r:
+ l = pprint.pformat(l)
+ r = pprint.pformat(r)
+ self.fail(f"vars_perms not as expected:\nvars_perms:\n{l}"
+ f"\nexpected_vars_perms:\n{r}")
+
+
+NUM_VARS = 3
+
+
+def needed_equivalence_classes():
+ # type: () -> set[EquivalenceClass]
+ ZERO = TruthTable.for_const(NUM_VARS, False)
+ ONE = TruthTable.for_const(NUM_VARS, True)
+ A = TruthTable.for_var(NUM_VARS, 0)
+ B = TruthTable.for_var(NUM_VARS, 1)
+ equivalence_classes = set() # type: set[EquivalenceClass]
+ for truth_table in TruthTable.all_truth_tables(NUM_VARS):
+ equivalence_classes.add(EquivalenceClass.get(truth_table))
+
+ # remove equivalence classes that are already covered by
+ # existing PowerISA instructions:
+
+ # li
+ equivalence_classes.remove(EquivalenceClass.get(ZERO))
+ equivalence_classes.remove(EquivalenceClass.get(ONE))
+
+ # mv
+ equivalence_classes.remove(EquivalenceClass.get(A))
+
+ # not
+ equivalence_classes.remove(EquivalenceClass.get(~A))
+
+ # and
+ equivalence_classes.remove(EquivalenceClass.get(A & B))
+
+ # andc
+ equivalence_classes.remove(EquivalenceClass.get(A & ~B))
+
+ # nand
+ equivalence_classes.remove(EquivalenceClass.get(~(A & B)))
+
+ # or
+ equivalence_classes.remove(EquivalenceClass.get(A | B))
+
+ # orc
+ equivalence_classes.remove(EquivalenceClass.get(A | ~B))
+
+ # nor
+ equivalence_classes.remove(EquivalenceClass.get(~(A | B)))
+
+ # xor
+ equivalence_classes.remove(EquivalenceClass.get(A ^ B))
+
+ # eqv
+ equivalence_classes.remove(EquivalenceClass.get(~(A ^ B)))
+
+ return equivalence_classes
+
+
+NEEDED_EQUIVALENCE_CLASSES = frozenset(needed_equivalence_classes())
+
+
+class MissingNeededEquivalenceClasses(Exception):
+ pass
+
+
+@dataclass(frozen=True, unsafe_hash=True)
+class Imm7AndVarsPerm:
+ imm7: int
+ vars_perm: "tuple[int, ...]"
+
+ def __repr__(self):
+ # type: () -> str
+ return (f"Imm7AndVarsPerm(imm7={self.imm7:#04x}, "
+ f"vars_perm={self.vars_perm})")
+
+
+def verify_covers_needed_equivalence_classes(expand_encoded_imm):
+ # type: (Callable[[int], int]) -> str
+ equivalence_classes = set(NEEDED_EQUIVALENCE_CLASSES)
+ imm8_imm7_map = defaultdict(set) # type: dict[int, set[Imm7AndVarsPerm]]
+ log_lines = [] # type: list[str]
+ for imm7 in range(1 << 7):
+ imm8 = expand_encoded_imm(imm7)
+ truth_table = TruthTable.from_bits(num_vars=NUM_VARS, bits=imm8)
+ log_lines.append(
+ f"imm7={hex(imm7)} imm8={hex(imm8)} truth_table={truth_table}")
+ eqv_cls, vars_perms = EquivalenceClass.get_with_vars_perms(truth_table)
+ if eqv_cls in equivalence_classes:
+ log_lines.append(f"new equivalence class: {eqv_cls}")
+ equivalence_classes.remove(eqv_cls)
+ for truth_table, vars_perms_set in vars_perms.items():
+ imm8 = truth_table.bits
+ for vars_perm in vars_perms_set:
+ imm8_imm7_map[imm8].add(Imm7AndVarsPerm(
+ imm7=imm7, vars_perm=vars_perm))
+
+ for imm8 in range(1 << 8):
+ imm8_imm7_map[imm8] # fill in empty set defaults
+
+ log_lines.append("imm8_imm7_map:")
+ log_lines.append(pprint.pformat(dict(imm8_imm7_map)))
+ log_lines.append("equivalence classes we failed to cover:")
+ log_lines.append(pprint.pformat(sorted(equivalence_classes)))
+ if len(equivalence_classes) != 0:
+ log_lines.append("failed to cover all equivalence classes")
+ raise MissingNeededEquivalenceClasses("\n".join(log_lines))
+ return "\n".join(log_lines)
+
+
+@enum.unique
+class OpKind(enum.Enum):
+ Shl = "<<"
+ Shr = ">>"
+ And = "&"
+ Or = "|"
+ Xor = "^"
+ Input = "imm7"
+
+ @property
+ def is_shift(self):
+ return self is OpKind.Shl or self is OpKind.Shr
+
+
+@dataclass(unsafe_hash=True, frozen=True)
+class Operation:
+ ssa_reg: int
+ kind: OpKind
+ lhs: "Operation | int"
+ rhs: "Operation | int"
+
+ @staticmethod
+ def make_random(op_count, rand):
+ # type: (int, random.Random) -> Operation
+ inp = Operation(ssa_reg=0, kind=OpKind.Input, lhs=0, rhs=0)
+ values = [None, inp] # type: list[Operation | None]
+ kinds = tuple(OpKind)
+
+ def make_arg():
+ # type: () -> Operation | int
+ arg = rand.choice(values)
+ if arg is None:
+ return rand.randrange(0, 1 << 8)
+ return arg
+ for ssa_reg in range(1, op_count):
+ kind = rand.choice(kinds)
+ lhs = make_arg()
+ if kind.is_shift:
+ rhs = rand.randint(1, 7)
+ else:
+ rhs = make_arg()
+ values.append(Operation(ssa_reg=ssa_reg, kind=kind,
+ lhs=lhs, rhs=rhs))
+ retval = values[-1]
+ assert retval is not None
+ return retval
+
+ def to_python(self):
+ # type: () -> str
+ lines = [
+ "def expand_encoded_imm(imm7):",
+ " imm7 &= 0x7F",
+ ]
+ ssa_regs = set() # type: set[int]
+
+ def visit(op):
+ # type: (Operation | int) -> str
+ if isinstance(op, int):
+ return hex(op)
+ reg = f"v{op.ssa_reg}"
+ if op.ssa_reg in ssa_regs:
+ return reg
+ ssa_regs.add(op.ssa_reg)
+ lhs = visit(op.lhs)
+ rhs = visit(op.rhs)
+ if op.kind is OpKind.Shl:
+ lines.append(f" {reg} = {lhs} << {rhs}")
+ elif op.kind is OpKind.Shr:
+ lines.append(f" {reg} = {lhs} >> {rhs}")
+ elif op.kind is OpKind.And:
+ lines.append(f" {reg} = {lhs} & {rhs}")
+ elif op.kind is OpKind.Or:
+ lines.append(f" {reg} = {lhs} | {rhs}")
+ elif op.kind is OpKind.Xor:
+ lines.append(f" {reg} = {lhs} ^ {rhs}")
+ elif op.kind is OpKind.Input:
+ lines.append(f" {reg} = imm7")
+ else:
+ unreachable(op.kind)
+ return reg
+ final_reg = visit(self)
+ lines.append(f" return {final_reg} & 0xFF")
+ return "\n".join(lines)
+
+ def __repr__(self):
+ python_str = self.to_python()
+ return f"Operation{{\n{python_str}\n}}"
+
+
+@enum.unique
+class Cmd(enum.Enum):
+ Generate = "generate"
+ Check = "check"
+ UnitTest = "unittest"
+ Help = "--help"
+
+
+def parse_command():
+ cmd = ""
+ if len(sys.argv) > 1:
+ cmd = sys.argv[1]
+ try:
+ retval = Cmd(cmd)
+ if retval is not Cmd.Help:
+ return retval
+ except ValueError:
+ pass
+ cmds = "|".join(i.value for i in Cmd)
+ print(f"usage: {sys.argv[0]} {cmds}", file=sys.stderr)
+ exit(1)
+
+
+def main():
+ cmd = parse_command()
+ if cmd is Cmd.UnitTest:
+ del sys.argv[1]
+ unittest.main()
+ return
+ print("equivalence classes we need to cover:")
+ pprint.pprint(sorted(NEEDED_EQUIVALENCE_CLASSES))
+ print(flush=True)
+ if cmd is Cmd.Check:
+ log_str = verify_covers_needed_equivalence_classes(
+ expand_encoded_imm=expand_encoded_imm)
+ print(log_str)
+ elif cmd is Cmd.Generate:
+ rand = random.Random()
+ seed = sys.argv[-1]
+ rand.seed(seed, version=2) # seed for reproducibility
+ for i in range(100000):
+ if i % 1000 == 0:
+ print(f"try #{i} seed={seed!r}", flush=True)
+ op_count = rand.randint(1, 10)
+ python_str = Operation.make_random(op_count, rand).to_python()
+ try:
+ globals_ = {}
+ exec(python_str, globals_)
+ fn = globals_["expand_encoded_imm"]
+ verify_covers_needed_equivalence_classes(expand_encoded_imm=fn)
+ except MissingNeededEquivalenceClasses:
+ continue
+ success_str = f"# found working function:\nseed={seed!r}\n\n{python_str}"
+ print(success_str)
+ with open("found_working_function.txt", "xt") as f:
+ print(success_str, file=f)
+ return
+
+
+if __name__ == "__main__":
+ main()