Skip to content

Commit 78e388f

Browse files
committed
feat: add new mcp tool to screenshot pages
Implements #244
1 parent 9619c64 commit 78e388f

2 files changed

Lines changed: 194 additions & 6 deletions

File tree

scrapling/core/ai.py

Lines changed: 75 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
from datetime import datetime, timezone
44
from dataclasses import dataclass, field
55

6-
from mcp.server.fastmcp import FastMCP
6+
from mcp.server.fastmcp import FastMCP, Image
7+
from mcp.types import ImageContent, TextContent
78
from pydantic import BaseModel, Field
89

910
from scrapling.core.shell import Convertor
@@ -31,6 +32,7 @@
3132
)
3233

3334
SessionType = Literal["dynamic", "stealthy"]
35+
ScreenshotType = Literal["png", "jpeg"]
3436

3537

3638
class ResponseModel(BaseModel):
@@ -106,14 +108,14 @@ class ScraplingMCPServer:
106108
def __init__(self):
107109
self._sessions: Dict[str, _SessionEntry] = {}
108110

109-
def _get_session(self, session_id: str, expected_type: SessionType) -> _SessionEntry:
110-
"""Look up a session by ID and validate its type."""
111+
def _get_session(self, session_id: str, expected_type: Optional[SessionType]) -> _SessionEntry:
112+
"""Look up a session by ID, optionally validating its type. Pass `None` to skip the type check."""
111113
entry = self._sessions.get(session_id)
112114
if entry is None:
113115
raise ValueError(f"Session '{session_id}' not found. Use list_sessions to see active sessions.")
114116
if not entry.session._is_alive:
115117
raise ValueError(f"Session '{session_id}' is no longer alive. Open a new session.")
116-
if entry.session_type != expected_type:
118+
if expected_type is not None and entry.session_type != expected_type:
117119
raise ValueError(
118120
f"Session '{session_id}' is a '{entry.session_type}' session, but this tool requires a "
119121
f"'{expected_type}' session. Use the matching fetch tool for your session type."
@@ -260,6 +262,69 @@ async def list_sessions(self) -> List[SessionInfo]:
260262
for sid, entry in self._sessions.items()
261263
]
262264

