MayaFlux 0.3.0
Digital-First Multimedia Processing Framework
Loading...
Searching...
No Matches
VKShaderModule.hpp
Go to the documentation of this file.
1#pragma once
2
3#include <vulkan/vulkan.hpp>
4
5namespace spirv_cross {
6class Compiler;
7struct SPIRType;
8}
9
10namespace MayaFlux::Core {
11
12/**
 13 * @enum Stage
14 * @brief High-level shader type enumeration
15 *
16 * Used for specifying shader types in a generic way.
17 */
18enum class Stage : uint8_t {
19 COMPUTE,
20 VERTEX,
25 MESH,
26 TASK
27};
28
30 std::vector<vk::Format> color_formats;
31 vk::Format depth_format;
32 vk::Format stencil_format;
33};
34
36 std::vector<vk::VertexInputBindingDescription> bindings;
37 std::vector<vk::VertexInputAttributeDescription> attributes;
38};
39
41 struct Attribute {
42 uint32_t location; // layout(location = N)
43 vk::Format format; // vec3 -> eR32G32B32Sfloat
44 uint32_t offset; // byte offset in vertex
45 std::string name; // variable name (from reflection)
46 };
47
48 struct Binding {
49 uint32_t binding; // vertex buffer binding point
50 uint32_t stride; // bytes per vertex
51 vk::VertexInputRate rate; // per-vertex or per-instance
52 };
53
54 std::vector<Attribute> attributes;
55 std::vector<Binding> bindings;
56};
57
59 struct Attachment {
60 uint32_t location; // layout(location = N)
61 vk::Format format; // vec4 -> eR32G32B32A32Sfloat
62 std::string name; // output variable name
63 };
64
65 std::vector<Attachment> color_attachments;
66 bool has_depth_output = false;
67 bool has_stencil_output = false;
68};
69
71 uint32_t offset;
72 uint32_t size;
73 std::string name; // struct name (if any)
74 vk::ShaderStageFlags stages; // which stages use it
75};
76
77/**
78 * @struct ShaderReflection
79 * @brief Metadata extracted from shader module
80 *
81 * Contains information about shader resources for descriptor set layout creation
82 * and pipeline configuration. Extracted via SPIRV-Cross or manual parsing.
83 */
86 uint32_t set; ///< Descriptor set index
87 uint32_t binding; ///< Binding point within set
88 vk::DescriptorType type; ///< Type (uniform buffer, storage buffer, etc.)
89 vk::ShaderStageFlags stage; ///< Stage visibility
90 uint32_t count; ///< Array size (1 for non-arrays)
91 std::string name; ///< Variable name in shader
92 };
93
95 vk::ShaderStageFlags stage; ///< Stage visibility
96 uint32_t offset; ///< Offset in push constant block
97 uint32_t size; ///< Size in bytes
98 };
99
101 uint32_t constant_id; ///< Specialization constant ID
102 uint32_t size; ///< Size in bytes
103 std::string name; ///< Variable name in shader
104 };
105
106 std::vector<DescriptorBinding> bindings;
107 std::vector<PushConstantRange> push_constants;
108 std::vector<SpecializationConstant> specialization_constants;
109
110 std::optional<std::array<uint32_t, 3>> workgroup_size; ///< local_size_x/y/z
111
112 std::vector<vk::VertexInputAttributeDescription> vertex_attributes;
113 std::vector<vk::VertexInputBindingDescription> vertex_bindings;
114};
115
116/**
117 * @class VKShaderModule
118 * @brief Wrapper for Vulkan shader module with lifecycle and reflection
119 *
120 * Responsibilities:
121 * - Create vk::ShaderModule from SPIR-V binary or GLSL source
122 * - Load shaders from disk or memory
123 * - Extract shader metadata via reflection
124 * - Provide pipeline stage info for pipeline creation
125 * - Enable hot-reload support (recreation)
126 *
127 * Does NOT handle:
128 * - Pipeline creation (that's VKComputePipeline/VKGraphicsPipeline)
129 * - Descriptor set allocation (that's VKDescriptorManager)
 130 * - GLSL compiler implementation (compilation is delegated to shaderc or an external compiler)
131 *
132 * Integration points:
133 * - VKComputePipeline/VKGraphicsPipeline: uses get_stage_create_info()
134 * - VKDescriptorManager: uses get_reflection() for layout creation
135 * - VKBufferProcessor: subclasses use this to load compute shaders
136 */
137class MAYAFLUX_API VKShaderModule {
138public:
139 VKShaderModule() = default;
141
144 VKShaderModule(VKShaderModule&&) noexcept;
145 VKShaderModule& operator=(VKShaderModule&&) noexcept;
146
147 /**
148 * @brief Create shader module from SPIR-V binary
149 * @param device Logical device
150 * @param spirv_code SPIR-V bytecode (must be aligned to uint32_t)
151 * @param stage Shader stage (compute, vertex, fragment, etc.)
152 * @param entry_point Entry point function name (default: "main")
153 * @param enable_reflection Extract descriptor bindings and resources
154 * @return true if creation succeeded
155 *
156 * This is the lowest-level creation method. All other create methods
157 * eventually funnel through this one.
158 */
159 bool create_from_spirv(
160 vk::Device device,
161 const std::vector<uint32_t>& spirv_code,
162 vk::ShaderStageFlagBits stage,
163 const std::string& entry_point = "main",
164 bool enable_reflection = true);
165
166 /**
167 * @brief Create shader module from SPIR-V file
168 * @param device Logical device
169 * @param spirv_path Path to .spv file
170 * @param stage Shader stage
171 * @param entry_point Entry point function name
172 * @param enable_reflection Extract metadata
173 * @return true if creation succeeded
174 *
175 * Reads binary file and calls create_from_spirv().
176 */
177 bool create_from_spirv_file(
178 vk::Device device,
179 const std::string& spirv_path,
180 vk::ShaderStageFlagBits stage,
181 const std::string& entry_point = "main",
182 bool enable_reflection = true);
183
184 /**
185 * @brief Create shader module from GLSL source string
186 * @param device Logical device
187 * @param glsl_source GLSL source code
188 * @param stage Shader stage (determines compiler mode)
189 * @param entry_point Entry point function name
190 * @param enable_reflection Extract metadata
191 * @param include_directories Paths for #include resolution
192 * @param defines Preprocessor definitions (e.g., {"DEBUG", "MAX_LIGHTS=4"})
193 * @return true if creation succeeded
194 *
 195 * Compiles GLSL → SPIR-V (via shaderc when MAYAFLUX_USE_SHADERC is defined,
 196 * otherwise via an external compiler), then calls create_from_spirv().
197 */
198 bool create_from_glsl(
199 vk::Device device,
200 const std::string& glsl_source,
201 vk::ShaderStageFlagBits stage,
202 const std::string& entry_point = "main",
203 bool enable_reflection = true,
204 const std::vector<std::string>& include_directories = {},
205 const std::unordered_map<std::string, std::string>& defines = {});
206
207 /**
208 * @brief Create shader module from GLSL file
209 * @param device Logical device
 210 * @param glsl_path Path to a GLSL source file (e.g. .comp/.vert/.frag/.geom/.tesc/.tese)
211 * @param stage Shader stage (auto-detected from extension if not specified)
212 * @param entry_point Entry point function name
213 * @param enable_reflection Extract metadata
214 * @param include_directories Paths for #include resolution
215 * @param defines Preprocessor definitions
216 * @return true if creation succeeded
217 *
218 * Reads file, compiles GLSL → SPIR-V, calls create_from_spirv().
219 * Stage auto-detection:
220 * .comp → Compute
221 * .vert → Vertex
222 * .frag → Fragment
223 * .geom → Geometry
224 * .tesc → Tessellation Control
225 * .tese → Tessellation Evaluation
226 */
227 bool create_from_glsl_file(
228 vk::Device device,
229 const std::string& glsl_path,
230 std::optional<vk::ShaderStageFlagBits> stage = std::nullopt,
231 const std::string& entry_point = "main",
232 bool enable_reflection = true,
233 const std::vector<std::string>& include_directories = {},
234 const std::unordered_map<std::string, std::string>& defines = {});
235
236 /**
237 * @brief Cleanup shader module
238 * @param device Logical device (must match creation device)
239 *
240 * Destroys vk::ShaderModule and clears metadata.
241 * Safe to call multiple times or on uninitialized modules.
242 */
243 void cleanup(vk::Device device);
244
245 /**
246 * @brief Check if module is valid
247 * @return true if shader module was successfully created
248 */
249 [[nodiscard]] bool is_valid() const { return m_module != nullptr; }
250
251 /**
252 * @brief Get raw Vulkan shader module handle
253 * @return vk::ShaderModule handle
254 */
255 [[nodiscard]] vk::ShaderModule get() const { return m_module; }
256
257 /**
258 * @brief Get shader stage
259 * @return Stage flags (compute, vertex, fragment, etc.)
260 */
261 [[nodiscard]] vk::ShaderStageFlagBits get_stage() const { return m_stage; }
262
263 /**
264 * @brief Get entry point function name
265 * @return Entry point string (typically "main")
266 */
267 [[nodiscard]] const std::string& get_entry_point() const { return m_entry_point; }
268
269 /**
270 * @brief Get pipeline shader stage create info
271 * @return vk::PipelineShaderStageCreateInfo for pipeline creation
272 *
273 * This is the primary integration point with pipeline builders.
274 * Usage:
275 * auto stage_info = shader_module.get_stage_create_info();
276 * pipeline_builder.add_shader_stage(stage_info);
277 */
278 [[nodiscard]] vk::PipelineShaderStageCreateInfo get_stage_create_info() const;
279
280 /**
281 * @brief Get shader reflection metadata
282 * @return Const reference to extracted metadata
283 *
284 * Used by descriptor managers and pipeline builders to automatically
285 * configure layouts and bindings without manual specification.
286 */
287 [[nodiscard]] const ShaderReflection& get_reflection() const { return m_reflection; }
288
289 /**
290 * @brief Get SPIR-V bytecode
291 * @return Vector of SPIR-V words (empty if not preserved)
292 *
293 * Useful for caching, serialization, or re-creation.
 294 * Only populated if SPIR-V preservation was enabled via set_preserve_spirv(true) at creation time.
295 */
296 [[nodiscard]] const std::vector<uint32_t>& get_spirv() const { return m_spirv_code; }
297
298 /**
299 * @brief Set specialization constants
300 * @param constants Map of constant_id → value
301 *
302 * Updates the specialization info used in get_stage_create_info().
303 * Must be called before using the shader in pipeline creation.
304 *
305 * Example:
306 * shader.set_specialization_constants({
307 * {0, 256}, // WORKGROUP_SIZE = 256
308 * {1, 1} // ENABLE_OPTIMIZATION = true
309 * });
310 */
311 void set_specialization_constants(const std::unordered_map<uint32_t, uint32_t>& constants);
312
313 /**
314 * @brief Enable SPIR-V preservation for hot-reload
315 * @param preserve If true, stores SPIR-V bytecode internally
316 *
317 * Increases memory usage but enables recreation without recompilation.
318 */
319 void set_preserve_spirv(bool preserve) { m_preserve_spirv = preserve; }
320
321 /**
322 * @brief Get shader stage type
323 * @return Stage enum (easier than vk::ShaderStageFlagBits for logic)
324 */
325 [[nodiscard]] Stage get_stage_type() const;
326
327 /**
328 * @brief Get vertex input state (vertex shaders only)
329 * @return Vertex input metadata, empty if not a vertex shader
330 */
331 [[nodiscard]] const VertexInputInfo& get_vertex_input() const
332 {
333 return m_vertex_input;
334 }
335
336 /**
337 * @brief Check if vertex input is available
338 */
339 [[nodiscard]] bool has_vertex_input() const
340 {
341 return !m_vertex_input.attributes.empty();
342 }
343
344 /**
345 * @brief Get fragment output state (fragment shaders only)
346 * @return Fragment output metadata, empty if not a fragment shader
347 */
348 [[nodiscard]] const FragmentOutputInfo& get_fragment_output() const
349 {
350 return m_fragment_output;
351 }
352
353 /**
354 * @brief Get detailed push constant info
355 * @return Push constant metadata (replaces simple PushConstantRange)
356 */
357 [[nodiscard]] const std::vector<PushConstantInfo>& get_push_constants() const
358 {
359 return m_push_constants;
360 }
361
 362 // Workgroup size for compute shaders
363 /**
364 * @brief Get compute workgroup size (compute shaders only)
365 * @return {local_size_x, local_size_y, local_size_z} or nullopt
366 */
367 [[nodiscard]] std::optional<std::array<uint32_t, 3>> get_workgroup_size() const
368 {
369 return m_reflection.workgroup_size;
370 }
371
372 /**
373 * @brief Auto-detect shader stage from file extension
374 * @param filepath Path to shader file
375 * @return Detected stage, or nullopt if unknown extension
376 */
377 static std::optional<vk::ShaderStageFlagBits> detect_stage_from_extension(const std::string& filepath);
378
379private:
380 vk::ShaderModule m_module = nullptr;
381 vk::ShaderStageFlagBits m_stage = vk::ShaderStageFlagBits::eCompute;
382 std::string m_entry_point = "main";
383
385 std::vector<uint32_t> m_spirv_code; ///< Preserved SPIR-V (if enabled)
386
387 bool m_preserve_spirv {};
388
389 std::unordered_map<uint32_t, uint32_t> m_specialization_map;
390 std::vector<vk::SpecializationMapEntry> m_specialization_entries;
391 std::vector<uint32_t> m_specialization_data;
392 vk::SpecializationInfo m_specialization_info;
393
396 std::vector<PushConstantInfo> m_push_constants;
397
398 /**
399 * @brief Perform reflection on SPIR-V bytecode
400 * @param spirv_code SPIR-V bytecode
401 * @return true if reflection succeeded
402 *
403 * Uses SPIRV-Cross library to extract bindings, push constants,
404 * workgroup sizes, etc. Falls back to basic parsing if library unavailable.
405 */
406 bool reflect_spirv(const std::vector<uint32_t>& spirv_code);
407
408 /**
409 * @brief Compile GLSL to SPIR-V using shaderc
410 * @param glsl_source GLSL source code
411 * @param stage Shader stage (affects compiler settings)
412 * @param include_directories Include paths
413 * @param defines Preprocessor macros
414 * @return SPIR-V bytecode, or empty vector on failure
415 */
416 std::vector<uint32_t> compile_glsl_to_spirv(
417 const std::string& glsl_source,
418 vk::ShaderStageFlagBits stage,
419 const std::vector<std::string>& include_directories,
420 const std::unordered_map<std::string, std::string>& defines);
421
422 /**
423 * @brief Read binary file into vector
424 * @param filepath Path to file
425 * @return File contents, or empty vector on failure
426 */
427 static std::vector<uint32_t> read_spirv_file(const std::string& filepath);
428
429 /**
430 * @brief Read text file into string
431 * @param filepath Path to file
432 * @return File contents, or empty string on failure
433 */
434 static std::string read_text_file(const std::string& filepath);
435
436#ifndef MAYAFLUX_USE_SHADERC
437 /*
438 * @brief Compile GLSL to SPIR-V using external compiler
439 * @param glsl_source GLSL source code
440 * @param stage Shader stage (affects compiler settings)
441 * @param include_directories Include paths
442 * @param defines Preprocessor macros
443 */
444 static std::vector<uint32_t> compile_glsl_to_spirv_external(
445 const std::string& glsl_source,
446 vk::ShaderStageFlagBits stage,
447 const std::vector<std::string>& include_directories = {},
448 const std::unordered_map<std::string, std::string>& defines = {});
449#endif
450
451 /**
452 * @brief Update specialization info from current map
453 * Called before get_stage_create_info() to ensure fresh data
454 */
455 void update_specialization_info();
456
457 /**
458 * @brief Convert SPIRV-Cross type to Vulkan vertex attribute format
459 * @param type SPIR-V type information
460 * @return Corresponding Vulkan format
461 */
462 static vk::Format spirv_type_to_vk_format(const spirv_cross::SPIRType& type);
463};
464
465} // namespace MayaFlux::Core
vk::SpecializationInfo m_specialization_info
const std::vector< uint32_t > & get_spirv() const
Get SPIR-V bytecode.
std::optional< std::array< uint32_t, 3 > > get_workgroup_size() const
Get compute workgroup size (compute shaders only)
bool is_valid() const
Check if module is valid.
bool has_vertex_input() const
Check if vertex input is available.
std::vector< vk::SpecializationMapEntry > m_specialization_entries
vk::ShaderModule get() const
Get raw Vulkan shader module handle.
VKShaderModule & operator=(const VKShaderModule &)=delete
void set_preserve_spirv(bool preserve)
Enable SPIR-V preservation for hot-reload.
VKShaderModule(const VKShaderModule &)=delete
const FragmentOutputInfo & get_fragment_output() const
Get fragment output state (fragment shaders only)
std::vector< PushConstantInfo > m_push_constants
std::vector< uint32_t > m_specialization_data
const std::vector< PushConstantInfo > & get_push_constants() const
Get detailed push constant info.
std::unordered_map< uint32_t, uint32_t > m_specialization_map
const std::string & get_entry_point() const
Get entry point function name.
const ShaderReflection & get_reflection() const
Get shader reflection metadata.
std::vector< uint32_t > m_spirv_code
Preserved SPIR-V (if enabled)
const VertexInputInfo & get_vertex_input() const
Get vertex input state (vertex shaders only)
vk::ShaderStageFlagBits get_stage() const
Get shader stage.
Wrapper for Vulkan shader module with lifecycle and reflection.
int main()
Definition main.cpp:33
std::vector< Attachment > color_attachments
std::vector< vk::Format > color_formats
vk::ShaderStageFlags stage
Stage visibility.
uint32_t count
Array size (1 for non-arrays)
vk::DescriptorType type
Type (uniform buffer, storage buffer, etc.)
vk::ShaderStageFlags stage
Stage visibility.
uint32_t offset
Offset in push constant block.
std::vector< SpecializationConstant > specialization_constants
std::vector< vk::VertexInputBindingDescription > vertex_bindings
std::vector< DescriptorBinding > bindings
std::vector< PushConstantRange > push_constants
std::vector< vk::VertexInputAttributeDescription > vertex_attributes
std::optional< std::array< uint32_t, 3 > > workgroup_size
local_size_x/y/z
Metadata extracted from shader module.
std::vector< Binding > bindings
std::vector< Attribute > attributes
std::vector< vk::VertexInputAttributeDescription > attributes
std::vector< vk::VertexInputBindingDescription > bindings