From 0c44dfd485a5ad974921b88f361b5a3105cd286b Mon Sep 17 00:00:00 2001 From: "jiajian.chi" Date: Thu, 18 Jun 2026 23:48:41 +0800 Subject: [PATCH 1/2] Add AK/SK REST authentication path --- README.md | 37 ++- src/zstack_mcp/server.py | 121 +++++++-- src/zstack_mcp/zstack_client.py | 331 ++++++++++++++++++++++++- tests/test_auth_refactor.py | 96 ++++++- tests/test_zstack_client_access_key.py | 164 ++++++++++++ 5 files changed, 715 insertions(+), 34 deletions(-) create mode 100644 tests/test_zstack_client_access_key.py diff --git a/README.md b/README.md index aec8cf3..f288648 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,10 @@ export ZSTACK_PASSWORD="your-password" # 密码(明文) # 认证方式二:直接传入 SessionID(优先级更高,设置后忽略用户名密码) export ZSTACK_SESSION_ID="your-session-uuid" # 已有的 Session UUID +# 认证方式三:AK/SK(AccessKey/SecretKey,请求签名认证) +export ZSTACK_ACCESS_KEY_ID="your-access-key-id" +export ZSTACK_ACCESS_KEY_SECRET="your-access-key-secret" + # 查询响应控制(可选) export ZSTACK_QUERY_DEFAULT_LIMIT="50" # Query API 默认 limit(设 0 禁用) export ZSTACK_RESPONSE_SIZE_LIMIT="65536" # 响应大小上限,字节(设 0 禁用) @@ -48,8 +52,10 @@ export ZSTACK_RESPONSE_SIZE_LIMIT="65536" # 响应大小上限,字节( |------|----------|------| | 用户名密码 | `ZSTACK_ACCOUNT` + `ZSTACK_PASSWORD` | 自动登录获取 Session | | Session ID | `ZSTACK_SESSION_ID` | 直接使用已有 Session(优先级更高) | +| AK/SK | `ZSTACK_ACCESS_KEY_ID` + `ZSTACK_ACCESS_KEY_SECRET` | 通过 REST API 签名调用,不创建 Session | -> 💡 如果同时设置了 `ZSTACK_SESSION_ID` 和用户名密码,会优先使用 Session ID +> 💡 环境变量同时存在时优先级为:`ZSTACK_SESSION_ID` > AK/SK > 用户名密码 +> AK/SK 只适用于已配置 REST 路由映射的 API。未映射的 API 会返回 `REST_MAPPING_NOT_FOUND`,不会回退调用 `/zstack/api/` message API。 ### 安全说明 @@ -141,9 +147,11 @@ uvx zstack-mcp-server | `X-ZStack-Account` | `ZSTACK_ACCOUNT` | 账户名 | | `X-ZStack-Password` | `ZSTACK_PASSWORD` | 密码 | | `X-ZStack-Session-Id` | `ZSTACK_SESSION_ID` | 已有 Session(优先级高于账号密码) | +| `X-ZStack-Access-Key-Id` | `ZSTACK_ACCESS_KEY_ID` | AccessKey ID | +| `X-ZStack-Access-Key-Secret` | `ZSTACK_ACCESS_KEY_SECRET` | AccessKey Secret | | `X-ZStack-API-URL` | `ZSTACK_API_URL` | ZStack 管理节点地址(可代理多套环境) | -凭据优先级:HTTP 头 > 环境变量 +凭据优先级:HTTP 头 > 环境变量;同一来源内优先级为 Session ID > AK/SK > 用户名密码。 典型用法: ```bash @@ -212,6 +220,30 @@ ZSTACK_ALLOW_ALL_API=false uvx zstack-mcp-server --transport streamable-http --h } ``` +**方式三:使用 AK/SK** +```json +{ + "mcpServers": { + "zstack": { + "command": "uvx", + "args": ["zstack-mcp-server"], + "env": { + "ZSTACK_API_URL": "http://your-zstack-server:8080", + "ZSTACK_ACCESS_KEY_ID": "your-access-key-id", + "ZSTACK_ACCESS_KEY_SECRET": "your-access-key-secret", + "ZSTACK_ALLOW_ALL_API": "false" + } + } + } +} +``` + +AK/SK 模式下,MCP Server 会将已映射的 Query API 转为 REST GET 请求,例如: +- `QueryZone` → `GET /zstack/v1/zones` +- `QueryVmInstance` → `GET /zstack/v1/vm-instances` + +未映射的 API 会返回明确错误,避免误打 `/zstack/api/` 后触发缺少 session 的 `ID.1001`。 + > 💡 将 `ZSTACK_ALLOW_ALL_API` 设为 `"true"` 可启用写操作(创建/删除/修改等) ## 可用工具 @@ -349,4 +381,3 @@ pytest ## License MIT - diff --git a/src/zstack_mcp/server.py b/src/zstack_mcp/server.py index 141ae72..aca65f6 100644 --- a/src/zstack_mcp/server.py +++ b/src/zstack_mcp/server.py @@ -17,6 +17,7 @@ import asyncio import atexit import copy +import hashlib import json import logging import os @@ -76,6 +77,8 @@ class RequestAuth: account: Optional[str] = None password: Optional[str] = None session_id: Optional[str] = None + access_key_id: Optional[str] = None + access_key_secret: Optional[str] = None api_url: Optional[str] = None @@ -92,6 +95,16 @@ def _extract_auth_from_context(ctx: Context) -> RequestAuth: auth.account = headers.get("x-zstack-account") or None auth.password = headers.get("x-zstack-password") or None auth.session_id = headers.get("x-zstack-session-id") or None + auth.access_key_id = ( + headers.get("x-zstack-access-key-id") + or headers.get("x-zstack-ak") + or None + ) + auth.access_key_secret = ( + headers.get("x-zstack-access-key-secret") + or headers.get("x-zstack-sk") + or None + ) auth.api_url = headers.get("x-zstack-api-url") or None except Exception: # stdio 模式或其他无 HTTP request 的场景,安全忽略 @@ -111,6 +124,8 @@ async def get_client( account: Optional[str] = None, password: Optional[str] = None, session_id: Optional[str] = None, + access_key_id: Optional[str] = None, + access_key_secret: Optional[str] = None, api_url: Optional[str] = None, ) -> ZStackClient: """获取 client,优先用缓存的 session @@ -120,46 +135,115 @@ async def get_client( """ # 确定凭据来源 env_session_id = os.environ.get("ZSTACK_SESSION_ID", "") + env_access_key_id = os.environ.get("ZSTACK_ACCESS_KEY_ID", "") or os.environ.get("ZSTACK_AK", "") + env_access_key_secret = os.environ.get("ZSTACK_ACCESS_KEY_SECRET", "") or os.environ.get("ZSTACK_SK", "") env_account = os.environ.get("ZSTACK_ACCOUNT", "") env_password = os.environ.get("ZSTACK_PASSWORD", "") env_api_url = os.environ.get("ZSTACK_API_URL", "") - effective_session_id = session_id or env_session_id - effective_account = account or env_account - effective_password = password or env_password effective_api_url = api_url or env_api_url or None - # 如果有 session_id(参数传入或环境变量)且没有传账号密码 → 直接用 session_id - if effective_session_id and not account and not password: - cache_key = f"__session_id__{effective_api_url or ''}|{effective_session_id}" + request_session = bool(session_id) + request_access_key = bool(access_key_id or access_key_secret) + request_account = bool(account or password) + + # HTTP 头中的凭据优先于环境变量;同一认证方式内允许从环境变量补齐另一半。 + if request_session: + cache_key = f"__session_id__{effective_api_url or ''}|{session_id}" if cache_key in self._clients: self._clients.move_to_end(cache_key) return self._clients[cache_key] - client = ZStackClient(session_id=effective_session_id, api_url=effective_api_url) - self._clients[cache_key] = client - return client + client = ZStackClient(session_id=session_id, api_url=effective_api_url) + return await self._cache_client(cache_key, client) + + if request_access_key: + effective_access_key_id = access_key_id or env_access_key_id + effective_access_key_secret = access_key_secret or env_access_key_secret + return await self._get_access_key_client( + effective_api_url, + effective_access_key_id, + effective_access_key_secret, + ) + + if request_account: + effective_account = account or env_account + effective_password = password or env_password + return await self._get_password_client(effective_api_url, effective_account, effective_password) + + # 无 HTTP 凭据时回退环境变量,认证优先级:Session ID > AK/SK > 账号密码。 + if env_session_id: + cache_key = f"__session_id__{effective_api_url or ''}|{env_session_id}" + if cache_key in self._clients: + self._clients.move_to_end(cache_key) + return self._clients[cache_key] + client = ZStackClient(session_id=env_session_id, api_url=effective_api_url) + return await self._cache_client(cache_key, client) + + if env_access_key_id or env_access_key_secret: + return await self._get_access_key_client( + effective_api_url, + env_access_key_id, + env_access_key_secret, + ) + + return await self._get_password_client(effective_api_url, env_account, env_password) + + async def _get_access_key_client( + self, + api_url: Optional[str], + access_key_id: Optional[str], + access_key_secret: Optional[str], + ) -> ZStackClient: + if not access_key_id or not access_key_secret: + raise ZStackApiError( + "缺少 AK/SK 认证凭据。请同时传入 X-ZStack-Access-Key-Id / " + "X-ZStack-Access-Key-Secret,或设置环境变量 " + "ZSTACK_ACCESS_KEY_ID + ZSTACK_ACCESS_KEY_SECRET" + ) + + secret_fingerprint = hashlib.sha256(access_key_secret.encode("utf-8")).hexdigest()[:16] + cache_key = f"__access_key__{api_url or ''}|{access_key_id}|{secret_fingerprint}" + if cache_key in self._clients: + self._clients.move_to_end(cache_key) + return self._clients[cache_key] + client = ZStackClient( + access_key_id=access_key_id, + access_key_secret=access_key_secret, + api_url=api_url, + ) + return await self._cache_client(cache_key, client) + + async def _get_password_client( + self, + api_url: Optional[str], + account: Optional[str], + password: Optional[str], + ) -> ZStackClient: # 必须有凭据 - if not effective_account or not effective_password: + if not account or not password: raise ZStackApiError( "缺少认证凭据。请通过 HTTP 头传入 X-ZStack-Account/X-ZStack-Password," "或设置环境变量 ZSTACK_ACCOUNT + ZSTACK_PASSWORD," - "或设置 ZSTACK_SESSION_ID / X-ZStack-Session-Id 直接使用已有会话" + "或设置 ZSTACK_SESSION_ID / X-ZStack-Session-Id 直接使用已有会话," + "或设置 ZSTACK_ACCESS_KEY_ID + ZSTACK_ACCESS_KEY_SECRET 使用 AK/SK 认证" ) - cache_key = f"{effective_api_url or ''}|{effective_account}" + cache_key = f"{api_url or ''}|{account}" if cache_key in self._clients: self._clients.move_to_end(cache_key) return self._clients[cache_key] # 缓存未命中 → 创建新 client 并登录 client = ZStackClient( - account=effective_account, - password=effective_password, - api_url=effective_api_url, + account=account, + password=password, + api_url=api_url, ) await client.login() + return await self._cache_client(cache_key, client) + async def _cache_client(self, cache_key: str, client: ZStackClient) -> ZStackClient: # 超过上限 → 淘汰最早的 while len(self._clients) >= self._max_sessions: _, old_client = self._clients.popitem(last=False) @@ -951,6 +1035,8 @@ async def execute_api( account=auth.account, password=auth.password, session_id=auth.session_id, + access_key_id=auth.access_key_id, + access_key_secret=auth.access_key_secret, api_url=auth.api_url, ) is_async = api_info.call_type == 'async' @@ -1173,6 +1259,8 @@ async def get_metric_data( account=auth.account, password=auth.password, session_id=auth.session_id, + access_key_id=auth.access_key_id, + access_key_secret=auth.access_key_secret, api_url=auth.api_url, ) result = await client.query_metric_data( @@ -1277,6 +1365,8 @@ async def get_metric_summary( account=auth.account, password=auth.password, session_id=auth.session_id, + access_key_id=auth.access_key_id, + access_key_secret=auth.access_key_secret, api_url=auth.api_url, ) metrics = [] @@ -1656,4 +1746,3 @@ def _shutdown_cleanup(): if __name__ == "__main__": main() - diff --git a/src/zstack_mcp/zstack_client.py b/src/zstack_mcp/zstack_client.py index 1ef9746..60fc571 100644 --- a/src/zstack_mcp/zstack_client.py +++ b/src/zstack_mcp/zstack_client.py @@ -1,9 +1,10 @@ """ ZStack API 客户端 - 处理与 ZStack Cloud 的 API 通信 -支持两种认证方式: +支持三种认证方式: 1. 用户名密码登录获取 Session 2. 直接传入 SessionID(通过环境变量 ZSTACK_SESSION_ID) +3. AccessKey/SecretKey 请求签名(通过环境变量 ZSTACK_ACCESS_KEY_ID / ZSTACK_ACCESS_KEY_SECRET) 支持: - 自动登录和 session 管理 @@ -12,12 +13,16 @@ """ import asyncio +import base64 import hashlib +import hmac import json import os from datetime import datetime, timezone +from email.utils import format_datetime from typing import Any, Optional from dataclasses import dataclass +from urllib.parse import urlencode, urlparse import httpx @@ -39,13 +44,66 @@ class ZStackSession: expire_date: Optional[str] = None +@dataclass(frozen=True) +class ZStackRestRoute: + """ZStack REST 路由映射。path 不包含 /zstack 前缀。""" + method: str + path: str + + +REST_API_ROUTES: dict[str, ZStackRestRoute] = { + # 常用只读 Query API,路径来自 zstack-sdk-go-v2 generated actions。 + "QueryAccessKey": ZStackRestRoute("GET", "v1/accesskeys"), + "QueryAccount": ZStackRestRoute("GET", "v1/accounts"), + "QueryBackupStorage": ZStackRestRoute("GET", "v1/backup-storage"), + "QueryCephBackupStorage": ZStackRestRoute("GET", "v1/backup-storage/ceph"), + "QueryCephPrimaryStorage": ZStackRestRoute("GET", "v1/primary-storage/ceph"), + "QueryCluster": ZStackRestRoute("GET", "v1/clusters"), + "QueryDiskOffering": ZStackRestRoute("GET", "v1/disk-offerings"), + "QueryEip": ZStackRestRoute("GET", "v1/eips"), + "QueryGlobalConfig": ZStackRestRoute("GET", "v1/global-configurations"), + "QueryHost": ZStackRestRoute("GET", "v1/hosts"), + "QueryImage": ZStackRestRoute("GET", "v1/images"), + "QueryImageStoreBackupStorage": ZStackRestRoute("GET", "v1/backup-storage/image-store"), + "QueryInstanceOffering": ZStackRestRoute("GET", "v1/instance-offerings"), + "QueryIpRange": ZStackRestRoute("GET", "v1/l3-networks/ip-ranges"), + "QueryL2Network": ZStackRestRoute("GET", "v1/l2-networks"), + "QueryL3Network": ZStackRestRoute("GET", "v1/l3-networks"), + "QueryLoadBalancer": ZStackRestRoute("GET", "v1/load-balancers"), + "QueryLoadBalancerListener": ZStackRestRoute("GET", "v1/load-balancers/listeners"), + "QueryLocalStorageResourceRef": ZStackRestRoute("GET", "v1/primary-storage/local-storage/resource-refs"), + "QueryLongJob": ZStackRestRoute("GET", "v1/longjobs"), + "QueryManagementNode": ZStackRestRoute("GET", "v1/management-nodes"), + "QueryPolicy": ZStackRestRoute("GET", "v1/accounts/policies"), + "QueryPortForwardingRule": ZStackRestRoute("GET", "v1/port-forwarding"), + "QueryPrimaryStorage": ZStackRestRoute("GET", "v1/primary-storage"), + "QueryRole": ZStackRestRoute("GET", "v1/identities/roles"), + "QuerySecurityGroup": ZStackRestRoute("GET", "v1/security-groups"), + "QuerySftpBackupStorage": ZStackRestRoute("GET", "v1/backup-storage/sftp"), + "QuerySystemTag": ZStackRestRoute("GET", "v1/system-tags"), + "QueryUser": ZStackRestRoute("GET", "v1/accounts/users"), + "QueryUserTag": ZStackRestRoute("GET", "v1/user-tags"), + "QueryVip": ZStackRestRoute("GET", "v1/vips"), + "QueryVirtualRouterOffering": ZStackRestRoute("GET", "v1/instance-offerings/virtual-routers"), + "QueryVirtualRouterVm": ZStackRestRoute("GET", "v1/vm-instances/appliances/virtual-routers"), + "QueryVmInstance": ZStackRestRoute("GET", "v1/vm-instances"), + "QueryVmNic": ZStackRestRoute("GET", "v1/vm-instances/nics"), + "QueryVolume": ZStackRestRoute("GET", "v1/volumes"), + "QueryVolumeSnapshot": ZStackRestRoute("GET", "v1/volume-snapshots"), + "QueryVRouterRouteEntry": ZStackRestRoute("GET", "v1/vrouter-route-tables/route-entries"), + "QueryVRouterRouteTable": ZStackRestRoute("GET", "v1/vrouter-route-tables"), + "QueryZone": ZStackRestRoute("GET", "v1/zones"), +} + + class ZStackClient: """ ZStack API 客户端 认证方式(按优先级): 1. 如果设置了 ZSTACK_SESSION_ID,直接使用该 Session - 2. 否则使用 ZSTACK_ACCOUNT + ZSTACK_PASSWORD 登录获取 Session + 2. 如果设置了 ZSTACK_ACCESS_KEY_ID + ZSTACK_ACCESS_KEY_SECRET,使用 AK/SK 签名 + 3. 否则使用 ZSTACK_ACCOUNT + ZSTACK_PASSWORD 登录获取 Session """ # 轮询 Job 的配置 @@ -58,6 +116,8 @@ def __init__( account: Optional[str] = None, password: Optional[str] = None, session_id: Optional[str] = None, + access_key_id: Optional[str] = None, + access_key_secret: Optional[str] = None, ): """ 初始化 ZStack 客户端 @@ -67,13 +127,31 @@ def __init__( account: 账户名(用户名密码认证时使用) password: 密码(明文,会自动进行 SHA512 加密) session_id: 直接传入的 Session UUID(优先级高于用户名密码) + access_key_id: AccessKey ID(AK/SK 认证时使用) + access_key_secret: AccessKey Secret(AK/SK 认证时使用) """ self.api_url = api_url or os.environ.get('ZSTACK_API_URL', 'http://localhost:8080') self.account = account or os.environ.get('ZSTACK_ACCOUNT', 'admin') self.password = password or os.environ.get('ZSTACK_PASSWORD', '') + self.access_key_id = ( + access_key_id + or os.environ.get('ZSTACK_ACCESS_KEY_ID', '') + or os.environ.get('ZSTACK_AK', '') + ) + self.access_key_secret = ( + access_key_secret + or os.environ.get('ZSTACK_ACCESS_KEY_SECRET', '') + or os.environ.get('ZSTACK_SK', '') + ) - # 优先使用直接传入的 session_id - env_session_id = session_id or os.environ.get('ZSTACK_SESSION_ID', '') + # 显式传入 AK/SK 时不再回退环境变量里的 session,避免 HTTP 头凭据被进程级 session 覆盖。 + explicit_access_key = access_key_id is not None or access_key_secret is not None + if session_id is not None: + env_session_id = session_id + elif explicit_access_key: + env_session_id = '' + else: + env_session_id = os.environ.get('ZSTACK_SESSION_ID', '') # 如果有 session_id,直接创建 session 对象 if env_session_id: @@ -95,6 +173,8 @@ def auth_mode(self) -> str: if not self.session.account_uuid: return "session_id" # 直接传入的 session return "session" # 登录获取的 session + if self.access_key_id or self.access_key_secret: + return "access_key" return "password" # 需要密码登录 async def _get_http_client(self) -> httpx.AsyncClient: @@ -105,6 +185,9 @@ async def _get_http_client(self) -> httpx.AsyncClient: async def logout(self) -> None: """调用 LogOut API 销毁当前 session,然后关闭 HTTP 客户端""" + if self.auth_mode == "access_key": + await self.close() + return if self.session and self.session.uuid: try: await self.execute( @@ -128,6 +211,203 @@ def _sha512(text: str) -> str: """SHA512 加密""" return hashlib.sha512(text.encode('utf-8')).hexdigest() + @staticmethod + def _format_access_key_date() -> str: + """返回 ZStack AK/SK 签名使用的本地时区 Date 字符串。""" + now = datetime.now().astimezone() + zone_name = now.strftime("%Z") + if zone_name: + return now.strftime("%a, %d %b %Y %H:%M:%S %Z") + return format_datetime(datetime.now(timezone.utc), usegmt=True) + + @staticmethod + def _canonical_access_key_uri(url: str) -> str: + """按 ZStack Go SDK 规则提取签名用 URI:去掉 /zstack context path 和 query。""" + parsed = urlparse(url) + path = parsed.path or "/" + context_path = "/zstack" + idx = path.find(context_path) + if idx >= 0: + uri = path[idx + len(context_path):] + return uri or "/" + return path + + def _validate_access_key(self) -> None: + if not self.access_key_id or not self.access_key_secret: + raise ZStackApiError( + "AK/SK 未配置完整,请同时设置 ZSTACK_ACCESS_KEY_ID 和 " + "ZSTACK_ACCESS_KEY_SECRET,或通过 HTTP 头传入 " + "X-ZStack-Access-Key-Id / X-ZStack-Access-Key-Secret" + ) + + def _access_key_auth_headers( + self, + method: str, + url: str, + date: Optional[str] = None, + ) -> dict[str, str]: + """生成 ZStack AK/SK 请求签名头。 + + 签名算法参考 zstack-sdk-go-v2: + base64(hmac-sha1(secret, METHOD + "\n" + Date + "\n" + uri)) + """ + self._validate_access_key() + date = date or self._format_access_key_date() + method = method.upper() + uri = self._canonical_access_key_uri(url) + string_to_sign = f"{method}\n{date}\n{uri}" + digest = hmac.new( + self.access_key_secret.encode("utf-8"), + string_to_sign.encode("utf-8"), + hashlib.sha1, + ).digest() + signature = base64.b64encode(digest).decode("ascii") + return { + "Authorization": f"ZStack {self.access_key_id}:{signature}", + "Date": date, + } + + def _request_headers(self, method: str, url: str) -> dict[str, str]: + headers = {"Content-Type": "application/json"} + if self.auth_mode == "access_key": + headers.update(self._access_key_auth_headers(method, url)) + return headers + + def _rest_url(self, path: str) -> str: + return f"{self.api_url.rstrip('/')}/zstack/{path.lstrip('/')}" + + @staticmethod + def _rest_query_value(value: Any) -> str: + if isinstance(value, bool): + return str(value).lower() + if isinstance(value, (list, tuple, set)): + return ",".join(str(item) for item in value) + return str(value) + + @classmethod + def _rest_condition_to_q(cls, condition: Any) -> Optional[str]: + if isinstance(condition, str): + return condition.strip() or None + if not isinstance(condition, dict): + return None + name = condition.get("name") + value = condition.get("value") + if name is None or value is None: + return None + op = str(condition.get("op") or "=").strip() + if op == "==": + op = "=" + return f"{name}{op}{cls._rest_query_value(value)}" + + @classmethod + def _rest_query_params(cls, parameters: dict[str, Any]) -> dict[str, Any]: + query: dict[str, Any] = {} + passthrough_keys = ( + "limit", + "start", + "replyWithCount", + "count", + "groupBy", + "filterName", + "sort", + ) + for key in passthrough_keys: + if key in parameters and parameters[key] is not None: + query[key] = cls._rest_query_value(parameters[key]) + + fields = parameters.get("fields") + if fields: + query["fields"] = cls._rest_query_value(fields) + + q_values: list[str] = [] + raw_q = parameters.get("q") + if isinstance(raw_q, str): + q_values.append(raw_q) + elif isinstance(raw_q, (list, tuple, set)): + q_values.extend(str(item) for item in raw_q if item is not None) + + conditions = parameters.get("conditions") + if isinstance(conditions, dict): + conditions = [conditions] + if isinstance(conditions, (list, tuple)): + for condition in conditions: + q = cls._rest_condition_to_q(condition) + if q: + q_values.append(q) + if q_values: + query["q"] = q_values + + return query + + def _rest_route_for_api(self, api_name: str) -> ZStackRestRoute: + route = REST_API_ROUTES.get(api_name) + if route is None: + raise ZStackApiError( + message=( + f"AK/SK 认证不支持 /zstack/api/ message API,且当前未配置 " + f"{api_name} 的 REST 路由映射" + ), + code="REST_MAPPING_NOT_FOUND", + details={ + "apiName": api_name, + "authMode": "access_key", + "hint": "请为该 API 增加 REST path/method 映射,或改用账号密码/session 认证。", + }, + ) + return route + + async def execute_rest( + self, + api_name: str, + parameters: dict[str, Any], + ) -> dict[str, Any]: + """使用 REST API 执行 AK/SK 请求。""" + route = self._rest_route_for_api(api_name) + if route.method != "GET": + raise ZStackApiError( + message=f"AK/SK REST 路由 {api_name} 的方法 {route.method} 暂未实现", + code="REST_METHOD_NOT_IMPLEMENTED", + details={"apiName": api_name, "method": route.method, "path": route.path}, + ) + + url = self._rest_url(route.path) + query = self._rest_query_params(parameters) + if query: + url = f"{url}?{urlencode(query, doseq=True)}" + + client = await self._get_http_client() + response = await client.get( + url, + headers=self._request_headers(route.method, url), + ) + if response.status_code >= 400: + raise ZStackApiError( + message=f"HTTP 错误 {response.status_code}: {response.text[:500]}", + code=str(response.status_code), + ) + + try: + result = response.json() + except Exception as e: + raise ZStackApiError( + message=f"响应解析失败: {str(e)}, 响应内容: {response.text[:500]}", + ) + + if isinstance(result, list): + return {"inventories": result} + if isinstance(result, dict): + if "error" in result: + error = result["error"] + if isinstance(error, dict): + raise ZStackApiError( + message=error.get("description", "请求失败"), + code=error.get("code"), + details=error, + ) + raise ZStackApiError(message=str(error or "请求失败")) + return result + return {"raw": result} + @staticmethod def _normalize_metric_time(value: Any) -> Any: """将时间规范化为秒级时间戳(支持 ISO 字符串/毫秒/秒)""" @@ -300,7 +580,7 @@ def _is_session_invalid_error(error: ZStackApiError) -> bool: return 'session' in message and ('invalid' in message or 'expired' in message) def _can_refresh_session(self) -> bool: - if self.auth_mode == "session_id": + if self.auth_mode in ("session_id", "access_key"): return False return bool(self.password) @@ -315,8 +595,14 @@ async def login(self) -> ZStackSession: Returns: ZStackSession 对象 """ + if self.auth_mode == "access_key": + raise ZStackApiError("AK/SK 认证不需要登录,请直接执行 API") if not self.password: - raise ZStackApiError("密码未配置,请设置 ZSTACK_PASSWORD 环境变量,或设置 ZSTACK_SESSION_ID 直接使用已有会话") + raise ZStackApiError( + "密码未配置,请设置 ZSTACK_PASSWORD 环境变量," + "或设置 ZSTACK_SESSION_ID 直接使用已有会话," + "或设置 ZSTACK_ACCESS_KEY_ID + ZSTACK_ACCESS_KEY_SECRET 使用 AK/SK 认证" + ) password_hash = self._sha512(self.password) @@ -331,7 +617,7 @@ async def login(self) -> ZStackSession: response = await client.post( self.api_endpoint, json=request_body, - headers={"Content-Type": "application/json"} + headers=self._request_headers("POST", self.api_endpoint) ) # 检查 HTTP 状态码 @@ -387,11 +673,17 @@ async def execute( Returns: API 返回结果 """ + if self.auth_mode == "access_key": + return await self.execute_rest(api_name, parameters) + base_parameters = dict(parameters) async def send_once() -> dict[str, Any]: # 确保已登录(除了登录 API 本身) - if 'LogIn' not in api_name: + if self.auth_mode == "access_key": + self._validate_access_key() + request_parameters = base_parameters + elif 'LogIn' not in api_name: session = await self.ensure_session() # 添加 session 信息 request_parameters = { @@ -410,7 +702,7 @@ async def send_once() -> dict[str, Any]: response = await client.post( self.api_endpoint, json=request_body, - headers={"Content-Type": "application/json"} + headers=self._request_headers("POST", self.api_endpoint) ) # 检查 HTTP 状态码 @@ -470,7 +762,7 @@ async def _poll_job(self, job_location: str) -> dict[str, Any]: response = await client.get( job_location, - headers={"Content-Type": "application/json"} + headers=self._request_headers("GET", job_location) ) # 检查 HTTP 状态码 @@ -537,9 +829,20 @@ async def query_metric_data( labels = self._normalize_metric_labels(labels) async def send_once() -> dict[str, Any]: - session = await self.ensure_session() + if self.auth_mode == "access_key": + raise ZStackApiError( + message=( + "AK/SK 认证不支持 /zstack/api/ message API," + "get_metric_data 暂未配置 REST 路由映射" + ), + code="REST_MAPPING_NOT_FOUND", + details={ + "apiName": "GetMetricData", + "authMode": "access_key", + "hint": "请为 GetMetricData 增加 REST path/method/parameter 映射,或改用账号密码/session 认证。", + }, + ) payload = { - "session": {"uuid": session.uuid}, "namespace": namespace, "metricName": metric_name, "startTime": start_time, @@ -547,6 +850,8 @@ async def send_once() -> dict[str, Any]: "period": period, "labels": labels, } + session = await self.ensure_session() + payload["session"] = {"uuid": session.uuid} payload = {key: value for key, value in payload.items() if value is not None} request_body = { "org.zstack.zwatch.api.APIGetMetricDataMsg": payload @@ -556,7 +861,7 @@ async def send_once() -> dict[str, Any]: response = await client.post( self.api_endpoint, json=request_body, - headers={"Content-Type": "application/json"} + headers=self._request_headers("POST", self.api_endpoint) ) # 检查 HTTP 状态码 diff --git a/tests/test_auth_refactor.py b/tests/test_auth_refactor.py index 4ff7770..1fe8a66 100644 --- a/tests/test_auth_refactor.py +++ b/tests/test_auth_refactor.py @@ -5,6 +5,7 @@ 2. cache_key 从 account 改为 api_url|account(多环境隔离) 3. execute_api 通过 ctx: Context 从 HTTP 头取认证(用 mock ctx) 4. 多环境:同一账号不同 api_url → 各自独立 session +5. AK/SK:从 HTTP 头或环境变量传入 AccessKey/SecretKey,且不触发登录 运行方式: pytest tests/test_auth_refactor.py -v @@ -70,7 +71,14 @@ def _skip_if_unreachable(env_key: str): # 工具函数 # --------------------------------------------------------------------------- -def _make_mock_ctx(account: str, password: str, api_url: str, session_id: str = "") -> MagicMock: +def _make_mock_ctx( + account: str, + password: str, + api_url: str, + session_id: str = "", + access_key_id: str = "", + access_key_secret: str = "", +) -> MagicMock: """构造一个带 HTTP headers 的 mock FastMCP Context""" headers = { "x-zstack-account": account, @@ -79,6 +87,10 @@ def _make_mock_ctx(account: str, password: str, api_url: str, session_id: str = } if session_id: headers["x-zstack-session-id"] = session_id + if access_key_id: + headers["x-zstack-access-key-id"] = access_key_id + if access_key_secret: + headers["x-zstack-access-key-secret"] = access_key_secret mock_request = MagicMock() mock_request.headers = headers @@ -88,7 +100,16 @@ def _make_mock_ctx(account: str, password: str, api_url: str, session_id: str = def _clear_auth_env(): - for key in ("ZSTACK_ACCOUNT", "ZSTACK_PASSWORD", "ZSTACK_SESSION_ID", "ZSTACK_API_URL"): + for key in ( + "ZSTACK_ACCOUNT", + "ZSTACK_PASSWORD", + "ZSTACK_SESSION_ID", + "ZSTACK_ACCESS_KEY_ID", + "ZSTACK_ACCESS_KEY_SECRET", + "ZSTACK_AK", + "ZSTACK_SK", + "ZSTACK_API_URL", + ): os.environ.pop(key, None) @@ -116,6 +137,22 @@ def test_extract_auth_with_session_id(): assert auth.api_url == "http://172.20.0.37:8080" +def test_extract_auth_with_access_key(): + """HTTP 模式:包含 AK/SK 时正确提取""" + ctx = _make_mock_ctx( + "", + "", + "http://172.20.0.37:8080", + access_key_id="ak-123", + access_key_secret="sk-456", + ) + auth = _extract_auth_from_context(ctx) + + assert auth.access_key_id == "ak-123" + assert auth.access_key_secret == "sk-456" + assert auth.api_url == "http://172.20.0.37:8080" + + def test_extract_auth_from_stdio_context(): """stdio 模式(无 HTTP request context):安全返回全空""" class _StdioCtx: @@ -127,6 +164,8 @@ def request_context(self): assert auth.account is None assert auth.password is None assert auth.session_id is None + assert auth.access_key_id is None + assert auth.access_key_secret is None assert auth.api_url is None @@ -156,6 +195,59 @@ async def test_no_credentials_error(): await mgr.get_client() +@pytest.mark.anyio +async def test_access_key_client_uses_cache_without_login(): + """AK/SK 模式不登录,按 api_url + access_key_id 缓存 client""" + _clear_auth_env() + mgr = _SessionManager(max_sessions=3) + try: + client1 = await mgr.get_client( + access_key_id="ak-123", + access_key_secret="sk-456", + api_url="http://dev1:8080", + ) + client2 = await mgr.get_client( + access_key_id="ak-123", + access_key_secret="sk-456", + api_url="http://dev1:8080", + ) + + assert client1 is client2 + assert client1.auth_mode == "access_key" + assert client1.session is None + assert len(mgr._clients) == 1 + finally: + await mgr.logout_all() + + +@pytest.mark.anyio +async def test_access_key_env_fallback(): + """stdio/env 模式可使用 ZSTACK_ACCESS_KEY_ID + ZSTACK_ACCESS_KEY_SECRET""" + _clear_auth_env() + os.environ["ZSTACK_API_URL"] = "http://dev1:8080" + os.environ["ZSTACK_ACCESS_KEY_ID"] = "ak-env" + os.environ["ZSTACK_ACCESS_KEY_SECRET"] = "sk-env" + + mgr = _SessionManager(max_sessions=3) + try: + client = await mgr.get_client() + assert client.auth_mode == "access_key" + assert client.access_key_id == "ak-env" + assert client.access_key_secret == "sk-env" + finally: + await mgr.logout_all() + _clear_auth_env() + + +@pytest.mark.anyio +async def test_access_key_requires_id_and_secret(): + """AK/SK 必须成对传入""" + _clear_auth_env() + mgr = _SessionManager(max_sessions=3) + with pytest.raises(ZStackApiError, match="AK/SK"): + await mgr.get_client(access_key_id="ak-only") + + # --------------------------------------------------------------------------- # 集成测试:env1(172.20.0.37) # --------------------------------------------------------------------------- diff --git a/tests/test_zstack_client_access_key.py b/tests/test_zstack_client_access_key.py new file mode 100644 index 0000000..7abe959 --- /dev/null +++ b/tests/test_zstack_client_access_key.py @@ -0,0 +1,164 @@ +import base64 +import hashlib +import hmac +from typing import Any + +import pytest + +from zstack_mcp.zstack_client import ZStackApiError, ZStackClient + + +class _DummyResponse: + def __init__(self, payload: dict[str, Any], status_code: int = 200): + self._payload = payload + self.status_code = status_code + self.text = str(payload) + + def json(self) -> dict[str, Any]: + return self._payload + + +class _RecordingHttpClient: + def __init__(self): + self.posts: list[dict[str, Any]] = [] + self.gets: list[dict[str, Any]] = [] + self.post_payload: dict[str, Any] = { + "org.zstack.header.zone.APIQueryZoneReply": { + "success": True, + "inventories": [], + } + } + self.get_payload: dict[str, Any] = { + "org.zstack.header.zone.APIQueryZoneReply": { + "success": True, + "inventories": [], + } + } + + async def post(self, url: str, json: dict[str, Any], headers: dict[str, str]): + self.posts.append({"url": url, "json": json, "headers": headers}) + return _DummyResponse(self.post_payload) + + async def get(self, url: str, headers: dict[str, str]): + self.gets.append({"url": url, "headers": headers}) + return _DummyResponse(self.get_payload) + + async def aclose(self): + return None + + +def test_access_key_signature_matches_go_sdk_shape() -> None: + client = ZStackClient( + api_url="http://example.com:8080", + access_key_id="ak", + access_key_secret="sk", + ) + date = "Mon, 02 Jan 2006 15:04:05 UTC" + headers = client._access_key_auth_headers("POST", client.api_endpoint, date=date) + + string_to_sign = f"POST\n{date}\n/api/" + expected_signature = base64.b64encode( + hmac.new(b"sk", string_to_sign.encode("utf-8"), hashlib.sha1).digest() + ).decode("ascii") + + assert headers["Authorization"] == f"ZStack ak:{expected_signature}" + assert headers["Date"] == date + assert ZStackClient._canonical_access_key_uri("http://example.com:8080/zstack/api/?x=1") == "/api/" + + +def test_explicit_access_key_is_not_overridden_by_env_session(monkeypatch) -> None: + monkeypatch.setenv("ZSTACK_SESSION_ID", "env-session") + + client = ZStackClient( + api_url="http://example.com:8080", + access_key_id="ak", + access_key_secret="sk", + ) + + assert client.auth_mode == "access_key" + assert client.session is None + + +@pytest.mark.anyio +async def test_execute_with_access_key_uses_rest_get(monkeypatch) -> None: + recorder = _RecordingHttpClient() + recorder.get_payload = {"inventories": []} + client = ZStackClient( + api_url="http://example.com:8080", + access_key_id="ak", + access_key_secret="sk", + ) + + async def get_http_client(): + return recorder + + monkeypatch.setattr(client, "_get_http_client", get_http_client) + + result = await client.execute( + "QueryZone", + "org.zstack.header.zone.APIQueryZoneMsg", + { + "conditions": [{"name": "name", "op": "=", "value": "zone-a"}], + "limit": 10, + "replyWithCount": True, + }, + ) + + assert result == {"inventories": []} + assert recorder.posts == [] + assert len(recorder.gets) == 1 + assert recorder.gets[0]["url"].startswith("http://example.com:8080/zstack/v1/zones?") + assert "q=name%3Dzone-a" in recorder.gets[0]["url"] + assert "limit=10" in recorder.gets[0]["url"] + assert "replyWithCount=true" in recorder.gets[0]["url"] + assert recorder.gets[0]["headers"]["Authorization"].startswith("ZStack ak:") + assert "Date" in recorder.gets[0]["headers"] + + +@pytest.mark.anyio +async def test_execute_with_access_key_blocks_unmapped_message_api(monkeypatch) -> None: + recorder = _RecordingHttpClient() + client = ZStackClient( + api_url="http://example.com:8080", + access_key_id="ak", + access_key_secret="sk", + ) + + async def get_http_client(): + return recorder + + monkeypatch.setattr(client, "_get_http_client", get_http_client) + + with pytest.raises(ZStackApiError) as exc: + await client.execute( + "QueryNotMappedResource", + "org.zstack.header.unknown.APIQueryNotMappedResourceMsg", + {"conditions": []}, + ) + + assert exc.value.code == "REST_MAPPING_NOT_FOUND" + assert recorder.posts == [] + assert recorder.gets == [] + + +@pytest.mark.anyio +async def test_poll_job_with_access_key_signs_get(monkeypatch) -> None: + recorder = _RecordingHttpClient() + client = ZStackClient( + api_url="http://example.com:8080", + access_key_id="ak", + access_key_secret="sk", + ) + client.JOB_POLL_INTERVAL = 0 + + async def get_http_client(): + return recorder + + monkeypatch.setattr(client, "_get_http_client", get_http_client) + + result = await client._poll_job("http://example.com:8080/zstack/api/result/job-1") + + assert result["success"] is True + assert len(recorder.gets) == 1 + assert recorder.gets[0]["headers"]["Authorization"].startswith("ZStack ak:") + assert "Date" in recorder.gets[0]["headers"] From 4ae1b841861de012a6174674ed6e8667dcb2f0f1 Mon Sep 17 00:00:00 2001 From: "jiajian.chi" Date: Fri, 19 Jun 2026 00:34:15 +0800 Subject: [PATCH 2/2] Add AK/SK REST metric data support --- src/zstack_mcp/zstack_client.py | 89 ++++++++++++++++++++---- tests/test_zstack_client_access_key.py | 93 ++++++++++++++++++++++++++ 2 files changed, 169 insertions(+), 13 deletions(-) diff --git a/src/zstack_mcp/zstack_client.py b/src/zstack_mcp/zstack_client.py index 60fc571..0d232d2 100644 --- a/src/zstack_mcp/zstack_client.py +++ b/src/zstack_mcp/zstack_client.py @@ -52,7 +52,8 @@ class ZStackRestRoute: REST_API_ROUTES: dict[str, ZStackRestRoute] = { - # 常用只读 Query API,路径来自 zstack-sdk-go-v2 generated actions。 + # 常用只读 REST API,路径来自 zstack-sdk-go-v2 generated actions。 + "GetMetricData": ZStackRestRoute("GET", "v1/zwatch/metrics"), "QueryAccessKey": ZStackRestRoute("GET", "v1/accesskeys"), "QueryAccount": ZStackRestRoute("GET", "v1/accounts"), "QueryBackupStorage": ZStackRestRoute("GET", "v1/backup-storage"), @@ -339,6 +340,68 @@ def _rest_query_params(cls, parameters: dict[str, Any]) -> dict[str, Any]: return query + @staticmethod + def _first_present(parameters: dict[str, Any], *keys: str) -> Any: + for key in keys: + if key in parameters and parameters[key] is not None: + return parameters[key] + return None + + @classmethod + def _metric_rest_query_params(cls, parameters: dict[str, Any]) -> dict[str, Any]: + query: dict[str, Any] = {} + + metric_name = cls._first_present(parameters, "metricName", "metric_name") + scalar_values = { + "namespace": cls._first_present(parameters, "namespace"), + "metricName": metric_name, + "startTime": cls._normalize_metric_time( + cls._first_present(parameters, "startTime", "start_time") + ), + "endTime": cls._normalize_metric_time( + cls._first_present(parameters, "endTime", "end_time") + ), + "offsetAheadOfCurrentTime": cls._first_present( + parameters, + "offsetAheadOfCurrentTime", + "offset_ahead_of_current_time", + ), + "period": cls._normalize_metric_period( + cls._first_present(parameters, "period") + ), + } + for key, value in scalar_values.items(): + if value is not None: + query[key] = value + + labels = cls._normalize_metric_labels(parameters.get("labels")) + if labels: + query["labels"] = labels + + repeated_values = { + "valueConditions": cls._first_present( + parameters, + "valueConditions", + "value_conditions", + ), + "functions": cls._first_present(parameters, "functions"), + } + for key, value in repeated_values.items(): + if value is not None: + query[key] = ( + list(value) + if isinstance(value, (list, tuple, set)) + else value + ) + + return query + + @classmethod + def _rest_params_for_api(cls, api_name: str, parameters: dict[str, Any]) -> dict[str, Any]: + if api_name == "GetMetricData": + return cls._metric_rest_query_params(parameters) + return cls._rest_query_params(parameters) + def _rest_route_for_api(self, api_name: str) -> ZStackRestRoute: route = REST_API_ROUTES.get(api_name) if route is None: @@ -371,7 +434,7 @@ async def execute_rest( ) url = self._rest_url(route.path) - query = self._rest_query_params(parameters) + query = self._rest_params_for_api(api_name, parameters) if query: url = f"{url}?{urlencode(query, doseq=True)}" @@ -394,7 +457,8 @@ async def execute_rest( ) if isinstance(result, list): - return {"inventories": result} + list_key = "data" if api_name == "GetMetricData" else "inventories" + return {list_key: result} if isinstance(result, dict): if "error" in result: error = result["error"] @@ -830,16 +894,15 @@ async def query_metric_data( async def send_once() -> dict[str, Any]: if self.auth_mode == "access_key": - raise ZStackApiError( - message=( - "AK/SK 认证不支持 /zstack/api/ message API," - "get_metric_data 暂未配置 REST 路由映射" - ), - code="REST_MAPPING_NOT_FOUND", - details={ - "apiName": "GetMetricData", - "authMode": "access_key", - "hint": "请为 GetMetricData 增加 REST path/method/parameter 映射,或改用账号密码/session 认证。", + return await self.execute_rest( + "GetMetricData", + { + "namespace": namespace, + "metricName": metric_name, + "startTime": start_time, + "endTime": end_time, + "period": period, + "labels": labels, }, ) payload = { diff --git a/tests/test_zstack_client_access_key.py b/tests/test_zstack_client_access_key.py index 7abe959..c49ce7a 100644 --- a/tests/test_zstack_client_access_key.py +++ b/tests/test_zstack_client_access_key.py @@ -2,6 +2,7 @@ import hashlib import hmac from typing import Any +from urllib.parse import parse_qs, urlparse import pytest @@ -141,6 +142,98 @@ async def get_http_client(): assert recorder.gets == [] +@pytest.mark.anyio +async def test_query_metric_data_with_access_key_uses_rest_get(monkeypatch) -> None: + recorder = _RecordingHttpClient() + recorder.get_payload = { + "data": [ + { + "labels": {"VMUuid": "vm-1", "CPUNum": "0"}, + "time": 1700000000, + "value": 42.5, + } + ] + } + client = ZStackClient( + api_url="http://example.com:8080", + access_key_id="ak", + access_key_secret="sk", + ) + + async def get_http_client(): + return recorder + + monkeypatch.setattr(client, "_get_http_client", get_http_client) + + result = await client.query_metric_data( + namespace="ZStack/VM", + metric_name="CPUUsedUtilization", + start_time=1700000000000, + end_time=1700000060, + period="60", + labels={"VMUuid": "vm-1", "CPUNum": "0"}, + ) + + assert result == recorder.get_payload + assert recorder.posts == [] + assert len(recorder.gets) == 1 + parsed = urlparse(recorder.gets[0]["url"]) + query = parse_qs(parsed.query) + assert parsed.path == "/zstack/v1/zwatch/metrics" + assert query["namespace"] == ["ZStack/VM"] + assert query["metricName"] == ["CPUUsedUtilization"] + assert query["startTime"] == ["1700000000"] + assert query["endTime"] == ["1700000060"] + assert query["period"] == ["60"] + assert query["labels"] == ["VMUuid=vm-1", "CPUNum=0"] + assert recorder.gets[0]["headers"]["Authorization"].startswith("ZStack ak:") + assert "Date" in recorder.gets[0]["headers"] + + +@pytest.mark.anyio +async def test_execute_get_metric_data_with_access_key_uses_metric_rest_params(monkeypatch) -> None: + recorder = _RecordingHttpClient() + recorder.get_payload = [ + { + "labels": {"VMUuid": "vm-1"}, + "time": 1700000000, + "value": 7, + } + ] + client = ZStackClient( + api_url="http://example.com:8080", + access_key_id="ak", + access_key_secret="sk", + ) + + async def get_http_client(): + return recorder + + monkeypatch.setattr(client, "_get_http_client", get_http_client) + + result = await client.execute( + "GetMetricData", + "org.zstack.zwatch.api.APIGetMetricDataMsg", + { + "namespace": "ZStack/VM", + "metricName": "CPUUsedUtilization", + "valueConditions": ["value>1"], + "functions": ["max"], + }, + ) + + assert result == {"data": recorder.get_payload} + assert recorder.posts == [] + assert len(recorder.gets) == 1 + parsed = urlparse(recorder.gets[0]["url"]) + query = parse_qs(parsed.query) + assert parsed.path == "/zstack/v1/zwatch/metrics" + assert query["namespace"] == ["ZStack/VM"] + assert query["metricName"] == ["CPUUsedUtilization"] + assert query["valueConditions"] == ["value>1"] + assert query["functions"] == ["max"] + + @pytest.mark.anyio async def test_poll_job_with_access_key_signs_get(monkeypatch) -> None: recorder = _RecordingHttpClient()