Skip to content

Commit a87c47c

Browse files
committed
add test and implement source generated logging
1 parent 926aacd commit a87c47c

File tree

2 files changed

+155
-112
lines changed

2 files changed

+155
-112
lines changed

src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs

Lines changed: 114 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,16 @@
1111
namespace ModelContextProtocol.Server;
1212

1313
/// <summary>Provides an <see cref="McpServerTool"/> that's implemented via an <see cref="AIFunction"/>.</summary>
14-
internal sealed class AIFunctionMcpServerTool : McpServerTool
14+
internal sealed partial class AIFunctionMcpServerTool : McpServerTool
1515
{
1616
private readonly ILogger _logger;
1717

1818
/// <summary>
1919
/// Creates an <see cref="McpServerTool"/> instance for a method, specified via a <see cref="Delegate"/> instance.
2020
/// </summary>
2121
public static new AIFunctionMcpServerTool Create(
22-
Delegate method,
23-
McpServerToolCreateOptions? options)
22+
Delegate method,
23+
McpServerToolCreateOptions? options)
2424
{
2525
Throw.IfNull(method);
2626

@@ -33,139 +33,139 @@ internal sealed class AIFunctionMcpServerTool : McpServerTool
3333
/// Creates an <see cref="McpServerTool"/> instance for a method, specified via a <see cref="MethodInfo"/> instance.
3434
/// </summary>
3535
public static new AIFunctionMcpServerTool Create(
36-
MethodInfo method,
37-
object? target,
38-
McpServerToolCreateOptions? options)
36+
MethodInfo method,
37+
object? target,
38+
McpServerToolCreateOptions? options)
3939
{
4040
Throw.IfNull(method);
4141

4242
options = DeriveOptions(method, options);
4343

4444
return Create(
45-
AIFunctionFactory.Create(method, target, CreateAIFunctionFactoryOptions(method, options)),
46-
options);
45+
AIFunctionFactory.Create(method, target, CreateAIFunctionFactoryOptions(method, options)),
46+
options);
4747
}
4848

4949
/// <summary>
5050
/// Creates an <see cref="McpServerTool"/> instance for a method, specified via a <see cref="MethodInfo"/> instance.
5151
/// </summary>
5252
public static new AIFunctionMcpServerTool Create(
53-
MethodInfo method,
54-
Func<RequestContext<CallToolRequestParams>, object> createTargetFunc,
55-
McpServerToolCreateOptions? options)
53+
MethodInfo method,
54+
Func<RequestContext<CallToolRequestParams>, object> createTargetFunc,
55+
McpServerToolCreateOptions? options)
5656
{
5757
Throw.IfNull(method);
5858
Throw.IfNull(createTargetFunc);
5959

6060
options = DeriveOptions(method, options);
6161

6262
return Create(
63-
AIFunctionFactory.Create(method, args =>
64-
{
65-
var request = (RequestContext<CallToolRequestParams>)args.Context![typeof(RequestContext<CallToolRequestParams>)]!;
66-
return createTargetFunc(request);
67-
}, CreateAIFunctionFactoryOptions(method, options)),
68-
options);
63+
AIFunctionFactory.Create(method, args =>
64+
{
65+
var request = (RequestContext<CallToolRequestParams>)args.Context![typeof(RequestContext<CallToolRequestParams>)]!;
66+
return createTargetFunc(request);
67+
}, CreateAIFunctionFactoryOptions(method, options)),
68+
options);
6969
}
7070

7171
// TODO: Fix the need for this suppression.
7272
[UnconditionalSuppressMessage("ReflectionAnalysis", "IL2111:ReflectionToDynamicallyAccessedMembers",
73-
Justification = "AIFunctionFactory ensures that the Type passed to AIFunctionFactoryOptions.CreateInstance has public constructors preserved")]
73+
Justification = "AIFunctionFactory ensures that the Type passed to AIFunctionFactoryOptions.CreateInstance has public constructors preserved")]
7474
internal static Func<Type, AIFunctionArguments, object> GetCreateInstanceFunc() =>
75-
static ([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] type, args) => args.Services is { } services ?
76-
ActivatorUtilities.CreateInstance(services, type) :
77-
Activator.CreateInstance(type)!;
75+
static ([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] type, args) => args.Services is { } services ?
76+
ActivatorUtilities.CreateInstance(services, type) :
77+
Activator.CreateInstance(type)!;
7878

