if (so->key.has_gs || so->key.tessellation) {
                switch (so->shader->type) {
                case MESA_SHADER_VERTEX:
-                       NIR_PASS_V(s, ir3_nir_lower_to_explicit_output, so->shader, so->key.tessellation);
+                       NIR_PASS_V(s, ir3_nir_lower_to_explicit_output, so, so->key.tessellation);
                        progress = true;
                        break;
                case MESA_SHADER_TESS_CTRL:
-                       NIR_PASS_V(s, ir3_nir_lower_tess_ctrl, so->shader, so->key.tessellation);
+                       NIR_PASS_V(s, ir3_nir_lower_tess_ctrl, so, so->key.tessellation);
                        NIR_PASS_V(s, ir3_nir_lower_to_explicit_input);
                        progress = true;
                        break;
                case MESA_SHADER_TESS_EVAL:
                        NIR_PASS_V(s, ir3_nir_lower_tess_eval, so->key.tessellation);
                        if (so->key.has_gs)
-                               NIR_PASS_V(s, ir3_nir_lower_to_explicit_output, so->shader, so->key.tessellation);
+                               NIR_PASS_V(s, ir3_nir_lower_to_explicit_output, so, so->key.tessellation);
                        progress = true;
                        break;
                case MESA_SHADER_GEOMETRY:
 
 
 
 void ir3_nir_lower_to_explicit_output(nir_shader *shader,
-               struct ir3_shader *s, unsigned topology);
+               struct ir3_shader_variant *v, unsigned topology);
 void ir3_nir_lower_to_explicit_input(nir_shader *shader);
-void ir3_nir_lower_tess_ctrl(nir_shader *shader, struct ir3_shader *s, unsigned topology);
+void ir3_nir_lower_tess_ctrl(nir_shader *shader, struct ir3_shader_variant *v, unsigned topology);
 void ir3_nir_lower_tess_eval(nir_shader *shader, unsigned topology);
 void ir3_nir_lower_gs(nir_shader *shader);
 
 
 }
 
 void
-ir3_nir_lower_to_explicit_output(nir_shader *shader, struct ir3_shader *s, unsigned topology)
+ir3_nir_lower_to_explicit_output(nir_shader *shader, struct ir3_shader_variant *v,
+               unsigned topology)
 {
        struct state state = { };
 
        build_primitive_map(shader, &state.map, &shader->outputs);
-       memcpy(s->output_loc, state.map.loc, sizeof(s->output_loc));
+       memcpy(v->output_loc, state.map.loc, sizeof(v->output_loc));
 
        nir_function_impl *impl = nir_shader_get_entrypoint(shader);
        assert(impl);
        nir_builder_init(&b, impl);
        b.cursor = nir_before_cf_list(&impl->body);
 
-       if (s->type == MESA_SHADER_VERTEX && topology != IR3_TESS_NONE)
+       if (v->type == MESA_SHADER_VERTEX && topology != IR3_TESS_NONE)
                state.header = nir_load_tcs_header_ir3(&b);
        else
                state.header = nir_load_gs_header_ir3(&b);
        nir_metadata_preserve(impl, nir_metadata_block_index |
                        nir_metadata_dominance);
 
-       s->output_size = state.map.stride;
+       v->output_size = state.map.stride;
 }
 
 
 }
 
 void
-ir3_nir_lower_tess_ctrl(nir_shader *shader, struct ir3_shader *s, unsigned topology)
+ir3_nir_lower_tess_ctrl(nir_shader *shader, struct ir3_shader_variant *v,
+               unsigned topology)
 {
        struct state state = { .topology = topology };
 
        }
 
        build_primitive_map(shader, &state.map, &shader->outputs);
