diff --git a/shaders/slang/_rename.py b/shaders/slang/_rename.py index dfd09a4f..802137aa 100644 --- a/shaders/slang/_rename.py +++ b/shaders/slang/_rename.py @@ -24,13 +24,25 @@ def checkRenameFiles(samplename): "raytracinggltf.rmiss.spv": "miss.rmiss.spv", "raytracinggltf.rgen.spv": "raygen.rgen.spv", "raytracinggltf.rahit.spv": "anyhit.rahit.spv", - } + } + case "raytracingpositionfetch": + mappings = { + "raytracingpositionfetch.rchit.spv": "closesthit.rchit.spv", + "raytracingpositionfetch.rmiss.spv": "miss.rmiss.spv", + "raytracingpositionfetch.rgen.spv": "raygen.rgen.spv", + } case "raytracingreflections": mappings = { "raytracingreflections.rchit.spv": "closesthit.rchit.spv", "raytracingreflections.rmiss.spv": "miss.rmiss.spv", "raytracingreflections.rgen.spv": "raygen.rgen.spv", } + case "raytracingsbtdata": + mappings = { + "raytracingsbtdata.rchit.spv": "closesthit.rchit.spv", + "raytracingsbtdata.rmiss.spv": "miss.rmiss.spv", + "raytracingsbtdata.rgen.spv": "raygen.rgen.spv", + } case "raytracingshadows": mappings = { "raytracingshadows.rchit.spv": "closesthit.rchit.spv", diff --git a/shaders/slang/raytracingpositionfetch/raytracingpositionfetch.slang b/shaders/slang/raytracingpositionfetch/raytracingpositionfetch.slang new file mode 100644 index 00000000..96c341e9 --- /dev/null +++ b/shaders/slang/raytracingpositionfetch/raytracingpositionfetch.slang @@ -0,0 +1,78 @@ +/* Copyright (c) 2025, Sascha Willems + * + * SPDX-License-Identifier: MIT + * + */ + +struct Attributes +{ + float2 bary; +}; + +struct Payload +{ + [[vk::location(0)]] float3 hitValue; +}; + +RaytracingAccelerationStructure accelStruct; +RWTexture2D image; +struct UBO +{ + float4x4 viewInverse; + float4x4 projInverse; + float4 lightPos; +}; +ConstantBuffer ubo; + +[shader("raygeneration")] +void raygenerationMain() +{ + uint3 LaunchID = DispatchRaysIndex(); + uint3 LaunchSize = DispatchRaysDimensions(); + + const float2 pixelCenter = float2(LaunchID.xy) + float2(0.5, 0.5); + const float2 inUV = pixelCenter / float2(LaunchSize.xy); + float2 d = inUV * 2.0 - 1.0; + float4 target = mul(ubo.projInverse, float4(d.x, d.y, 1, 1)); + + RayDesc rayDesc; + rayDesc.Origin = mul(ubo.viewInverse, float4(0, 0, 0, 1)).xyz; + rayDesc.Direction = mul(ubo.viewInverse, float4(normalize(target.xyz), 0)).xyz; + rayDesc.TMin = 0.001; + rayDesc.TMax = 10000.0; + + Payload payload; + TraceRay(accelStruct, RAY_FLAG_FORCE_OPAQUE, 0xff, 0, 0, 0, rayDesc, payload); + + image[int2(LaunchID.xy)] = float4(payload.hitValue, 0.0); +} + +[shader("closesthit")] +void closesthitMain(inout Payload payload, in Attributes attribs) +{ + // We need the barycentric coordinates to calculate data for the current position + const float3 barycentricCoords = float3(1.0f - attribs.bary.x - attribs.bary.y, attribs.bary.x, attribs.bary.y); + + // With VK_KHR_ray_tracing_position_fetch we can access the vertices for the hit triangle in the shader + + float3 vertexPos0 = HitTriangleVertexPosition(0); + float3 vertexPos1 = HitTriangleVertexPosition(1); + float3 vertexPos2 = HitTriangleVertexPosition(2); + float3 currentPos = vertexPos0 * barycentricCoords.x + vertexPos1 * barycentricCoords.y + vertexPos2 * barycentricCoords.z; + + // Calcualte the normal from above values + float3 normal = normalize(cross(vertexPos1 - vertexPos0, vertexPos2 - vertexPos0)); + normal = normalize(mul(float4(normal, 1.0), WorldToObject4x3())); + + // Basic lighting + float3 lightDir = normalize(ubo.lightPos.xyz - currentPos); + float diffuse = max(dot(normal, lightDir), 0.0); + + payload.hitValue.rgb = 0.1 + diffuse; +} + +[shader("miss")] +void missMain(inout Payload payload) +{ + payload.hitValue = float3(0.0, 0.0, 0.2); +} \ No newline at end of file diff --git a/shaders/slang/raytracingsbtdata/raytracingsbtdata.slang b/shaders/slang/raytracingsbtdata/raytracingsbtdata.slang new file mode 100644 index 00000000..5ec0bc13 --- /dev/null +++ b/shaders/slang/raytracingsbtdata/raytracingsbtdata.slang @@ -0,0 +1,88 @@ +/* Copyright (c) 2025, Sascha Willems + * + * SPDX-License-Identifier: MIT + * + */ + +struct Attributes +{ + float2 bary; +}; + +struct Payload +{ + float3 hitValue; +}; + +RaytracingAccelerationStructure accelStruct; +RWTexture2D image; +struct CameraProperties +{ + float4x4 viewInverse; + float4x4 projInverse; +}; +ConstantBuffer cam; + +struct SBT { + float r; + float g; + float b; +}; +[[vk::shader_record]] ConstantBuffer sbt; + +[shader("raygeneration")] +void raygenerationMain() +{ + uint3 LaunchID = DispatchRaysIndex(); + uint3 LaunchSize = DispatchRaysDimensions(); + + const float2 pixelCenter = float2(LaunchID.xy) + float2(0.5, 0.5); + const float2 inUV = pixelCenter / float2(LaunchSize.xy); + float2 d = inUV * 2.0 - 1.0; + float4 target = mul(cam.projInverse, float4(d.x, d.y, 1, 1)); + + RayDesc rayDesc; + rayDesc.Origin = mul(cam.viewInverse, float4(0, 0, 0, 1)).xyz; + rayDesc.Direction = mul(cam.viewInverse, float4(normalize(target.xyz), 0)).xyz; + rayDesc.TMin = 0.001; + rayDesc.TMax = 10000.0; + + Payload payload; + + // use border to demonstrate raygen record data + if (all(LaunchID.xy > int2(16, 16)) && all(LaunchID.xy < LaunchSize.xy - int2(16, 16))) + { + // Generate a checker board pattern to trace out rays or use hit record data + int2 pos = int2(LaunchID.xy / 16); + if (((pos.x + pos.y % 2) % 2) == 0) { + // This will set hit value to either hit or miss SBT record color + TraceRay(accelStruct, RAY_FLAG_FORCE_OPAQUE, 0xff, 0, 0, 0, rayDesc, payload); + } + else { + // Set the hit value to the raygen SBT data + payload.hitValue = float3(sbt.r, sbt.g, sbt.b); + } + } + else { + // Set hit value to black + payload.hitValue = float3(0.0, 0.0, 0.0); + } + + image[int2(LaunchID.xy)] = float4(payload.hitValue, 0.0); +} + +[shader("closesthit")] +void closesthitMain(inout Payload payload, in Attributes attribs) +{ + // Update the hit value to the hit record SBT data associated with this + // geometry ID and ray ID + payload.hitValue = float3(sbt.r, sbt.g, sbt.g); +} + +[shader("miss")] +void missMain(inout Payload payload) +{ + // Update the hit value to the hit record SBT data associated with this + // miss record + payload.hitValue = float3(sbt.r, sbt.g, sbt.g); +}