zink: adjust zink_shader struct to contain full streamout info
[mesa.git] / src / gallium / drivers / zink / nir_to_spirv / nir_to_spirv.c
index 68a5d00a3506c3233e773cdea859e9f1586b432e..a4106e7439c3d921a3fb7c55b33991556b508924 100644 (file)
@@ -889,11 +889,12 @@ get_output_type(struct ntv_context *ctx, unsigned register_index, unsigned num_c
    from complete outputs, so we just can't use the created packed outputs */
 static void
 emit_so_info(struct ntv_context *ctx, unsigned max_output_location,
-             const struct pipe_stream_output_info *so_info, struct pipe_stream_output_info *local_so_info)
+             const struct zink_so_info *so_info)
 {
-   for (unsigned i = 0; i < local_so_info->num_outputs; i++) {
-      struct pipe_stream_output so_output = local_so_info->output[i];
-      SpvId out_type = get_output_type(ctx, so_output.register_index, so_output.num_components);
+   for (unsigned i = 0; i < so_info->so_info.num_outputs; i++) {
+      struct pipe_stream_output so_output = so_info->so_info.output[i];
+      unsigned slot = so_info->so_info_slots[i];
+      SpvId out_type = get_output_type(ctx, slot, so_output.num_components);
       SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
                                                       SpvStorageClassOutput,
                                                       out_type);
@@ -905,7 +906,7 @@ emit_so_info(struct ntv_context *ctx, unsigned max_output_location,
       spirv_builder_emit_name(&ctx->builder, var_id, name);
       spirv_builder_emit_offset(&ctx->builder, var_id, (so_output.dst_offset * 4));
       spirv_builder_emit_xfb_buffer(&ctx->builder, var_id, so_output.output_buffer);
-      spirv_builder_emit_xfb_stride(&ctx->builder, var_id, so_info->stride[so_output.output_buffer] * 4);
+      spirv_builder_emit_xfb_stride(&ctx->builder, var_id, so_info->so_info.stride[so_output.output_buffer] * 4);
 
       /* output location is incremented by VARYING_SLOT_VAR0 for non-builtins in vtn,
        * so we need to ensure that the new xfb location slot doesn't conflict with any previously-emitted
@@ -939,32 +940,33 @@ emit_so_info(struct ntv_context *ctx, unsigned max_output_location,
 
 static void
 emit_so_outputs(struct ntv_context *ctx,
-                const struct pipe_stream_output_info *so_info, struct pipe_stream_output_info *local_so_info)
+                const struct zink_so_info *so_info)
 {
    SpvId loaded_outputs[VARYING_SLOT_MAX] = {};
-   for (unsigned i = 0; i < local_so_info->num_outputs; i++) {
+   for (unsigned i = 0; i < so_info->so_info.num_outputs; i++) {
       uint32_t components[NIR_MAX_VEC_COMPONENTS];
-      struct pipe_stream_output so_output = local_so_info->output[i];
+      unsigned slot = so_info->so_info_slots[i];
+      struct pipe_stream_output so_output = so_info->so_info.output[i];
       uint32_t so_key = (uint32_t) so_output.register_index << 2 | so_output.start_component;
       struct hash_entry *he = _mesa_hash_table_search(ctx->so_outputs, &so_key);
       assert(he);
       SpvId so_output_var_id = (SpvId)(intptr_t)he->data;
 
-      SpvId type = get_output_type(ctx, so_output.register_index, so_output.num_components);
-      SpvId output = ctx->outputs[so_output.register_index];
-      SpvId output_type = ctx->so_output_types[so_output.register_index];
-      const struct glsl_type *out_type = ctx->so_output_gl_types[so_output.register_index];
+      SpvId type = get_output_type(ctx, slot, so_output.num_components);
+      SpvId output = ctx->outputs[slot];
+      SpvId output_type = ctx->so_output_types[slot];
+      const struct glsl_type *out_type = ctx->so_output_gl_types[slot];
 
-      if (!loaded_outputs[so_output.register_index])
-         loaded_outputs[so_output.register_index] = spirv_builder_emit_load(&ctx->builder, output_type, output);
-      SpvId src = loaded_outputs[so_output.register_index];
+      if (!loaded_outputs[slot])
+         loaded_outputs[slot] = spirv_builder_emit_load(&ctx->builder, output_type, output);
+      SpvId src = loaded_outputs[slot];
 
       SpvId result;
 
       for (unsigned c = 0; c < so_output.num_components; c++) {
          components[c] = so_output.start_component + c;
          /* this is the second half of a 2 * vec4 array */
-         if (ctx->stage == MESA_SHADER_VERTEX && so_output.register_index == VARYING_SLOT_CLIP_DIST1)
+         if (ctx->stage == MESA_SHADER_VERTEX && slot == VARYING_SLOT_CLIP_DIST1)
             components[c] += 4;
       }
 
@@ -991,7 +993,7 @@ emit_so_outputs(struct ntv_context *ctx,
                 uint32_t member[] = { so_output.start_component + c };
                 SpvId base_type = get_glsl_type(ctx, glsl_without_array(out_type));
 
-                if (ctx->stage == MESA_SHADER_VERTEX && so_output.register_index == VARYING_SLOT_CLIP_DIST1)
+                if (ctx->stage == MESA_SHADER_VERTEX && slot == VARYING_SLOT_CLIP_DIST1)
                    member[0] += 4;
                 components[c] = spirv_builder_emit_composite_extract(&ctx->builder, base_type, src, member, 1);
              }
@@ -2227,7 +2229,7 @@ emit_cf_list(struct ntv_context *ctx, struct exec_list *list)
 }
 
 struct spirv_shader *
-nir_to_spirv(struct nir_shader *s, const struct pipe_stream_output_info *so_info, struct pipe_stream_output_info *local_so_info)
+nir_to_spirv(struct nir_shader *s, const struct zink_so_info *so_info)
 {
    struct spirv_shader *ret = NULL;
 
@@ -2312,7 +2314,7 @@ nir_to_spirv(struct nir_shader *s, const struct pipe_stream_output_info *so_info
       emit_output(&ctx, var);
 
    if (so_info)
-      emit_so_info(&ctx, util_last_bit64(s->info.outputs_written), so_info, local_so_info);
+      emit_so_info(&ctx, util_last_bit64(s->info.outputs_written), so_info);
    nir_foreach_variable_with_modes(var, s, nir_var_uniform |
                                            nir_var_mem_ubo |
                                            nir_var_mem_ssbo)
@@ -2326,7 +2328,7 @@ nir_to_spirv(struct nir_shader *s, const struct pipe_stream_output_info *so_info
                                       SpvExecutionModeDepthReplacing);
    }
 
-   if (so_info && so_info->num_outputs) {
+   if (so_info && so_info->so_info.num_outputs) {
       spirv_builder_emit_cap(&ctx.builder, SpvCapabilityTransformFeedback);
       spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
                                    SpvExecutionModeXfb);
@@ -2379,7 +2381,7 @@ nir_to_spirv(struct nir_shader *s, const struct pipe_stream_output_info *so_info
    emit_cf_list(&ctx, &entry->body);
 
    if (so_info)
-      emit_so_outputs(&ctx, so_info, local_so_info);
+      emit_so_outputs(&ctx, so_info);
 
    spirv_builder_return(&ctx.builder); // doesn't belong here, but whatevz
    spirv_builder_function_end(&ctx.builder);