# -*- coding: utf-8 -*-
"""
Subprocess worker that hosts a single CodeExecutor instance for one analysis session.

The worker reads one JSON request per line from stdin and writes one JSON
response per line to stdout. Anything printed while a request is being handled
is captured and forwarded to stderr so that stdout carries protocol messages
only.
"""

import json
import os
import sys
import traceback
from contextlib import redirect_stderr, redirect_stdout
from io import StringIO
from pathlib import Path
from typing import Any, Dict, Iterable, Optional

PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)


class WorkerProtocolError(RuntimeError):
    """Raised when the worker receives an invalid protocol message."""


def _is_path_like(value: Any) -> bool:
    """Return True when *value* names a filesystem path rather than a buffer.

    ``bytes`` is included on purpose: ``open(b"...")`` is a valid way to name
    a file and must go through the same policy checks as ``str`` paths.
    """
    return isinstance(value, (str, bytes, Path)) or hasattr(value, "__fspath__")


def _is_within(normalized_path: str, root: str) -> bool:
    """Return True when *normalized_path* is *root* itself or lies below it."""
    return normalized_path == root or normalized_path.startswith(root + os.sep)


class FileAccessPolicy:
    """Controls which files the worker may read and where it may write outputs.

    Attributes:
        allowed_reads: Normalized paths of individually readable files.
        explicit_read_roots: Normalized directories passed to ``configure``
            as ``allowed_read_roots``.
        allowed_read_roots: ``explicit_read_roots`` plus the parent directory
            of every entry in ``allowed_reads``.
        allowed_write_root: Normalized directory writes are confined to, or
            ``""`` when no writing is allowed.
    """

    def __init__(self):
        self.allowed_reads = set()
        self.explicit_read_roots = set()
        self.allowed_read_roots = set()
        self.allowed_write_root = ""

    @staticmethod
    def _normalize(path: Any) -> str:
        """Return the absolute, symlink-resolved string form of *path*.

        Raises:
            TypeError: If *path* is not ``str``, ``bytes`` or path-like.
        """
        if not _is_path_like(path):
            raise TypeError(f"不支持的路径类型: {type(path).__name__}")
        # os.fspath may hand back bytes; decode so prefix checks compare str.
        text_path = os.fsdecode(os.fspath(path))
        return os.path.realpath(os.path.abspath(text_path))

    def configure(
        self,
        allowed_reads: Iterable[Any],
        allowed_write_root: Any,
        allowed_read_roots: Optional[Iterable[Any]] = None,
    ) -> None:
        """Replace the whole policy.

        Args:
            allowed_reads: Files that may be read; falsy entries are skipped.
            allowed_write_root: Directory writes are confined to; a falsy
                value disables writing.
            allowed_read_roots: Extra directories whose contents may be read.
        """
        self.allowed_reads = {
            self._normalize(path) for path in (allowed_reads or []) if path
        }
        # Explicit roots are stored on their own so that a later
        # reconfiguration does not mistake previously derived directories
        # for explicit grants.
        self.explicit_read_roots = {
            self._normalize(path) for path in (allowed_read_roots or []) if path
        }
        derived_roots = {os.path.dirname(path) for path in self.allowed_reads}
        self.allowed_read_roots = self.explicit_read_roots | derived_roots
        self.allowed_write_root = (
            self._normalize(allowed_write_root) if allowed_write_root else ""
        )

    def ensure_readable(self, path: Any) -> str:
        """Return the normalized *path* if it may be read.

        A path is readable when it is an allowed file, lies within a read
        root, or lies within the write root.

        Raises:
            PermissionError: If none of the above applies.
        """
        normalized_path = self._normalize(path)
        if (
            normalized_path in self.allowed_reads
            or self._is_within_read_roots(normalized_path)
            or self._is_within_write_root(normalized_path)
        ):
            return normalized_path
        raise PermissionError(f"禁止读取未授权文件: {normalized_path}")

    def ensure_writable(self, path: Any) -> str:
        """Return the normalized *path* if it lies within the write root.

        Raises:
            PermissionError: If the path is outside the write root.
        """
        normalized_path = self._normalize(path)
        if self._is_within_write_root(normalized_path):
            return normalized_path
        raise PermissionError(f"禁止写入会话目录之外的路径: {normalized_path}")

    def _is_within_write_root(self, normalized_path: str) -> bool:
        """Return True when *normalized_path* is inside the write root."""
        if not self.allowed_write_root:
            return False
        return _is_within(normalized_path, self.allowed_write_root)

    def _is_within_read_roots(self, normalized_path: str) -> bool:
        """Return True when *normalized_path* is inside any read root."""
        return any(
            _is_within(normalized_path, root) for root in self.allowed_read_roots
        )


