v3d: Add Compute Shader compilation support.
[mesa.git] / src / gallium / drivers / v3d / v3d_program.c
index e3e491e9fd798fdc49adb1379aa0ee746b04c837..7805b808a010edaa712e4d9892e705de6abe5019 100644 (file)
@@ -38,7 +38,8 @@
 #include "broadcom/cle/v3d_packet_v33_pack.h"
 
 static struct v3d_compiled_shader *
-v3d_get_compiled_shader(struct v3d_context *v3d, struct v3d_key *key);
+v3d_get_compiled_shader(struct v3d_context *v3d,
+                        struct v3d_key *key, size_t key_size);
 static void
 v3d_setup_shared_precompile_key(struct v3d_uncompiled_shader *uncompiled,
                                 struct v3d_key *key);
@@ -200,7 +201,7 @@ v3d_shader_precompile(struct v3d_context *v3d,
                 }
 
                 v3d_setup_shared_precompile_key(so, &key.base);
-                v3d_get_compiled_shader(v3d, &key.base);
+                v3d_get_compiled_shader(v3d, &key.base, sizeof(key));
         } else {
                 struct v3d_vs_key key = {
                         .base.shader_state = so,
@@ -223,7 +224,7 @@ v3d_shader_precompile(struct v3d_context *v3d,
                         }
                 }
 
-                v3d_get_compiled_shader(v3d, &key.base);
+                v3d_get_compiled_shader(v3d, &key.base, sizeof(key));
 
                 /* Compile VS bin shader: only position (XXX: include TF) */
                 key.is_coord = true;
@@ -233,13 +234,13 @@ v3d_shader_precompile(struct v3d_context *v3d,
                                 v3d_slot_from_slot_and_component(VARYING_SLOT_POS,
                                                                  i);
                 }
-                v3d_get_compiled_shader(v3d, &key.base);
+                v3d_get_compiled_shader(v3d, &key.base, sizeof(key));
         }
 }
 
 static void *
