spirv: Use nir_const_value for spec constants
authorJason Ekstrand <jason@jlekstrand.net>
Wed, 22 Apr 2020 19:05:13 +0000 (14:05 -0500)
committerMarge Bot <eric+marge@anholt.net>
Fri, 24 Apr 2020 09:23:59 +0000 (09:23 +0000)
When we originally wrote spirv_to_nir we didn't have a good scalar value
union to handily use so we rolled our own thing for spec constants.  Now
that we have nir_const_value, we can use that and simplify a bunch of
the spec constant logic.

Reviewed-by: Caio Marcelo de Oliveira Filho <caio.oliveira@intel.com>
Acked-by: Bas Nieuwenhuizen <bas@basnieuwenhuizen.nl>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/4675>

src/amd/vulkan/radv_shader.c
src/compiler/spirv/nir_spirv.h
src/compiler/spirv/spirv_to_nir.c
src/freedreno/vulkan/tu_shader.c
src/intel/vulkan/anv_pipeline.c
src/mesa/main/glspirv.c

index 7d8e8a290118335dba31eba79de4f48e71c4bb18..7ed697057b4ecc16ac940ea76759c917ec1c0072 100644 (file)
@@ -320,7 +320,7 @@ radv_shader_compile_to_nir(struct radv_device *device,
                struct nir_spirv_specialization *spec_entries = NULL;
                if (spec_info && spec_info->mapEntryCount > 0) {
                        num_spec_entries = spec_info->mapEntryCount;
-                       spec_entries = malloc(num_spec_entries * sizeof(*spec_entries));
+                       spec_entries = calloc(num_spec_entries, sizeof(*spec_entries));
                        for (uint32_t i = 0; i < num_spec_entries; i++) {
                                VkSpecializationMapEntry entry = spec_info->pMapEntries[i];
                                const void *data = spec_info->pData + entry.offset;
@@ -329,16 +329,16 @@ radv_shader_compile_to_nir(struct radv_device *device,
                                spec_entries[i].id = spec_info->pMapEntries[i].constantID;
                                switch (entry.size) {
                                case 8:
-                                       spec_entries[i].data64 = *(const uint64_t *)data;
+                                       spec_entries[i].value.u64 = *(const uint64_t *)data;
                                        break;
                                case 4:
-                                       spec_entries[i].data32 = *(const uint32_t *)data;
+                                       spec_entries[i].value.u32 = *(const uint32_t *)data;
                                        break;
                                case 2:
-                                       spec_entries[i].data32 = *(const uint16_t *)data;
+                                       spec_entries[i].value.u16 = *(const uint16_t *)data;
                                        break;
                                case 1:
-                                       spec_entries[i].data32 = *(const uint8_t *)data;
+                                       spec_entries[i].value.u8 = *(const uint8_t *)data;
                                        break;
                                default:
                                        assert(!"Invalid spec constant size");
index 37fbf351bc9b9a7133034a81b04ad86f01edb190..3d6f74e43ca9337952a7f9743c0d999cb805e6ed 100644 (file)
@@ -37,10 +37,7 @@ extern "C" {
 
 struct nir_spirv_specialization {
    uint32_t id;
-   union {
-      uint32_t data32;
-      uint64_t data64;
-   };
+   nir_const_value value;
    bool defined_on_module;
 };
 
index 6cf43cd2ca7c520ad4add7c87fc799d156c782a6..2ea517077344a990e9dcce5a5305428409624489 100644 (file)
@@ -163,14 +163,6 @@ _vtn_fail(struct vtn_builder *b, const char *file, unsigned line,
    longjmp(b->fail_jump, 1);
 }
 
-struct spec_constant_value {
-   bool is_double;
-   union {
-      uint32_t data32;
-      uint64_t data64;
-   };
-};
-
 static struct vtn_ssa_value *
 vtn_undef_ssa_value(struct vtn_builder *b, const struct glsl_type *type)
 {
@@ -1547,41 +1539,15 @@ spec_constant_decoration_cb(struct vtn_builder *b, UNUSED struct vtn_value *val,
    if (dec->decoration != SpvDecorationSpecId)
       return;
 
-   struct spec_constant_value *const_value = data;
-
+   nir_const_value *value = data;
    for (unsigned i = 0; i < b->num_specializations; i++) {
       if (b->specializations[i].id == dec->operands[0]) {
-         if (const_value->is_double)
-            const_value->data64 = b->specializations[i].data64;
-         else
-            const_value->data32 = b->specializations[i].data32;
+         *value = b->specializations[i].value;
          return;
       }
    }
 }
 
-static uint32_t
-get_specialization(struct vtn_builder *b, struct vtn_value *val,
-                   uint32_t const_value)
-{
-   struct spec_constant_value data;
-   data.is_double = false;
-   data.data32 = const_value;
-   vtn_foreach_decoration(b, val, spec_constant_decoration_cb, &data);
-   return data.data32;
-}
-
-static uint64_t
-get_specialization64(struct vtn_builder *b, struct vtn_value *val,
-                   uint64_t const_value)
-{
-   struct spec_constant_value data;
-   data.is_double = true;
-   data.data64 = const_value;
-   vtn_foreach_decoration(b, val, spec_constant_decoration_cb, &data);
-   return data.data64;
-}
-
 static void
 handle_workgroup_size_decoration_cb(struct vtn_builder *b,
                                     struct vtn_value *val,
@@ -1613,18 +1579,21 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
                   "Result type of %s must be OpTypeBool",
                   spirv_op_to_string(opcode));
 
-      uint32_t int_val = (opcode == SpvOpConstantTrue ||
-                          opcode == SpvOpSpecConstantTrue);
+      bool bval = (opcode == SpvOpConstantTrue ||
+                   opcode == SpvOpSpecConstantTrue);
+
+      nir_const_value u32val = nir_const_value_for_uint(bval, 32);
 
       if (opcode == SpvOpSpecConstantTrue ||
           opcode == SpvOpSpecConstantFalse)
-         int_val = get_specialization(b, val, int_val);
+         vtn_foreach_decoration(b, val, spec_constant_decoration_cb, &u32val);
 
-      val->constant->values[0].b = int_val != 0;
+      val->constant->values[0].b = u32val.u32 != 0;
       break;
    }
 
-   case SpvOpConstant: {
+   case SpvOpConstant:
+   case SpvOpSpecConstant: {
       vtn_fail_if(val->type->base_type != vtn_base_type_scalar,
                   "Result type of %s must be a scalar",
                   spirv_op_to_string(opcode));
@@ -1645,31 +1614,10 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
       default:
          vtn_fail("Unsupported SpvOpConstant bit size: %u", bit_size);
       }
-      break;
-   }
 
-   case SpvOpSpecConstant: {
-      vtn_fail_if(val->type->base_type != vtn_base_type_scalar,
-                  "Result type of %s must be a scalar",
-                  spirv_op_to_string(opcode));
-      int bit_size = glsl_get_bit_size(val->type->type);
-      switch (bit_size) {
-      case 64:
-         val->constant->values[0].u64 =
-            get_specialization64(b, val, vtn_u64_literal(&w[3]));
-         break;
-      case 32:
-         val->constant->values[0].u32 = get_specialization(b, val, w[3]);
-         break;
-      case 16:
-         val->constant->values[0].u16 = get_specialization(b, val, w[3]);
-         break;
-      case 8:
-         val->constant->values[0].u8 = get_specialization(b, val, w[3]);
-         break;
-      default:
-         vtn_fail("Unsupported SpvOpSpecConstant bit size");
-      }
+      if (opcode == SpvOpSpecConstant)
+         vtn_foreach_decoration(b, val, spec_constant_decoration_cb,
+                                &val->constant->values[0]);
       break;
    }
 
@@ -1719,7 +1667,9 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
    }
 
    case SpvOpSpecConstantOp: {
-      SpvOp opcode = get_specialization(b, val, w[3]);
+      nir_const_value u32op = nir_const_value_for_uint(w[3], 32);
+      vtn_foreach_decoration(b, val, spec_constant_decoration_cb, &u32op);
+      SpvOp opcode = u32op.u32;
       switch (opcode) {
       case SpvOpVectorShuffle: {
          struct vtn_value *v0 = &b->values[w[4]];
index 3c674156afe85c22520b95579aedc11b735441c5..33826e37691db5e75fc76aae1e281bcc6c8e06fd 100644 (file)
@@ -53,7 +53,7 @@ tu_spirv_to_nir(struct ir3_compiler *compiler,
    struct nir_spirv_specialization *spec = NULL;
    uint32_t num_spec = 0;
    if (spec_info && spec_info->mapEntryCount) {
-      spec = malloc(sizeof(*spec) * spec_info->mapEntryCount);
+      spec = calloc(spec_info->mapEntryCount, sizeof(*spec));
       if (!spec)
          return NULL;
 
@@ -64,16 +64,16 @@ tu_spirv_to_nir(struct ir3_compiler *compiler,
          spec[i].id = entry->constantID;
          switch (entry->size) {
          case 8:
-            spec[i].data64 = *(const uint64_t *)data;
+            spec[i].value.u64 = *(const uint64_t *)data;
             break;
          case 4:
-            spec[i].data32 = *(const uint32_t *)data;
+            spec[i].value.u32 = *(const uint32_t *)data;
             break;
          case 2:
-            spec[i].data32 = *(const uint16_t *)data;
+            spec[i].value.u16 = *(const uint16_t *)data;
             break;
          case 1:
-            spec[i].data32 = *(const uint8_t *)data;
+            spec[i].value.u8 = *(const uint8_t *)data;
             break;
          default:
             assert(!"Invalid spec constant size");
index ea2329446667035aa9dbe69c08bcf1ab27fb5044..5c6150d26ff51e40c28a076c302580a012ce0fd0 100644 (file)
@@ -140,7 +140,7 @@ anv_shader_compile_to_nir(struct anv_device *device,
    struct nir_spirv_specialization *spec_entries = NULL;
    if (spec_info && spec_info->mapEntryCount > 0) {
       num_spec_entries = spec_info->mapEntryCount;
-      spec_entries = malloc(num_spec_entries * sizeof(*spec_entries));
+      spec_entries = calloc(num_spec_entries, sizeof(*spec_entries));
       for (uint32_t i = 0; i < num_spec_entries; i++) {
          VkSpecializationMapEntry entry = spec_info->pMapEntries[i];
          const void *data = spec_info->pData + entry.offset;
@@ -149,16 +149,16 @@ anv_shader_compile_to_nir(struct anv_device *device,
          spec_entries[i].id = spec_info->pMapEntries[i].constantID;
          switch (entry.size) {
          case 8:
-            spec_entries[i].data64 = *(const uint64_t *)data;
+            spec_entries[i].value.u64 = *(const uint64_t *)data;
             break;
          case 4:
-            spec_entries[i].data32 = *(const uint32_t *)data;
+            spec_entries[i].value.u32 = *(const uint32_t *)data;
             break;
          case 2:
-            spec_entries[i].data32 = *(const uint16_t *)data;
+            spec_entries[i].value.u16 = *(const uint16_t *)data;
             break;
          case 1:
-            spec_entries[i].data32 = *(const uint8_t *)data;
+            spec_entries[i].value.u8 = *(const uint8_t *)data;
             break;
          default:
             assert(!"Invalid spec constant size");
index 4a8165fa4e16ae937df26223a281d52d355ea5d7..9a6630748167ac79e36a5bed882ce01d92efa414 100644 (file)
@@ -239,7 +239,7 @@ _mesa_spirv_to_nir(struct gl_context *ctx,
 
    for (unsigned i = 0; i < spirv_data->NumSpecializationConstants; ++i) {
       spec_entries[i].id = spirv_data->SpecializationConstantsIndex[i];
-      spec_entries[i].data32 = spirv_data->SpecializationConstantsValue[i];
+      spec_entries[i].value.u32 = spirv_data->SpecializationConstantsValue[i];
       spec_entries[i].defined_on_module = false;
    }
 
@@ -370,7 +370,7 @@ _mesa_SpecializeShaderARB(GLuint shader,
 
    for (unsigned i = 0; i < numSpecializationConstants; ++i) {
       spec_entries[i].id = pConstantIndex[i];
-      spec_entries[i].data32 = pConstantValue[i];
+      spec_entries[i].value.u32 = pConstantValue[i];
       spec_entries[i].defined_on_module = false;
    }