anv/nir: Add a central helper for figuring out SSBO address formats
authorJason Ekstrand <jason@jlekstrand.net>
Thu, 18 Apr 2019 17:08:57 +0000 (12:08 -0500)
committerJason Ekstrand <jason@jlekstrand.net>
Fri, 19 Apr 2019 19:56:42 +0000 (19:56 +0000)
Reviewed-by: Caio Marcelo de Oliveira Filho <caio.oliveira@intel.com>
src/intel/vulkan/anv_nir.h
src/intel/vulkan/anv_nir_apply_pipeline_layout.c
src/intel/vulkan/anv_pipeline.c

index dd6c89529acb13e7c5ac3d30e5dfd3b50df0f000..c132264b2999a092d9eb8c017b8031c05dc861ab 100644 (file)
@@ -40,6 +40,20 @@ bool anv_nir_lower_multiview(nir_shader *shader, uint32_t view_mask);
 bool anv_nir_lower_ycbcr_textures(nir_shader *shader,
                                   struct anv_pipeline_layout *layout);
 
+static inline nir_address_format
+anv_nir_ssbo_addr_format(const struct anv_physical_device *pdevice,
+                         bool robust_buffer_access)
+{
+   if (pdevice->has_a64_buffer_access) {
+      if (robust_buffer_access)
+         return nir_address_format_64bit_bounded_global;
+      else
+         return nir_address_format_64bit_global;
+   } else {
+      return nir_address_format_32bit_index_offset;
+   }
+}
+
 void anv_nir_apply_pipeline_layout(const struct anv_physical_device *pdevice,
                                    bool robust_buffer_access,
                                    struct anv_pipeline_layout *layout,
index 23b1cb72098d15964bf51e52c9fe0433bcbb24b1..3d9ba5c3ecd557f7443b2489edb071548ddc7acc 100644 (file)
@@ -41,6 +41,7 @@ struct apply_pipeline_layout_state {
 
    struct anv_pipeline_layout *layout;
    bool add_bounds_checks;
+   nir_address_format ssbo_addr_format;
 
    /* Place to flag lowered instructions so we don't lower them twice */
    struct set *lowered_instrs;
@@ -338,6 +339,15 @@ lower_direct_buffer_access(nir_function_impl *impl,
    }
 }
 
+static nir_address_format
+desc_addr_format(VkDescriptorType desc_type,
+                 struct apply_pipeline_layout_state *state)
+{
+   return (desc_type == VK_DESCRIPTOR_TYPE_STORAGE_BUFFER ||
+           desc_type == VK_DESCRIPTOR_TYPE_STORAGE_BUFFER_DYNAMIC) ?
+           state->ssbo_addr_format : nir_address_format_32bit_index_offset;
+}
+
 static void
 lower_res_index_intrinsic(nir_intrinsic_instr *intrin,
                           struct apply_pipeline_layout_state *state)
@@ -383,7 +393,8 @@ lower_res_index_intrinsic(nir_intrinsic_instr *intrin,
          dynamic_offset_index;
 
       if (state->add_bounds_checks) {
-         /* We're using nir_address_format_64bit_bounded_global */
+         assert(desc_addr_format(desc_type, state) ==
+                nir_address_format_64bit_bounded_global);
          assert(intrin->dest.ssa.num_components == 4);
          assert(intrin->dest.ssa.bit_size == 32);
          index = nir_vec4(b, nir_imm_int(b, desc_offset),
@@ -391,7 +402,8 @@ lower_res_index_intrinsic(nir_intrinsic_instr *intrin,
                              nir_imm_int(b, array_size - 1),
                              nir_ssa_undef(b, 1, 32));
       } else {
-         /* We're using nir_address_format_64bit_global */
+         assert(desc_addr_format(desc_type, state) ==
+                nir_address_format_64bit_global);
          assert(intrin->dest.ssa.num_components == 1);
          assert(intrin->dest.ssa.bit_size == 64);
          index = nir_pack_64_2x32_split(b, nir_imm_int(b, desc_offset),
@@ -399,15 +411,17 @@ lower_res_index_intrinsic(nir_intrinsic_instr *intrin,
       }
    } else if (bind_layout->data & ANV_DESCRIPTOR_INLINE_UNIFORM) {
       /* This is an inline uniform block.  Just reference the descriptor set
-       * and use the descriptor offset as the base.  Inline uniforms always
-       * use  nir_address_format_32bit_index_offset
+       * and use the descriptor offset as the base.
        */
+      assert(desc_addr_format(desc_type, state) ==
+             nir_address_format_32bit_index_offset);
       assert(intrin->dest.ssa.num_components == 2);
       assert(intrin->dest.ssa.bit_size == 32);
       index = nir_imm_ivec2(b, state->set[set].desc_offset,
                                bind_layout->descriptor_offset);
    } else {
-      /* We're using nir_address_format_32bit_index_offset */
+      assert(desc_addr_format(desc_type, state) ==
+             nir_address_format_32bit_index_offset);
       assert(intrin->dest.ssa.num_components == 2);
       assert(intrin->dest.ssa.bit_size == 32);
       index = nir_vec2(b, nir_iadd_imm(b, array_index, surface_index),
@@ -438,32 +452,37 @@ lower_res_reindex_intrinsic(nir_intrinsic_instr *intrin,
    nir_ssa_def *offset = intrin->src[1].ssa;
 
    nir_ssa_def *new_index;
-   if (state->pdevice->has_a64_buffer_access &&
-       (desc_type == VK_DESCRIPTOR_TYPE_STORAGE_BUFFER ||
-        desc_type == VK_DESCRIPTOR_TYPE_STORAGE_BUFFER_DYNAMIC)) {
-      if (state->add_bounds_checks) {
-         /* We're using nir_address_format_64bit_bounded_global */
-         assert(intrin->dest.ssa.num_components == 4);
-         assert(intrin->dest.ssa.bit_size == 32);
-         new_index = nir_vec4(b, nir_channel(b, old_index, 0),
-                                 nir_iadd(b, nir_channel(b, old_index, 1),
-                                             offset),
-                                 nir_channel(b, old_index, 2),
-                                 nir_ssa_undef(b, 1, 32));
-      } else {
-         /* We're using nir_address_format_64bit_global */
-         assert(intrin->dest.ssa.num_components == 1);
-         assert(intrin->dest.ssa.bit_size == 64);
-         nir_ssa_def *base = nir_unpack_64_2x32_split_x(b, old_index);
-         nir_ssa_def *arr_idx = nir_unpack_64_2x32_split_y(b, old_index);
-         new_index = nir_pack_64_2x32_split(b, base, nir_iadd(b, arr_idx, offset));
-      }
-   } else {
-      /* We're using nir_address_format_32bit_index_offset */
+   switch (desc_addr_format(desc_type, state)) {
+   case nir_address_format_64bit_bounded_global:
+      /* See also lower_res_index_intrinsic() */
+      assert(intrin->dest.ssa.num_components == 4);
+      assert(intrin->dest.ssa.bit_size == 32);
+      new_index = nir_vec4(b, nir_channel(b, old_index, 0),
+                              nir_iadd(b, nir_channel(b, old_index, 1),
+                                          offset),
+                              nir_channel(b, old_index, 2),
+                              nir_ssa_undef(b, 1, 32));
+      break;
+
+   case nir_address_format_64bit_global: {
+      /* See also lower_res_index_intrinsic() */
+      assert(intrin->dest.ssa.num_components == 1);
+      assert(intrin->dest.ssa.bit_size == 64);
+      nir_ssa_def *base = nir_unpack_64_2x32_split_x(b, old_index);
+      nir_ssa_def *arr_idx = nir_unpack_64_2x32_split_y(b, old_index);
+      new_index = nir_pack_64_2x32_split(b, base, nir_iadd(b, arr_idx, offset));
+      break;
+   }
+
+   case nir_address_format_32bit_index_offset:
       assert(intrin->dest.ssa.num_components == 2);
       assert(intrin->dest.ssa.bit_size == 32);
       new_index = nir_vec2(b, nir_iadd(b, nir_channel(b, old_index, 0), offset),
                               nir_channel(b, old_index, 1));
+      break;
+
+   default:
+      unreachable("Uhandled address format");
    }
 
    assert(intrin->dest.is_ssa);
@@ -479,14 +498,22 @@ build_ssbo_descriptor_load(const VkDescriptorType desc_type,
    nir_builder *b = &state->builder;
 
    nir_ssa_def *desc_offset, *array_index;
-   if (state->add_bounds_checks) {
-      /* We're using nir_address_format_64bit_bounded_global */
+   switch (state->ssbo_addr_format) {
+   case nir_address_format_64bit_bounded_global:
+      /* See also lower_res_index_intrinsic() */
       desc_offset = nir_channel(b, index, 0);
       array_index = nir_umin(b, nir_channel(b, index, 1),
                                 nir_channel(b, index, 2));
-   } else {
+      break;
+
+   case nir_address_format_64bit_global:
+      /* See also lower_res_index_intrinsic() */
       desc_offset = nir_unpack_64_2x32_split_x(b, index);
       array_index = nir_unpack_64_2x32_split_y(b, index);
+      break;
+
+   default:
+      unreachable("Unhandled address format for SSBO");
    }
 
    /* The desc_offset is actually 16.8.8 */
@@ -541,14 +568,22 @@ lower_load_vulkan_descriptor(nir_intrinsic_instr *intrin,
           * dynamic offset.
           */
          nir_ssa_def *desc_offset, *array_index;
-         if (state->add_bounds_checks) {
-            /* We're using nir_address_format_64bit_bounded_global */
+         switch (state->ssbo_addr_format) {
+         case nir_address_format_64bit_bounded_global:
+            /* See also lower_res_index_intrinsic() */
             desc_offset = nir_channel(b, index, 0);
             array_index = nir_umin(b, nir_channel(b, index, 1),
                                       nir_channel(b, index, 2));
-         } else {
+            break;
+
+         case nir_address_format_64bit_global:
+            /* See also lower_res_index_intrinsic() */
             desc_offset = nir_unpack_64_2x32_split_x(b, index);
             array_index = nir_unpack_64_2x32_split_y(b, index);
+            break;
+
+         default:
+            unreachable("Unhandled address format for SSBO");
          }
 
          nir_ssa_def *dyn_offset_base =
@@ -573,11 +608,10 @@ lower_load_vulkan_descriptor(nir_intrinsic_instr *intrin,
             nir_bcsel(b, nir_ieq(b, dyn_offset_base, nir_imm_int(b, 0xff)),
                          nir_imm_int(b, 0), &dyn_load->dest.ssa);
 
-         if (state->add_bounds_checks) {
+         switch (state->ssbo_addr_format) {
+         case nir_address_format_64bit_bounded_global: {
             /* The dynamic offset gets added to the base pointer so that we
              * have a sliding window range.
-             *
-             * We're using nir_address_format_64bit_bounded_global.
              */
             nir_ssa_def *base_ptr =
                nir_pack_64_2x32(b, nir_channels(b, desc, 0x3));
@@ -586,9 +620,15 @@ lower_load_vulkan_descriptor(nir_intrinsic_instr *intrin,
                                nir_unpack_64_2x32_split_y(b, base_ptr),
                                nir_channel(b, desc, 2),
                                nir_channel(b, desc, 3));
-         } else {
-            /* We're using nir_address_format_64bit_global */
+            break;
+         }
+
+         case nir_address_format_64bit_global:
             desc = nir_iadd(b, desc, nir_u2u64(b, dynamic_offset));
+            break;
+
+         default:
+            unreachable("Unhandled address format for SSBO");
          }
       }
    } else {
@@ -967,6 +1007,7 @@ anv_nir_apply_pipeline_layout(const struct anv_physical_device *pdevice,
       .shader = shader,
       .layout = layout,
       .add_bounds_checks = robust_buffer_access,
+      .ssbo_addr_format = anv_nir_ssbo_addr_format(pdevice, robust_buffer_access),
       .lowered_instrs = _mesa_pointer_set_create(mem_ctx),
       .dynamic_offset_uniform_start = -1,
    };
index 64d4d93803cc4f4b828347590613404c751e2e24..20eab548fb27e2c1381edf4f3d71661737df7412 100644 (file)
@@ -134,6 +134,8 @@ anv_shader_compile_to_nir(struct anv_device *device,
       }
    }
 
+   nir_address_format ssbo_addr_format =
+      anv_nir_ssbo_addr_format(pdevice, device->robust_buffer_access);
    struct spirv_to_nir_options spirv_options = {
       .lower_workgroup_access_to_offsets = true,
       .caps = {
@@ -169,19 +171,12 @@ anv_shader_compile_to_nir(struct anv_device *device,
          .variable_pointers = true,
       },
       .ubo_ptr_type = glsl_vector_type(GLSL_TYPE_UINT, 2),
+      .ssbo_ptr_type = nir_address_format_to_glsl_type(ssbo_addr_format),
       .phys_ssbo_ptr_type = glsl_vector_type(GLSL_TYPE_UINT64, 1),
       .push_const_ptr_type = glsl_uint_type(),
       .shared_ptr_type = glsl_uint_type(),
    };
 
-   if (pdevice->has_a64_buffer_access) {
-      if (device->robust_buffer_access)
-         spirv_options.ssbo_ptr_type = glsl_vector_type(GLSL_TYPE_UINT, 4);
-      else
-         spirv_options.ssbo_ptr_type = glsl_vector_type(GLSL_TYPE_UINT64, 1);
-   } else {
-      spirv_options.ssbo_ptr_type = glsl_vector_type(GLSL_TYPE_UINT, 2);
-   }
 
    nir_function *entry_point =
       spirv_to_nir(spirv, module->size / 4,
@@ -626,18 +621,9 @@ anv_pipeline_lower_nir(struct anv_pipeline *pipeline,
 
       NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_mem_ubo,
                  nir_address_format_32bit_index_offset);
-
-      nir_address_format ssbo_address_format;
-      if (pdevice->has_a64_buffer_access) {
-         if (pipeline->device->robust_buffer_access)
-            ssbo_address_format = nir_address_format_64bit_bounded_global;
-         else
-            ssbo_address_format = nir_address_format_64bit_global;
-      } else {
-         ssbo_address_format = nir_address_format_32bit_index_offset;
-      }
       NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_mem_ssbo,
-                 ssbo_address_format);
+                 anv_nir_ssbo_addr_format(pdevice,
+                    pipeline->device->robust_buffer_access));
 
       NIR_PASS_V(nir, nir_opt_constant_folding);