def _write_message(message: Dict[str, Any]) -> None:
    """Write *message* to stdout as a single JSON line and flush."""
    sys.stdout.write(json.dumps(message, ensure_ascii=False) + "\n")
    sys.stdout.flush()


def _write_log(text: str) -> None:
    """Write *text* to stderr, newline-terminated; empty text is skipped."""
    if not text:
        return
    sys.stderr.write(text)
    if not text.endswith("\n"):
        sys.stderr.write("\n")
    sys.stderr.flush()


class ExecutionWorker:
    """JSON-line protocol wrapper around CodeExecutor."""

    def __init__(self):
        self.executor = None
        self.access_policy = FileAccessPolicy()
        self._patches_installed = False

    def handle_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one decoded request and return the response dict.

        Never raises for ordinary failures: any ``Exception`` is converted
        into a ``status: "error"`` response carrying the message and the
        formatted traceback.
        """
        is_object = isinstance(request, dict)
        request_id = request.get("request_id", "") if is_object else ""
        try:
            if not is_object:
                # Valid JSON that is not an object (list, string, number...)
                # must yield an error response instead of crashing the loop.
                raise WorkerProtocolError(
                    f"请求必须是JSON对象: {type(request).__name__}"
                )
            payload = request.get("payload", {})
            if payload is None:
                payload = {}
            return self._dispatch(request_id, request.get("type"), payload)
        except Exception as exc:
            return {
                "request_id": request_id,
                "status": "error",
                "error": str(exc),
                "traceback": traceback.format_exc(),
            }

    def _dispatch(
        self, request_id: str, request_type: Any, payload: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Route a request to its handler based on *request_type*.

        Raises:
            WorkerProtocolError: For an unknown type, or when a type that
                needs the executor arrives before ``init_session``.
        """
        if request_type == "ping":
            return self._ok(request_id, {"alive": True})
        if request_type == "init_session":
            return self._handle_init_session(request_id, payload)
        if request_type == "execute_code":
            self._require_executor()
            return self._ok(
                request_id,
                self.executor.execute_code(payload.get("code", "")),
            )
        if request_type == "set_variable":
            self._require_executor()
            self._handle_set_variable(payload["name"], payload["value"])
            return self._ok(request_id, {"set": True})
        if request_type == "get_environment_info":
            self._require_executor()
            return self._ok(
                request_id,
                {"environment_info": self.executor.get_environment_info()},
            )
        if request_type == "reset_environment":
            self._require_executor()
            self.executor.reset_environment()
            return self._ok(request_id, {"reset": True})
        if request_type == "shutdown":
            return self._ok(request_id, {"shutdown": True})
        raise WorkerProtocolError(f"未知请求类型: {request_type}")

    def _handle_init_session(
        self, request_id: str, payload: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Create the executor, configure the access policy, seed variables.

        Raises:
            WorkerProtocolError: If the payload has no ``output_dir``.
        """
        output_dir = payload.get("output_dir")
        if not output_dir:
            raise WorkerProtocolError("init_session 缺少 output_dir")
        variables = payload.get("variables") or {}

        # Imported lazily; PROJECT_ROOT was put on sys.path at module import.
        from utils.code_executor import CodeExecutor

        self.executor = CodeExecutor(output_dir)
        self.access_policy.configure(
            variables.get("allowed_files", []),
            output_dir,
            variables.get("allowed_read_roots", []),
        )
        self._install_file_guards()
        for name, value in variables.items():
            self.executor.set_variable(name, value)
        return self._ok(request_id, {"initialized": True})

    def _install_file_guards(self) -> None:
        """Monkey-patch common file entry points to consult the access policy.

        Patches ``builtins.open``, ``pd.read_csv``, ``pd.read_excel``,
        ``DataFrame.to_csv``, ``DataFrame.to_excel``, ``plt.savefig`` and
        ``Figure.savefig``. Only path-like targets are checked; file objects
        and buffers pass through untouched. Installed at most once.
        """
        if self._patches_installed:
            return

        import builtins

        import matplotlib.figure
        import matplotlib.pyplot as plt
        import pandas as pd

        policy = self.access_policy
        original_open = builtins.open
        original_read_csv = pd.read_csv
        original_read_excel = pd.read_excel
        original_to_csv = pd.DataFrame.to_csv
        original_to_excel = pd.DataFrame.to_excel
        original_plt_savefig = plt.savefig
        original_figure_savefig = matplotlib.figure.Figure.savefig

        def guarded_open(file, mode="r", *args, **kwargs):
            if _is_path_like(file):
                # Any mode able to create or modify the file needs write access.
                if any(flag in mode for flag in ("w", "a", "x", "+")):
                    policy.ensure_writable(file)
                else:
                    policy.ensure_readable(file)
            return original_open(file, mode, *args, **kwargs)

        def guarded_read_csv(filepath_or_buffer, *args, **kwargs):
            if _is_path_like(filepath_or_buffer):
                policy.ensure_readable(filepath_or_buffer)
            return original_read_csv(filepath_or_buffer, *args, **kwargs)

        def guarded_read_excel(io, *args, **kwargs):
            if _is_path_like(io):
                policy.ensure_readable(io)
            return original_read_excel(io, *args, **kwargs)

        def guarded_to_csv(df, path_or_buf=None, *args, **kwargs):
            if _is_path_like(path_or_buf):
                policy.ensure_writable(path_or_buf)
            return original_to_csv(df, path_or_buf, *args, **kwargs)

        def guarded_to_excel(df, excel_writer, *args, **kwargs):
            if _is_path_like(excel_writer):
                policy.ensure_writable(excel_writer)
            return original_to_excel(df, excel_writer, *args, **kwargs)

        def guarded_savefig(*args, **kwargs):
            target = args[0] if args else kwargs.get("fname")
            if target is not None and _is_path_like(target):
                policy.ensure_writable(target)
            return original_plt_savefig(*args, **kwargs)

        def guarded_figure_savefig(fig, fname, *args, **kwargs):
            if _is_path_like(fname):
                policy.ensure_writable(fname)
            return original_figure_savefig(fig, fname, *args, **kwargs)

        builtins.open = guarded_open
        pd.read_csv = guarded_read_csv
        pd.read_excel = guarded_read_excel
        pd.DataFrame.to_csv = guarded_to_csv
        pd.DataFrame.to_excel = guarded_to_excel
        plt.savefig = guarded_savefig
        matplotlib.figure.Figure.savefig = guarded_figure_savefig
        self._patches_installed = True

    def _require_executor(self) -> None:
        """Raise WorkerProtocolError unless ``init_session`` has run."""
        if self.executor is None:
            raise WorkerProtocolError("执行会话尚未初始化")

    def _handle_set_variable(self, name: str, value: Any) -> None:
        """Set an executor variable and keep the access policy in sync.

        ``allowed_files``, ``allowed_read_roots`` and ``session_output_dir``
        also reconfigure the policy. Only the explicit read roots are carried
        over, so directories derived from files that are no longer allowed
        stop being readable.
        """
        self.executor.set_variable(name, value)
        policy = self.access_policy
        if name == "allowed_files":
            policy.configure(
                value, policy.allowed_write_root, policy.explicit_read_roots
            )
        elif name == "allowed_read_roots":
            policy.configure(
                policy.allowed_reads, policy.allowed_write_root, value
            )
        elif name == "session_output_dir":
            policy.configure(
                policy.allowed_reads, value, policy.explicit_read_roots
            )

    @staticmethod
    def _ok(request_id: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Build a success response envelope."""
        return {
            "request_id": request_id,
            "status": "ok",
            "payload": payload,
        }


def main() -> int:
    """Run the request loop until stdin closes or a shutdown succeeds.

    Returns:
        Process exit code (always 0).
    """
    worker = ExecutionWorker()
    for raw_line in sys.stdin:
        raw_line = raw_line.strip()
        if not raw_line:
            continue
        try:
            request = json.loads(raw_line)
        except json.JSONDecodeError as exc:
            _write_message(
                {
                    "request_id": "",
                    "status": "error",
                    "error": f"无效JSON请求: {exc}",
                }
            )
            continue

        # stdout is the protocol channel, so output produced while handling
        # the request is captured and re-emitted on stderr.
        captured_stdout = StringIO()
        captured_stderr = StringIO()
        with redirect_stdout(captured_stdout), redirect_stderr(captured_stderr):
            response = worker.handle_request(request)

        _write_log(captured_stdout.getvalue())
        _write_log(captured_stderr.getvalue())
        try:
            _write_message(response)
        except (TypeError, ValueError) as exc:
            # json.dumps rejected the payload (unsupported type or circular
            # reference); still answer so the caller gets a response.
            _write_message(
                {
                    "request_id": response.get("request_id", ""),
                    "status": "error",
                    "error": f"响应无法序列化为JSON: {exc}",
                }
            )
            continue

        if (
            isinstance(request, dict)
            and request.get("type") == "shutdown"
            and response.get("status") == "ok"
        ):
            return 0
    return 0


if __name__ == "__main__":
    raise SystemExit(main())