diff --git a/README.md b/README.md index 9b106a24..61451a1f 100644 --- a/README.md +++ b/README.md @@ -308,7 +308,7 @@ Renders a complex scene with reflective surfaces using the new ray tracing exten #### [Callable ray tracing shaders](examples/raytracingcallable) -Callable shaders can be dynamically invoked from within other ray tracing shaders to execute different shaders based on your own conditions. The example ray traces multiple geometries, with each calling a different callable shader from the closest hit shader. +Callable shaders can be dynamically invoked from within other ray tracing shaders to execute different shaders based on dynamic conditions. The example ray traces multiple geometries, with each calling a different callable shader from the closest hit shader. #### [Ray query](examples/rayquery) diff --git a/data/shaders/glsl/raytracingcallable/callable3.rcall b/data/shaders/glsl/raytracingcallable/callable3.rcall index 6c52133b..1cc3a701 100644 --- a/data/shaders/glsl/raytracingcallable/callable3.rcall +++ b/data/shaders/glsl/raytracingcallable/callable3.rcall @@ -8,6 +8,4 @@ void main() // Generate a line pattern vec2 pos = vec2(gl_LaunchIDEXT / 8); outColor = vec3(mod(pos.y, 2.0)); - -// outColor = vec3(0.0, 0.0, 1.0); } \ No newline at end of file diff --git a/data/shaders/glsl/raytracingcallable/shadow.rmiss b/data/shaders/glsl/raytracingcallable/shadow.rmiss deleted file mode 100644 index 36d9b7ba..00000000 --- a/data/shaders/glsl/raytracingcallable/shadow.rmiss +++ /dev/null @@ -1,9 +0,0 @@ -#version 460 -#extension GL_EXT_ray_tracing : require - -layout(location = 2) rayPayloadInEXT bool shadowed; - -void main() -{ - shadowed = false; -} \ No newline at end of file diff --git a/data/shaders/glsl/raytracingcallable/shadow.rmiss.spv b/data/shaders/glsl/raytracingcallable/shadow.rmiss.spv deleted file mode 100644 index 65b01ba9..00000000 Binary files a/data/shaders/glsl/raytracingcallable/shadow.rmiss.spv and /dev/null differ diff --git a/data/shaders/hlsl/raytracingcallable/callable1.rcall b/data/shaders/hlsl/raytracingcallable/callable1.rcall new file mode 100644 index 00000000..40fd7c71 --- /dev/null +++ b/data/shaders/hlsl/raytracingcallable/callable1.rcall @@ -0,0 +1,14 @@ +// Copyright 2021 Sascha Willems + +struct CallData +{ + vec3 outColor; +}; + +[shader("callable")] +void main(inout CallData data) +{ + // Generate a checker board pattern + float2 pos = float2(DispatchRaysIndex() / 8); + data.outColor = float3(mod(pos.x + mod(pos.y, 2.0), 2.0)); +} \ No newline at end of file diff --git a/data/shaders/hlsl/raytracingcallable/callable2.rcall b/data/shaders/hlsl/raytracingcallable/callable2.rcall new file mode 100644 index 00000000..498dfe05 --- /dev/null +++ b/data/shaders/hlsl/raytracingcallable/callable2.rcall @@ -0,0 +1,12 @@ +// Copyright 2021 Sascha Willems + +struct CallData +{ + vec3 outColor; +}; + +[shader("callable")] +void main(inout CallData data) +{ + data.outColor = float3(0.0, 1.0, 0.0); +} \ No newline at end of file diff --git a/data/shaders/hlsl/raytracingcallable/callable3.rcall b/data/shaders/hlsl/raytracingcallable/callable3.rcall new file mode 100644 index 00000000..18c97cac --- /dev/null +++ b/data/shaders/hlsl/raytracingcallable/callable3.rcall @@ -0,0 +1,14 @@ +// Copyright 2021 Sascha Willems + +struct CallData +{ + vec3 outColor; +}; + +[shader("callable")] +void main(inout CallData data) +{ + // Generate a checker board pattern + float2 pos = float2(DispatchRaysIndex() / 8); + data.outColor = float3(mod(pos.y, 2.0)); +} \ No newline at end of file diff --git a/data/shaders/hlsl/raytracingcallable/closesthit.rchit b/data/shaders/hlsl/raytracingcallable/closesthit.rchit new file mode 100644 index 00000000..61caef21 --- /dev/null +++ b/data/shaders/hlsl/raytracingcallable/closesthit.rchit @@ -0,0 +1,26 @@ +// Copyright 2021 Sascha Willems + +struct Attribute +{ + float2 attribs; +}; + +struct Payload +{ +[[vk::location(0)]] float3 hitValue; +}; + +struct CallData +{ + float3 outColor; +}; + +[shader("closesthit")] +void main(inout Payload p, in float3 attribs) +{ + // Execute the callable shader indexed by the current geometry being hit + // For our sample this means that the first callable shader in the SBT is invoked for the first triangle, the second callable shader for the second triangle, etc. + CallData callData; + CallShader(GeometryIndex(), callData); + p.hitValue = callData.outColor; +} diff --git a/data/shaders/hlsl/raytracingcallable/closesthit.rchit.spv b/data/shaders/hlsl/raytracingcallable/closesthit.rchit.spv new file mode 100644 index 00000000..8e3fca2c Binary files /dev/null and b/data/shaders/hlsl/raytracingcallable/closesthit.rchit.spv differ diff --git a/data/shaders/hlsl/raytracingcallable/miss.rmiss b/data/shaders/hlsl/raytracingcallable/miss.rmiss new file mode 100644 index 00000000..3342b168 --- /dev/null +++ b/data/shaders/hlsl/raytracingcallable/miss.rmiss @@ -0,0 +1,12 @@ +// Copyright 2021 Sascha Willems + +struct Payload +{ +[[vk::location(0)]] float3 hitValue; +}; + +[shader("miss")] +void main(inout Payload p) +{ + p.hitValue = float3(0.0, 0.0, 0.2); +} \ No newline at end of file diff --git a/data/shaders/hlsl/raytracingcallable/miss.rmiss.spv b/data/shaders/hlsl/raytracingcallable/miss.rmiss.spv new file mode 100644 index 00000000..839732c5 Binary files /dev/null and b/data/shaders/hlsl/raytracingcallable/miss.rmiss.spv differ diff --git a/data/shaders/hlsl/raytracingcallable/raygen.rgen b/data/shaders/hlsl/raytracingcallable/raygen.rgen new file mode 100644 index 00000000..6e08cd05 --- /dev/null +++ b/data/shaders/hlsl/raytracingcallable/raygen.rgen @@ -0,0 +1,39 @@ +// Copyright 2021 Sascha Willems + +RaytracingAccelerationStructure rs : register(t0); +RWTexture2D image : register(u1); + +struct CameraProperties +{ + float4x4 viewInverse; + float4x4 projInverse; +}; +cbuffer cam : register(b2) { CameraProperties cam; }; + +struct Payload +{ +[[vk::location(0)]] float3 hitValue; +}; + +[shader("raygeneration")] +void main() +{ + uint3 LaunchID = DispatchRaysIndex(); + uint3 LaunchSize = DispatchRaysDimensions(); + + const float2 pixelCenter = float2(LaunchID.xy) + float2(0.5, 0.5); + const float2 inUV = pixelCenter/float2(LaunchSize.xy); + float2 d = inUV * 2.0 - 1.0; + float4 target = mul(cam.projInverse, float4(d.x, d.y, 1, 1)); + + RayDesc rayDesc; + rayDesc.Origin = mul(cam.viewInverse, float4(0,0,0,1)).xyz; + rayDesc.Direction = mul(cam.viewInverse, float4(normalize(target.xyz), 0)).xyz; + rayDesc.TMin = 0.001; + rayDesc.TMax = 10000.0; + + Payload payload; + TraceRay(rs, RAY_FLAG_FORCE_OPAQUE, 0xff, 0, 0, 0, rayDesc, payload); + + image[int2(LaunchID.xy)] = float4(payload.hitValue, 0.0); +} diff --git a/data/shaders/hlsl/raytracingcallable/raygen.rgen.spv b/data/shaders/hlsl/raytracingcallable/raygen.rgen.spv new file mode 100644 index 00000000..611d1de1 Binary files /dev/null and b/data/shaders/hlsl/raytracingcallable/raygen.rgen.spv differ diff --git a/examples/raytracingcallable/raytracingcallable.cpp b/examples/raytracingcallable/raytracingcallable.cpp index a949c0f6..b2e56af5 100644 --- a/examples/raytracingcallable/raytracingcallable.cpp +++ b/examples/raytracingcallable/raytracingcallable.cpp @@ -1,9 +1,11 @@ /* * Vulkan Example - Hardware accelerated ray tracing callable shaders example * -* Renders a complex scene using multiple hit and miss shaders for implementing shadows +* Dynamically calls different shaders based on the geoemtry id in the closest hit shader * -* Copyright (C) by Sascha Willems - www.saschawillems.de +* Relevant code parts are marked with [POI] +* +* Copyright (C) 2021 by Sascha Willems - www.saschawillems.de * * This code is licensed under the MIT license (MIT) (http://opensource.org/licenses/MIT) */ @@ -56,20 +58,22 @@ public: ~VulkanExample() { - vkDestroyPipeline(device, pipeline, nullptr); - vkDestroyPipelineLayout(device, pipelineLayout, nullptr); - vkDestroyDescriptorSetLayout(device, descriptorSetLayout, nullptr); - deleteStorageImage(); - deleteAccelerationStructure(bottomLevelAS); - deleteAccelerationStructure(topLevelAS); - shaderBindingTables.raygen.destroy(); - shaderBindingTables.miss.destroy(); - shaderBindingTables.hit.destroy(); - shaderBindingTables.callable.destroy(); - vertexBuffer.destroy(); - indexBuffer.destroy(); - transformBuffer.destroy(); - ubo.destroy(); + if (device) { + vkDestroyPipeline(device, pipeline, nullptr); + vkDestroyPipelineLayout(device, pipelineLayout, nullptr); + vkDestroyDescriptorSetLayout(device, descriptorSetLayout, nullptr); + deleteStorageImage(); + deleteAccelerationStructure(bottomLevelAS); + deleteAccelerationStructure(topLevelAS); + shaderBindingTables.raygen.destroy(); + shaderBindingTables.miss.destroy(); + shaderBindingTables.hit.destroy(); + shaderBindingTables.callable.destroy(); + vertexBuffer.destroy(); + indexBuffer.destroy(); + transformBuffer.destroy(); + ubo.destroy(); + } } /* @@ -183,6 +187,7 @@ public: accelerationBuildGeometryInfo.pGeometries = accelerationStructureGeometries.data(); accelerationBuildGeometryInfo.scratchData.deviceAddress = scratchBuffer.deviceAddress; + // [POI] The bottom level acceleration structure for this sample contains three separate triangle geometries, so we can use gl_GeometryIndexEXT in the closest hit shader to select different callable shaders std::vector accelerationStructureBuildRangeInfos{}; for (uint32_t i = 0; i < objectCount; i++) { VkAccelerationStructureBuildRangeInfoKHR accelerationStructureBuildRangeInfo{}; @@ -350,7 +355,7 @@ public: createShaderBindingTable(shaderBindingTables.raygen, 1); createShaderBindingTable(shaderBindingTables.miss, 1); createShaderBindingTable(shaderBindingTables.hit, 1); - // The callable shader binding table contains one shader handle per ray traced object + // [POI] The callable shader binding table contains one shader handle per ray traced object createShaderBindingTable(shaderBindingTables.callable, objectCount); // Copy handles @@ -437,50 +442,44 @@ public: Setup ray tracing shader groups */ std::vector shaderStages; + VkRayTracingShaderGroupCreateInfoKHR shaderGroup; // Ray generation shader group - { - shaderStages.push_back(loadShader(getShadersPath() + "raytracingcallable/raygen.rgen.spv", VK_SHADER_STAGE_RAYGEN_BIT_KHR)); - VkRayTracingShaderGroupCreateInfoKHR shaderGroup = vks::initializers::rayTracingShaderGroupCreateInfoKHR(); - shaderGroup.type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR; - shaderGroup.generalShader = static_cast(shaderStages.size()) - 1; - shaderGroup.closestHitShader = VK_SHADER_UNUSED_KHR; - shaderGroup.anyHitShader = VK_SHADER_UNUSED_KHR; - shaderGroup.intersectionShader = VK_SHADER_UNUSED_KHR; - shaderGroups.push_back(shaderGroup); - } + shaderStages.push_back(loadShader(getShadersPath() + "raytracingcallable/raygen.rgen.spv", VK_SHADER_STAGE_RAYGEN_BIT_KHR)); + shaderGroup = vks::initializers::rayTracingShaderGroupCreateInfoKHR(); + shaderGroup.type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR; + shaderGroup.generalShader = static_cast(shaderStages.size()) - 1; + shaderGroup.closestHitShader = VK_SHADER_UNUSED_KHR; + shaderGroup.anyHitShader = VK_SHADER_UNUSED_KHR; + shaderGroup.intersectionShader = VK_SHADER_UNUSED_KHR; + shaderGroups.push_back(shaderGroup); // Miss shader group - { - shaderStages.push_back(loadShader(getShadersPath() + "raytracingcallable/miss.rmiss.spv", VK_SHADER_STAGE_MISS_BIT_KHR)); - VkRayTracingShaderGroupCreateInfoKHR shaderGroup = vks::initializers::rayTracingShaderGroupCreateInfoKHR(); - shaderGroup.type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR; - shaderGroup.generalShader = static_cast(shaderStages.size()) - 1; - shaderGroup.closestHitShader = VK_SHADER_UNUSED_KHR; - shaderGroup.anyHitShader = VK_SHADER_UNUSED_KHR; - shaderGroup.intersectionShader = VK_SHADER_UNUSED_KHR; - shaderGroups.push_back(shaderGroup); - } + shaderStages.push_back(loadShader(getShadersPath() + "raytracingcallable/miss.rmiss.spv", VK_SHADER_STAGE_MISS_BIT_KHR)); + shaderGroup = vks::initializers::rayTracingShaderGroupCreateInfoKHR(); + shaderGroup.type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR; + shaderGroup.generalShader = static_cast(shaderStages.size()) - 1; + shaderGroup.closestHitShader = VK_SHADER_UNUSED_KHR; + shaderGroup.anyHitShader = VK_SHADER_UNUSED_KHR; + shaderGroup.intersectionShader = VK_SHADER_UNUSED_KHR; + shaderGroups.push_back(shaderGroup); // Closest hit shader group - { - shaderStages.push_back(loadShader(getShadersPath() + "raytracingcallable/closesthit.rchit.spv", VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR)); - VkRayTracingShaderGroupCreateInfoKHR shaderGroup = vks::initializers::rayTracingShaderGroupCreateInfoKHR(); - shaderGroup.type = VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR; - shaderGroup.generalShader = VK_SHADER_UNUSED_KHR; - shaderGroup.closestHitShader = static_cast(shaderStages.size()) - 1; - shaderGroup.anyHitShader = VK_SHADER_UNUSED_KHR; - shaderGroup.intersectionShader = VK_SHADER_UNUSED_KHR; - shaderGroups.push_back(shaderGroup); - } + shaderStages.push_back(loadShader(getShadersPath() + "raytracingcallable/closesthit.rchit.spv", VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR)); + shaderGroup = vks::initializers::rayTracingShaderGroupCreateInfoKHR(); + shaderGroup.type = VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR; + shaderGroup.generalShader = VK_SHADER_UNUSED_KHR; + shaderGroup.closestHitShader = static_cast(shaderStages.size()) - 1; + shaderGroup.anyHitShader = VK_SHADER_UNUSED_KHR; + shaderGroup.intersectionShader = VK_SHADER_UNUSED_KHR; + shaderGroups.push_back(shaderGroup); - // Callable shader group - // This sample's hit shader will call different callable shaders depending on the geometry index, so as we render three different geometries, we'll also use three callable shaders + // [POI] Callable shader group + // This sample's hit shader will call different callable shaders depending on the geometry index using executeCallableEXT, so as we render three geometries, we'll also use three callable shaders for (uint32_t i = 0; i < objectCount; i++) { shaderStages.push_back(loadShader(getShadersPath() + "raytracingcallable/callable" + std::to_string(i+1) + ".rcall.spv", VK_SHADER_STAGE_CALLABLE_BIT_KHR)); - VkRayTracingShaderGroupCreateInfoKHR shaderGroup{}; - shaderGroup.sType = VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR; + shaderGroup = vks::initializers::rayTracingShaderGroupCreateInfoKHR(); shaderGroup.type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR; shaderGroup.generalShader = static_cast(shaderStages.size()) - 1; shaderGroup.closestHitShader = VK_SHADER_UNUSED_KHR;