spirv: Implement OpPtrEqual, OpPtrNotEqual and OpPtrDiff
[mesa.git] / src / compiler / spirv / spirv_to_nir.c
index c273e9e9fee732e41ddca36c965ac70c33a9d099..42fbc68023088b1917117dd868dae7529701f142 100644 (file)
@@ -4084,6 +4084,64 @@ vtn_handle_variable_or_type_instruction(struct vtn_builder *b, SpvOp opcode,
    return true;
 }
 
+static void
+vtn_handle_ptr(struct vtn_builder *b, SpvOp opcode,
+               const uint32_t *w, unsigned count)
+{
+      struct vtn_type *type1 = vtn_untyped_value(b, w[3])->type;
+      struct vtn_type *type2 = vtn_untyped_value(b, w[4])->type;
+      vtn_fail_if(type1->base_type != vtn_base_type_pointer ||
+                  type2->base_type != vtn_base_type_pointer,
+                  "%s operands must have pointer types",
+                  spirv_op_to_string(opcode));
+      vtn_fail_if(type1->storage_class != type2->storage_class,
+                  "%s operands must have the same storage class",
+                  spirv_op_to_string(opcode));
+
+      const struct glsl_type *type =
+         vtn_value(b, w[1], vtn_value_type_type)->type->type;
+
+      nir_address_format addr_format = vtn_mode_to_address_format(
+         b, vtn_storage_class_to_mode(b, type1->storage_class, NULL, NULL));
+
+      nir_ssa_def *def;
+
+      switch (opcode) {
+      case SpvOpPtrDiff: {
+         /* OpPtrDiff returns the difference in number of elements (not byte offset). */
+         unsigned elem_size, elem_align;
+         glsl_get_natural_size_align_bytes(type1->deref->type,
+                                           &elem_size, &elem_align);
+
+         def = nir_build_addr_isub(&b->nb,
+                                   vtn_ssa_value(b, w[3])->def,
+                                   vtn_ssa_value(b, w[4])->def,
+                                   addr_format);
+         def = nir_idiv(&b->nb, def, nir_imm_intN_t(&b->nb, elem_size, def->bit_size));
+         def = nir_i2i(&b->nb, def, glsl_get_bit_size(type));
+         break;
+      }
+
+      case SpvOpPtrEqual:
+      case SpvOpPtrNotEqual: {
+         def = nir_build_addr_ieq(&b->nb,
+                                  vtn_ssa_value(b, w[3])->def,
+                                  vtn_ssa_value(b, w[4])->def,
+                                  addr_format);
+         if (opcode == SpvOpPtrNotEqual)
+            def = nir_inot(&b->nb, def);
+         break;
+      }
+
+      default:
+         unreachable("Invalid ptr operation");
+      }
+
+      struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
+      val->ssa = vtn_create_ssa_value(b, type);
+      val->ssa->def = def;
+}
+
 static bool
 vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
                             const uint32_t *w, unsigned count)
@@ -4421,6 +4479,12 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
       vtn_handle_subgroup(b, opcode, w, count);
       break;
 
+   case SpvOpPtrDiff:
+   case SpvOpPtrEqual:
+   case SpvOpPtrNotEqual:
+      vtn_handle_ptr(b, opcode, w, count);
+      break;
+
    default:
       vtn_fail_with_opcode("Unhandled opcode", opcode);
    }