MayaFlux 0.1.0
Digital-First Multimedia Processing Framework
Loading...
Searching...
No Matches
VKShaderModule.hpp
Go to the documentation of this file.
1#pragma once
2
3#include <vulkan/vulkan.hpp>
4
5namespace MayaFlux::Core {
6
7/**
8 * @enum Stage
9 * @brief High-level shader stage enumeration
10 *
11 * Used to identify a shader's stage generically, without going through vk::ShaderStageFlagBits (see VKShaderModule::get_stage_type()).
12 */
13enum class Stage : uint8_t {
14 COMPUTE,
15 VERTEX,
20};
21
23 std::vector<vk::Format> color_formats;
24 vk::Format depth_format;
25 vk::Format stencil_format;
26};
27
29 std::vector<vk::VertexInputBindingDescription> bindings;
30 std::vector<vk::VertexInputAttributeDescription> attributes;
31};
32
34 struct Attribute {
35 uint32_t location; // layout(location = N)
36 vk::Format format; // vec3 -> eR32G32B32Sfloat
37 uint32_t offset; // byte offset in vertex
38 std::string name; // variable name (from reflection)
39 };
40
41 struct Binding {
42 uint32_t binding; // vertex buffer binding point
43 uint32_t stride; // bytes per vertex
44 vk::VertexInputRate rate; // per-vertex or per-instance
45 };
46
47 std::vector<Attribute> attributes;
48 std::vector<Binding> bindings;
49};
50
52 struct Attachment {
53 uint32_t location; // layout(location = N)
54 vk::Format format; // vec4 -> eR32G32B32A32Sfloat
55 std::string name; // output variable name
56 };
57
58 std::vector<Attachment> color_attachments;
59 bool has_depth_output = false;
60 bool has_stencil_output = false;
61};
62
64 uint32_t offset;
65 uint32_t size;
66 std::string name; // struct name (if any)
67 vk::ShaderStageFlags stages; // which stages use it
68};
69
70/**
71 * @struct ShaderReflection
72 * @brief Metadata extracted from shader module
73 *
74 * Contains information about shader resources for descriptor set layout creation
75 * and pipeline configuration. Extracted via SPIRV-Reflect or manual parsing.
76 */
79 uint32_t set; ///< Descriptor set index
80 uint32_t binding; ///< Binding point within set
81 vk::DescriptorType type; ///< Type (uniform buffer, storage buffer, etc.)
82 vk::ShaderStageFlags stage; ///< Stage visibility
83 uint32_t count; ///< Array size (1 for non-arrays)
84 std::string name; ///< Variable name in shader
85 };
86
88 vk::ShaderStageFlags stage; ///< Stage visibility
89 uint32_t offset; ///< Offset in push constant block
90 uint32_t size; ///< Size in bytes
91 };
92
94 uint32_t constant_id; ///< Specialization constant ID
95 uint32_t size; ///< Size in bytes
96 std::string name; ///< Variable name in shader
97 };
98
99 std::vector<DescriptorBinding> bindings;
100 std::vector<PushConstantRange> push_constants;
101 std::vector<SpecializationConstant> specialization_constants;
102
103 std::optional<std::array<uint32_t, 3>> workgroup_size; ///< local_size_x/y/z
104
105 std::vector<vk::VertexInputAttributeDescription> vertex_attributes;
106 std::vector<vk::VertexInputBindingDescription> vertex_bindings;
107};
108
109/**
110 * @class VKShaderModule
111 * @brief Wrapper for Vulkan shader module with lifecycle and reflection
112 *
113 * Responsibilities:
114 * - Create vk::ShaderModule from SPIR-V binary or GLSL source
115 * - Load shaders from disk or memory
116 * - Extract shader metadata via reflection
117 * - Provide pipeline stage info for pipeline creation
118 * - Enable hot-reload support (recreation)
119 *
120 * Does NOT handle:
121 * - Pipeline creation (that's VKComputePipeline/VKGraphicsPipeline)
122 * - Descriptor set allocation (that's VKDescriptorManager)
123 * - Shader compilation (delegates to external compiler)
124 *
125 * Integration points:
126 * - VKComputePipeline/VKGraphicsPipeline: uses get_stage_create_info()
127 * - VKDescriptorManager: uses get_reflection() for layout creation
128 * - VKBufferProcessor: subclasses use this to load compute shaders
129 */
130class MAYAFLUX_API VKShaderModule {
131public:
132 VKShaderModule() = default;
134
137 VKShaderModule(VKShaderModule&&) noexcept;
138 VKShaderModule& operator=(VKShaderModule&&) noexcept;
139
140 /**
141 * @brief Create shader module from SPIR-V binary
142 * @param device Logical device
143 * @param spirv_code SPIR-V bytecode (must be aligned to uint32_t)
144 * @param stage Shader stage (compute, vertex, fragment, etc.)
145 * @param entry_point Entry point function name (default: "main")
146 * @param enable_reflection Extract descriptor bindings and resources
147 * @return true if creation succeeded
148 *
149 * This is the lowest-level creation method. All other create methods
150 * eventually funnel through this one.
151 */
152 bool create_from_spirv(
153 vk::Device device,
154 const std::vector<uint32_t>& spirv_code,
155 vk::ShaderStageFlagBits stage,
156 const std::string& entry_point = "main",
157 bool enable_reflection = true);
158
159 /**
160 * @brief Create shader module from SPIR-V file
161 * @param device Logical device
162 * @param spirv_path Path to .spv file
163 * @param stage Shader stage
164 * @param entry_point Entry point function name
165 * @param enable_reflection Extract metadata
166 * @return true if creation succeeded
167 *
168 * Reads binary file and calls create_from_spirv().
169 */
170 bool create_from_spirv_file(
171 vk::Device device,
172 const std::string& spirv_path,
173 vk::ShaderStageFlagBits stage,
174 const std::string& entry_point = "main",
175 bool enable_reflection = true);
176
177 /**
178 * @brief Create shader module from GLSL source string
179 * @param device Logical device
180 * @param glsl_source GLSL source code
181 * @param stage Shader stage (determines compiler mode)
182 * @param entry_point Entry point function name
183 * @param enable_reflection Extract metadata
184 * @param include_directories Paths for #include resolution
185 * @param defines Preprocessor definitions (e.g., {"DEBUG", "MAX_LIGHTS=4"})
186 * @return true if creation succeeded
187 *
188 * Compiles GLSL → SPIR-V using shaderc, then calls create_from_spirv().
189 * Requires shaderc library to be available.
190 */
191 bool create_from_glsl(
192 vk::Device device,
193 const std::string& glsl_source,
194 vk::ShaderStageFlagBits stage,
195 const std::string& entry_point = "main",
196 bool enable_reflection = true,
197 const std::vector<std::string>& include_directories = {},
198 const std::unordered_map<std::string, std::string>& defines = {});
199
200 /**
201 * @brief Create shader module from GLSL file
202 * @param device Logical device
203 * @param glsl_path Path to .comp/.vert/.frag/.geom file
204 * @param stage Shader stage (auto-detected from extension if not specified)
205 * @param entry_point Entry point function name
206 * @param enable_reflection Extract metadata
207 * @param include_directories Paths for #include resolution
208 * @param defines Preprocessor definitions
209 * @return true if creation succeeded
210 *
211 * Reads file, compiles GLSL → SPIR-V, calls create_from_spirv().
212 * Stage auto-detection:
213 * .comp → Compute
214 * .vert → Vertex
215 * .frag → Fragment
216 * .geom → Geometry
217 * .tesc → Tessellation Control
218 * .tese → Tessellation Evaluation
219 */
220 bool create_from_glsl_file(
221 vk::Device device,
222 const std::string& glsl_path,
223 std::optional<vk::ShaderStageFlagBits> stage = std::nullopt,
224 const std::string& entry_point = "main",
225 bool enable_reflection = true,
226 const std::vector<std::string>& include_directories = {},
227 const std::unordered_map<std::string, std::string>& defines = {});
228
229 /**
230 * @brief Cleanup shader module
231 * @param device Logical device (must match creation device)
232 *
233 * Destroys vk::ShaderModule and clears metadata.
234 * Safe to call multiple times or on uninitialized modules.
235 */
236 void cleanup(vk::Device device);
237
238 /**
239 * @brief Check if module is valid
240 * @return true if shader module was successfully created
241 */
242 [[nodiscard]] bool is_valid() const { return m_module != nullptr; }
243
244 /**
245 * @brief Get raw Vulkan shader module handle
246 * @return vk::ShaderModule handle
247 */
248 [[nodiscard]] vk::ShaderModule get() const { return m_module; }
249
250 /**
251 * @brief Get shader stage
252 * @return Stage flags (compute, vertex, fragment, etc.)
253 */
254 [[nodiscard]] vk::ShaderStageFlagBits get_stage() const { return m_stage; }
255
256 /**
257 * @brief Get entry point function name
258 * @return Entry point string (typically "main")
259 */
260 [[nodiscard]] const std::string& get_entry_point() const { return m_entry_point; }
261
262 /**
263 * @brief Get pipeline shader stage create info
264 * @return vk::PipelineShaderStageCreateInfo for pipeline creation
265 *
266 * This is the primary integration point with pipeline builders.
267 * Usage:
268 * auto stage_info = shader_module.get_stage_create_info();
269 * pipeline_builder.add_shader_stage(stage_info);
270 */
271 [[nodiscard]] vk::PipelineShaderStageCreateInfo get_stage_create_info() const;
272
273 /**
274 * @brief Get shader reflection metadata
275 * @return Const reference to extracted metadata
276 *
277 * Used by descriptor managers and pipeline builders to automatically
278 * configure layouts and bindings without manual specification.
279 */
280 [[nodiscard]] const ShaderReflection& get_reflection() const { return m_reflection; }
281
282 /**
283 * @brief Get SPIR-V bytecode
284 * @return Vector of SPIR-V words (empty if not preserved)
285 *
286 * Useful for caching, serialization, or re-creation.
287 * Only populated if set_preserve_spirv(true) was called before the module was created.
288 */
289 [[nodiscard]] const std::vector<uint32_t>& get_spirv() const { return m_spirv_code; }
290
291 /**
292 * @brief Set specialization constants
293 * @param constants Map of constant_id → value
294 *
295 * Updates the specialization info used in get_stage_create_info().
296 * Must be called before using the shader in pipeline creation.
297 *
298 * Example:
299 * shader.set_specialization_constants({
300 * {0, 256}, // WORKGROUP_SIZE = 256
301 * {1, 1} // ENABLE_OPTIMIZATION = true
302 * });
303 */
304 void set_specialization_constants(const std::unordered_map<uint32_t, uint32_t>& constants);
305
306 /**
307 * @brief Enable SPIR-V preservation for hot-reload
308 * @param preserve If true, stores SPIR-V bytecode internally
309 *
310 * Increases memory usage but enables recreation without recompilation.
311 */
312 void set_preserve_spirv(bool preserve) { m_preserve_spirv = preserve; }
313
314 /**
315 * @brief Get shader stage type
316 * @return Stage enum (easier than vk::ShaderStageFlagBits for logic)
317 */
318 [[nodiscard]] Stage get_stage_type() const;
319
320 /**
321 * @brief Get vertex input state (vertex shaders only)
322 * @return Vertex input metadata, empty if not a vertex shader
323 */
324 [[nodiscard]] const VertexInputInfo& get_vertex_input() const
325 {
326 return m_vertex_input;
327 }
328
329 /**
330 * @brief Check if vertex input is available
331 */
332 [[nodiscard]] bool has_vertex_input() const
333 {
334 return !m_vertex_input.attributes.empty();
335 }
336
337 /**
338 * @brief Get fragment output state (fragment shaders only)
339 * @return Fragment output metadata, empty if not a fragment shader
340 */
341 [[nodiscard]] const FragmentOutputInfo& get_fragment_output() const
342 {
343 return m_fragment_output;
344 }
345
346 /**
347 * @brief Get detailed push constant info
348 * @return Push constant metadata (replaces simple PushConstantRange)
349 */
350 [[nodiscard]] const std::vector<PushConstantInfo>& get_push_constants() const
351 {
352 return m_push_constants;
353 }
354
355 // Workgroup size for compute shaders
356 /**
357 * @brief Get compute workgroup size (compute shaders only)
358 * @return {local_size_x, local_size_y, local_size_z} or nullopt
359 */
360 [[nodiscard]] std::optional<std::array<uint32_t, 3>> get_workgroup_size() const
361 {
362 return m_reflection.workgroup_size;
363 }
364
365 /**
366 * @brief Auto-detect shader stage from file extension
367 * @param filepath Path to shader file
368 * @return Detected stage, or nullopt if unknown extension
369 */
370 static std::optional<vk::ShaderStageFlagBits> detect_stage_from_extension(const std::string& filepath);
371
372private:
373 vk::ShaderModule m_module = nullptr;
374 vk::ShaderStageFlagBits m_stage = vk::ShaderStageFlagBits::eCompute;
375 std::string m_entry_point = "main";
376
378 std::vector<uint32_t> m_spirv_code; ///< Preserved SPIR-V (if enabled)
379
380 bool m_preserve_spirv {};
381
382 std::unordered_map<uint32_t, uint32_t> m_specialization_map;
383 std::vector<vk::SpecializationMapEntry> m_specialization_entries;
384 std::vector<uint32_t> m_specialization_data;
385 vk::SpecializationInfo m_specialization_info;
386
389 std::vector<PushConstantInfo> m_push_constants;
390
391 /**
392 * @brief Perform reflection on SPIR-V bytecode
393 * @param spirv_code SPIR-V bytecode
394 * @return true if reflection succeeded
395 *
396 * Uses SPIRV-Reflect library to extract bindings, push constants,
397 * workgroup sizes, etc. Falls back to basic parsing if library unavailable.
398 */
399 bool reflect_spirv(const std::vector<uint32_t>& spirv_code);
400
401 /**
402 * @brief Compile GLSL to SPIR-V using shaderc
403 * @param glsl_source GLSL source code
404 * @param stage Shader stage (affects compiler settings)
405 * @param include_directories Include paths
406 * @param defines Preprocessor macros
407 * @return SPIR-V bytecode, or empty vector on failure
408 */
409 std::vector<uint32_t> compile_glsl_to_spirv(
410 const std::string& glsl_source,
411 vk::ShaderStageFlagBits stage,
412 const std::vector<std::string>& include_directories,
413 const std::unordered_map<std::string, std::string>& defines);
414
415 /**
416 * @brief Read binary file into vector
417 * @param filepath Path to file
418 * @return File contents, or empty vector on failure
419 */
420 static std::vector<uint32_t> read_spirv_file(const std::string& filepath);
421
422 /**
423 * @brief Read text file into string
424 * @param filepath Path to file
425 * @return File contents, or empty string on failure
426 */
427 static std::string read_text_file(const std::string& filepath);
428
429 /**
430 * @brief Update specialization info from current map
431 * Called before get_stage_create_info() to ensure fresh data
432 */
433 void update_specialization_info();
434};
435
436} // namespace MayaFlux::Core
vk::SpecializationInfo m_specialization_info
const std::vector< uint32_t > & get_spirv() const
Get SPIR-V bytecode.
std::optional< std::array< uint32_t, 3 > > get_workgroup_size() const
Get compute workgroup size (compute shaders only)
bool is_valid() const
Check if module is valid.
bool has_vertex_input() const
Check if vertex input is available.
std::vector< vk::SpecializationMapEntry > m_specialization_entries
vk::ShaderModule get() const
Get raw Vulkan shader module handle.
VKShaderModule & operator=(const VKShaderModule &)=delete
void set_preserve_spirv(bool preserve)
Enable SPIR-V preservation for hot-reload.
VKShaderModule(const VKShaderModule &)=delete
const FragmentOutputInfo & get_fragment_output() const
Get fragment output state (fragment shaders only)
std::vector< PushConstantInfo > m_push_constants
std::vector< uint32_t > m_specialization_data
const std::vector< PushConstantInfo > & get_push_constants() const
Get detailed push constant info.
std::unordered_map< uint32_t, uint32_t > m_specialization_map
const std::string & get_entry_point() const
Get entry point function name.
const ShaderReflection & get_reflection() const
Get shader reflection metadata.
std::vector< uint32_t > m_spirv_code
Preserved SPIR-V (if enabled)
const VertexInputInfo & get_vertex_input() const
Get vertex input state (vertex shaders only)
vk::ShaderStageFlagBits get_stage() const
Get shader stage.
Wrapper for Vulkan shader module with lifecycle and reflection.
int main()
Definition main.cpp:33
std::vector< Attachment > color_attachments
std::vector< vk::Format > color_formats
vk::ShaderStageFlags stage
Stage visibility.
uint32_t count
Array size (1 for non-arrays)
vk::DescriptorType type
Type (uniform buffer, storage buffer, etc.)
vk::ShaderStageFlags stage
Stage visibility.
uint32_t offset
Offset in push constant block.
std::vector< SpecializationConstant > specialization_constants
std::vector< vk::VertexInputBindingDescription > vertex_bindings
std::vector< DescriptorBinding > bindings
std::vector< PushConstantRange > push_constants
std::vector< vk::VertexInputAttributeDescription > vertex_attributes
std::optional< std::array< uint32_t, 3 > > workgroup_size
local_size_x/y/z
Metadata extracted from shader module.
std::vector< Binding > bindings
std::vector< Attribute > attributes
std::vector< vk::VertexInputAttributeDescription > attributes
std::vector< vk::VertexInputBindingDescription > bindings