-v3d_shader_state_create(struct pipe_context *pctx,
-                        const struct pipe_shader_state *cso)
+v3d_uncompiled_shader_create(struct pipe_context *pctx,
+                             enum pipe_shader_ir type, void *ir)
 {
         struct v3d_context *v3d = v3d_context(pctx);
         struct v3d_uncompiled_shader *so = CALLOC_STRUCT(v3d_uncompiled_shader);
@@ -250,21 +251,21 @@ v3d_shader_state_create(struct pipe_context *pctx,
 
         nir_shader *s;
 
-        if (cso->type == PIPE_SHADER_IR_NIR) {
+        if (type == PIPE_SHADER_IR_NIR) {
                 /* The backend takes ownership of the NIR shader on state
                  * creation.
                  */
-                s = cso->ir.nir;
+                s = ir;
         } else {
-                assert(cso->type == PIPE_SHADER_IR_TGSI);
+                assert(type == PIPE_SHADER_IR_TGSI);
 
                 if (V3D_DEBUG & V3D_DEBUG_TGSI) {
                         fprintf(stderr, "prog %d TGSI:\n",
                                 so->program_id);
-                        tgsi_dump(cso->tokens, 0);
+                        tgsi_dump(ir, 0);
                         fprintf(stderr, "\n");
                 }
-                s = tgsi_to_nir(cso->tokens, pctx->screen);
+                s = tgsi_to_nir(ir, pctx->screen);
         }
 
         nir_variable_mode lower_mode = nir_var_all & ~nir_var_uniform;
@@ -289,8 +290,6 @@ v3d_shader_state_create(struct pipe_context *pctx,
         so->base.type = PIPE_SHADER_IR_NIR;
         so->base.ir.nir = s;
 
-        v3d_set_transform_feedback_outputs(so, &cso->stream_output);
-
         if (V3D_DEBUG & (V3D_DEBUG_NIR |
                          v3d_debug_flag_for_shader_stage(s->info.stage))) {
                 fprintf(stderr, "%s prog %d NIR:\n",
@@ -314,22 +313,31 @@ v3d_shader_debug_output(const char *message, void *data)
         pipe_debug_message(&v3d->debug, SHADER_INFO, "%s", message);
 }
 
-static struct v3d_compiled_shader *
-v3d_get_compiled_shader(struct v3d_context *v3d, struct v3d_key *key)
+static void *
+v3d_shader_state_create(struct pipe_context *pctx,
+                        const struct pipe_shader_state *cso)
+{
+        struct v3d_uncompiled_shader *so =
+                v3d_uncompiled_shader_create(pctx,
+                                             cso->type,
+                                             (cso->type == PIPE_SHADER_IR_TGSI ?
+                                              (void *)cso->tokens :
+                                              cso->ir.nir));
+
+        v3d_set_transform_feedback_outputs(so, &cso->stream_output);
+
+        return so;
+}
+
+struct v3d_compiled_shader *
+v3d_get_compiled_shader(struct v3d_context *v3d,
+                        struct v3d_key *key,
+                        size_t key_size)
 {
         struct v3d_uncompiled_shader *shader_state = key->shader_state;
         nir_shader *s = shader_state->base.ir.nir;
 
-        struct hash_table *ht;
-        uint32_t key_size;
-        if (s->info.stage == MESA_SHADER_FRAGMENT) {
-                ht = v3d->fs_cache;
-                key_size = sizeof(struct v3d_fs_key);
-        } else {
-                ht = v3d->vs_cache;
-                key_size = sizeof(struct v3d_vs_key);
-        }
-
+        struct hash_table *ht = v3d->prog.cache[s->info.stage];
         struct hash_entry *entry = _mesa_hash_table_search(ht, key);
         if (entry)
                 return entry->data;
@@ -359,10 +367,12 @@ v3d_get_compiled_shader(struct v3d_context *v3d, struct v3d_key *key)
 
         free(qpu_insts);
 
-        struct v3d_key *dup_key;
-        dup_key = ralloc_size(shader, key_size);
-        memcpy(dup_key, key, key_size);
-        _mesa_hash_table_insert(ht, dup_key, shader);
+        if (ht) {
+                struct v3d_key *dup_key;
+                dup_key = ralloc_size(shader, key_size);
+                memcpy(dup_key, key, key_size);
+                _mesa_hash_table_insert(ht, dup_key, shader);
+        }
 
         if (shader->prog_data.base->spill_size >
             v3d->prog.spill_size_per_thread) {
@@ -446,8 +456,6 @@ v3d_setup_shared_key(struct v3d_context *v3d, struct v3d_key *key,
                                 sampler_state->wrap_r == PIPE_TEX_WRAP_CLAMP;
                 }
         }
-
-        key->ucp_enables = v3d->rasterizer->base.clip_plane_enable;
 }
 
 static void
@@ -489,6 +497,7 @@ v3d_update_compiled_fs(struct v3d_context *v3d, uint8_t prim_mode)
         memset(key, 0, sizeof(*key));
         v3d_setup_shared_key(v3d, &key->base, &v3d->tex[PIPE_SHADER_FRAGMENT]);
         key->base.shader_state = v3d->prog.bind_fs;
+        key->base.ucp_enables = v3d->rasterizer->base.clip_plane_enable;
         key->is_points = (prim_mode == PIPE_PRIM_POINTS);
         key->is_lines = (prim_mode >= PIPE_PRIM_LINES &&
                          prim_mode <= PIPE_PRIM_LINE_STRIP);
@@ -554,7 +563,7 @@ v3d_update_compiled_fs(struct v3d_context *v3d, uint8_t prim_mode)
         key->shade_model_flat = v3d->rasterizer->base.flatshade;
 
         struct v3d_compiled_shader *old_fs = v3d->prog.fs;
-        v3d->prog.fs = v3d_get_compiled_shader(v3d, &key->base);
+        v3d->prog.fs = v3d_get_compiled_shader(v3d, &key->base, sizeof(*key));
         if (v3d->prog.fs == old_fs)
                 return;
 
@@ -602,6 +611,7 @@ v3d_update_compiled_vs(struct v3d_context *v3d, uint8_t prim_mode)
         memset(key, 0, sizeof(*key));
         v3d_setup_shared_key(v3d, &key->base, &v3d->tex[PIPE_SHADER_VERTEX]);
         key->base.shader_state = v3d->prog.bind_vs;
+        key->base.ucp_enables = v3d->rasterizer->base.clip_plane_enable;
         key->num_fs_inputs = v3d->prog.fs->prog_data.fs->num_inputs;
         STATIC_ASSERT(sizeof(key->fs_inputs) ==
                       sizeof(v3d->prog.fs->prog_data.fs->input_slots));
@@ -614,7 +624,7 @@ v3d_update_compiled_vs(struct v3d_context *v3d, uint8_t prim_mode)
                  v3d->rasterizer->base.point_size_per_vertex);
 
         struct v3d_compiled_shader *vs =
-                v3d_get_compiled_shader(v3d, &key->base);
+                v3d_get_compiled_shader(v3d, &key->base, sizeof(*key));
         if (vs != v3d->prog.vs) {
                 v3d->prog.vs = vs;
                 v3d->dirty |= VC5_DIRTY_COMPILED_VS;
@@ -634,7 +644,7 @@ v3d_update_compiled_vs(struct v3d_context *v3d, uint8_t prim_mode)
         key->num_fs_inputs = shader_state->num_tf_outputs;
 
         struct v3d_compiled_shader *cs =
-                v3d_get_compiled_shader(v3d, &key->base);
+                v3d_get_compiled_shader(v3d, &key->base, sizeof(*key));
         if (cs != v3d->prog.cs) {
                 v3d->prog.cs = cs;
                 v3d->dirty |= VC5_DIRTY_COMPILED_CS;
@@ -648,6 +658,30 @@ v3d_update_compiled_shaders(struct v3d_context *v3d, uint8_t prim_mode)
         v3d_update_compiled_vs(v3d, prim_mode);
 }
 
+void
+v3d_update_compiled_cs(struct v3d_context *v3d)
+{
+        struct v3d_key local_key;
+        struct v3d_key *key = &local_key;
+
+        if (!(v3d->dirty & (~0 | /* XXX */
+                            VC5_DIRTY_VERTTEX |
+                            VC5_DIRTY_UNCOMPILED_FS))) {
+                return;
+        }
+
+        memset(key, 0, sizeof(*key));
+        v3d_setup_shared_key(v3d, key, &v3d->tex[PIPE_SHADER_COMPUTE]);
+        key->shader_state = v3d->prog.bind_compute;
+
+        struct v3d_compiled_shader *cs =
+                v3d_get_compiled_shader(v3d, key, sizeof(*key));
+        if (cs != v3d->prog.compute) {
+                v3d->prog.compute = cs;
+                v3d->dirty |= VC5_DIRTY_COMPILED_CS; /* XXX */
+        }
+}
+
 static uint32_t
 fs_cache_hash(const void *key)
 {
@@ -660,6 +694,12 @@ vs_cache_hash(const void *key)
         return _mesa_hash_data(key, sizeof(struct v3d_vs_key));
 }
 
+static uint32_t
+cs_cache_hash(const void *key)
+{
+        return _mesa_hash_data(key, sizeof(struct v3d_key));
+}
+
 static bool
 fs_cache_compare(const void *key1, const void *key2)
 {
@@ -672,23 +712,10 @@ vs_cache_compare(const void *key1, const void *key2)
         return memcmp(key1, key2, sizeof(struct v3d_vs_key)) == 0;
 }
 
-static void
-delete_from_cache_if_matches(struct hash_table *ht,
-                             struct v3d_compiled_shader **last_compile,
-                             struct hash_entry *entry,
-                             struct v3d_uncompiled_shader *so)
+static bool
+cs_cache_compare(const void *key1, const void *key2)
 {
-        const struct v3d_key *key = entry->key;
-
-        if (key->shader_state == so) {
-                struct v3d_compiled_shader *shader = entry->data;
-                _mesa_hash_table_remove(ht, entry);
-
-                if (shader == *last_compile)
-                        *last_compile = NULL;
-
-                v3d_free_compiled_shader(shader);
-        }
+        return memcmp(key1, key2, sizeof(struct v3d_key)) == 0;
 }
 
 static void
@@ -696,14 +723,26 @@ v3d_shader_state_delete(struct pipe_context *pctx, void *hwcso)
 {
         struct v3d_context *v3d = v3d_context(pctx);
         struct v3d_uncompiled_shader *so = hwcso;
+        nir_shader *s = so->base.ir.nir;
 
-        hash_table_foreach(v3d->fs_cache, entry) {
-                delete_from_cache_if_matches(v3d->fs_cache, &v3d->prog.fs,
-                                             entry, so);
-        }
-        hash_table_foreach(v3d->vs_cache, entry) {
-                delete_from_cache_if_matches(v3d->vs_cache, &v3d->prog.vs,
-                                             entry, so);
+        hash_table_foreach(v3d->prog.cache[s->info.stage], entry) {
+                const struct v3d_key *key = entry->key;
+                struct v3d_compiled_shader *shader = entry->data;
+
+                if (key->shader_state != so)
+                        continue;
+
+                if (v3d->prog.fs == shader)
+                        v3d->prog.fs = NULL;
+                if (v3d->prog.vs == shader)
+                        v3d->prog.vs = NULL;
+                if (v3d->prog.cs == shader)
+                        v3d->prog.cs = NULL;
+                if (v3d->prog.compute == shader)
+                        v3d->prog.compute = NULL;
+
+                _mesa_hash_table_remove(v3d->prog.cache[s->info.stage], entry);
+                v3d_free_compiled_shader(shader);
         }
 
         ralloc_free(so->base.ir.nir);
@@ -726,6 +765,22 @@ v3d_vp_state_bind(struct pipe_context *pctx, void *hwcso)
         v3d->dirty |= VC5_DIRTY_UNCOMPILED_VS;
 }
 
+static void
+v3d_compute_state_bind(struct pipe_context *pctx, void *state)
+{
+        struct v3d_context *v3d = v3d_context(pctx);
+
+        v3d->prog.bind_compute = state;
+}
+
+static void *
+v3d_create_compute_state(struct pipe_context *pctx,
+                         const struct pipe_compute_state *cso)
+{
+        return v3d_uncompiled_shader_create(pctx, cso->ir_type,
+                                            (void *)cso->prog);
+}
+
 void
 v3d_program_init(struct pipe_context *pctx)
 {
@@ -740,10 +795,18 @@ v3d_program_init(struct pipe_context *pctx)
         pctx->bind_fs_state = v3d_fp_state_bind;
         pctx->bind_vs_state = v3d_vp_state_bind;
 
-        v3d->fs_cache = _mesa_hash_table_create(pctx, fs_cache_hash,
-                                                fs_cache_compare);
-        v3d->vs_cache = _mesa_hash_table_create(pctx, vs_cache_hash,
-                                                vs_cache_compare);
+        if (v3d->screen->has_csd) {
+                pctx->create_compute_state = v3d_create_compute_state;
+                pctx->delete_compute_state = v3d_shader_state_delete;
+                pctx->bind_compute_state = v3d_compute_state_bind;
+        }
+
+        v3d->prog.cache[MESA_SHADER_VERTEX] =
+                _mesa_hash_table_create(pctx, vs_cache_hash, vs_cache_compare);
+        v3d->prog.cache[MESA_SHADER_FRAGMENT] =
+                _mesa_hash_table_create(pctx, fs_cache_hash, fs_cache_compare);
+        v3d->prog.cache[MESA_SHADER_COMPUTE] =
+                _mesa_hash_table_create(pctx, cs_cache_hash, cs_cache_compare);
 }
 
 void
@@ -751,16 +814,16 @@ v3d_program_fini(struct pipe_context *pctx)
 {
         struct v3d_context *v3d = v3d_context(pctx);
 
-        hash_table_foreach(v3d->fs_cache, entry) {
-                struct v3d_compiled_shader *shader = entry->data;
-                v3d_free_compiled_shader(shader);
-                _mesa_hash_table_remove(v3d->fs_cache, entry);
-        }
+        for (int i = 0; i < MESA_SHADER_STAGES; i++) {
+                struct hash_table *cache = v3d->prog.cache[i];
+                if (!cache)
+                        continue;
 
-        hash_table_foreach(v3d->vs_cache, entry) {
-                struct v3d_compiled_shader *shader = entry->data;
-                v3d_free_compiled_shader(shader);
-                _mesa_hash_table_remove(v3d->vs_cache, entry);
+                hash_table_foreach(cache, entry) {
+                        struct v3d_compiled_shader *shader = entry->data;
+                        v3d_free_compiled_shader(shader);
+                        _mesa_hash_table_remove(cache, entry);
+                }
         }
 
         v3d_bo_unreference(&v3d->prog.spill_bo);