MayaFlux 0.1.0
Digital-First Multimedia Processing Framework
Loading...
Searching...
No Matches
ShaderProcessor.hpp
Go to the documentation of this file.
1#pragma once
2
#include <cstring>
#include <type_traits>
5
6namespace MayaFlux::Buffers {
7
8/**
9 * @struct ShaderBinding
10 * @brief Describes how a VKBuffer binds to a shader descriptor
 *
 * Pairs a descriptor location (set index + binding point) with the Vulkan
 * descriptor type used there. ShaderProcessorConfig::bindings stores these
 * keyed by a logical descriptor name (e.g. "input", "output").
11 */
13 uint32_t set = 0; ///< Descriptor set index
14 uint32_t binding = 0; ///< Binding point within set
15 vk::DescriptorType type = vk::DescriptorType::eStorageBuffer; ///< Vulkan descriptor type (defaults to a storage buffer)
16
 /// Default binding: set 0, binding 0, storage buffer.
17 ShaderBinding() = default;
 /**
 * @brief Construct a binding at (set, binding) with the given descriptor type
 * @param s Descriptor set index
 * @param b Binding point within the set
 * @param t Descriptor type (defaults to vk::DescriptorType::eStorageBuffer)
 */
18 ShaderBinding(uint32_t s, uint32_t b, vk::DescriptorType t = vk::DescriptorType::eStorageBuffer)
19 : set(s)
20 , binding(b)
21 , type(t)
22 {
23 }
24};
25
26/**
27 * @struct ShaderDispatchConfig
28 * @brief Configuration for compute shader dispatch
 *
 * Holds the shader workgroup (local) size and the strategy used to derive
 * group counts. workgroup_x/y/z should mirror the shader's local_size
 * declaration.
29 */
31 uint32_t workgroup_x = 256; ///< Workgroup size X (should match shader)
32 uint32_t workgroup_y = 1; ///< Workgroup size Y (should match shader)
33 uint32_t workgroup_z = 1; ///< Workgroup size Z (should match shader)
34
 /// Strategy used to determine group counts; the selected value is stored in `mode`.
35 enum class DispatchMode : uint8_t {
36 ELEMENT_COUNT, ///< Calculate from buffer element count
37 MANUAL, ///< Use explicit group counts
38 BUFFER_SIZE, ///< Calculate from buffer byte size
39 CUSTOM ///< User-provided calculation function
40 } mode
42
43 // Manual dispatch (MANUAL mode)
44 uint32_t group_count_x = 1; ///< Group count X (MANUAL mode)
45 uint32_t group_count_y = 1; ///< Group count Y (MANUAL mode)
46 uint32_t group_count_z = 1; ///< Group count Z (MANUAL mode)
47
 /// Calculator for CUSTOM mode: returns {group_count_x, group_count_y, group_count_z}
 /// for the given buffer. NOTE(review): presumably ignored in other modes — confirm.
48 std::function<std::array<uint32_t, 3>(const std::shared_ptr<VKBuffer>&)> custom_calculator;
49
51};
52
53/**
54 * @struct ShaderProcessorConfig
55 * @brief Complete configuration for shader processor
 *
 * Bundles the shader source path, entry point, named descriptor bindings and
 * specialization constants passed to ShaderProcessor(ShaderProcessorConfig).
56 */
58 std::string shader_path; ///< Path to shader file
60 std::string entry_point = "main"; ///< Shader entry-point function name
61
63
64 std::unordered_map<std::string, ShaderBinding> bindings; ///< Logical descriptor name -> descriptor location/type
65
67
68 std::unordered_map<uint32_t, uint32_t> specialization_constants; ///< Specialization constant ID -> value
69
 /**
 * @brief Construct a configuration for the given shader
 * @param path Shader path, moved into shader_path
 *
 * NOTE(review): this constructor is not `explicit`, so a std::string converts
 * implicitly to ShaderProcessorConfig — confirm that is intended.
 */