-       memcpy(s->output_loc, state.map.loc, sizeof(s->output_loc));
-       s->output_size = state.map.stride;
+       memcpy(v->output_loc, state.map.loc, sizeof(v->output_loc));
+       v->output_size = state.map.stride;
 
        nir_function_impl *impl = nir_shader_get_entrypoint(shader);
        assert(impl);
                nir_foreach_variable(out_var, &producer->shader->nir->outputs) {
                        if (in_var->data.location == out_var->data.location) {
                                locs[in_var->data.driver_location] =
-                                       producer->shader->output_loc[out_var->data.driver_location] * factor;
+                                       producer->output_loc[out_var->data.driver_location] * factor;
 
                                debug_assert(num_loc <= in_var->data.driver_location + 1);
                                num_loc = in_var->data.driver_location + 1;
 
        } outputs[32 + 2];  /* +POSITION +PSIZE */
        bool writes_pos, writes_smask, writes_psize;
 
+       /* Size in dwords of all outputs for VS, size of entire patch for HS. */
+       uint32_t output_size;
+
+       /* Map from driver_location to byte offset in per-primitive storage */
+       unsigned output_loc[32];
+
        /* attributes (VS) / varyings (FS):
         * Note that sysval's should come *after* normal inputs.
         */
        struct ir3_shader_variant *variants;
        mtx_t variants_lock;
 
-       uint32_t output_size; /* Size in dwords of all outputs for VS, size of entire patch for HS. */
-
-       /* Map from driver_location to byte offset in per-primitive storage */
-       unsigned output_loc[32];
-
        /* Bitmask of bits of the shader key used by this shader.  Used to avoid
         * recompiles for GL NOS that doesn't actually apply to the shader.
         */
 
          invocations = gs->shader->nir->info.gs.invocations - 1;
          /* Size of per-primitive alloction in ldlw memory in vec4s. */
          vec4_size = gs->shader->nir->info.gs.vertices_in *
-                     DIV_ROUND_UP(vs->shader->output_size, 4);
+                     DIV_ROUND_UP(vs->output_size, 4);
       } else {
          vertices_out = 3;
          output = TESS_CW_TRIS;
       tu_cs_emit(cs, 0);
 
       tu_cs_emit_pkt4(cs, REG_A6XX_SP_GS_PRIM_SIZE, 1);
-      tu_cs_emit(cs, vs->shader->output_size);
+      tu_cs_emit(cs, vs->output_size);
    }
 
    tu_cs_emit_pkt4(cs, REG_A6XX_SP_PRIMITIVE_CNTL, 1);
    unsigned num_vertices = gs->shader->nir->info.gs.vertices_in;
 
    uint32_t params[4] = {
-      vs->shader->output_size * num_vertices * 4,  /* primitive stride */
-      vs->shader->output_size * 4,                 /* vertex stride */
+      vs->output_size * num_vertices * 4,  /* primitive stride */
+      vs->output_size * 4,                 /* vertex stride */
       0,
       0,
    };
 
                emit->gs->shader->nir->info.gs.vertices_in;
 
        uint32_t vs_params[4] = {
-               emit->vs->shader->output_size * num_vertices * 4,       /* vs primitive stride */
-               emit->vs->shader->output_size * 4,                                      /* vs vertex stride */
+               emit->vs->output_size * num_vertices * 4,       /* vs primitive stride */
+               emit->vs->output_size * 4,                                      /* vs vertex stride */
                0,
                0
        };
 
        if (emit->hs) {
                uint32_t hs_params[4] = {
-                       emit->vs->shader->output_size * num_vertices * 4,       /* vs primitive stride */
-                       emit->vs->shader->output_size * 4,                                      /* vs vertex stride */
-                       emit->hs->shader->output_size,
+                       emit->vs->output_size * num_vertices * 4,       /* vs primitive stride */
+                       emit->vs->output_size * 4,                                      /* vs vertex stride */
+                       emit->hs->output_size,
                        emit->info->vertices_per_patch
                };
 
                        num_vertices = emit->gs->shader->nir->info.gs.vertices_in;
 
                uint32_t ds_params[4] = {
-                       emit->ds->shader->output_size * num_vertices * 4,       /* ds primitive stride */
-                       emit->ds->shader->output_size * 4,                                      /* ds vertex stride */
-                       emit->hs->shader->output_size,                      /* hs vertex stride (dwords) */
+                       emit->ds->output_size * num_vertices * 4,       /* ds primitive stride */
+                       emit->ds->output_size * 4,                                      /* ds vertex stride */
+                       emit->hs->output_size,                      /* hs vertex stride (dwords) */
                        emit->hs->shader->nir->info.tess.tcs_vertices_out
                };
 
                        prev = emit->vs;
 
                uint32_t gs_params[4] = {
-                       prev->shader->output_size * num_vertices * 4,   /* ds primitive stride */
-                       prev->shader->output_size * 4,                                  /* ds vertex stride */
+                       prev->output_size * num_vertices * 4,   /* ds primitive stride */
+                       prev->output_size * 4,                                  /* ds vertex stride */
                        0,
                        0,
                };
 
 
                ctx->batch->tessellation = true;
                ctx->batch->tessparam_size = MAX2(ctx->batch->tessparam_size,
-                               emit.hs->shader->output_size * 4 * info->count);
+                               emit.hs->output_size * 4 * info->count);
                ctx->batch->tessfactor_size = MAX2(ctx->batch->tessfactor_size,
                                factor_stride * info->count);
 
 
 
                /* Total attribute slots in HS incoming patch. */
                OUT_PKT4(ring, REG_A6XX_PC_UNKNOWN_9801, 1);
-               OUT_RING(ring, hs_info->tess.tcs_vertices_out * vs->shader->output_size / 4);
+               OUT_RING(ring, hs_info->tess.tcs_vertices_out * vs->output_size / 4);
 
                OUT_PKT4(ring, REG_A6XX_SP_HS_UNKNOWN_A831, 1);
-               OUT_RING(ring, vs->shader->output_size);
+               OUT_RING(ring, vs->output_size);
 
                shader_info *ds_info = &ds->shader->nir->info;
                OUT_PKT4(ring, REG_A6XX_PC_TESS_CNTL, 1);
                /* Size of per-primitive alloction in ldlw memory in vec4s. */
                uint32_t vec4_size =
                        gs->shader->nir->info.gs.vertices_in *
-                       DIV_ROUND_UP(prev->shader->output_size, 4);
+                       DIV_ROUND_UP(prev->output_size, 4);
                OUT_PKT4(ring, REG_A6XX_PC_PRIMITIVE_CNTL_6, 1);
                OUT_RING(ring, A6XX_PC_PRIMITIVE_CNTL_6_STRIDE_IN_VPC(vec4_size));
 
                OUT_RING(ring, 0);
 
                OUT_PKT4(ring, REG_A6XX_SP_GS_PRIM_SIZE, 1);
-               OUT_RING(ring, prev->shader->output_size);
+               OUT_RING(ring, prev->output_size);
        } else {
                OUT_PKT4(ring, REG_A6XX_PC_PRIMITIVE_CNTL_6, 1);
                OUT_RING(ring, 0);