nir: add support for address bit sized system values
authorKarol Herbst <kherbst@redhat.com>
Thu, 19 Jul 2018 14:39:58 +0000 (16:39 +0200)
committerKarol Herbst <kherbst@redhat.com>
Tue, 5 Mar 2019 21:28:29 +0000 (22:28 +0100)
v2: add assert in else clause
    make local group intrinsics 32 bit wide
v3: always use 32 bit constant for local_size
v4: add comment by Jason

Signed-off-by: Karol Herbst <kherbst@redhat.com>
Reviewed-by: Bas Nieuwenhuizen <bas@basnieuwenhuizen.nl>
Reviewed-by: Jason Ekstrand <jason@jlekstrand.net>
src/compiler/nir/nir_intrinsics.py
src/compiler/nir/nir_lower_system_values.c

index 1d388c64fc9c11733a7e6a0d45121bb2f59b45b7..d53b26c88d605b44c28abcd0344a36c96479f268 100644 (file)
@@ -532,7 +532,7 @@ system_value("subgroup_lt_mask", 0, bit_sizes=[32, 64])
 system_value("num_subgroups", 1)
 system_value("subgroup_id", 1)
 system_value("local_group_size", 3)
-system_value("global_invocation_id", 3)
+system_value("global_invocation_id", 3, bit_sizes=[32, 64])
 system_value("work_dim", 1)
 # Driver-specific viewport scale/offset parameters.
 #
index 68b0ea89c8d52b7fa7e212e18baab2df03251b68..de5ccab0f38823aaff3c0fa1730a3bdf4c93076a 100644 (file)
@@ -29,7 +29,7 @@
 #include "nir_builder.h"
 
 static nir_ssa_def*
-build_local_group_size(nir_builder *b)
+build_local_group_size(nir_builder *b, unsigned bit_size)
 {
    nir_ssa_def *local_size;
 
@@ -40,6 +40,8 @@ build_local_group_size(nir_builder *b)
    if (b->shader->info.cs.local_size_variable) {
       local_size = nir_load_local_group_size(b);
    } else {
+      /* using a 32 bit constant is safe here as no device/driver needs more
+       * than 32 bits for the local size */
       nir_const_value local_size_const;
       memset(&local_size_const, 0, sizeof(local_size_const));
       local_size_const.u32[0] = b->shader->info.cs.local_size[0];
@@ -48,12 +50,15 @@ build_local_group_size(nir_builder *b)
       local_size = nir_build_imm(b, 3, 32, local_size_const);
    }
 
-   return local_size;
+   return nir_u2u(b, local_size, bit_size);
 }
 
 static nir_ssa_def *
-build_local_invocation_id(nir_builder *b)
+build_local_invocation_id(nir_builder *b, unsigned bit_size)
 {
+   /* If lower_cs_local_id_from_index is true, then we derive the local
+    * index from the local id.
+    */
    if (b->shader->options->lower_cs_local_id_from_index) {
       /* We lower gl_LocalInvocationID from gl_LocalInvocationIndex based
        * on this formula:
@@ -73,8 +78,12 @@ build_local_invocation_id(nir_builder *b)
        * large so it can safely be omitted.
        */
       nir_ssa_def *local_index = nir_load_local_invocation_index(b);
-      nir_ssa_def *local_size = build_local_group_size(b);
+      nir_ssa_def *local_size = build_local_group_size(b, 32);
 
+      /* Because no hardware supports a local workgroup size greater than
+       * about 1K, this calculation can be done in 32-bit and can save some
+       * 64-bit arithmetic.
+       */
       nir_ssa_def *id_x, *id_y, *id_z;
       id_x = nir_umod(b, local_index,
                          nir_channel(b, local_size, 0));
@@ -84,9 +93,9 @@ build_local_invocation_id(nir_builder *b)
       id_z = nir_udiv(b, local_index,
                          nir_imul(b, nir_channel(b, local_size, 0),
                                      nir_channel(b, local_size, 1)));
-      return nir_vec3(b, id_x, id_y, id_z);
+      return nir_u2u(b, nir_vec3(b, id_x, id_y, id_z), bit_size);
    } else {
-      return nir_load_local_invocation_id(b);
+      return nir_u2u(b, nir_load_local_invocation_id(b), bit_size);
    }
 }
 
@@ -120,6 +129,7 @@ convert_block(nir_block *block, nir_builder *b)
 
       b->cursor = nir_after_instr(&load_deref->instr);
 
+      unsigned bit_size = nir_dest_bit_size(load_deref->dest);
       nir_ssa_def *sysval = NULL;
       switch (var->data.location) {
       case SYSTEM_VALUE_GLOBAL_INVOCATION_ID: {
@@ -128,9 +138,9 @@ convert_block(nir_block *block, nir_builder *b)
           *    "The value of gl_GlobalInvocationID is equal to
           *    gl_WorkGroupID * gl_WorkGroupSize + gl_LocalInvocationID"
           */
-         nir_ssa_def *group_size = build_local_group_size(b);
-         nir_ssa_def *group_id = nir_load_work_group_id(b);
-         nir_ssa_def *local_id = build_local_invocation_id(b);
+         nir_ssa_def *group_size = build_local_group_size(b, bit_size);
+         nir_ssa_def *group_id = nir_u2u(b, nir_load_work_group_id(b), bit_size);
+         nir_ssa_def *local_id = build_local_invocation_id(b, bit_size);
 
          sysval = nir_iadd(b, nir_imul(b, group_id, group_size), local_id);
          break;
@@ -157,24 +167,25 @@ convert_block(nir_block *block, nir_builder *b)
          nir_ssa_def *size_y =
             nir_imm_int(b, b->shader->info.cs.local_size[1]);
 
+         /* Because no hardware supports a local workgroup size greater than
+          * about 1K, this calculation can be done in 32-bit and can save some
+          * 64-bit arithmetic.
+          */
          sysval = nir_imul(b, nir_channel(b, local_id, 2),
                               nir_imul(b, size_x, size_y));
          sysval = nir_iadd(b, sysval,
                               nir_imul(b, nir_channel(b, local_id, 1), size_x));
          sysval = nir_iadd(b, sysval, nir_channel(b, local_id, 0));
+         sysval = nir_u2u(b, sysval, bit_size);
          break;
       }
 
       case SYSTEM_VALUE_LOCAL_INVOCATION_ID:
-         /* If lower_cs_local_id_from_index is true, then we derive the local
-          * index from the local id.
-          */
-         if (b->shader->options->lower_cs_local_id_from_index)
-            sysval = build_local_invocation_id(b);
+         sysval = build_local_invocation_id(b, bit_size);
          break;
 
       case SYSTEM_VALUE_LOCAL_GROUP_SIZE: {
-         sysval = build_local_group_size(b);
+         sysval = build_local_group_size(b, bit_size);
          break;
       }
 
@@ -248,8 +259,8 @@ convert_block(nir_block *block, nir_builder *b)
          break;
 
       case SYSTEM_VALUE_GLOBAL_GROUP_SIZE: {
-         nir_ssa_def *group_size = build_local_group_size(b);
-         nir_ssa_def *num_work_groups = nir_load_num_work_groups(b);
+         nir_ssa_def *group_size = build_local_group_size(b, bit_size);
+         nir_ssa_def *num_work_groups = nir_u2u(b, nir_load_num_work_groups(b), bit_size);
          sysval = nir_imul(b, group_size, num_work_groups);
          break;
       }