Skip to content

Commit 35296e4

Browse files
committed
refactor: improve type safety for page scoped tools
1 parent dfdac26 commit 35296e4

27 files changed

Lines changed: 476 additions & 197 deletions

src/server.ts

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@ import {
2323
SetLevelRequestSchema,
2424
} from './third_party/index.js';
2525
import {ToolCategory} from './tools/categories.js';
26-
import type {ToolDefinition} from './tools/ToolDefinition.js';
26+
import type {
27+
DefinedPageTool,
28+
ToolDefinition,
29+
} from './tools/ToolDefinition.js';
2730
import {pageIdSchema} from './tools/ToolDefinition.js';
2831
import {createTools} from './tools/tools.js';
2932
import {VERSION} from './version.js';
@@ -107,7 +110,7 @@ export async function createMcpServer(
107110

108111
const toolMutex = new Mutex();
109112

110-
function registerTool(tool: ToolDefinition): void {
113+
function registerTool(tool: ToolDefinition | DefinedPageTool): void {
111114
if (
112115
tool.annotations.category === ToolCategory.EMULATION &&
113116
serverArgs.categoryEmulation === false
@@ -151,7 +154,9 @@ export async function createMcpServer(
151154
return;
152155
}
153156
const schema =
154-
tool.annotations.pageScoped && serverArgs.experimentalPageIdRouting
157+
'pageScoped' in tool &&
158+
tool.pageScoped &&
159+
serverArgs.experimentalPageIdRouting
155160
? {...tool.schema, ...pageIdSchema}
156161
: tool.schema;
157162

@@ -174,22 +179,31 @@ export async function createMcpServer(
174179
const response = serverArgs.slim
175180
? new SlimMcpResponse(serverArgs)
176181
: new McpResponse(serverArgs);
177-
const page =
178-
tool.annotations.pageScoped && serverArgs.experimentalPageIdRouting
179-
? context.resolvePageById(params.pageId as number | undefined)
180-
: undefined;
181-
if (page) {
182-
context.setRequestPage(page);
183-
}
184182
try {
185-
await tool.handler(
186-
{
187-
params,
188-
page,
189-
},
190-
response,
191-
context,
192-
);
183+
if ('pageScoped' in tool && tool.pageScoped) {
184+
const page =
185+
serverArgs.experimentalPageIdRouting && params.pageId
186+
? context.resolvePageById(params.pageId)
187+
: context.getSelectedPage();
188+
context.setRequestPage(page);
189+
await tool.handler(
190+
{
191+
params,
192+
page,
193+
},
194+
response,
195+
context,
196+
);
197+
} else {
198+
await tool.handler(
199+
// @ts-expect-error types do not match.
200+
{
201+
params,
202+
},
203+
response,
204+
context,
205+
);
206+
}
193207
const {content, structuredContent} = await response.handle(
194208
tool.name,
195209
context,

src/tools/ToolDefinition.ts

Lines changed: 60 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import type {PaginationOptions} from '../utils/types.js';
2020

2121
import type {ToolCategory} from './categories.js';
2222

23-
export interface ToolDefinition<
23+
export interface BaseToolDefinition<
2424
Schema extends zod.ZodRawShape = zod.ZodRawShape,
2525
> {
2626
name: string;
@@ -33,14 +33,14 @@ export interface ToolDefinition<
3333
*/
3434
readOnlyHint: boolean;
3535
conditions?: string[];
36-
/**
37-
* If true, the tool operates on a specific page.
38-
* The `pageId` schema field is auto-injected and the resolved
39-
* page is provided via `request.page`.
40-
*/
41-
pageScoped?: boolean;
4236
};
4337
schema: Schema;
38+
}
39+
40+
export interface ToolDefinition<
41+
Schema extends zod.ZodRawShape = zod.ZodRawShape,
42+
> extends BaseToolDefinition<Schema> {
43+
schema: Schema;
4444
handler: (
4545
request: Request<Schema>,
4646
response: Response,
@@ -50,8 +50,6 @@ export interface ToolDefinition<
5050

5151
export interface Request<Schema extends zod.ZodRawShape> {
5252
params: zod.objectOutputType<Schema, zod.ZodTypeAny>;
53-
/** Populated centrally for tools with `pageScoped: true`. */
54-
page?: Page;
5553
}
5654

5755
export interface ImageContentData {
@@ -215,28 +213,66 @@ export function defineTool<
215213
if (typeof definition === 'function') {
216214
const factory = definition;
217215
return (args: Args) => {
218-
const tool = factory(args);
219-
wrapPageScopedHandler(tool);
220-
return tool;
216+
return factory(args);
221217
};
222218
}
223-
wrapPageScopedHandler(definition);
224219
return definition;
225220
}
226221

227-
function wrapPageScopedHandler<Schema extends zod.ZodRawShape>(
228-
definition: ToolDefinition<Schema>,
229-
) {
230-
if (definition.annotations.pageScoped) {
231-
const originalHandler = definition.handler;
232-
definition.handler = async (request, response, context) => {
233-
// In production, main.ts resolves request.page centrally before calling
234-
// the handler. This fallback exists for tests that invoke handlers
235-
// directly without going through main.ts.
236-
request.page ??= context.getSelectedPage();
237-
return originalHandler(request, response, context);
222+
interface PageToolDefinition<
223+
Schema extends zod.ZodRawShape = zod.ZodRawShape,
224+
> extends BaseToolDefinition<Schema> {
225+
handler: (
226+
request: Request<Schema> & {page: Page},
227+
response: Response,
228+
context: Context,
229+
) => Promise<void>;
230+
}
231+
232+
export type DefinedPageTool<
233+
Schema extends zod.ZodRawShape = zod.ZodRawShape,
234+
> = PageToolDefinition<Schema> & {
235+
pageScoped: true;
236+
handler: (
237+
request: Request<Schema> & {page: Page},
238+
response: Response,
239+
context: Context,
240+
) => Promise<void>;
241+
};
242+
243+
export function definePageTool<Schema extends zod.ZodRawShape>(
244+
definition: PageToolDefinition<Schema>,
245+
): DefinedPageTool<Schema>;
246+
247+
export function definePageTool<
248+
Schema extends zod.ZodRawShape,
249+
Args extends ParsedArguments = ParsedArguments,
250+
>(
251+
definition: (args?: Args) => PageToolDefinition<Schema>,
252+
): (args?: Args) => DefinedPageTool<Schema>;
253+
254+
export function definePageTool<
255+
Schema extends zod.ZodRawShape,
256+
Args extends ParsedArguments = ParsedArguments,
257+
>(
258+
definition:
259+
| PageToolDefinition<Schema>
260+
| ((args?: Args) => PageToolDefinition<Schema>),
261+
): DefinedPageTool<Schema> | ((args?: Args) => DefinedPageTool<Schema>) {
262+
if (typeof definition === 'function') {
263+
return (args?: Args): DefinedPageTool<Schema> => {
264+
const tool = definition(args);
265+
return {
266+
...tool,
267+
pageScoped: true,
268+
};
238269
};
239270
}
271+
272+
return {
273+
...definition,
274+
pageScoped: true,
275+
} as DefinedPageTool<Schema>;
240276
}
241277

242278
export const CLOSE_PAGE_ERROR =

src/tools/console.ts

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import {zod} from '../third_party/index.js';
88
import type {ConsoleMessageType} from '../third_party/index.js';
99

1010
import {ToolCategory} from './categories.js';
11-
import {defineTool} from './ToolDefinition.js';
11+
import {definePageTool} from './ToolDefinition.js';
1212
type ConsoleResponseType = ConsoleMessageType | 'issue';
1313

1414
const FILTERABLE_MESSAGE_TYPES: [
@@ -37,14 +37,13 @@ const FILTERABLE_MESSAGE_TYPES: [
3737
'issue',
3838
];
3939

40-
export const listConsoleMessages = defineTool({
40+
export const listConsoleMessages = definePageTool({
4141
name: 'list_console_messages',
4242
description:
4343
'List all console messages for the currently selected page since the last navigation.',
4444
annotations: {
4545
category: ToolCategory.DEBUGGING,
4646
readOnlyHint: true,
47-
pageScoped: true,
4847
},
4948
schema: {
5049
pageSize: zod
@@ -87,13 +86,12 @@ export const listConsoleMessages = defineTool({
8786
},
8887
});
8988

90-
export const getConsoleMessage = defineTool({
89+
export const getConsoleMessage = definePageTool({
9190
name: 'get_console_message',
9291
description: `Gets a console message by its ID. You can get all messages by calling ${listConsoleMessages.name}.`,
9392
annotations: {
9493
category: ToolCategory.DEBUGGING,
9594
readOnlyHint: true,
96-
pageScoped: true,
9795
},
9896
schema: {
9997
msgid: zod

src/tools/emulation.ts

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,20 @@
88
import {zod, PredefinedNetworkConditions} from '../third_party/index.js';
99

1010
import {ToolCategory} from './categories.js';
11-
import {defineTool} from './ToolDefinition.js';
11+
import {definePageTool} from './ToolDefinition.js';
1212

1313
const throttlingOptions: [string, ...string[]] = [
1414
'No emulation',
1515
'Offline',
1616
...Object.keys(PredefinedNetworkConditions),
1717
];
1818

19-
export const emulate = defineTool({
19+
export const emulate = definePageTool({
2020
name: 'emulate',
2121
description: `Emulates various features on the selected page.`,
2222
annotations: {
2323
category: ToolCategory.EMULATION,
2424
readOnlyHint: false,
25-
pageScoped: true,
2625
},
2726
schema: {
2827
networkConditions: zod
@@ -105,7 +104,7 @@ export const emulate = defineTool({
105104
),
106105
},
107106
handler: async (request, _response, context) => {
108-
const page = request.page!;
107+
const page = request.page;
109108
await context.emulate(request.params, page);
110109
},
111110
});

0 commit comments

Comments
 (0)