zink: implement support for derivative-control
[mesa.git] / src / gallium / drivers / zink / nir_to_spirv / nir_to_spirv.c
index 922ed8c81f558c7b28dcff3cc632a8b9300f28ce..af1ddef11bdf0f6c80abf3bbe5d77ae896a46d04 100644 (file)
@@ -57,7 +57,7 @@ struct ntv_context {
    bool block_started;
    SpvId loop_break, loop_cont;
 
-   SpvId front_face_var, vertex_id_var;
+   SpvId front_face_var, instance_id_var, vertex_id_var;
 };
 
 static SpvId
@@ -386,13 +386,15 @@ type_to_dim(enum glsl_sampler_dim gdim, bool *is_ms)
 static void
 emit_sampler(struct ntv_context *ctx, struct nir_variable *var)
 {
+   const struct glsl_type *type = glsl_without_array(var->type);
+
    bool is_ms;
-   SpvDim dimension = type_to_dim(glsl_get_sampler_dim(var->type), &is_ms);
+   SpvDim dimension = type_to_dim(glsl_get_sampler_dim(type), &is_ms);
 
-   SpvId result_type = get_glsl_basetype(ctx, glsl_get_sampler_result_type(var->type));
+   SpvId result_type = get_glsl_basetype(ctx, glsl_get_sampler_result_type(type));
    SpvId image_type = spirv_builder_type_image(&ctx->builder, result_type,
                                                dimension, false,
-                                               glsl_sampler_type_is_array(var->type),
+                                               glsl_sampler_type_is_array(type),
                                                is_ms, 1,
                                                SpvImageFormatUnknown);
 
@@ -401,21 +403,45 @@ emit_sampler(struct ntv_context *ctx, struct nir_variable *var)
    SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
                                                    SpvStorageClassUniformConstant,
                                                    sampled_type);
-   SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
-                                         SpvStorageClassUniformConstant);
 
-   if (var->name)
-      spirv_builder_emit_name(&ctx->builder, var_id, var->name);
+   if (glsl_type_is_array(var->type)) {
+      for (int i = 0; i < glsl_get_length(var->type); ++i) {
+         SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
+                                               SpvStorageClassUniformConstant);
 
-   assert(ctx->num_samplers < ARRAY_SIZE(ctx->image_types));
-   ctx->image_types[ctx->num_samplers] = image_type;
+         if (var->name) {
+            char element_name[100];
+            snprintf(element_name, sizeof(element_name), "%s_%d", var->name, i);
+            spirv_builder_emit_name(&ctx->builder, var_id, var->name);
+         }
 
-   assert(ctx->num_samplers < ARRAY_SIZE(ctx->samplers));
-   ctx->samplers[ctx->num_samplers++] = var_id;
+         assert(ctx->num_samplers < ARRAY_SIZE(ctx->image_types));
+         ctx->image_types[ctx->num_samplers] = image_type;
 
-   spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
-                                     var->data.descriptor_set);
-   spirv_builder_emit_binding(&ctx->builder, var_id, var->data.binding);
+         assert(ctx->num_samplers < ARRAY_SIZE(ctx->samplers));
+         ctx->samplers[ctx->num_samplers++] = var_id;
+
+         spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
+                                           var->data.descriptor_set);
+         spirv_builder_emit_binding(&ctx->builder, var_id, var->data.binding);
+      }
+   } else {
+      SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
+                                            SpvStorageClassUniformConstant);
+
+      if (var->name)
+         spirv_builder_emit_name(&ctx->builder, var_id, var->name);
+
+      assert(ctx->num_samplers < ARRAY_SIZE(ctx->image_types));
+      ctx->image_types[ctx->num_samplers] = image_type;
+
+      assert(ctx->num_samplers < ARRAY_SIZE(ctx->samplers));
+      ctx->samplers[ctx->num_samplers++] = var_id;
+
+      spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
+                                        var->data.descriptor_set);
+      spirv_builder_emit_binding(&ctx->builder, var_id, var->data.binding);
+   }
 }
 
 static void
@@ -465,7 +491,7 @@ emit_uniform(struct ntv_context *ctx, struct nir_variable *var)
       emit_ubo(ctx, var);
    else {
       assert(var->data.mode == nir_var_uniform);
-      if (glsl_type_is_sampler(var->type))
+      if (glsl_type_is_sampler(glsl_without_array(var->type)))
          emit_sampler(ctx, var);
    }
 }
@@ -885,7 +911,11 @@ emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
    UNOP(nir_op_ineg, SpvOpSNegate)
    UNOP(nir_op_fneg, SpvOpFNegate)
    UNOP(nir_op_fddx, SpvOpDPdx)
+   UNOP(nir_op_fddx_coarse, SpvOpDPdxCoarse)
+   UNOP(nir_op_fddx_fine, SpvOpDPdxFine)
    UNOP(nir_op_fddy, SpvOpDPdy)
+   UNOP(nir_op_fddy_coarse, SpvOpDPdyCoarse)
+   UNOP(nir_op_fddy_fine, SpvOpDPdyFine)
    UNOP(nir_op_f2i32, SpvOpConvertFToS)
    UNOP(nir_op_f2u32, SpvOpConvertFToU)
    UNOP(nir_op_i2f32, SpvOpConvertSToF)
@@ -1287,6 +1317,22 @@ emit_load_front_face(struct ntv_context *ctx, nir_intrinsic_instr *intr)
    store_dest_uint(ctx, &intr->dest, result);
 }
 
+static void
+emit_load_instance_id(struct ntv_context *ctx, nir_intrinsic_instr *intr)
+{
+   SpvId var_type = spirv_builder_type_uint(&ctx->builder, 32);
+   if (!ctx->instance_id_var)
+      ctx->instance_id_var = create_builtin_var(ctx, var_type,
+                                               SpvStorageClassInput,
+                                               "gl_InstanceId",
+                                               SpvBuiltInInstanceIndex);
+
+   SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
+                                          ctx->instance_id_var);
+   assert(1 == nir_dest_num_components(intr->dest));
+   store_dest_uint(ctx, &intr->dest, result);
+}
+
 static void
 emit_load_vertex_id(struct ntv_context *ctx, nir_intrinsic_instr *intr)
 {
@@ -1327,6 +1373,10 @@ emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
       emit_load_front_face(ctx, intr);
       break;
 
+   case nir_intrinsic_load_instance_id:
+      emit_load_instance_id(ctx, intr);
+      break;
+
    case nir_intrinsic_load_vertex_id:
       emit_load_vertex_id(ctx, intr);
       break;
@@ -1800,6 +1850,7 @@ nir_to_spirv(struct nir_shader *s)
    if (s->info.stage == MESA_SHADER_FRAGMENT) {
       spirv_builder_emit_cap(&ctx.builder, SpvCapabilitySampled1D);
       spirv_builder_emit_cap(&ctx.builder, SpvCapabilityImageQuery);
+      spirv_builder_emit_cap(&ctx.builder, SpvCapabilityDerivativeControl);
    }
 
    ctx.stage = s->info.stage;