7979
private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions(
80-
MethodInfo method, McpServerToolCreateOptions? options) =>
81-
new()
80+
MethodInfo method, McpServerToolCreateOptions? options) =>
81+
new()
82+
{
83+
Name = options?.Name ?? method.GetCustomAttribute<McpServerToolAttribute>()?.Name,
84+
Description = options?.Description,
85+
MarshalResult = static (result, _, cancellationToken) => new ValueTask<object?>(result),
86+
SerializerOptions = options?.SerializerOptions ?? McpJsonUtilities.DefaultOptions,
87+
ConfigureParameterBinding = pi =>
88+
{
89+
if (pi.ParameterType == typeof(RequestContext<CallToolRequestParams>))
8290
{
83-
Name = options?.Name ?? method.GetCustomAttribute<McpServerToolAttribute>()?.Name,
84-
Description = options?.Description,
85-
MarshalResult = static (result, _, cancellationToken) => new ValueTask<object?>(result),
86-
SerializerOptions = options?.SerializerOptions ?? McpJsonUtilities.DefaultOptions,
87-
ConfigureParameterBinding = pi =>
91+
return new()
8892
{
89-
if (pi.ParameterType == typeof(RequestContext<CallToolRequestParams>))
90-
{
91-
return new()
92-
{
93-
ExcludeFromSchema = true,
94-
BindParameter = (pi, args) => GetRequestContext(args),
95-
};
96-
}
93+
ExcludeFromSchema = true,
94+
BindParameter = (pi, args) => GetRequestContext(args),
95+
};
96+
}
9797

98-
if (pi.ParameterType == typeof(IMcpServer))
99-
{
100-
return new()
101-
{
102-
ExcludeFromSchema = true,
103-
BindParameter = (pi, args) => GetRequestContext(args)?.Server,
104-
};
105-
}
98+
if (pi.ParameterType == typeof(IMcpServer))
99+
{
100+
return new()
101+
{
102+
ExcludeFromSchema = true,
103+
BindParameter = (pi, args) => GetRequestContext(args)?.Server,
104+
};
105+
}
106106

107-
if (pi.ParameterType == typeof(IProgress<ProgressNotificationValue>))
107+
if (pi.ParameterType == typeof(IProgress<ProgressNotificationValue>))
108+
{
109+
// Bind IProgress<ProgressNotificationValue> to the progress token in the request,
110+
// if there is one. If we can't get one, return a nop progress.
111+
return new()
112+
{
113+
ExcludeFromSchema = true,
114+
BindParameter = (pi, args) =>
115+
{
116+
var requestContent = GetRequestContext(args);
117+
if (requestContent?.Server is { } server &&
118+
requestContent?.Params?.Meta?.ProgressToken is { } progressToken)
108119
{
109-
// Bind IProgress<ProgressNotificationValue> to the progress token in the request,
110-
// if there is one. If we can't get one, return a nop progress.
111-
return new()
112-
{
113-
ExcludeFromSchema = true,
114-
BindParameter = (pi, args) =>
115-
{
116-
var requestContent = GetRequestContext(args);
117-
if (requestContent?.Server is { } server &&
118-
requestContent?.Params?.Meta?.ProgressToken is { } progressToken)
119-
{
120-
return new TokenProgress(server, progressToken);
121-
}
122-
123-
return NullProgress.Instance;
124-
},
125-
};
120+
return new TokenProgress(server, progressToken);
126121
}
127122

128-
if (options?.Services is { } services &&
129-
services.GetService<IServiceProviderIsService>() is { } ispis &&
130-
ispis.IsService(pi.ParameterType))
131-
{
132-
return new()
133-
{
134-
ExcludeFromSchema = true,
135-
BindParameter = (pi, args) =>
136-
GetRequestContext(args)?.Services?.GetService(pi.ParameterType) ??
137-
(pi.HasDefaultValue ? null :
138-
throw new ArgumentException("No service of the requested type was found.")),
139-
};
140-
}
123+
return NullProgress.Instance;
124+
},
125+
};
126+
}
141127

142-
if (pi.GetCustomAttribute<FromKeyedServicesAttribute>() is { } keyedAttr)
143-
{
144-
return new()
145-
{
146-
ExcludeFromSchema = true,
147-
BindParameter = (pi, args) =>
148-
(GetRequestContext(args)?.Services as IKeyedServiceProvider)?.GetKeyedService(pi.ParameterType, keyedAttr.Key) ??
149-
(pi.HasDefaultValue ? null :
150-
throw new ArgumentException("No service of the requested type was found.")),
151-
};
152-
}
128+
if (options?.Services is { } services &&
129+
services.GetService<IServiceProviderIsService>() is { } ispis &&
130+
ispis.IsService(pi.ParameterType))
131+
{
132+
return new()
133+
{
134+
ExcludeFromSchema = true,
135+
BindParameter = (pi, args) =>
136+
GetRequestContext(args)?.Services?.GetService(pi.ParameterType) ??
137+
(pi.HasDefaultValue ? null :
138+
throw new ArgumentException("No service of the requested type was found.")),
139+
};
140+
}
141+
142+
if (pi.GetCustomAttribute<FromKeyedServicesAttribute>() is { } keyedAttr)
143+
{
144+
return new()
145+
{
146+
ExcludeFromSchema = true,
147+
BindParameter = (pi, args) =>
148+
(GetRequestContext(args)?.Services as IKeyedServiceProvider)?.GetKeyedService(pi.ParameterType, keyedAttr.Key) ??
149+
(pi.HasDefaultValue ? null :
150+
throw new ArgumentException("No service of the requested type was found.")),
151+
};
152+
}
153153

154-
return default;
154+
return default;
155155

156-
static RequestContext<CallToolRequestParams>? GetRequestContext(AIFunctionArguments args)
157-
{
158-
if (args.Context?.TryGetValue(typeof(RequestContext<CallToolRequestParams>), out var orc) is true &&
159-
orc is RequestContext<CallToolRequestParams> requestContext)
160-
{
161-
return requestContext;
162-
}
156+
static RequestContext<CallToolRequestParams>? GetRequestContext(AIFunctionArguments args)
157+
{
158+
if (args.Context?.TryGetValue(typeof(RequestContext<CallToolRequestParams>), out var orc) is true &&
159+
orc is RequestContext<CallToolRequestParams> requestContext)
160+
{
161+
return requestContext;
162+
}
163163

164-
return null;
165-
}
166-
},
167-
JsonSchemaCreateOptions = options?.SchemaCreateOptions,
168-
};
164+
return null;
165+
}
166+
},
167+
JsonSchemaCreateOptions = options?.SchemaCreateOptions,
168+
};
169169

