nir/vtn: CL SPIR-V callers should specify address modes
authorJesse Natalie <jenatali@microsoft.com>
Thu, 21 May 2020 22:12:15 +0000 (15:12 -0700)
committerMarge Bot <eric+marge@anholt.net>
Mon, 17 Aug 2020 14:36:18 +0000 (14:36 +0000)
Instead of inferring the address mode from the environment, allows
callers to override to suit their needs.

Reviewed-by: Jason Ekstrand <jason@jlekstrand.net>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/6330>

src/compiler/spirv/spirv_to_nir.c
src/gallium/frontends/clover/nir/invocation.cpp

index 747d1a0cdcce5e929984a34ce794f406e718ce1b..26c42fd46d817d411197ac936f283e742d5e7730 100644 (file)
@@ -4378,18 +4378,28 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
                      "AddressingModelPhysical32 only supported for kernels");
          b->shader->info.cs.ptr_size = 32;
          b->physical_ptrs = true;
-         b->options->shared_addr_format = nir_address_format_32bit_global;
-         b->options->global_addr_format = nir_address_format_32bit_global;
-         b->options->temp_addr_format = nir_address_format_32bit_global;
+         assert(nir_address_format_bit_size(b->options->global_addr_format) == 32);
+         assert(nir_address_format_num_components(b->options->global_addr_format) == 1);
+         assert(nir_address_format_bit_size(b->options->shared_addr_format) == 32);
+         assert(nir_address_format_num_components(b->options->shared_addr_format) == 1);
+         if (!b->options->constant_as_global) {
+            assert(nir_address_format_bit_size(b->options->ubo_addr_format) == 32);
+            assert(nir_address_format_num_components(b->options->ubo_addr_format) == 1);
+         }
          break;
       case SpvAddressingModelPhysical64:
          vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
                      "AddressingModelPhysical64 only supported for kernels");
          b->shader->info.cs.ptr_size = 64;
          b->physical_ptrs = true;
-         b->options->shared_addr_format = nir_address_format_64bit_global;
-         b->options->global_addr_format = nir_address_format_64bit_global;
-         b->options->temp_addr_format = nir_address_format_64bit_global;
+         assert(nir_address_format_bit_size(b->options->global_addr_format) == 64);
+         assert(nir_address_format_num_components(b->options->global_addr_format) == 1);
+         assert(nir_address_format_bit_size(b->options->shared_addr_format) == 64);
+         assert(nir_address_format_num_components(b->options->shared_addr_format) == 1);
+         if (!b->options->constant_as_global) {
+            assert(nir_address_format_bit_size(b->options->ubo_addr_format) == 64);
+            assert(nir_address_format_num_components(b->options->ubo_addr_format) == 1);
+         }
          break;
       case SpvAddressingModelLogical:
          vtn_fail_if(b->shader->info.stage == MESA_SHADER_KERNEL,
@@ -5531,6 +5541,10 @@ spirv_to_nir(const uint32_t *words, size_t word_count,
       return NULL;
    }
 
+   /* Ensure a sane address mode is being used for function temps */
+   assert(nir_address_format_bit_size(b->options->temp_addr_format) == nir_get_ptr_bitsize(b->shader));
+   assert(nir_address_format_num_components(b->options->temp_addr_format) == 1);
+
    /* Set shader info defaults */
    if (stage == MESA_SHADER_GEOMETRY)
       b->shader->info.gs.invocations = 1;
index cae6ff235bab229c106668314288b92218640e2f..36ee8c9a2eaf7bee776e3691af071b4a399a8a88 100644 (file)
@@ -63,6 +63,17 @@ module clover::nir::spirv_to_nir(const module &mod, const device &dev,
 {
    struct spirv_to_nir_options spirv_options = {};
    spirv_options.environment = NIR_SPIRV_OPENCL;
+   if (dev.address_bits() == 32u) {
+      spirv_options.shared_addr_format = nir_address_format_32bit_global;
+      spirv_options.global_addr_format = nir_address_format_32bit_global;
+      spirv_options.temp_addr_format = nir_address_format_32bit_global;
+      spirv_options.ubo_addr_format = nir_address_format_32bit_global;
+   } else {
+      spirv_options.shared_addr_format = nir_address_format_64bit_global;
+      spirv_options.global_addr_format = nir_address_format_64bit_global;
+      spirv_options.temp_addr_format = nir_address_format_64bit_global;
+      spirv_options.ubo_addr_format = nir_address_format_32bit_index_offset;
+   }
    spirv_options.caps.address = true;
    spirv_options.caps.float64 = true;
    spirv_options.caps.int8 = true;