Skip to content

chore: add tests for metadata actions #91

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 16 additions & 14 deletions src/tools/mongodb/metadata/collectionSchema.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { CallToolResult } from "@modelcontextprotocol/sdk/types.js";
import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js";
import { ToolArgs, OperationType } from "../../tool.js";
import { parseSchema, SchemaField } from "mongodb-schema";
import { getSimplifiedSchema } from "mongodb-schema";

export class CollectionSchemaTool extends MongoDBToolBase {
protected name = "collection-schema";
Expand All @@ -13,29 +13,31 @@ export class CollectionSchemaTool extends MongoDBToolBase {
protected async execute({ database, collection }: ToolArgs<typeof DbOperationArgs>): Promise<CallToolResult> {
const provider = await this.ensureConnected();
const documents = await provider.find(database, collection, {}, { limit: 5 }).toArray();
const schema = await parseSchema(documents);
const schema = await getSimplifiedSchema(documents);

const fieldsCount = Object.entries(schema).length;
if (fieldsCount === 0) {
return {
content: [
{
text: `Could not deduce the schema for "${database}.${collection}". This may be because it doesn't exist or is empty.`,
type: "text",
},
],
};
}

return {
content: [
{
text: `Found ${schema.fields.length} fields in the schema for \`${database}.${collection}\``,
text: `Found ${fieldsCount} fields in the schema for "${database}.${collection}"`,
type: "text",
},
{
text: this.formatFieldOutput(schema.fields),
text: JSON.stringify(schema),
type: "text",
},
],
};
}

private formatFieldOutput(fields: SchemaField[]): string {
let result = "| Field | Type | Confidence |\n";
result += "|-------|------|-------------|\n";
for (const field of fields) {
const fieldType = Array.isArray(field.type) ? field.type.join(", ") : field.type;
result += `| ${field.name} | \`${fieldType}\` | ${(field.probability * 100).toFixed(0)}% |\n`;
}
return result;
}
}
44 changes: 41 additions & 3 deletions src/tools/mongodb/metadata/collectionStorageSize.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { ToolArgs, OperationType } from "../../tool.js";

export class CollectionStorageSizeTool extends MongoDBToolBase {
protected name = "collection-storage-size";
protected description = "Gets the size of the collection in MB";
protected description = "Gets the size of the collection";
protected argsShape = DbOperationArgs;

protected operationType: OperationType = "metadata";
Expand All @@ -14,17 +14,55 @@ export class CollectionStorageSizeTool extends MongoDBToolBase {
const [{ value }] = (await provider
.aggregate(database, collection, [
{ $collStats: { storageStats: {} } },
{ $group: { _id: null, value: { $sum: "$storageStats.storageSize" } } },
{ $group: { _id: null, value: { $sum: "$storageStats.size" } } },
])
.toArray()) as [{ value: number }];

const { units, value: scaledValue } = CollectionStorageSizeTool.getStats(value);

return {
content: [
{
text: `The size of \`${database}.${collection}\` is \`${(value / 1024 / 1024).toFixed(2)} MB\``,
text: `The size of "${database}.${collection}" is \`${scaledValue.toFixed(2)} ${units}\``,
type: "text",
},
],
};
}

protected handleError(
error: unknown,
args: ToolArgs<typeof this.argsShape>
): Promise<CallToolResult> | CallToolResult {
if (error instanceof Error && "codeName" in error && error.codeName === "NamespaceNotFound") {
return {
content: [
{
text: `The size of "${args.database}.${args.collection}" cannot be determined because the collection does not exist.`,
type: "text",
},
],
};
}

return super.handleError(error, args);
}

private static getStats(value: number): { value: number; units: string } {
const kb = 1024;
const mb = kb * 1024;
const gb = mb * 1024;

if (value > gb) {
return { value: value / gb, units: "GB" };
}

if (value > mb) {
return { value: value / mb, units: "MB" };
}
if (value > kb) {
return { value: value / kb, units: "KB" };
}
return { value, units: "bytes" };
}
}
9 changes: 6 additions & 3 deletions src/tools/mongodb/mongodbTool.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { z } from "zod";
import { ToolBase, ToolCategory } from "../tool.js";
import { ToolArgs, ToolBase, ToolCategory } from "../tool.js";
import { Session } from "../../session.js";
import { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver";
import { CallToolResult } from "@modelcontextprotocol/sdk/types.js";
Expand Down Expand Up @@ -30,7 +30,10 @@ export abstract class MongoDBToolBase extends ToolBase {
return this.session.serviceProvider;
}

protected handleError(error: unknown): Promise<CallToolResult> | CallToolResult {
protected handleError(
error: unknown,
args: ToolArgs<typeof this.argsShape>
): Promise<CallToolResult> | CallToolResult {
if (error instanceof MongoDBError && error.code === ErrorCodes.NotConnectedToMongoDB) {
return {
content: [
Expand All @@ -47,7 +50,7 @@ export abstract class MongoDBToolBase extends ToolBase {
};
}

return super.handleError(error);
return super.handleError(error, args);
}

protected async connectToMongoDB(connectionString: string): Promise<void> {
Expand Down
8 changes: 6 additions & 2 deletions src/tools/tool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ export abstract class ToolBase {
} catch (error: unknown) {
logger.error(mongoLogId(1_000_000), "tool", `Error executing ${this.name}: ${error as string}`);

return await this.handleError(error);
return await this.handleError(error, args[0] as ToolArgs<typeof this.argsShape>);
}
};

Expand Down Expand Up @@ -76,7 +76,11 @@ export abstract class ToolBase {
}

// This method is intended to be overridden by subclasses to handle errors
protected handleError(error: unknown): Promise<CallToolResult> | CallToolResult {
protected handleError(
error: unknown,
// eslint-disable-next-line @typescript-eslint/no-unused-vars
args: ToolArgs<typeof this.argsShape>
): Promise<CallToolResult> | CallToolResult {
return {
content: [
{
Expand Down
92 changes: 88 additions & 4 deletions tests/integration/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js";
import { MongoClient, ObjectId } from "mongodb";
import { toIncludeAllMembers } from "jest-extended";
import config from "../../src/config.js";
import { McpError } from "@modelcontextprotocol/sdk/types.js";

interface ParameterInfo {
name: string;
Expand Down Expand Up @@ -223,10 +224,93 @@ export const dbOperationParameters: ParameterInfo[] = [
{ name: "collection", type: "string", description: "Collection name", required: true },
];

export function validateParameters(tool: ToolInfo, parameters: ParameterInfo[]): void {
const toolParameters = getParameters(tool);
expect(toolParameters).toHaveLength(parameters.length);
expect(toolParameters).toIncludeAllMembers(parameters);
export const dbOperationInvalidArgTests = [{}, { database: 123 }, { foo: "bar", database: "test" }, { database: [] }];

export function validateToolMetadata(
integration: IntegrationTest,
name: string,
description: string,
parameters: ParameterInfo[]
): void {
it("should have correct metadata", async () => {
const { tools } = await integration.mcpClient().listTools();
const tool = tools.find((tool) => tool.name === name)!;
expect(tool).toBeDefined();
expect(tool.description).toBe(description);

const toolParameters = getParameters(tool);
expect(toolParameters).toHaveLength(parameters.length);
expect(toolParameters).toIncludeAllMembers(parameters);
});
}

export function validateAutoConnectBehavior(
integration: IntegrationTest,
name: string,
validation: () => {
args: { [x: string]: unknown };
expectedResponse?: string;
validate?: (content: unknown) => void;
},
beforeEachImpl?: () => Promise<void>
): void {
describe("when not connected", () => {
if (beforeEachImpl) {
beforeEach(() => beforeEachImpl());
}

it("connects automatically if connection string is configured", async () => {
config.connectionString = integration.connectionString();

const validationInfo = validation();

const response = await integration.mcpClient().callTool({
name,
arguments: validationInfo.args,
});

if (validationInfo.expectedResponse) {
const content = getResponseContent(response.content);
expect(content).toContain(validationInfo.expectedResponse);
}

if (validationInfo.validate) {
validationInfo.validate(response.content);
}
});

it("throws an error if connection string is not configured", async () => {
const response = await integration.mcpClient().callTool({
name,
arguments: validation().args,
});
const content = getResponseContent(response.content);
expect(content).toContain("You need to connect to a MongoDB instance before you can access its data.");
});
});
}

export function validateThrowsForInvalidArguments(
integration: IntegrationTest,
name: string,
args: { [x: string]: unknown }[]
): void {
describe("with invalid arguments", () => {
for (const arg of args) {
it(`throws a schema error for: ${JSON.stringify(arg)}`, async () => {
await integration.connectMcpClient();
try {
await integration.mcpClient().callTool({ name, arguments: arg });
expect.fail("Expected an error to be thrown");
} catch (error) {
expect(error).toBeInstanceOf(McpError);
const mcpError = error as McpError;
expect(mcpError.code).toEqual(-32602);
expect(mcpError.message).toContain(`Invalid arguments for tool ${name}`);
}
});
}
});
}

export function describeAtlas(name: number | string | Function | jest.FunctionLike, fn: jest.EmptyFunction) {
Expand Down
73 changes: 16 additions & 57 deletions tests/integration/tools/mongodb/create/createCollection.test.ts
Original file line number Diff line number Diff line change
@@ -1,50 +1,24 @@
import {
getResponseContent,
validateParameters,
dbOperationParameters,
setupIntegrationTest,
validateToolMetadata,
validateAutoConnectBehavior,
validateThrowsForInvalidArguments,
dbOperationInvalidArgTests,
} from "../../../helpers.js";
import { toIncludeSameMembers } from "jest-extended";
import { McpError } from "@modelcontextprotocol/sdk/types.js";
import { ObjectId } from "bson";
import config from "../../../../../src/config.js";

describe("createCollection tool", () => {
const integration = setupIntegrationTest();

it("should have correct metadata", async () => {
const { tools } = await integration.mcpClient().listTools();
const listCollections = tools.find((tool) => tool.name === "create-collection")!;
expect(listCollections).toBeDefined();
expect(listCollections.description).toBe(
"Creates a new collection in a database. If the database doesn't exist, it will be created automatically."
);
validateToolMetadata(
integration,
"create-collection",
"Creates a new collection in a database. If the database doesn't exist, it will be created automatically.",
dbOperationParameters
);

validateParameters(listCollections, dbOperationParameters);
});

describe("with invalid arguments", () => {
const args = [
{},
{ database: 123, collection: "bar" },
{ foo: "bar", database: "test", collection: "bar" },
{ collection: [], database: "test" },
];
for (const arg of args) {
it(`throws a schema error for: ${JSON.stringify(arg)}`, async () => {
await integration.connectMcpClient();
try {
await integration.mcpClient().callTool({ name: "create-collection", arguments: arg });
expect.fail("Expected an error to be thrown");
} catch (error) {
expect(error).toBeInstanceOf(McpError);
const mcpError = error as McpError;
expect(mcpError.code).toEqual(-32602);
expect(mcpError.message).toContain("Invalid arguments for tool create-collection");
}
});
}
});
validateThrowsForInvalidArguments(integration, "create-collection", dbOperationInvalidArgTests);

describe("with non-existent database", () => {
it("creates a new collection", async () => {
Expand Down Expand Up @@ -114,25 +88,10 @@ describe("createCollection tool", () => {
});
});

describe("when not connected", () => {
it("connects automatically if connection string is configured", async () => {
config.connectionString = integration.connectionString();

const response = await integration.mcpClient().callTool({
name: "create-collection",
arguments: { database: integration.randomDbName(), collection: "new-collection" },
});
const content = getResponseContent(response.content);
expect(content).toEqual(`Collection "new-collection" created in database "${integration.randomDbName()}".`);
});

it("throws an error if connection string is not configured", async () => {
const response = await integration.mcpClient().callTool({
name: "create-collection",
arguments: { database: integration.randomDbName(), collection: "new-collection" },
});
const content = getResponseContent(response.content);
expect(content).toContain("You need to connect to a MongoDB instance before you can access its data.");
});
validateAutoConnectBehavior(integration, "create-collection", () => {
return {
args: { database: integration.randomDbName(), collection: "new-collection" },
expectedResponse: `Collection "new-collection" created in database "${integration.randomDbName()}".`,
};
});
});
Loading
Loading