265+
async def screenshot(
266+
self,
267+
url: str,
268+
session_id: str,
269+
image_type: ScreenshotType = "png",
270+
full_page: bool = False,
271+
quality: Optional[int] = None,
272+
wait: int | float = 0,
273+
wait_selector: Optional[str] = None,
274+
wait_selector_state: SelectorWaitStates = "attached",
275+
network_idle: bool = False,
276+
timeout: int | float = 30000,
277+
) -> List[ImageContent | TextContent]:
278+
"""Capture a screenshot of a web page using an existing browser session and return it as an image.
279+
A browser session must be opened first with `open_session` (either `dynamic` or `stealthy`); the session ID is then passed here.
280+
281+
:param url: The URL to navigate to and capture.
282+
:param session_id: ID of an open browser session created with `open_session`.
283+
:param image_type: Image format. Defaults to "png". Use "jpeg" for smaller file sizes.
284+
:param full_page: When True, captures the full scrollable page instead of just the viewport. Defaults to False.
285+
:param quality: Image quality (0-100) for JPEG only. Raises if passed with `image_type="png"`.
286+
:param wait: Time in milliseconds to wait after page load before capturing. Defaults to 0.
287+
:param wait_selector: Optional CSS selector to wait for before capturing.
288+
:param wait_selector_state: State to wait for the selector. Defaults to "attached".
289+
:param network_idle: Wait for the page until there are no network connections for at least 500 ms.
290+
:param timeout: Timeout in milliseconds for page operations. Defaults to 30,000.
291+
"""
292+
if quality is not None and image_type != "jpeg":
293+
raise ValueError("'quality' is only valid when 'image_type' is 'jpeg'.")
294+
295+
entry = self._get_session(session_id, expected_type=None)
296+
297+
screenshot_kwargs: Dict[str, Any] = {"type": image_type, "full_page": full_page}
298+
if quality is not None:
299+
screenshot_kwargs["quality"] = quality
300+
301+
captured: Dict[str, Any] = {}
302+
303+
async def _capture(page: Any) -> None:
304+
try:
305+
captured["bytes"] = await page.screenshot(**screenshot_kwargs)
306+
captured["url"] = page.url
307+
except Exception as exc:
308+
captured["error"] = exc
309+
310+
await entry.session.fetch(
311+
url,
312+
wait=wait,
313+
timeout=timeout,
314+
network_idle=network_idle,
315+
wait_selector=wait_selector,
316+
wait_selector_state=wait_selector_state,
317+
page_action=_capture,
318+
)
319+
320+
if "error" in captured:
321+
raise captured["error"]
322+
if "bytes" not in captured:
323+
raise RuntimeError(f"Failed to capture screenshot for {url}")
324+
325+
image = Image(data=captured["bytes"], format=image_type).to_image_content()
326+
return [image, TextContent(type="text", text=captured["url"])]
327+
263328
@staticmethod
264329
async def get(
265330
url: str,
@@ -298,7 +363,8 @@ async def get(
298363
:param headers: Headers to include in the request.
299364
:param cookies: Cookies to use in the request.
300365
:param timeout: Number of seconds to wait before timing out.
301-
:param follow_redirects: Whether to follow redirects. Defaults to "safe", which follows redirects but rejects those targeting internal/private IPs (SSRF protection). Pass True to follow all redirects without restriction.
366+
:param follow_redirects: Whether to follow redirects. Defaults to "safe", which follows redirects but rejects those targeting internal/private IPs (SSRF protection).
367+
Pass True to follow all redirects without restriction.
302368
:param max_redirects: Maximum number of redirects. Default 30, use -1 for unlimited.
303369
:param retries: Number of retry attempts. Defaults to 3.
304370
:param retry_delay: Number of seconds to wait between retry attempts. Defaults to 1 second.
@@ -371,7 +437,8 @@ async def bulk_get(
371437
:param headers: Headers to include in the request.
372438
:param cookies: Cookies to use in the request.
373439
:param timeout: Number of seconds to wait before timing out.
374-
:param follow_redirects: Whether to follow redirects. Defaults to "safe", which follows redirects but rejects those targeting internal/private IPs (SSRF protection). Pass True to follow all redirects without restriction.
440+
:param follow_redirects: Whether to follow redirects. Defaults to "safe", which follows redirects but rejects those targeting internal/private IPs (SSRF protection).
441+
Pass True to follow all redirects without restriction.
375442
:param max_redirects: Maximum number of redirects. Default 30, use -1 for unlimited.
376443
:param retries: Number of retry attempts. Defaults to 3.
377444
:param retry_delay: Number of seconds to wait between retry attempts. Defaults to 1 second.
@@ -835,4 +902,6 @@ def serve(self, http: bool, host: str, port: int):
835902
description=self.bulk_stealthy_fetch.__doc__,
836903
structured_output=True,
837904
)
905+
# Screenshot tool (returns image + url content blocks, not structured JSON)
906+
server.add_tool(self.screenshot, title="screenshot", description=self.screenshot.__doc__)
838907
server.run(transport="stdio" if not http else "streamable-http")

tests/ai/test_ai_mcp.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
1+
import base64
2+
import struct
3+
from contextlib import contextmanager
4+
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
5+
from threading import Thread
6+
17
import pytest
28
import pytest_httpbin
9+
from mcp.types import ImageContent, TextContent
310

411
from scrapling.core.ai import (
512
ScraplingMCPServer,
@@ -197,6 +204,118 @@ async def test_open_session_duplicate_id_raises(self, server):
197204
await server.close_session("dupe")
198205

199206

207+
def _png_height(data: bytes) -> int:
208+
"""Read the height field from a PNG IHDR chunk."""
209+
return struct.unpack(">I", data[20:24])[0]
210+
211+
212+
@contextmanager
213+
def _serve_html(body: bytes):
214+
"""Serve a fixed HTML body on localhost, yielding its URL."""
215+
216+
class _Handler(BaseHTTPRequestHandler):
217+
def do_GET(self):
218+
self.send_response(200)
219+
self.send_header("Content-Type", "text/html; charset=utf-8")
220+
self.send_header("Content-Length", str(len(body)))
221+
self.end_headers()
222+
self.wfile.write(body)
223+
224+
def log_message(self, *args, **kwargs):
225+
pass
226+
227+
server = ThreadingHTTPServer(("127.0.0.1", 0), _Handler)
228+
thread = Thread(target=server.serve_forever, daemon=True)
229+
thread.start()
230+
try:
231+
yield f"http://127.0.0.1:{server.server_address[1]}/"
232+
finally:
233+
server.shutdown()
234+
server.server_close()
235+
236+
237+
@pytest_httpbin.use_class_based_httpbin
238+
class TestScreenshot:
239+
"""Test the screenshot tool"""
240+
241+
@pytest.fixture(scope="class")
242+
def test_url(self, httpbin):
243+
return f"{httpbin.url}/html"
244+
245+
@pytest.fixture
246+
def server(self):
247+
return ScraplingMCPServer()
248+
249+
@pytest.mark.asyncio
250+
async def test_screenshot_png_with_dynamic_session(self, server, test_url):
251+
"""PNG screenshot via a dynamic session returns image and url content blocks"""
252+
opened = await server.open_session(session_type="dynamic", headless=True)
253+
try:
254+
result = await server.screenshot(url=test_url, session_id=opened.session_id)
255+
assert isinstance(result, list) and len(result) == 2
256+
assert isinstance(result[0], ImageContent)
257+
assert result[0].mimeType == "image/png"
258+
assert isinstance(result[1], TextContent)
259+
assert result[1].text == test_url
260+
finally:
261+
await server.close_session(opened.session_id)
262+
263+
@pytest.mark.asyncio
264+
async def test_screenshot_jpeg_with_quality(self, server, test_url):
265+
"""JPEG screenshot with quality parameter via a dynamic session"""
266+
opened = await server.open_session(session_type="dynamic", headless=True)
267+
try:
268+
result = await server.screenshot(url=test_url, session_id=opened.session_id, image_type="jpeg", quality=80)
269+
assert isinstance(result[0], ImageContent)
270+
assert result[0].mimeType == "image/jpeg"
271+
finally:
272+
await server.close_session(opened.session_id)
273+
274+
@pytest.mark.asyncio
275+
async def test_screenshot_with_stealthy_session(self, server, test_url):
276+
"""PNG screenshot via a stealthy session"""
277+
opened = await server.open_session(session_type="stealthy", headless=True)
278+
try:
279+
result = await server.screenshot(url=test_url, session_id=opened.session_id)
280+
assert isinstance(result[0], ImageContent)
281+
assert result[0].mimeType == "image/png"
282+
finally:
283+
await server.close_session(opened.session_id)
284+
285+
@pytest.mark.asyncio
286+
async def test_screenshot_full_page_taller_than_viewport(self, server):
287+
"""full_page=True produces an image taller than the viewport-only capture"""
288+
tall_html = b"<html><body><div style='height:5000px;background:#abc'></div></body></html>"
289+
with _serve_html(tall_html) as tall_url:
290+
opened = await server.open_session(session_type="dynamic", headless=True)
291+
try:
292+
viewport_result = await server.screenshot(url=tall_url, session_id=opened.session_id, full_page=False)
293+
full_result = await server.screenshot(url=tall_url, session_id=opened.session_id, full_page=True)
294+
295+
viewport_png = base64.b64decode(viewport_result[0].data)
296+
full_png = base64.b64decode(full_result[0].data)
297+
298+
assert _png_height(full_png) > _png_height(viewport_png)
299+
finally:
300+
await server.close_session(opened.session_id)
301+
302+
@pytest.mark.asyncio
303+
async def test_screenshot_invalid_session_id_raises(self, server, test_url):
304+
"""Unknown session_id raises ValueError"""
305+
with pytest.raises(ValueError, match="not found"):
306+
await server.screenshot(url=test_url, session_id="does-not-exist")
307+
308+
@pytest.mark.asyncio
309+
async def test_screenshot_quality_with_png_raises(self, server, test_url):
310+
"""quality is rejected when image_type is png"""
311+
opened = await server.open_session(session_type="dynamic", headless=True)
312+
try:
313+
with pytest.raises(ValueError, match="quality"):
314+
await server.screenshot(url=test_url, session_id=opened.session_id, image_type="png", quality=90)
315+
finally:
316+
await server.close_session(opened.session_id)
317+
318+
200319
class TestNormalizeCredentials:
201320
"""Test the _normalize_credentials helper"""
202321

0 commit comments

Comments
 (0)