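"""Module implementing an MCP proxy server that connects to stdio or SSE based MCP servers.

The proxy re-exposes the upstream server to clients over the legacy SSE transport and the streamable HTTP transport.

Heavily inspired by: https://github.com/sparfenyuk/mcp-proxy
"""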
5
6from __future__ import annotations
7
8import contextlib
9import logging
10from typing import TYPE_CHECKING, Any
11
12import uvicorn
13from mcp.client.session import ClientSession
14from mcp.client.sse import sse_client
15from mcp.client.stdio import StdioServerParameters, stdio_client
16from mcp.server.sse import SseServerTransport
17from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
18from starlette.applications import Starlette
19from starlette.requests import Request
20from starlette.responses import JSONResponse, Response
21from starlette.routing import Mount, Route
22
23from .event_store import InMemoryEventStore
24from .models import ServerParameters, ServerType, SseServerParameters
25from .proxy_server import create_proxy_server
26
27if TYPE_CHECKING:
28 from collections.abc import AsyncIterator, Callable
29
30 from mcp.server import Server
31 from starlette import types as st
32 from starlette.requests import Request
33 from starlette.types import Receive, Scope, Send
34
# Records go to the shared 'apify' logger rather than a module-named (__name__) logger.
logger = logging.getLogger(name='apify')
36
37
class ProxyServer:
    """Main class implementing the proxy functionality using the MCP SDK.

    This proxy runs a Starlette app that exposes the ``/sse`` and ``/messages/`` endpoints for the legacy
    SSE transport, and a ``/mcp`` endpoint for the streamable HTTP transport.
    It connects to a stdio or SSE based MCP server and forwards messages between it and the clients.

    The server can optionally charge for operations using a provided charging function.
    This is typically used in Apify Actors to charge users for MCP operations.
    The charging function should accept an event name and a count.
    """

    def __init__(
        self,
        config: ServerParameters,
        host: str,
        port: int,
        actor_charge_function: Callable[[str, int], None] | None = None,
    ) -> None:
        """Initialize the proxy server.

        Args:
            config: Server configuration (stdio or SSE parameters).
            host: Host to bind the server to.
            port: Port to bind the server to.
            actor_charge_function: Optional function to charge for operations.
                Should accept (event_name: str, count: int).
                Typically, Actor.charge in Apify Actors.
                If None, no charging will occur.

        Raises:
            ValueError: If ``config`` does not validate for the detected server type.
        """
        # Anything that is not StdioServerParameters is treated as SSE and validated as such.
        self.server_type = ServerType.STDIO if isinstance(config, StdioServerParameters) else ServerType.SSE
        self.config = self._validate_config(self.server_type, config)
        # NOTE(review): neither path attribute is read within this class; the app built by
        # create_starlette_app serves '/sse' and '/messages/' (plural), not '/message'.
        self.path_sse: str = '/sse'
        self.path_message: str = '/message'
        self.host: str = host
        self.port: int = port
        self.actor_charge_function = actor_charge_function

    @staticmethod
    def _validate_config(client_type: ServerType, config: ServerParameters) -> ServerParameters:
        """Validate ``config`` against the parameter model that matches ``client_type``.

        Returns:
            The validated ``StdioServerParameters`` or ``SseServerParameters`` instance.

        Raises:
            ValueError: If ``client_type`` is unknown or the configuration fails validation;
                the underlying error is chained as ``__cause__``.
        """

        def validate_and_return() -> ServerParameters:
            if client_type == ServerType.STDIO:
                return StdioServerParameters.model_validate(config)
            if client_type == ServerType.SSE:
                return SseServerParameters.model_validate(config)
            raise ValueError(f'Invalid client type: {client_type}')

        try:
            return validate_and_return()
        except Exception as e:
            raise ValueError(f'Invalid server configuration: {e}') from e

    @staticmethod
    def _rewrite_mcp_path(app: st.ASGIApp) -> st.ASGIApp:
        """Wrap ``app`` so HTTP requests to ``/mcp`` are routed as if they targeted ``/mcp/``.

        The streamable HTTP handler is mounted at ``/mcp/``, which a bare ``/mcp`` path does not match.
        This is a pure ASGI middleware (not ``BaseHTTPMiddleware``), so streaming responses such as
        SSE pass through untouched.

        Args:
            app: The inner ASGI application.

        Returns:
            An ASGI application that forwards every scope to ``app``. For HTTP scopes whose path is
            exactly ``/mcp``, it rewrites ``path`` and ``raw_path`` to ``/mcp/``.
        """

        async def rewrite(scope: Scope, receive: Receive, send: Send) -> None:
            if scope['type'] == 'http' and scope['path'] == '/mcp':
                # Copy instead of mutating so the change cannot leak back to outer layers.
                scope = {**scope, 'path': '/mcp/', 'raw_path': b'/mcp/'}
            await app(scope, receive, send)

        return rewrite

    @staticmethod
    async def create_starlette_app(mcp_server: Server) -> Starlette:
        """Create a Starlette app that serves ``mcp_server`` over both supported HTTP transports.

        Routes:
            - ``/``: a JSON status document, or plain ``ok`` for Apify readiness probes.
            - ``/sse`` (GET) and ``/messages/``: legacy SSE transport.
            - ``/mcp`` (and ``/mcp/``): streamable HTTP transport.

        Args:
            mcp_server: The MCP server that handles requests from every connected client.

        Returns:
            The configured Starlette application (not yet running).
        """
        transport = SseServerTransport('/messages/')
        event_store = InMemoryEventStore()
        session_manager = StreamableHTTPSessionManager(
            app=mcp_server,
            event_store=event_store,
            json_response=False,
        )

        @contextlib.asynccontextmanager
        async def lifespan(_app: Starlette) -> AsyncIterator[None]:
            """Keep the streamable HTTP session manager running for the lifetime of the app."""
            async with session_manager.run():
                logger.info('Application started with StreamableHTTP session manager!')
                try:
                    yield
                finally:
                    logger.info('Application shutting down...')

        async def handle_root(request: Request) -> Response:
            """Answer readiness probes with ``ok`` and any other request with a status document."""
            if 'x-apify-container-server-readiness-probe' in request.headers:
                return Response(
                    content=b'ok',
                    media_type='text/plain',
                    status_code=200,
                )

            return JSONResponse(
                {
                    'status': 'running',
                    'type': 'mcp-server',
                    'transport': 'sse+streamable-http',
                    'endpoints': {
                        'sse': '/sse',
                        'messages': '/messages/',
                        'streamableHttp': '/mcp',
                    },
                }
            )

        async def handle_sse(request: Request) -> Response:
            """Serve one legacy SSE client session until it ends."""
            try:
                # connect_sse writes the event stream itself through the raw ASGI send callable,
                # which Request only exposes via its private ``_send`` attribute.
                async with transport.connect_sse(request.scope, request.receive, request._send) as streams:  # noqa: SLF001
                    init_options = mcp_server.create_initialization_options()
                    await mcp_server.run(streams[0], streams[1], init_options)
            except Exception as e:
                logger.exception('Error in SSE connection')
                return Response(status_code=500, content=str(e))
            finally:
                logger.info('SSE connection closed')

            return Response(status_code=204)

        async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None:
            """Delegate a raw ASGI request to the streamable HTTP session manager."""
            await session_manager.handle_request(scope, receive, send)

        # NOTE(review): debug=True makes Starlette render tracebacks in error responses; confirm intended.
        app = Starlette(
            debug=True,
            routes=[
                Route('/', endpoint=handle_root),
                Route('/sse', endpoint=handle_sse, methods=['GET']),
                Mount('/messages/', app=transport.handle_post_message),
                Mount('/mcp/', app=handle_streamable_http),
            ],
            lifespan=lifespan,
        )
        # Rewrite a bare '/mcp' before routing so it reaches the '/mcp/' mount.
        app.add_middleware(ProxyServer._rewrite_mcp_path)

        return app

    async def _run_server(self, app: Starlette) -> None:
        """Serve ``app`` with uvicorn on the configured host and port until the server stops."""
        config_ = uvicorn.Config(
            app,
            host=self.host,
            port=self.port,
            log_level='info',
            access_log=True,
        )
        server = uvicorn.Server(config_)
        await server.serve()

    async def _initialize_and_run_server(self, client_session_factory: Any, **client_params: Any) -> None:
        """Connect to the upstream MCP server, build the proxy app and serve it until shutdown.

        Args:
            client_session_factory: Async context manager factory (``stdio_client`` or ``sse_client``
                in ``start``) that yields the stream pair passed to ``ClientSession``.
            **client_params: Keyword arguments forwarded to ``client_session_factory``.
        """
        async with client_session_factory(**client_params) as streams, ClientSession(*streams) as session:
            mcp_server = await create_proxy_server(session, self.actor_charge_function)
            app = await self.create_starlette_app(mcp_server)
            # The upstream connection stays open for as long as uvicorn keeps serving.
            await self._run_server(app)

    async def start(self) -> None:
        """Start the Starlette app and connect it to the configured stdio or SSE based MCP server.

        Raises:
            ValueError: If ``server_type`` is neither STDIO nor SSE.
        """
        # NOTE(review): the config is logged verbatim; it may contain secrets (e.g. stdio env vars).
        logger.info(f'Starting MCP server with client type: {self.server_type} and config {self.config}')

        if self.server_type == ServerType.STDIO:
            logger.info(f'Starting and connecting to stdio based MCP server with config {self.config}')
            await self._initialize_and_run_server(stdio_client, server=self.config)
        elif self.server_type == ServerType.SSE:
            logger.info(f'Connecting to SSE based MCP server with config {self.config}')
            params = self.config.model_dump(exclude_unset=True)
            await self._initialize_and_run_server(sse_client, **params)
        else:
            raise ValueError(f'Invalid server type: {self.server_type}')