# -*- coding: utf-8 -*-
"""
Client for a per-analysis execution worker subprocess.
"""

import json
import os
import queue
import subprocess
import sys
import threading
import uuid
from typing import Any, Dict, Iterable, List, Optional


class WorkerSessionError(RuntimeError):
    """Raised when the execution worker cannot serve a request."""


class WorkerTimeoutError(WorkerSessionError):
    """Raised when the worker does not respond within the configured timeout."""


class ExecutionSessionClient:
    """Client that proxies CodeExecutor methods to a dedicated worker process.

    The worker is started in ``__init__`` and spoken to over its stdin/stdout
    using one JSON document per line.  Each request carries a ``request_id``,
    a ``type`` and a ``payload``; the matching response is expected to be a
    JSON object carrying the same ``request_id``, a ``status`` and either a
    ``payload`` or an ``error``.

    The client can be used as a context manager; leaving the ``with`` block
    calls :meth:`close`.
    """

    def __init__(
        self,
        output_dir: str,
        allowed_files: Optional[Iterable[str]] = None,
        python_executable: Optional[str] = None,
        request_timeout_seconds: float = 60.0,
        startup_timeout_seconds: float = 180.0,
    ):
        """Start the worker process and initialise its session.

        Args:
            output_dir: Directory used for the worker's log file and runtime
                directories; stored as an absolute path.
            allowed_files: Paths forwarded to the worker as ``allowed_files``.
                Their parent directories are forwarded as
                ``allowed_read_roots``.
            python_executable: Interpreter used to run the worker script.
                Defaults to ``sys.executable``.
            request_timeout_seconds: Default time to wait for a response.
            startup_timeout_seconds: Time to wait for the ``init_session``
                response.

        Raises:
            WorkerSessionError: If the worker cannot serve ``init_session``.
            WorkerTimeoutError: If ``init_session`` is not answered in time.
        """
        # Set the private handles first so close()/__del__ are always safe,
        # even when a later statement in this constructor raises.
        self._process: Optional[subprocess.Popen] = None
        self._stderr_handle = None

        self.output_dir = os.path.abspath(output_dir)
        self.allowed_files: List[str] = [
            os.path.abspath(path) for path in (allowed_files or [])
        ]
        self.allowed_read_roots: List[str] = sorted(
            {os.path.dirname(path) for path in self.allowed_files}
        )
        self.python_executable = python_executable or sys.executable
        self.request_timeout_seconds = request_timeout_seconds
        self.startup_timeout_seconds = startup_timeout_seconds

        self._start_worker()
        try:
            self._request(
                "init_session",
                {
                    "output_dir": self.output_dir,
                    "variables": {
                        "session_output_dir": self.output_dir,
                        "allowed_files": self.allowed_files,
                        "allowed_read_roots": self.allowed_read_roots,
                    },
                },
                timeout_seconds=self.startup_timeout_seconds,
            )
        except BaseException:
            # Do not leave a half-initialised worker running behind a
            # constructor that failed.
            self._teardown_worker()
            raise

    def __enter__(self) -> "ExecutionSessionClient":
        """Return the client itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Close the worker session when the ``with`` block exits."""
        self.close()

    def execute_code(self, code: str) -> Dict[str, Any]:
        """Send ``code`` to the worker and return the response payload."""
        return self._request("execute_code", {"code": code})

    def set_variable(self, name: str, value: Any) -> None:
        """Ask the worker to bind ``name`` to ``value``.

        ``value`` must be JSON-serialisable, because requests are encoded
        with ``json.dumps``.
        """
        self._request("set_variable", {"name": name, "value": value})

    def get_environment_info(self) -> str:
        """Return the worker's ``environment_info`` string ("" if absent)."""
        payload = self._request("get_environment_info", {})
        return payload.get("environment_info", "")

    def reset_environment(self) -> None:
        """Reset the worker environment and re-send the session variables."""
        self._request("reset_environment", {})
        self.set_variable("session_output_dir", self.output_dir)
        self.set_variable("allowed_files", self.allowed_files)
        self.set_variable("allowed_read_roots", self.allowed_read_roots)

    def ping(self) -> bool:
        """Return the truthiness of the ``alive`` field of a ``ping`` reply."""
        payload = self._request("ping", {})
        return bool(payload.get("alive"))

    def close(self) -> None:
        """Shut the worker down and release every handle held by the client.

        A ``shutdown`` request is attempted when the process is still
        running; any failure of that request is ignored on purpose because
        the process is torn down right afterwards anyway.  Calling this
        method more than once is harmless.
        """
        process = self._process
        try:
            if process is not None and process.poll() is None:
                self._request("shutdown", {}, timeout_seconds=5)
        except Exception:
            # Best effort only: the teardown below stops the process.
            pass
        finally:
            # Always run the teardown, even without a process, so a log
            # handle opened before a failed start is still closed.
            self._teardown_worker()

    def _start_worker(self) -> None:
        """Create the runtime directories and launch the worker process.

        The worker's stderr is appended to ``execution_worker.log`` inside
        ``output_dir``.  ``MPLCONFIGDIR`` and ``IPYTHONDIR`` are pointed at
        per-session directories under ``output_dir/.worker_runtime``.
        """
        runtime_dir = os.path.join(self.output_dir, ".worker_runtime")
        mpl_dir = os.path.join(runtime_dir, "mplconfig")
        ipython_dir = os.path.join(runtime_dir, "ipython")
        os.makedirs(mpl_dir, exist_ok=True)
        os.makedirs(ipython_dir, exist_ok=True)

        module_dir = os.path.dirname(__file__)
        worker_script = os.path.join(module_dir, "execution_worker.py")

        env = os.environ.copy()
        env["MPLCONFIGDIR"] = mpl_dir
        env["IPYTHONDIR"] = ipython_dir
        env.setdefault("PYTHONIOENCODING", "utf-8")

        stderr_log_path = os.path.join(self.output_dir, "execution_worker.log")
        stderr_handle = open(stderr_log_path, "a", encoding="utf-8")
        try:
            self._process = subprocess.Popen(
                [self.python_executable, worker_script],
                stdin=subprocess.PIPE,
                stdout=subprocess.PIPE,
                stderr=stderr_handle,
                text=True,
                encoding="utf-8",
                cwd=os.path.abspath(os.path.join(module_dir, "..")),
                env=env,
                bufsize=1,
            )
        except BaseException:
            # The process never started, so nothing else will close the log.
            stderr_handle.close()
            raise
        self._stderr_handle = stderr_handle

    def _request(
        self,
        request_type: str,
        payload: Dict[str, Any],
        timeout_seconds: Optional[float] = None,
    ) -> Dict[str, Any]:
        """Send one request to the worker and return the response payload.

        Args:
            request_type: Value of the ``type`` field of the request.
            payload: JSON-serialisable request payload.
            timeout_seconds: Time to wait for the response; falls back to
                ``request_timeout_seconds`` when ``None``.

        Returns:
            The ``payload`` field of the response, or ``{}`` when missing.

        Raises:
            WorkerSessionError: If the worker is not running, the pipe
                breaks, the response is empty, malformed, mismatched or
                reports a non-``ok`` status.
            WorkerTimeoutError: If no response arrives in time.
        """
        process = self._process
        if process is None or process.stdin is None or process.stdout is None:
            raise WorkerSessionError("执行子进程未启动")
        if process.poll() is not None:
            raise WorkerSessionError(
                f"执行子进程已退出，退出码: {process.returncode}"
            )

        request_id = str(uuid.uuid4())
        message = {
            "request_id": request_id,
            "type": request_type,
            "payload": payload,
        }
        # Serialise outside the try block so an unserialisable payload is
        # reported as such and does not tear down a healthy worker.
        encoded = json.dumps(message, ensure_ascii=False) + "\n"

        try:
            process.stdin.write(encoded)
            process.stdin.flush()
        except (OSError, ValueError) as exc:
            # OSError covers BrokenPipeError; ValueError is what a closed
            # file object raises on write.
            self._teardown_worker()
            raise WorkerSessionError("执行子进程通信中断") from exc

        effective_timeout = (
            self.request_timeout_seconds
            if timeout_seconds is None
            else timeout_seconds
        )
        response_line = self._read_response_line(effective_timeout)
        if not response_line:
            # An empty read means end-of-file on the worker's stdout.
            if process.poll() is not None:
                exit_code = process.returncode
                self._teardown_worker()
                raise WorkerSessionError(
                    f"执行子进程已异常退出，退出码: {exit_code}"
                )
            raise WorkerSessionError("执行子进程未返回响应")

        try:
            response = json.loads(response_line)
        except json.JSONDecodeError as exc:
            raise WorkerSessionError(
                f"执行子进程返回了无效JSON: {response_line}"
            ) from exc

        # Valid JSON that is not an object cannot be a protocol response.
        if not isinstance(response, dict):
            raise WorkerSessionError(
                f"执行子进程返回了无效JSON: {response_line}"
            )
        if response.get("request_id") != request_id:
            raise WorkerSessionError("执行子进程响应 request_id 不匹配")
        if response.get("status") != "ok":
            raise WorkerSessionError(
                response.get("error", "执行子进程返回未知错误")
            )
        return response.get("payload", {})

    def _read_response_line(self, timeout_seconds: float) -> str:
        """Read one line from the worker's stdout, bounded by a timeout.

        The blocking ``readline`` runs in a daemon thread so that the caller
        can give up after ``timeout_seconds``.

        Raises:
            WorkerTimeoutError: If no line arrives in time; the worker is
                forcibly torn down first.
            WorkerSessionError: If the worker is not running or the read
                itself fails.
        """
        process = self._process
        if process is None or process.stdout is None:
            raise WorkerSessionError("执行子进程未启动")

        # Bind the stream locally: a teardown may reset self._process while
        # the reader thread is still alive.
        stdout = process.stdout
        response_queue: queue.Queue = queue.Queue(maxsize=1)

        def _reader() -> None:
            try:
                response_queue.put((True, stdout.readline()))
            except Exception as exc:
                response_queue.put((False, exc))

        thread = threading.Thread(target=_reader, daemon=True)
        thread.start()

        try:
            success, value = response_queue.get(timeout=timeout_seconds)
        except queue.Empty as exc:
            self._teardown_worker(force=True)
            raise WorkerTimeoutError(
                f"执行子进程在 {timeout_seconds:.1f} 秒内未响应，已终止当前会话"
            ) from exc

        if success:
            return value

        self._teardown_worker()
        raise WorkerSessionError(f"读取执行子进程响应失败: {value}")

    def _teardown_worker(self, force: bool = False) -> None:
        """Stop the worker process and release the client's handles.

        Args:
            force: When true the process is killed immediately instead of
                being asked to terminate first.  This is used after a
                timeout, where the worker is already known to be
                unresponsive.
        """
        process, self._process = self._process, None
        try:
            if process is not None:
                self._stop_process(process, force)
        finally:
            handle, self._stderr_handle = self._stderr_handle, None
            if handle is not None:
                handle.close()

    @staticmethod
    def _stop_process(process: Any, force: bool) -> None:
        """Stop ``process`` and close its pipes, even if stopping fails."""
        try:
            if process.poll() is None:
                if force:
                    process.kill()
                else:
                    process.terminate()
                try:
                    process.wait(timeout=5)
                except subprocess.TimeoutExpired:
                    process.kill()
                    process.wait(timeout=5)
        finally:
            # On the forced (timeout) path a reader thread may still be
            # blocked in stdout.readline(); closing that stream from here
            # could block behind it, so only stdin is closed in that case.
            streams = [process.stdin]
            if not force:
                streams.append(process.stdout)
            for stream in streams:
                if stream is None:
                    continue
                try:
                    stream.close()
                except (OSError, ValueError):
                    # Closing flushes; a dead peer makes that fail, which
                    # is irrelevant during teardown.
                    pass

    def __del__(self):
        # Finalizers must not raise: the object may be only partially
        # constructed, or the interpreter may be shutting down.
        try:
            self.close()
        except Exception:
            pass