MayaFlux 0.2.0
Digital-First Multimedia Processing Framework
Loading...
Searching...
No Matches
VKShaderModule.hpp
Go to the documentation of this file.
1#pragma once
2
3#include <vulkan/vulkan.hpp>
4
5namespace spirv_cross {
6class Compiler;
7struct SPIRType;
8}
9
10namespace MayaFlux::Core {
11
12/**
13 * @enum Stage
14 * @brief High-level shader type enumeration
15 *
16 * Used for specifying shader types in a generic way.
17 */
18enum class Stage : uint8_t {
19 COMPUTE,
20 VERTEX,
25};
26
28 std::vector<vk::Format> color_formats;
29 vk::Format depth_format;
30 vk::Format stencil_format;
31};
32
34 std::vector<vk::VertexInputBindingDescription> bindings;
35 std::vector<vk::VertexInputAttributeDescription> attributes;
36};
37
39 struct Attribute {
40 uint32_t location; // layout(location = N)
41 vk::Format format; // vec3 -> eR32G32B32Sfloat
42 uint32_t offset; // byte offset in vertex
43 std::string name; // variable name (from reflection)
44 };
45
46 struct Binding {
47 uint32_t binding; // vertex buffer binding point
48 uint32_t stride; // bytes per vertex
49 vk::VertexInputRate rate; // per-vertex or per-instance
50 };
51
52 std::vector<Attribute> attributes;
53 std::vector<Binding> bindings;
54};
55
57 struct Attachment {
58 uint32_t location; // layout(location = N)
59 vk::Format format; // vec4 -> eR32G32B32A32Sfloat
60 std::string name; // output variable name
61 };
62
63 std::vector<Attachment> color_attachments;
64 bool has_depth_output = false;
65 bool has_stencil_output = false;
66};
67
69 uint32_t offset;
70 uint32_t size;
71 std::string name; // struct name (if any)
72 vk::ShaderStageFlags stages; // which stages use it
73};
74
75/**
76 * @struct ShaderReflection
77 * @brief Metadata extracted from shader module
78 *
79 * Contains information about shader resources for descriptor set layout creation
80 * and pipeline configuration. Extracted via SPIRV-Cross or manual parsing.
81 */
84 uint32_t set; ///< Descriptor set index
85 uint32_t binding; ///< Binding point within set
86 vk::DescriptorType type; ///< Type (uniform buffer, storage buffer, etc.)
87 vk::ShaderStageFlags stage; ///< Stage visibility
88 uint32_t count; ///< Array size (1 for non-arrays)
89 std::string name; ///< Variable name in shader
90 };
91
93 vk::ShaderStageFlags stage; ///< Stage visibility
94 uint32_t offset; ///< Offset in push constant block
95 uint32_t size; ///< Size in bytes
96 };
97
99 uint32_t constant_id; ///< Specialization constant ID
100 uint32_t size; ///< Size in bytes
101 std::string name; ///< Variable name in shader
102 };
103
104 std::vector<DescriptorBinding> bindings;
105 std::vector<PushConstantRange> push_constants;
106 std::vector<SpecializationConstant> specialization_constants;
107
108 std::optional<std::array<uint32_t, 3>> workgroup_size; ///< local_size_x/y/z
109
110 std::vector<vk::VertexInputAttributeDescription> vertex_attributes;
111 std::vector<vk::VertexInputBindingDescription> vertex_bindings;
112};
113
114/**
115 * @class VKShaderModule
116 * @brief Wrapper for Vulkan shader module with lifecycle and reflection
117 *
118 * Responsibilities:
119 * - Create vk::ShaderModule from SPIR-V binary or GLSL source
120 * - Load shaders from disk or memory
121 * - Extract shader metadata via reflection
122 * - Provide pipeline stage info for pipeline creation
123 * - Enable hot-reload support (recreation)
124 *
125 * Does NOT handle:
126 * - Pipeline creation (that's VKComputePipeline/VKGraphicsPipeline)
127 * - Descriptor set allocation (that's VKDescriptorManager)
128 * - Shader compilation itself (delegates to shaderc or an external compiler)
129 *
130 * Integration points:
131 * - VKComputePipeline/VKGraphicsPipeline: uses get_stage_create_info()
132 * - VKDescriptorManager: uses get_reflection() for layout creation
133 * - VKBufferProcessor: subclasses use this to load compute shaders
134 */
135class MAYAFLUX_API VKShaderModule {
136public:
137 VKShaderModule() = default;
139
142 VKShaderModule(VKShaderModule&&) noexcept;
143 VKShaderModule& operator=(VKShaderModule&&) noexcept;
144
145 /**
146 * @brief Create shader module from SPIR-V binary
147 * @param device Logical device
148 * @param spirv_code SPIR-V bytecode (must be aligned to uint32_t)
149 * @param stage Shader stage (compute, vertex, fragment, etc.)
150 * @param entry_point Entry point function name (default: "main")
151 * @param enable_reflection Extract descriptor bindings and resources
152 * @return true if creation succeeded
153 *
154 * This is the lowest-level creation method. All other create methods
155 * eventually funnel through this one.
156 */
157 bool create_from_spirv(
158 vk::Device device,
159 const std::vector<uint32_t>& spirv_code,
160 vk::ShaderStageFlagBits stage,
161 const std::string& entry_point = "main",
162 bool enable_reflection = true);
163
164 /**
165 * @brief Create shader module from SPIR-V file
166 * @param device Logical device
167 * @param spirv_path Path to .spv file
168 * @param stage Shader stage
169 * @param entry_point Entry point function name
170 * @param enable_reflection Extract metadata
171 * @return true if creation succeeded
172 *
173 * Reads binary file and calls create_from_spirv().
174 */
175 bool create_from_spirv_file(
176 vk::Device device,
177 const std::string& spirv_path,
178 vk::ShaderStageFlagBits stage,
179 const std::string& entry_point = "main",
180 bool enable_reflection = true);
181
182 /**
183 * @brief Create shader module from GLSL source string
184 * @param device Logical device
185 * @param glsl_source GLSL source code
186 * @param stage Shader stage (determines compiler mode)
187 * @param entry_point Entry point function name
188 * @param enable_reflection Extract metadata
189 * @param include_directories Paths for #include resolution
190 * @param defines Preprocessor definitions (e.g., {"DEBUG", "MAX_LIGHTS=4"})
191 * @return true if creation succeeded
192 *
193 * Compiles GLSL → SPIR-V (shaderc if built with MAYAFLUX_USE_SHADERC), then calls create_from_spirv().
194 * Without shaderc, compilation presumably goes through compile_glsl_to_spirv_external().
195 */
196 bool create_from_glsl(
197 vk::Device device,
198 const std::string& glsl_source,
199 vk::ShaderStageFlagBits stage,
200 const std::string& entry_point = "main",
201 bool enable_reflection = true,
202 const std::vector<std::string>& include_directories = {},
203 const std::unordered_map<std::string, std::string>& defines = {});
204
205 /**
206 * @brief Create shader module from GLSL file
207 * @param device Logical device
208 * @param glsl_path Path to .comp/.vert/.frag/.geom file
209 * @param stage Shader stage (auto-detected from extension if not specified)
210 * @param entry_point Entry point function name
211 * @param enable_reflection Extract metadata
212 * @param include_directories Paths for #include resolution
213 * @param defines Preprocessor definitions
214 * @return true if creation succeeded
215 *
216 * Reads file, compiles GLSL → SPIR-V, calls create_from_spirv().
217 * Stage auto-detection:
218 * .comp → Compute
219 * .vert → Vertex
220 * .frag → Fragment
221 * .geom → Geometry
222 * .tesc → Tessellation Control
223 * .tese → Tessellation Evaluation
224 */
225 bool create_from_glsl_file(
226 vk::Device device,
227 const std::string& glsl_path,
228 std::optional<vk::ShaderStageFlagBits> stage = std::nullopt,
229 const std::string& entry_point = "main",
230 bool enable_reflection = true,
231 const std::vector<std::string>& include_directories = {},
232 const std::unordered_map<std::string, std::string>& defines = {});
233
234 /**
235 * @brief Cleanup shader module
236 * @param device Logical device (must match creation device)
237 *
238 * Destroys vk::ShaderModule and clears metadata.
239 * Safe to call multiple times or on uninitialized modules.
240 */
241 void cleanup(vk::Device device);
242
243 /**
244 * @brief Check if module is valid
245 * @return true if shader module was successfully created
246 */
247 [[nodiscard]] bool is_valid() const { return m_module != nullptr; }
248
249 /**
250 * @brief Get raw Vulkan shader module handle
251 * @return vk::ShaderModule handle
252 */
253 [[nodiscard]] vk::ShaderModule get() const { return m_module; }
254
255 /**
256 * @brief Get shader stage
257 * @return Stage flags (compute, vertex, fragment, etc.)
258 */
259 [[nodiscard]] vk::ShaderStageFlagBits get_stage() const { return m_stage; }
260
261 /**
262 * @brief Get entry point function name
263 * @return Entry point string (typically "main")
264 */
265 [[nodiscard]] const std::string& get_entry_point() const { return m_entry_point; }
266
267 /**
268 * @brief Get pipeline shader stage create info
269 * @return vk::PipelineShaderStageCreateInfo for pipeline creation
270 *
271 * This is the primary integration point with pipeline builders.
272 * Usage:
273 * auto stage_info = shader_module.get_stage_create_info();
274 * pipeline_builder.add_shader_stage(stage_info);
275 */
276 [[nodiscard]] vk::PipelineShaderStageCreateInfo get_stage_create_info() const;
277
278 /**
279 * @brief Get shader reflection metadata
280 * @return Const reference to extracted metadata
281 *
282 * Used by descriptor managers and pipeline builders to automatically
283 * configure layouts and bindings without manual specification.
284 */
285 [[nodiscard]] const ShaderReflection& get_reflection() const { return m_reflection; }
286
287 /**
288 * @brief Get SPIR-V bytecode
289 * @return Vector of SPIR-V words (empty if not preserved)
290 *
291 * Useful for caching, serialization, or re-creation.
292 * Only populated if set_preserve_spirv(true) was called before creation.
293 */
294 [[nodiscard]] const std::vector<uint32_t>& get_spirv() const { return m_spirv_code; }
295
296 /**
297 * @brief Set specialization constants
298 * @param constants Map of constant_id → value
299 *
300 * Updates the specialization info used in get_stage_create_info().
301 * Must be called before using the shader in pipeline creation.
302 *
303 * Example:
304 * shader.set_specialization_constants({
305 * {0, 256}, // WORKGROUP_SIZE = 256
306 * {1, 1} // ENABLE_OPTIMIZATION = true
307 * });
308 */
309 void set_specialization_constants(const std::unordered_map<uint32_t, uint32_t>& constants);
310
311 /**
312 * @brief Enable SPIR-V preservation for hot-reload
313 * @param preserve If true, stores SPIR-V bytecode internally
314 *
315 * Increases memory usage but enables recreation without recompilation.
316 */
317 void set_preserve_spirv(bool preserve) { m_preserve_spirv = preserve; }
318
319 /**
320 * @brief Get shader stage type
321 * @return Stage enum (easier than vk::ShaderStageFlagBits for logic)
322 */
323 [[nodiscard]] Stage get_stage_type() const;
324
325 /**
326 * @brief Get vertex input state (vertex shaders only)
327 * @return Vertex input metadata, empty if not a vertex shader
328 */
329 [[nodiscard]] const VertexInputInfo& get_vertex_input() const
330 {
331 return m_vertex_input;
332 }
333
334 /**
335 * @brief Check if vertex input is available (reflection found at least one vertex attribute)
336 */
337 [[nodiscard]] bool has_vertex_input() const
338 {
339 return !m_vertex_input.attributes.empty();
340 }
341
342 /**
343 * @brief Get fragment output state (fragment shaders only)
344 * @return Fragment output metadata, empty if not a fragment shader
345 */
346 [[nodiscard]] const FragmentOutputInfo& get_fragment_output() const
347 {
348 return m_fragment_output;
349 }
350
351 /**
352 * @brief Get detailed push constant info
353 * @return Push constant metadata (includes block name; richer than ShaderReflection::push_constants)
354 */
355 [[nodiscard]] const std::vector<PushConstantInfo>& get_push_constants() const
356 {
357 return m_push_constants;
358 }
359
360 // Workgroup size for compute shaders
361 /**
362 * @brief Get compute workgroup size (compute shaders only)
363 * @return {local_size_x, local_size_y, local_size_z} or nullopt
364 */
365 [[nodiscard]] std::optional<std::array<uint32_t, 3>> get_workgroup_size() const
366 {
367 return m_reflection.workgroup_size;
368 }
369
370 /**
371 * @brief Auto-detect shader stage from file extension
372 * @param filepath Path to shader file
373 * @return Detected stage, or nullopt if unknown extension
374 */
375 static std::optional<vk::ShaderStageFlagBits> detect_stage_from_extension(const std::string& filepath);
376
377private:
378 vk::ShaderModule m_module = nullptr;
379 vk::ShaderStageFlagBits m_stage = vk::ShaderStageFlagBits::eCompute;
380 std::string m_entry_point = "main";
381
383 std::vector<uint32_t> m_spirv_code; ///< Preserved SPIR-V (if enabled)
384
385 bool m_preserve_spirv {};
386
387 std::unordered_map<uint32_t, uint32_t> m_specialization_map;
388 std::vector<vk::SpecializationMapEntry> m_specialization_entries;
389 std::vector<uint32_t> m_specialization_data;
390 vk::SpecializationInfo m_specialization_info; // NOTE(review): pMapEntries/pData point into the two vectors above; move ops must re-point them — confirm
391
394 std::vector<PushConstantInfo> m_push_constants;
395
396 /**
397 * @brief Perform reflection on SPIR-V bytecode
398 * @param spirv_code SPIR-V bytecode
399 * @return true if reflection succeeded
400 *
401 * Uses SPIRV-Cross library to extract bindings, push constants,
402 * workgroup sizes, etc. Falls back to basic parsing if library unavailable.
403 */
404 bool reflect_spirv(const std::vector<uint32_t>& spirv_code);
405
406 /**
407 * @brief Compile GLSL to SPIR-V using shaderc
408 * @param glsl_source GLSL source code
409 * @param stage Shader stage (affects compiler settings)
410 * @param include_directories Include paths
411 * @param defines Preprocessor macros
412 * @return SPIR-V bytecode, or empty vector on failure
413 */
414 std::vector<uint32_t> compile_glsl_to_spirv(
415 const std::string& glsl_source,
416 vk::ShaderStageFlagBits stage,
417 const std::vector<std::string>& include_directories,
418 const std::unordered_map<std::string, std::string>& defines);
419
420 /**
421 * @brief Read binary file into vector
422 * @param filepath Path to file
423 * @return File contents, or empty vector on failure
424 */
425 static std::vector<uint32_t> read_spirv_file(const std::string& filepath);
426
427 /**
428 * @brief Read text file into string
429 * @param filepath Path to file
430 * @return File contents, or empty string on failure
431 */
432 static std::string read_text_file(const std::string& filepath);
433
434#ifndef MAYAFLUX_USE_SHADERC
435 /**
436 * @brief Compile GLSL to SPIR-V using an external compiler (only declared when MAYAFLUX_USE_SHADERC is not defined)
437 * @param glsl_source GLSL source code
438 * @param stage Shader stage (affects compiler settings)
439 * @param include_directories Include paths
440 * @param defines Preprocessor macros
441 */
442 static std::vector<uint32_t> compile_glsl_to_spirv_external(
443 const std::string& glsl_source,
444 vk::ShaderStageFlagBits stage,
445 const std::vector<std::string>& include_directories = {},
446 const std::unordered_map<std::string, std::string>& defines = {});
447#endif
448
449 /**
450 * @brief Update specialization info from current map
451 * Must run before get_stage_create_info(), which is const and therefore cannot call this itself.
452 */
453 void update_specialization_info();
454
455 /**
456 * @brief Convert SPIRV-Cross type to Vulkan vertex attribute format
457 * @param type SPIR-V type information
458 * @return Corresponding Vulkan format
459 */
460 static vk::Format spirv_type_to_vk_format(const spirv_cross::SPIRType& type);
461};
462
463} // namespace MayaFlux::Core
vk::SpecializationInfo m_specialization_info
const std::vector< uint32_t > & get_spirv() const
Get SPIR-V bytecode.
std::optional< std::array< uint32_t, 3 > > get_workgroup_size() const
Get compute workgroup size (compute shaders only)
bool is_valid() const
Check if module is valid.
bool has_vertex_input() const
Check if vertex input is available.
std::vector< vk::SpecializationMapEntry > m_specialization_entries
vk::ShaderModule get() const
Get raw Vulkan shader module handle.
VKShaderModule & operator=(const VKShaderModule &)=delete
void set_preserve_spirv(bool preserve)
Enable SPIR-V preservation for hot-reload.
VKShaderModule(const VKShaderModule &)=delete
const FragmentOutputInfo & get_fragment_output() const
Get fragment output state (fragment shaders only)
std::vector< PushConstantInfo > m_push_constants
std::vector< uint32_t > m_specialization_data
const std::vector< PushConstantInfo > & get_push_constants() const
Get detailed push constant info.
std::unordered_map< uint32_t, uint32_t > m_specialization_map
const std::string & get_entry_point() const
Get entry point function name.
const ShaderReflection & get_reflection() const
Get shader reflection metadata.
std::vector< uint32_t > m_spirv_code
Preserved SPIR-V (if enabled)
const VertexInputInfo & get_vertex_input() const
Get vertex input state (vertex shaders only)
vk::ShaderStageFlagBits get_stage() const
Get shader stage.
Wrapper for Vulkan shader module with lifecycle and reflection.
int main()
Definition main.cpp:33
std::vector< Attachment > color_attachments
std::vector< vk::Format > color_formats
vk::ShaderStageFlags stage
Stage visibility.
uint32_t count
Array size (1 for non-arrays)
vk::DescriptorType type
Type (uniform buffer, storage buffer, etc.)
vk::ShaderStageFlags stage
Stage visibility.
uint32_t offset
Offset in push constant block.
std::vector< SpecializationConstant > specialization_constants
std::vector< vk::VertexInputBindingDescription > vertex_bindings
std::vector< DescriptorBinding > bindings
std::vector< PushConstantRange > push_constants
std::vector< vk::VertexInputAttributeDescription > vertex_attributes
std::optional< std::array< uint32_t, 3 > > workgroup_size
local_size_x/y/z
Metadata extracted from shader module.
std::vector< Binding > bindings
std::vector< Attribute > attributes
std::vector< vk::VertexInputAttributeDescription > attributes
std::vector< vk::VertexInputBindingDescription > bindings