Add table.product operator (#8020)
[cvc5.git] / src / theory / bags / theory_bags_type_rules.cpp
index 689b0e208cd2644dfdb15c4a5ccdd059c07af520..b0c79fb1d07bd05e0d427daead177024e99729b8 100644 (file)
@@ -449,6 +449,51 @@ TypeNode BagFoldTypeRule::computeType(NodeManager* nodeManager,
   return retType;
 }
 
+TypeNode TableProductTypeRule::computeType(NodeManager* nodeManager,
+                                           TNode n,
+                                           bool check)
+{
+  Assert(n.getKind() == kind::TABLE_PRODUCT);
+  Node A = n[0];
+  Node B = n[1];
+  TypeNode typeA = n[0].getType(check);
+  TypeNode typeB = n[1].getType(check);
+
+  if (check && !(typeA.isBag() && typeB.isBag()))
+  {
+    std::stringstream ss;
+    ss << "Operator " << n.getKind() << " expects two bags. "
+       << "Found two terms of types '" << typeA << "' and '" << typeB
+       << "' respectively.";
+    throw TypeCheckingExceptionPrivate(n, ss.str());
+  }
+
+  TypeNode elementAType = typeA.getBagElementType();
+  TypeNode elementBType = typeB.getBagElementType();
+
+  if (check && !(elementAType.isTuple() && elementBType.isTuple()))
+  {
+    std::stringstream ss;
+    ss << "Operator " << n.getKind() << " expects two tables (bags of tuples). "
+       << "Found two terms of types '" << typeA << "' and '" << typeB
+       << "' respectively.";
+    throw TypeCheckingExceptionPrivate(n, ss.str());
+  }
+
+  std::vector<TypeNode> productTupleTypes;
+  std::vector<TypeNode> tupleATypes = elementAType.getTupleTypes();
+  std::vector<TypeNode> tupleBTypes = elementBType.getTupleTypes();
+
+  productTupleTypes.insert(
+      productTupleTypes.end(), tupleATypes.begin(), tupleATypes.end());
+  productTupleTypes.insert(
+      productTupleTypes.end(), tupleBTypes.begin(), tupleBTypes.end());
+
+  TypeNode retTupleType = nodeManager->mkTupleType(productTupleTypes);
+  TypeNode retType = nodeManager->mkBagType(retTupleType);
+  return retType;
+}
+
 Cardinality BagsProperties::computeCardinality(TypeNode type)
 {
   return Cardinality::INTEGERS;