@@ -10,6 +10,24 @@ namespace Silk.NET.SilkTouch.Mods;
1010/// </summary>
1111public class TransformVulkan : IMod
1212{
13+ private const string MethodClassName = "Vk" ;
14+
15+ private const string InstanceTypeName = "InstanceHandle" ;
16+ private const string InstanceNativeTypeName = "VkInstance" ;
17+ private const string InstanceFieldName = "_currentInstance" ;
18+ private const string InstancePropertyName = "CurrentInstance" ;
19+
20+ private const string DeviceTypeName = "DeviceHandle" ;
21+ private const string DeviceNativeTypeName = "VkDevice" ;
22+ private const string DeviceFieldName = "_currentDevice" ;
23+ private const string DevicePropertyName = "CurrentDevice" ;
24+
25+ private const string VkCreateInstanceNativeName = "vkCreateInstance" ;
26+ private const string VkCreateDeviceNativeName = "vkCreateDevice" ;
27+
28+ private const string VkResultName = "Result" ;
29+ private const string VkResultSuccessName = "Success" ;
30+
1331 /// <inheritdoc />
1432 public async Task ExecuteAsync ( IModContext ctx , CancellationToken ct = default )
1533 {
@@ -38,25 +56,47 @@ public async Task ExecuteAsync(IModContext ctx, CancellationToken ct = default)
3856 ctx . SourceProject = proj ;
3957 }
4058
41- private class Rewriter : CSharpSyntaxRewriter
59+ /// <summary>
60+ /// Used by <see cref="Rewriter"/> to identify methods that call
61+ /// the native function pointer through the vtable slots field.
62+ /// </summary>
63+ private class SlotsMethodIdentifier : CSharpSyntaxWalker
4264 {
43- private const string MethodClassName = "Vk" ;
65+ public bool IsSlotsMethod { get ; private set ; }
66+
67+ private bool isInInvocationExpression = false ;
68+
69+ public override void VisitMethodDeclaration ( MethodDeclarationSyntax node )
70+ {
71+ IsSlotsMethod = false ;
72+ base . VisitMethodDeclaration ( node ) ;
73+ }
4474
45- private const string InstanceTypeName = "InstanceHandle" ;
46- private const string InstanceNativeTypeName = "VkInstance" ;
47- private const string InstanceFieldName = "_currentInstance" ;
48- private const string InstancePropertyName = "CurrentInstance" ;
75+ public override void VisitInvocationExpression ( InvocationExpressionSyntax node )
76+ {
77+ isInInvocationExpression = true ;
78+ base . VisitInvocationExpression ( node ) ;
79+ }
4980
50- private const string DeviceTypeName = "DeviceHandle" ;
51- private const string DeviceNativeTypeName = "VkDevice" ;
52- private const string DeviceFieldName = "_currentDevice" ;
53- private const string DevicePropertyName = "CurrentDevice" ;
81+ public override void VisitFunctionPointerType ( FunctionPointerTypeSyntax node )
82+ {
83+ if ( isInInvocationExpression )
84+ {
85+ IsSlotsMethod = true ;
86+ }
87+ }
88+ }
5489
55- private const string VkCreateInstanceNativeName = "vkCreateInstance" ;
56- private const string VkCreateDeviceNativeName = "vkCreateDevice" ;
90+ /// <summary>
91+ /// This does the following:
92+ /// 1. Add the instance/device members.
93+ /// 2. Rewrite the vkCreateInstance and vkCreateDevice methods to set those members.
94+ /// </summary>
95+ private class Rewriter : CSharpSyntaxRewriter
96+ {
97+ private readonly SlotsMethodIdentifier slotsMethodIdentifier = new ( ) ;
5798
58- private const string VkResultName = "Result" ;
59- private const string VkResultSuccessName = "Success" ;
99+ private bool hasOutputInstanceDeviceMembers ;
60100
61101 public override SyntaxNode ? VisitClassDeclaration ( ClassDeclarationSyntax node )
62102 {
@@ -65,28 +105,39 @@ private class Rewriter : CSharpSyntaxRewriter
65105 return base . VisitClassDeclaration ( node ) ;
66106 }
67107
68- var instanceField = FieldDeclaration (
69- VariableDeclaration ( NullableType ( IdentifierName ( InstanceTypeName ) ) )
70- . AddVariables ( VariableDeclarator ( InstanceFieldName ) )
71- ) . AddModifiers ( Token ( SyntaxKind . PrivateKeyword ) ) ;
108+ // Rewrite members
109+ node = node . WithMembers ( [
110+ .. node . Members . SelectMany ( RewriteMember )
111+ ] ) ;
72112
113+ // Output instance/device members if needed
114+ if ( ! hasOutputInstanceDeviceMembers )
115+ {
116+ var instanceField = FieldDeclaration (
117+ VariableDeclaration ( NullableType ( IdentifierName ( InstanceTypeName ) ) )
118+ . AddVariables ( VariableDeclarator ( InstanceFieldName ) )
119+ ) . AddModifiers ( Token ( SyntaxKind . PrivateKeyword ) ) ;
73120
74- var deviceField = FieldDeclaration (
75- VariableDeclaration ( NullableType ( IdentifierName ( DeviceTypeName ) ) )
76- . AddVariables ( VariableDeclarator ( DeviceFieldName ) )
77- ) . AddModifiers ( Token ( SyntaxKind . PrivateKeyword ) ) ;
78121
79- var instanceProperty = CreateProperty ( InstanceTypeName , InstancePropertyName , InstanceFieldName ) ;
122+ var deviceField = FieldDeclaration (
123+ VariableDeclaration ( NullableType ( IdentifierName ( DeviceTypeName ) ) )
124+ . AddVariables ( VariableDeclarator ( DeviceFieldName ) )
125+ ) . AddModifiers ( Token ( SyntaxKind . PrivateKeyword ) ) ;
80126
81- var deviceProperty = CreateProperty ( DeviceTypeName , DevicePropertyName , DeviceFieldName ) ;
127+ var instanceProperty = CreateProperty ( InstanceTypeName , InstancePropertyName , InstanceFieldName ) ;
82128
83- node = node . WithMembers ( [
84- instanceField ,
85- deviceField ,
86- instanceProperty ,
87- deviceProperty ,
88- ..node . Members . SelectMany ( RewriteMember )
89- ] ) ;
129+ var deviceProperty = CreateProperty ( DeviceTypeName , DevicePropertyName , DeviceFieldName ) ;
130+
131+ node = node . WithMembers ( [
132+ instanceField ,
133+ deviceField ,
134+ instanceProperty ,
135+ deviceProperty ,
136+ ..node . Members
137+ ] ) ;
138+ }
139+
140+ hasOutputInstanceDeviceMembers = true ;
90141
91142 return base . VisitClassDeclaration ( node ) ;
92143 }
@@ -105,13 +156,14 @@ private IEnumerable<MemberDeclarationSyntax> RewriteMember(MemberDeclarationSynt
105156 yield break ;
106157 }
107158
108- if ( ! method . Modifiers . Any ( modifier => modifier . IsKind ( SyntaxKind . ExternKeyword ) ) )
159+ if ( entryPoint != VkCreateInstanceNativeName && entryPoint != VkCreateDeviceNativeName )
109160 {
110161 yield return member ;
111162 yield break ;
112163 }
113164
114- if ( entryPoint != VkCreateInstanceNativeName && entryPoint != VkCreateDeviceNativeName )
165+ slotsMethodIdentifier . Visit ( member ) ;
166+ if ( ! slotsMethodIdentifier . IsSlotsMethod )
115167 {
116168 yield return member ;
117169 yield break ;
@@ -122,9 +174,9 @@ private IEnumerable<MemberDeclarationSyntax> RewriteMember(MemberDeclarationSynt
122174
123175 // Output the original method, but private
124176 yield return method
177+ . WithExplicitInterfaceSpecifier ( null )
125178 . WithIdentifier ( Identifier ( privateMethodName ) )
126179 . WithModifiers ( [
127- Token ( SyntaxKind . PrivateKeyword ) ,
128180 ..member . Modifiers . Where ( modifier =>
129181 ! SyntaxFacts . IsAccessibilityModifier ( modifier . Kind ( ) ) )
130182 ] ) ;
0 commit comments