From 6511d9d80295768d1d76be69c6a6a4390b85bdda Mon Sep 17 00:00:00 2001 From: Sascha Willems Date: Tue, 19 Mar 2024 18:29:43 +0100 Subject: [PATCH] Added HLSL shaders --- .../raytracingpositionfetch/closesthit.rchit | 57 +++++++++++++++++++ .../hlsl/raytracingpositionfetch/miss.rmiss | 16 ++++++ .../hlsl/raytracingpositionfetch/raygen.rgen | 43 ++++++++++++++ 3 files changed, 116 insertions(+) create mode 100644 shaders/hlsl/raytracingpositionfetch/closesthit.rchit create mode 100644 shaders/hlsl/raytracingpositionfetch/miss.rmiss create mode 100644 shaders/hlsl/raytracingpositionfetch/raygen.rgen diff --git a/shaders/hlsl/raytracingpositionfetch/closesthit.rchit b/shaders/hlsl/raytracingpositionfetch/closesthit.rchit new file mode 100644 index 00000000..59e91df1 --- /dev/null +++ b/shaders/hlsl/raytracingpositionfetch/closesthit.rchit @@ -0,0 +1,57 @@ +/* Copyright (c) 2024, Sascha Willems + * + * SPDX-License-Identifier: MIT + * + */ + +struct Attributes +{ + float2 bary; +}; + +struct Payload +{ + [[vk::location(0)]] float3 hitValue; +}; + +struct UBO +{ + float4x4 viewInverse; + float4x4 projInverse; + float4 lightPos; +}; +cbuffer ubo : register(b2) +{ + UBO ubo; +}; + +[[vk::ext_extension("SPV_KHR_ray_tracing_position_fetch")]] + +[shader("closesthit")] +void main(inout Payload p, in Attributes attribs) +{ + // We need the barycentric coordinates to calculate data for the current position + const float3 barycentricCoords = float3(1.0f - attribs.bary.x - attribs.bary.y, attribs.bary.x, attribs.bary.y); + + // With VK_KHR_ray_tracing_position_fetch we can access the vertices for the hit triangle in the shader + + // We need to use special syntax for SPIR-V inlines + #define HitTriangleVertexPositionsKHR 5391 + [[vk::ext_builtin_output(HitTriangleVertexPositionsKHR)]] + static float3 gl_HitTriangleVertexPositions[3]; + + float3 vertexPos0 = gl_HitTriangleVertexPositions[0]; + float3 vertexPos1 = gl_HitTriangleVertexPositions[1]; + float3 vertexPos2 = gl_HitTriangleVertexPositions[2]; + float3 currentPos = vertexPos0 * barycentricCoords.x + vertexPos1 * barycentricCoords.y + vertexPos2 * barycentricCoords.z; + + // Calcualte the normal from above values + float3 normal = normalize(cross(vertexPos1 - vertexPos0, vertexPos2 - vertexPos0)); + normal = normalize(mul(float4(normal, 1.0), WorldToObject4x3())); + + // Basic lighting + float3 lightDir = normalize(ubo.lightPos.xyz - currentPos); + float diffuse = max(dot(normal, lightDir), 0.0); + + p.hitValue.rgb = 0.1 + diffuse; +} \ No newline at end of file diff --git a/shaders/hlsl/raytracingpositionfetch/miss.rmiss b/shaders/hlsl/raytracingpositionfetch/miss.rmiss new file mode 100644 index 00000000..e28ec877 --- /dev/null +++ b/shaders/hlsl/raytracingpositionfetch/miss.rmiss @@ -0,0 +1,16 @@ +/* Copyright (c) 2024, Sascha Willems + * + * SPDX-License-Identifier: MIT + * + */ + +struct Payload +{ +[[vk::location(0)]] float3 hitValue; +}; + +[shader("miss")] +void main(inout Payload p) +{ + p.hitValue = float3(0.0, 0.0, 0.2); +} \ No newline at end of file diff --git a/shaders/hlsl/raytracingpositionfetch/raygen.rgen b/shaders/hlsl/raytracingpositionfetch/raygen.rgen new file mode 100644 index 00000000..a346a62f --- /dev/null +++ b/shaders/hlsl/raytracingpositionfetch/raygen.rgen @@ -0,0 +1,43 @@ +/* Copyright (c) 2024, Sascha Willems + * + * SPDX-License-Identifier: MIT + * + */ + +RaytracingAccelerationStructure rs : register(t0); +RWTexture2D image : register(u1); + +struct CameraProperties +{ + float4x4 viewInverse; + float4x4 projInverse; +}; +cbuffer cam : register(b2) { CameraProperties cam; }; + +struct Payload +{ +[[vk::location(0)]] float3 hitValue; +}; + +[shader("raygeneration")] +void main() +{ + uint3 LaunchID = DispatchRaysIndex(); + uint3 LaunchSize = DispatchRaysDimensions(); + + const float2 pixelCenter = float2(LaunchID.xy) + float2(0.5, 0.5); + const float2 inUV = pixelCenter/float2(LaunchSize.xy); + float2 d = inUV * 2.0 - 1.0; + float4 target = mul(cam.projInverse, float4(d.x, d.y, 1, 1)); + + RayDesc rayDesc; + rayDesc.Origin = mul(cam.viewInverse, float4(0,0,0,1)).xyz; + rayDesc.Direction = mul(cam.viewInverse, float4(normalize(target.xyz), 0)).xyz; + rayDesc.TMin = 0.001; + rayDesc.TMax = 10000.0; + + Payload payload; + TraceRay(rs, RAY_FLAG_FORCE_OPAQUE, 0xff, 0, 0, 0, rayDesc, payload); + + image[int2(LaunchID.xy)] = float4(payload.hitValue, 0.0); +}