170170
/// <summary>Creates an <see cref="McpServerTool"/> that wraps the specified <see cref="AIFunction"/>.</summary>
171171
public static new AIFunctionMcpServerTool Create(AIFunction function, McpServerToolCreateOptions? options)
@@ -182,10 +182,10 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions(
182182
if (options is not null)
183183
{
184184
if (options.Title is not null ||
185-
options.Idempotent is not null ||
186-
options.Destructive is not null ||
187-
options.OpenWorld is not null ||
188-
options.ReadOnly is not null)
185+
options.Idempotent is not null ||
186+
options.Destructive is not null ||
187+
options.OpenWorld is not null ||
188+
options.ReadOnly is not null)
189189
{
190190
tool.Annotations = new()
191191
{
@@ -255,7 +255,7 @@ private AIFunctionMcpServerTool(AIFunction function, Tool tool, IServiceProvider
255255

256256
/// <inheritdoc />
257257
public override async ValueTask<CallToolResponse> InvokeAsync(
258-
RequestContext<CallToolRequestParams> request, CancellationToken cancellationToken = default)
258+
RequestContext<CallToolRequestParams> request, CancellationToken cancellationToken = default)
259259
{
260260
Throw.IfNull(request);
261261
cancellationToken.ThrowIfCancellationRequested();
@@ -282,12 +282,11 @@ public override async ValueTask<CallToolResponse> InvokeAsync(
282282
}
283283
catch (Exception e) when (e is not OperationCanceledException)
284284
{
285-
_logger.LogError(e, "Error invoking AIFunction tool '{ToolName}' with arguments '{Args}'.",
286-
request.Params?.Name, string.Join(",", request.Params?.Arguments?.Keys ?? Array.Empty<string>()));
285+
ToolCallError(request.Params?.Name ?? string.Empty, e);
287286

288287
string errorMessage = e is McpException ?
289-
$"An error occurred invoking '{request.Params?.Name}': {e.Message}" :
290-
$"An error occurred invoking '{request.Params?.Name}'.";
288+
$"An error occurred invoking '{request.Params?.Name}': {e.Message}" :
289+
$"An error occurred invoking '{request.Params?.Name}'.";
291290

292291
return new()
293292
{
@@ -336,10 +335,10 @@ public override async ValueTask<CallToolResponse> InvokeAsync(
336335
_ => new()
337336
{
338337
Content = [new()
339-
{
340-
Text = JsonSerializer.Serialize(result, AIFunction.JsonSerializerOptions.GetTypeInfo(typeof(object))),
341-
Type = "text"
342-
}]
338+
{
339+
Text = JsonSerializer.Serialize(result, AIFunction.JsonSerializerOptions.GetTypeInfo(typeof(object))),
340+
Type = "text"
341+
}]
343342
},
344343
};
345344
}
@@ -367,4 +366,7 @@ private static CallToolResponse ConvertAIContentEnumerableToCallToolResponse(IEn
367366
IsError = allErrorContent && hasAny
368367
};
369368
}
369+
370+
[LoggerMessage(Level = LogLevel.Error, Message = "\"{ToolName}\" threw an unhandled exception.")]
371+
private partial void ToolCallError(string toolName, Exception exception);
370372
}

tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
using Microsoft.Extensions.AI;
22
using Microsoft.Extensions.DependencyInjection;
3+
using Microsoft.Extensions.Logging;
34
using ModelContextProtocol.Protocol;
45
using ModelContextProtocol.Server;
6+
using ModelContextProtocol.Tests.Utils;
57
using Moq;
68
using System.Reflection;
79
using System.Text.Json;
@@ -381,6 +383,45 @@ public async Task SupportsSchemaCreateOptions()
381383
);
382384
}
383385