71 ShaderProcessorConfig(std::string path)
72 : shader_path(std::move(path))
73 {
74 }
75};
76
77/**
78 * @class ShaderProcessor
79 * @brief Generic compute shader processor for VKBuffers
80 *
81 * ShaderProcessor is a fully functional base class that:
82 * - Loads compute shaders via Portal::Graphics::ShaderFoundry
83 * - Automatically creates compute pipelines and descriptor sets
84 * - Binds VKBuffers to shader descriptors with configurable mappings
85 * - Dispatches compute shaders with flexible workgroup calculation
86 * - Supports hot-reload via ShaderFoundry caching
87 * - Handles push constants and specialization constants
88 *
89 * Quality-of-life features:
90 * - **Data movement hints:** Query buffer usage (input/output/in-place) for automation and validation.
91 * - **Binding introspection:** Check if bindings exist, list expected bindings, and validate binding completeness.
92 * - **State queries:** Track last processed buffer and command buffer for chain management and debugging.
93 *
94 * Design Philosophy:
95 * - **Fully usable as-is**: Not just a base class, but a complete processor
96 * - **Inheritance-friendly**: Specialized processors can override behavior
97 * - **Buffer-agnostic**: Works with any VKBuffer modality/usage
98 * - **Flexible binding**: Map buffers to shader descriptors by name
99 * - **GPU-efficient**: Uses device-local buffers and staging where needed
100 *
101 * Integration:
102 * - Uses Portal::Graphics::ShaderFoundry for shader compilation
103 * - Leverages VKComputePipeline for execution
104 * - Works with existing BufferManager/ProcessingChain architecture
105 * - Compatible with all VKBuffer usage types (COMPUTE, STORAGE, etc.)
106 *
107 * Usage:
108 * // Simple usage - single buffer processor
109 * auto processor = std::make_shared<ShaderProcessor>("shaders/kernel.comp");
110 * processor->bind_buffer("input_buffer", my_buffer);
111 * my_buffer->set_default_processor(processor);
 * // NOTE(review): bind_buffer() documents that the name must be a key in
 * // config.bindings; the path-only constructor does not visibly register
 * // "input_buffer" — confirm it resolves, or call add_binding() first.
112 *
113 * // Advanced - multi-buffer with explicit bindings
114 * ShaderProcessorConfig config("shaders/complex.comp");
115 * config.bindings["input"] = ShaderBinding(0, 0);
116 * config.bindings["output"] = ShaderBinding(0, 1);
117 * config.dispatch.workgroup_x = 512;
118 *
119 * auto processor = std::make_shared<ShaderProcessor>(config);
120 * processor->bind_buffer("input", input_buffer);
121 * processor->bind_buffer("output", output_buffer);
122 *
123 * chain->add_processor(processor, input_buffer);
124 * chain->add_processor(processor, output_buffer);
125 *
126 * // With push constants
127 * struct Params { float scale; uint32_t iterations; };
128 * processor->set_push_constant_size<Params>();
129 * processor->set_push_constant_data(Params{2.0f, 100});
130 *
131 * Specialized Processors:
132 * class FFTProcessor : public ShaderProcessor {
133 * FFTProcessor() : ShaderProcessor("shaders/fft.comp") {
134 * configure_fft_bindings();
135 * }
136 *
137 * void on_attach(std::shared_ptr<Buffer> buffer) override {
138 * ShaderProcessor::on_attach(buffer);
139 * // FFT-specific setup
140 * }
141 * };
142 */
143class MAYAFLUX_API ShaderProcessor : public VKBufferProcessor {
144public:
145 /**
146 * @brief Buffer usage characteristics needed for safe data flow
147 *
148 * Bit flags (BIDIRECTIONAL == INPUT_READ | OUTPUT_WRITE) indicating:
149 * - Does compute read from input? (HOST_TO_DEVICE upload needed?)
150 * - Does compute write to output? (DEVICE_TO_HOST readback needed?)
151 *
152 * This lets ComputeProcessingChain auto-determine staging needs.
153 */
154 enum class BufferUsageHint : uint8_t {
155 NONE = 0, ///< No flags: shader neither reads nor writes the buffer
156 INPUT_READ = 1 << 0, ///< Shader reads input
157 OUTPUT_WRITE = 1 << 1, ///< Shader writes output (modifies)
158 BIDIRECTIONAL = INPUT_READ | OUTPUT_WRITE ///< Shader reads and writes (in-place)
159 };
160
161 /**
162 * @brief Construct processor with shader path
163 * @param shader_path Path to compute shader (.comp or .spv)
164 * @param workgroup_x Workgroup size X (default 256)
165 */
166 explicit ShaderProcessor(const std::string& shader_path, uint32_t workgroup_x = 256);
167
168 /**
169 * @brief Construct processor with full configuration
170 * @param config Complete shader processor configuration
171 */
172 explicit ShaderProcessor(ShaderProcessorConfig config);
173
174 ~ShaderProcessor() override;
175
176 //==========================================================================
177 // BufferProcessor Interface
178 //==========================================================================
179
180 void processing_function(std::shared_ptr<Buffer> buffer) override;
181 void on_attach(std::shared_ptr<Buffer> buffer) override;
182 void on_detach(std::shared_ptr<Buffer> buffer) override;
183
184 [[nodiscard]] bool is_compatible_with(std::shared_ptr<Buffer> buffer) const override;
185
186 //==========================================================================
187 // Buffer Binding - Multi-buffer Support
188 //==========================================================================
189
190 /**
191 * @brief Bind a VKBuffer to a named shader descriptor
192 * @param descriptor_name Logical name (e.g., "input", "output")
193 * @param buffer VKBuffer to bind
194 *
195 * Registers the buffer for descriptor set binding.
196 * The descriptor_name must match a key in config.bindings.
197 */
198 void bind_buffer(const std::string& descriptor_name, const std::shared_ptr<VKBuffer>& buffer);
199
200 /**
201 * @brief Unbind a buffer from a descriptor
202 * @param descriptor_name Logical name to unbind
203 */
204 void unbind_buffer(const std::string& descriptor_name);
205
206 /**
207 * @brief Get bound buffer for a descriptor name
208 * @param descriptor_name Logical name
209 * @return Bound buffer, or nullptr if not bound
210 */
211 [[nodiscard]] std::shared_ptr<VKBuffer> get_bound_buffer(const std::string& descriptor_name) const;
212
213 /**
214 * @brief Auto-bind buffer based on attachment order
215 * @param buffer Buffer to auto-bind
216 *
217 * First attachment -> "input" or first binding
218 * Second attachment -> "output" or second binding
219 * Useful for simple single-buffer or input/output patterns.
220 */
221 void auto_bind_buffer(const std::shared_ptr<VKBuffer>& buffer);
222
223 //==========================================================================
224 // Shader Management
225 //==========================================================================
226
227 /**
228 * @brief Hot-reload shader from ShaderFoundry
229 * @return True if reload succeeded
230 *
231 * Invalidates cached shader and rebuilds pipeline.
232 * Existing descriptor sets are preserved if compatible.
233 */
234 bool hot_reload_shader();
235
236 /**
237 * @brief Update shader path and reload
238 * @param shader_path New shader path
239 */
240 void set_shader(const std::string& shader_path);
241
242 /**
243 * @brief Get current shader path
244 */
245 [[nodiscard]] const std::string& get_shader_path() const { return m_config.shader_path; }
246
247 //==========================================================================
248 // Dispatch Configuration
249 //==========================================================================
250
251 /**
252 * @brief Set workgroup size (should match shader local_size)
253 * @param x Workgroup size X
254 * @param y Workgroup size Y (default 1)
255 * @param z Workgroup size Z (default 1)
256 */
257 void set_workgroup_size(uint32_t x, uint32_t y = 1, uint32_t z = 1);
258
259 /**
260 * @brief Set dispatch mode
261 * @param mode Dispatch calculation mode
262 */
263 void set_dispatch_mode(ShaderDispatchConfig::DispatchMode mode);
264
265 /**
266 * @brief Set manual dispatch group counts
267 * @param x Group count X
268 * @param y Group count Y (default 1)
269 * @param z Group count Z (default 1)
270 */
271 void set_manual_dispatch(uint32_t x, uint32_t y = 1, uint32_t z = 1);
272
273 /**
274 * @brief Set custom dispatch calculator
275 * @param calculator Function that returns {group_count_x, group_count_y, group_count_z} for a buffer
276 */
277 void set_custom_dispatch(std::function<std::array<uint32_t, 3>(const std::shared_ptr<VKBuffer>&)> calculator);
278
279 /**
280 * @brief Get current dispatch configuration
281 */
282 [[nodiscard]] const ShaderDispatchConfig& get_dispatch_config() const { return m_config.dispatch; }
283
284 //==========================================================================
285 // Push Constants
286 //==========================================================================
287
288 /**
289 * @brief Set push constant size
290 * @param size Size in bytes
291 */
292 void set_push_constant_size(size_t size);
293
294 /**
295 * @brief Set push constant size from type
296 * @tparam T Push constant struct type (size taken as sizeof(T))
297 */
298 template <typename T>
300 {
301 set_push_constant_size(sizeof(T));
302 }
303
304 /**
305 * @brief Update push constant data (type-safe)
306 * @tparam T Push constant struct type
307 * @param data Push constant data
308 *
309 * Data is copied and uploaded during next process() call.
 * The bytes of @p data are copied with std::memcpy into the buffer exposed by
 * get_push_constant_data(), which is resized to sizeof(T). T must therefore be
 * trivially copyable; sizeof(T) is limited to 128 bytes by a static_assert.
310 */
311 template <typename T>
312 void set_push_constant_data(const T& data);
313
314 /**
315 * @brief Update push constant data (raw bytes)
316 * @param data Pointer to data
317 * @param size Size in bytes
318 */
319 void set_push_constant_data_raw(const void* data, size_t size);
320
321 /**
322 * @brief Get current push constant data (raw bytes; const and mutable overloads)
323 */
324 [[nodiscard]] const std::vector<uint8_t>& get_push_constant_data() const { return m_push_constant_data; }
325 [[nodiscard]] std::vector<uint8_t>& get_push_constant_data() { return m_push_constant_data; }
326
327 //==========================================================================
328 // Specialization Constants
329 //==========================================================================
330
331 /**
332 * @brief Set specialization constant
333 * @param constant_id Specialization constant ID
334 * @param value Value to set
335 *
336 * Requires pipeline recreation to take effect.
337 */
338 void set_specialization_constant(uint32_t constant_id, uint32_t value);
339
340 /**
341 * @brief Clear all specialization constants
342 */
343 void clear_specialization_constants();
344
345 //==========================================================================
346 // Configuration
347 //==========================================================================
348
349 /**
350 * @brief Update entire configuration
351 * @param config New configuration
352 *
353 * Triggers pipeline recreation.
354 */
355 void set_config(const ShaderProcessorConfig& config);
356
357 /**
358 * @brief Get current configuration
359 */
360 [[nodiscard]] const ShaderProcessorConfig& get_config() const { return m_config; }
361
362 /**
363 * @brief Add descriptor binding configuration
364 * @param descriptor_name Logical name
365 * @param binding Shader binding info
366 */
367 void add_binding(const std::string& descriptor_name, const ShaderBinding& binding);
368
369 //==========================================================================
370 // Data movement hints
371 //==========================================================================
372
373 /**
374 * @brief Get buffer usage hint for a descriptor
375 * @param descriptor_name Binding name
376 * @return BufferUsageHint flags
377 */
378 [[nodiscard]] virtual BufferUsageHint get_buffer_usage_hint(const std::string& descriptor_name) const;
379
380 /**
381 * @brief Check if shader modifies a specific buffer in-place
382 * @param descriptor_name Binding name
383 * @return True if shader both reads and writes this buffer
384 */
385 [[nodiscard]] virtual bool is_in_place_operation(const std::string& descriptor_name) const;
386
387 /**
388 * @brief Check if a descriptor binding exists
389 * @param descriptor_name Name of the binding (e.g., "input", "output")
390 * @return True if binding is configured
391 */
392 [[nodiscard]] bool has_binding(const std::string& descriptor_name) const;
393
394 /**
395 * @brief Get all configured descriptor names
396 * @return Vector of binding names
397 *
398 * Useful for introspection: which buffers does this shader expect?
399 */
400 [[nodiscard]] std::vector<std::string> get_binding_names() const;
401
402 /**
403 * @brief Check if all required bindings are satisfied
404 * @return True if all configured bindings have buffers bound
405 */
406 [[nodiscard]] bool are_bindings_complete() const;
407
408 //==========================================================================
409 // State Queries
410 //==========================================================================
411
412 /**
413 * @brief Check if shader is loaded (m_shader_id is not INVALID_SHADER)
414 */
415 [[nodiscard]] bool is_shader_loaded() const { return m_shader_id != Portal::Graphics::INVALID_SHADER; }
416
417 /**
418 * @brief Check if pipeline is created (m_pipeline_id is not INVALID_COMPUTE_PIPELINE)
419 */
420 [[nodiscard]] bool is_pipeline_ready() const { return m_pipeline_id != Portal::Graphics::INVALID_COMPUTE_PIPELINE; }
421
422 /**
423 * @brief Check if descriptors are initialized (at least one descriptor set ID is held)
424 */
425 [[nodiscard]] bool are_descriptors_ready() const { return !m_descriptor_set_ids.empty(); }
426
427 /**
428 * @brief Get number of bound buffers
429 */
430 [[nodiscard]] size_t get_bound_buffer_count() const { return m_bound_buffers.size(); }
431
432 /**
433 * @brief Get the output buffer after compute dispatch
434 *
435 * Returns the buffer that was last processed (input/output depends on
436 * shader and binding configuration). Used by ComputeProcessingChain
437 * to determine where compute results ended up.
438 *
439 * Typically the buffer passed to processing_function(), but can be
440 * overridden by subclasses if compute modifies different buffers.
441 */
442 [[nodiscard]] virtual std::shared_ptr<VKBuffer> get_output_buffer() const { return m_last_processed_buffer; }
443
444 /**
445 * @brief Check if compute has been executed at least once
446 * @return True if processing_function() has been called
 *
 * Implemented as m_last_command_buffer != INVALID_COMMAND_BUFFER, so the
 * result relies on the dispatch path storing a command buffer ID there.
447 */
448 [[nodiscard]] virtual inline bool has_executed() const
449 {
450 return m_last_command_buffer != Portal::Graphics::INVALID_COMMAND_BUFFER;
451 }
452
453protected:
454 //==========================================================================
455 // Overridable Hooks for Specialized Processors
456 //==========================================================================
457
458 /**
459 * @brief Called before shader compilation
460 * @param shader_path Path to shader
461 *
462 * Override to modify shader compilation (e.g., add defines, includes).
463 */
464 virtual void on_before_compile(const std::string& shader_path);
465
466 /**
467 * @brief Called after shader is loaded
468 * @param shader_id ID of the loaded shader
469 *
470 * Override to extract reflection data or validate shader.
471 */
472 virtual void on_shader_loaded(Portal::Graphics::ShaderID shader_id);
473
474 /**
475 * @brief Called before pipeline creation
476 * @param pipeline_id Compute pipeline ID (Portal::Graphics::ComputePipelineID)
477 *
478 * Override to modify pipeline configuration.
 * NOTE(review): this runs before the pipeline exists — confirm which ID is passed.
479 */
481
482 /**
483 * @brief Called after pipeline is created
484 * @param pipeline_id ID of the created compute pipeline
485 *
486 * Override for post-pipeline setup.
487 */
488 virtual void on_pipeline_created(Portal::Graphics::ComputePipelineID pipeline_id);
489
490 /**
491 * @brief Called before descriptor sets are created
492 *
493 * Override to add custom descriptor bindings.
494 */
495 virtual void on_before_descriptors_create();
496
497 /**
498 * @brief Called after descriptor sets are created
499 *
500 * Override for custom descriptor updates.
501 */
502 virtual void on_descriptors_created();
503
504 /**
505 * @brief Called before each dispatch
506 * @param cmd_id ID of the command buffer used for the dispatch
507 * @param buffer Currently processing buffer
508 *
509 * Override to update push constants or dynamic descriptors.
510 */
511 virtual void on_before_dispatch(Portal::Graphics::CommandBufferID cmd_id, const std::shared_ptr<VKBuffer>& buffer);
512
513 /**
514 * @brief Called after each dispatch
515 * @param cmd_id ID of the command buffer used for the dispatch
516 * @param buffer Currently processed buffer
517 *
518 * Override for post-dispatch synchronization or state updates.
519 */
520 virtual void on_after_dispatch(Portal::Graphics::CommandBufferID cmd_id, const std::shared_ptr<VKBuffer>& buffer);
521
522 /**
523 * @brief Calculate dispatch size from buffer
524 * @param buffer Buffer to process
525 * @return {group_count_x, group_count_y, group_count_z}
526 *
527 * Override for custom dispatch calculation logic.
528 * Default implementation uses m_config.dispatch settings.
529 */
530 virtual std::array<uint32_t, 3> calculate_dispatch_size(const std::shared_ptr<VKBuffer>& buffer);
531
532 //==========================================================================
533 // Protected State - Available to Subclasses
534 //==========================================================================
535
537
538 Portal::Graphics::ShaderID m_shader_id = Portal::Graphics::INVALID_SHADER; ///< INVALID_SHADER until loaded (see is_shader_loaded())
539 Portal::Graphics::ComputePipelineID m_pipeline_id = Portal::Graphics::INVALID_COMPUTE_PIPELINE; ///< See is_pipeline_ready()
540 std::vector<Portal::Graphics::DescriptorSetID> m_descriptor_set_ids; ///< Empty until descriptors exist (see are_descriptors_ready())
541 Portal::Graphics::CommandBufferID m_last_command_buffer = Portal::Graphics::INVALID_COMMAND_BUFFER; ///< Checked by has_executed()
542
543 std::unordered_map<std::string, std::shared_ptr<VKBuffer>> m_bound_buffers; ///< Logical descriptor name -> bound buffer
544 std::shared_ptr<VKBuffer> m_last_processed_buffer; ///< Returned by get_output_buffer()
545
546 std::vector<uint8_t> m_push_constant_data; ///< Raw push-constant bytes (see get_push_constant_data())
547
548 bool m_initialized {};
 // Rebuild flags start true; presumably so the first use builds pipeline/descriptors — confirm in the .cpp.
549 bool m_needs_pipeline_rebuild = true;
550 bool m_needs_descriptor_rebuild = true;
551
552 size_t m_auto_bind_index {}; ///< NOTE(review): presumably the next slot used by auto_bind_buffer() — confirm
553
554protected:
 // Lifecycle helpers; virtual so specialized processors can override them.
555 virtual void initialize_pipeline(const std::shared_ptr<Buffer>& buffer);
556 virtual void cleanup();
557
558private:
559 //==========================================================================
560 // Internal Implementation
561 //==========================================================================
562
563 void initialize_shader();
564 void initialize_descriptors();
565 void update_descriptors();
566 void execute_dispatch(const std::shared_ptr<VKBuffer>& buffer);
567};
568
569template <typename T>
571{
572 static_assert(sizeof(T) <= 128, "Push constants typically limited to 128 bytes");
573 m_push_constant_data.resize(sizeof(T));
574 std::memcpy(m_push_constant_data.data(), &data, sizeof(T));
575}
576} // namespace MayaFlux::Buffers
const std::vector< uint8_t > & get_push_constant_data() const
Get current push constant data.
size_t get_bound_buffer_count() const
Get number of bound buffers.
std::unordered_map< std::string, std::shared_ptr< VKBuffer > > m_bound_buffers
std::vector< uint8_t > m_push_constant_data
bool is_pipeline_ready() const
Check if pipeline is created.
std::vector< uint8_t > & get_push_constant_data()
const std::string & get_shader_path() const
Get current shader path.
void set_push_constant_size()
Set push constant size from type.
virtual bool has_executed() const
Check if compute has been executed at least once.
void set_push_constant_data(const T &data)
Update push constant data (type-safe)
bool are_descriptors_ready() const
Check if descriptors are initialized.
std::shared_ptr< VKBuffer > m_last_processed_buffer
bool is_shader_loaded() const
Check if shader is loaded.
virtual std::shared_ptr< VKBuffer > get_output_buffer() const
Get the output buffer after compute dispatch.
BufferUsageHint
Get buffer usage characteristics needed for safe data flow.
virtual void on_before_pipeline_create(Portal::Graphics::ComputePipelineID pipeline_id)
Called before pipeline creation.
const ShaderProcessorConfig & get_config() const
Get current configuration.
const ShaderDispatchConfig & get_dispatch_config() const
Get current dispatch configuration.
std::vector< Portal::Graphics::DescriptorSetID > m_descriptor_set_ids
Generic compute shader processor for VKBuffers.
ShaderStage
User-friendly shader stage enum.
uint32_t binding
Binding point within set.
uint32_t set
Descriptor set index.
ShaderBinding(uint32_t s, uint32_t b, vk::DescriptorType t=vk::DescriptorType::eStorageBuffer)
Describes how a VKBuffer binds to a shader descriptor.
enum MayaFlux::Buffers::ShaderDispatchConfig::DispatchMode mode
std::function< std::array< uint32_t, 3 >(const std::shared_ptr< VKBuffer > &)> custom_calculator
@ CUSTOM
User-provided calculation function.
@ BUFFER_SIZE
Calculate from buffer byte size.
@ ELEMENT_COUNT
Calculate from buffer element count.
uint32_t workgroup_x
Workgroup size X (should match shader)
Configuration for compute shader dispatch.
std::unordered_map< uint32_t, uint32_t > specialization_constants
std::unordered_map< std::string, ShaderBinding > bindings
Portal::Graphics::ShaderStage stage
std::string shader_path
Path to shader file.
Complete configuration for shader processor.