nir: add vec2_index_32bit_offset address format
authorConnor Abbott <cwabbott0@gmail.com>
Mon, 29 Jun 2020 17:47:57 +0000 (19:47 +0200)
committerMarge Bot <eric+marge@anholt.net>
Mon, 6 Jul 2020 16:44:15 +0000 (16:44 +0000)
For turnip, we use the "bindless" model on a6xx. Loads and stores with
the bindless model require a bindless base, which is an immediate field
in the instruction that selects between 5 different 64-bit "bindless
base registers", a 32-bit descriptor index that's added to the base, and
the usual 32-bit offset. The bindless base usually, but not always,
corresponds to the Vulkan descriptor set.  We can handle the case where
the base is non-constant by using a bunch of if-statements, to make it a
little easier in core NIR, and this seems to be what Qualcomm's driver
does too. Therefore, the pointer format we need to use in NIR has a vec2
index, for the bindless base and descriptor index. Plumb this format
through core NIR.

Reviewed-by: Jason Ekstrand <jason@jlekstrand.net>
Reviewed-by: Kristian H. Kristensen <hoegsberg@google.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/5683>

src/compiler/nir/nir.h
src/compiler/nir/nir_intrinsics.py
src/compiler/nir/nir_lower_io.c

index c90b025f0a89021fc148f3707c413267b441003b..030589932436eb88bda47ae8103738765f93b562 100644 (file)
@@ -4017,6 +4017,12 @@ typedef enum {
     */
    nir_address_format_32bit_index_offset,
 
+   /**
+    * An address format which is comprised of a vec3 where the first two
+    * components specify the buffer and the third is an offset.
+    */
+   nir_address_format_vec2_index_32bit_offset,
+
    /**
     * An address format which is a simple 32-bit offset.
     */
@@ -4036,12 +4042,13 @@ static inline unsigned
 nir_address_format_bit_size(nir_address_format addr_format)
 {
    switch (addr_format) {
-   case nir_address_format_32bit_global:           return 32;
-   case nir_address_format_64bit_global:           return 64;
-   case nir_address_format_64bit_bounded_global:   return 32;
-   case nir_address_format_32bit_index_offset:     return 32;
-   case nir_address_format_32bit_offset:           return 32;
-   case nir_address_format_logical:                return 32;
+   case nir_address_format_32bit_global:              return 32;
+   case nir_address_format_64bit_global:              return 64;
+   case nir_address_format_64bit_bounded_global:      return 32;
+   case nir_address_format_32bit_index_offset:        return 32;
+   case nir_address_format_vec2_index_32bit_offset:   return 32;
+   case nir_address_format_32bit_offset:              return 32;
+   case nir_address_format_logical:                   return 32;
    }
    unreachable("Invalid address format");
 }
@@ -4050,12 +4057,13 @@ static inline unsigned
 nir_address_format_num_components(nir_address_format addr_format)
 {
    switch (addr_format) {
-   case nir_address_format_32bit_global:           return 1;
-   case nir_address_format_64bit_global:           return 1;
-   case nir_address_format_64bit_bounded_global:   return 4;
-   case nir_address_format_32bit_index_offset:     return 2;
-   case nir_address_format_32bit_offset:           return 1;
-   case nir_address_format_logical:                return 1;
+   case nir_address_format_32bit_global:              return 1;
+   case nir_address_format_64bit_global:              return 1;
+   case nir_address_format_64bit_bounded_global:      return 4;
+   case nir_address_format_32bit_index_offset:        return 2;
+   case nir_address_format_vec2_index_32bit_offset:   return 3;
+   case nir_address_format_32bit_offset:              return 1;
+   case nir_address_format_logical:                   return 1;
    }
    unreachable("Invalid address format");
 }
index bd3ed616ea36f8c786fabb653d75abb9abc1581c..5cc45f64c1b5669cebe1f344b07d7b5a360ce2b8 100644 (file)
@@ -481,20 +481,20 @@ intrinsic("deref_atomic_fcomp_swap", src_comp=[-1, 1, 1], dest_comp=1, indices=[
 # 2: The data parameter to the atomic function (i.e. the value to add
 #    in ssbo_atomic_add, etc).
 # 3: For CompSwap only: the second data parameter.
-intrinsic("ssbo_atomic_add",  src_comp=[1, 1, 1], dest_comp=1, indices=[ACCESS])
-intrinsic("ssbo_atomic_imin", src_comp=[1, 1, 1], dest_comp=1, indices=[ACCESS])
-intrinsic("ssbo_atomic_umin", src_comp=[1, 1, 1], dest_comp=1, indices=[ACCESS])
-intrinsic("ssbo_atomic_imax", src_comp=[1, 1, 1], dest_comp=1, indices=[ACCESS])
-intrinsic("ssbo_atomic_umax", src_comp=[1, 1, 1], dest_comp=1, indices=[ACCESS])
-intrinsic("ssbo_atomic_and",  src_comp=[1, 1, 1], dest_comp=1, indices=[ACCESS])
-intrinsic("ssbo_atomic_or",   src_comp=[1, 1, 1], dest_comp=1, indices=[ACCESS])
-intrinsic("ssbo_atomic_xor",  src_comp=[1, 1, 1], dest_comp=1, indices=[ACCESS])
-intrinsic("ssbo_atomic_exchange", src_comp=[1, 1, 1], dest_comp=1, indices=[ACCESS])
-intrinsic("ssbo_atomic_comp_swap", src_comp=[1, 1, 1, 1], dest_comp=1, indices=[ACCESS])
-intrinsic("ssbo_atomic_fadd", src_comp=[1, 1, 1], dest_comp=1, indices=[ACCESS])
-intrinsic("ssbo_atomic_fmin", src_comp=[1, 1, 1], dest_comp=1, indices=[ACCESS])
-intrinsic("ssbo_atomic_fmax", src_comp=[1, 1, 1], dest_comp=1, indices=[ACCESS])
-intrinsic("ssbo_atomic_fcomp_swap", src_comp=[1, 1, 1, 1], dest_comp=1, indices=[ACCESS])
+intrinsic("ssbo_atomic_add",  src_comp=[-1, 1, 1], dest_comp=1, indices=[ACCESS])
+intrinsic("ssbo_atomic_imin", src_comp=[-1, 1, 1], dest_comp=1, indices=[ACCESS])
+intrinsic("ssbo_atomic_umin", src_comp=[-1, 1, 1], dest_comp=1, indices=[ACCESS])
+intrinsic("ssbo_atomic_imax", src_comp=[-1, 1, 1], dest_comp=1, indices=[ACCESS])
+intrinsic("ssbo_atomic_umax", src_comp=[-1, 1, 1], dest_comp=1, indices=[ACCESS])
+intrinsic("ssbo_atomic_and",  src_comp=[-1, 1, 1], dest_comp=1, indices=[ACCESS])
+intrinsic("ssbo_atomic_or",   src_comp=[-1, 1, 1], dest_comp=1, indices=[ACCESS])
+intrinsic("ssbo_atomic_xor",  src_comp=[-1, 1, 1], dest_comp=1, indices=[ACCESS])
+intrinsic("ssbo_atomic_exchange", src_comp=[-1, 1, 1], dest_comp=1, indices=[ACCESS])
+intrinsic("ssbo_atomic_comp_swap", src_comp=[-1, 1, 1, 1], dest_comp=1, indices=[ACCESS])
+intrinsic("ssbo_atomic_fadd", src_comp=[-1, 1, 1], dest_comp=1, indices=[ACCESS])
+intrinsic("ssbo_atomic_fmin", src_comp=[-1, 1, 1], dest_comp=1, indices=[ACCESS])
+intrinsic("ssbo_atomic_fmax", src_comp=[-1, 1, 1], dest_comp=1, indices=[ACCESS])
+intrinsic("ssbo_atomic_fcomp_swap", src_comp=[-1, 1, 1, 1], dest_comp=1, indices=[ACCESS])
 
 # CS shared variable atomic intrinsics
 #
@@ -718,7 +718,7 @@ def load(name, src_comp, indices=[], flags=[]):
 # src[] = { offset }.
 load("uniform", [1], [BASE, RANGE, TYPE], [CAN_ELIMINATE, CAN_REORDER])
 # src[] = { buffer_index, offset }.
-load("ubo", [1, 1], [ACCESS, ALIGN_MUL, ALIGN_OFFSET], flags=[CAN_ELIMINATE, CAN_REORDER])
+load("ubo", [-1, 1], [ACCESS, ALIGN_MUL, ALIGN_OFFSET], flags=[CAN_ELIMINATE, CAN_REORDER])
 # src[] = { offset }.
 load("input", [1], [BASE, COMPONENT, TYPE], [CAN_ELIMINATE, CAN_REORDER])
 # src[] = { vertex_id, offset }.
@@ -729,7 +729,7 @@ load("per_vertex_input", [1, 1], [BASE, COMPONENT], [CAN_ELIMINATE, CAN_REORDER]
 load("interpolated_input", [2, 1], [BASE, COMPONENT], [CAN_ELIMINATE, CAN_REORDER])
 
 # src[] = { buffer_index, offset }.
-load("ssbo", [1, 1], [ACCESS, ALIGN_MUL, ALIGN_OFFSET], [CAN_ELIMINATE])
+load("ssbo", [-1, 1], [ACCESS, ALIGN_MUL, ALIGN_OFFSET], [CAN_ELIMINATE])
 # src[] = { buffer_index }
 load("ssbo_address", [1], [], [CAN_ELIMINATE, CAN_REORDER])
 # src[] = { offset }.
@@ -763,7 +763,7 @@ store("output", [1], [BASE, WRMASK, COMPONENT, TYPE])
 # src[] = { value, vertex, offset }.
 store("per_vertex_output", [1, 1], [BASE, WRMASK, COMPONENT])
 # src[] = { value, block_index, offset }
-store("ssbo", [1, 1], [WRMASK, ACCESS, ALIGN_MUL, ALIGN_OFFSET])
+store("ssbo", [-1, 1], [WRMASK, ACCESS, ALIGN_MUL, ALIGN_OFFSET])
 # src[] = { value, offset }.
 store("shared", [1], [BASE, WRMASK, ALIGN_MUL, ALIGN_OFFSET])
 # src[] = { value, address }.
index e45fffbcf9e539ee2c4e30c90f7b85585656af32..6a0e18144131497a26ceea0583c74e5f9b7914b7 100644 (file)
@@ -784,6 +784,10 @@ build_addr_iadd(nir_builder *b, nir_ssa_def *addr,
       assert(addr->num_components == 2);
       return nir_vec2(b, nir_channel(b, addr, 0),
                          nir_iadd(b, nir_channel(b, addr, 1), offset));
+   case nir_address_format_vec2_index_32bit_offset:
+      assert(addr->num_components == 3);
+      return nir_vec3(b, nir_channel(b, addr, 0), nir_channel(b, addr, 1),
+                         nir_iadd(b, nir_channel(b, addr, 2), offset));
    case nir_address_format_logical:
       unreachable("Unsupported address format");
    }
@@ -802,18 +806,30 @@ static nir_ssa_def *
 addr_to_index(nir_builder *b, nir_ssa_def *addr,
               nir_address_format addr_format)
 {
-   assert(addr_format == nir_address_format_32bit_index_offset);
-   assert(addr->num_components == 2);
-   return nir_channel(b, addr, 0);
+   if (addr_format == nir_address_format_32bit_index_offset) {
+      assert(addr->num_components == 2);
+      return nir_channel(b, addr, 0);
+   } else if (addr_format == nir_address_format_vec2_index_32bit_offset) {
+      assert(addr->num_components == 3);
+      return nir_channels(b, addr, 0x3);
+   } else {
+      unreachable("bad address format for index");
+   }
 }
 
 static nir_ssa_def *
 addr_to_offset(nir_builder *b, nir_ssa_def *addr,
                nir_address_format addr_format)
 {
-   assert(addr_format == nir_address_format_32bit_index_offset);
-   assert(addr->num_components == 2);
-   return nir_channel(b, addr, 1);
+   if (addr_format == nir_address_format_32bit_index_offset) {
+      assert(addr->num_components == 2);
+      return nir_channel(b, addr, 1);
+   } else if (addr_format == nir_address_format_vec2_index_32bit_offset) {
+      assert(addr->num_components == 3);
+      return nir_channel(b, addr, 2);
+   } else {
+      unreachable("bad address format for offset");
+   }
 }
 
 /** Returns true if the given address format resolves to a global address */
@@ -841,6 +857,7 @@ addr_to_global(nir_builder *b, nir_ssa_def *addr,
                          nir_u2u64(b, nir_channel(b, addr, 3)));
 
    case nir_address_format_32bit_index_offset:
+   case nir_address_format_vec2_index_32bit_offset:
    case nir_address_format_32bit_offset:
    case nir_address_format_logical:
       unreachable("Cannot get a 64-bit address with this address format");
@@ -1284,7 +1301,8 @@ lower_explicit_io_array_length(nir_builder *b, nir_intrinsic_instr *intrin,
    unsigned stride = glsl_get_explicit_stride(deref->type);
    assert(stride > 0);
 
-   assert(addr_format == nir_address_format_32bit_index_offset);
+   assert(addr_format == nir_address_format_32bit_index_offset ||
+          addr_format == nir_address_format_vec2_index_32bit_offset);
    nir_ssa_def *addr = &deref->dest.ssa;
    nir_ssa_def *index = addr_to_index(b, addr, addr_format);
    nir_ssa_def *offset = addr_to_offset(b, addr, addr_format);
@@ -1582,6 +1600,7 @@ nir_address_format_null_value(nir_address_format addr_format)
       [nir_address_format_64bit_global] = {{0}},
       [nir_address_format_64bit_bounded_global] = {{0}},
       [nir_address_format_32bit_index_offset] = {{.u32 = ~0}, {.u32 = ~0}},
+      [nir_address_format_vec2_index_32bit_offset] = {{.u32 = ~0}, {.u32 = ~0}, {.u32 = ~0}},
       [nir_address_format_32bit_offset] = {{.u32 = ~0}},
       [nir_address_format_logical] = {{.u32 = ~0}},
    };
@@ -1599,6 +1618,7 @@ nir_build_addr_ieq(nir_builder *b, nir_ssa_def *addr0, nir_ssa_def *addr1,
    case nir_address_format_64bit_global:
    case nir_address_format_64bit_bounded_global:
    case nir_address_format_32bit_index_offset:
+   case nir_address_format_vec2_index_32bit_offset:
    case nir_address_format_32bit_offset:
       return nir_ball_iequal(b, addr0, addr1);
 
@@ -1631,6 +1651,12 @@ nir_build_addr_isub(nir_builder *b, nir_ssa_def *addr0, nir_ssa_def *addr1,
       /* Assume the same buffer index. */
       return nir_isub(b, nir_channel(b, addr0, 1), nir_channel(b, addr1, 1));
 
+   case nir_address_format_vec2_index_32bit_offset:
+      assert(addr0->num_components == 3);
+      assert(addr1->num_components == 3);
+      /* Assume the same buffer index. */
+      return nir_isub(b, nir_channel(b, addr0, 2), nir_channel(b, addr1, 2));
+
    case nir_address_format_logical:
       unreachable("Unsupported address format");
    }