386+
[Fact]
387+
public async Task ToolCallError_LogsErrorMessage()
388+
{
389+
// Arrange
390+
var mockLoggerProvider = new MockLoggerProvider();
391+
var loggerFactory = new LoggerFactory(new[] { mockLoggerProvider });
392+
var services = new ServiceCollection();
393+
services.AddSingleton<ILoggerFactory>(loggerFactory);
394+
var serviceProvider = services.BuildServiceProvider();
395+
396+
var toolName = "tool-that-throws";
397+
var exceptionMessage = "Test exception message";
398+
399+
McpServerTool tool = McpServerTool.Create(() =>
400+
{
401+
throw new InvalidOperationException(exceptionMessage);
402+
}, new() { Name = toolName, Services = serviceProvider });
403+
404+
var mockServer = new Mock<IMcpServer>();
405+
var request = new RequestContext<CallToolRequestParams>(mockServer.Object)
406+
{
407+
Params = new CallToolRequestParams() { Name = toolName },
408+
Services = serviceProvider
409+
};
410+
411+
// Act
412+
var result = await tool.InvokeAsync(request, TestContext.Current.CancellationToken);
413+
414+
// Assert
415+
Assert.True(result.IsError);
416+
Assert.Single(result.Content);
417+
Assert.Equal($"An error occurred invoking '{toolName}'.", result.Content[0].Text);
418+
419+
var errorLog = Assert.Single(mockLoggerProvider.LogMessages, m => m.LogLevel == LogLevel.Error);
420+
Assert.Equal($"\"{toolName}\" threw an unhandled exception.", errorLog.Message);
421+
Assert.IsType<InvalidOperationException>(errorLog.Exception);
422+
Assert.Equal(exceptionMessage, errorLog.Exception.Message);
423+
}
424+
384425
private sealed class MyService;
385426

386427
private class DisposableToolType : IDisposable

0 commit comments

Comments
 (0)