322 lines
11 KiB
Python
322 lines
11 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
|
|
Subprocess worker that hosts a single CodeExecutor instance for one analysis session.
|
|
"""
|
|
|
|
import json
|
|
import os
|
|
import sys
|
|
import traceback
|
|
from pathlib import Path
|
|
from contextlib import redirect_stderr, redirect_stdout
|
|
from io import StringIO
|
|
from typing import Any, Dict, Iterable, Optional
|
|
|
|
# Absolute path of the directory one level above this file's directory.
PROJECT_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir)
)

# Put the project root at the front of the import path (once) so that
# project-local packages can be imported when this file runs as a script.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
|
|
|
|
|
|
class WorkerProtocolError(RuntimeError):
    """Signal a protocol-level problem with a request sent to the worker.

    Used for requests the worker cannot act on, such as an unknown request
    type, a missing required field, or a call made before the session has
    been initialized.
    """
|
|
|
|
|
|
class FileAccessPolicy:
    """Controls which files the worker may read and where it may write outputs.

    All paths are compared in normalized form (absolute, symlinks resolved).

    * A path is readable when it is one of the explicitly allowed files, lies
      inside one of the read roots, or lies inside the write root.
    * A path is writable only when it lies inside the write root.

    Read roots are the union of the roots passed to ``configure`` and the
    parent directory of every allowed file.
    """

    def __init__(self):
        # Normalized paths of individually allowed input files.
        self.allowed_reads = set()
        # Normalized directories under which any file may be read.
        self.allowed_read_roots = set()
        # Normalized directory under which files may be written ("" = none).
        self.allowed_write_root = ""

    @staticmethod
    def _normalize(path: Any) -> str:
        """Return *path* as an absolute path with symlinks resolved.

        Raises:
            TypeError: if *path* is neither a ``str`` nor a path-like object.
        """
        if hasattr(path, "__fspath__"):
            # Covers pathlib.Path and any other os.PathLike implementation.
            path = os.fspath(path)
        elif not isinstance(path, str):
            raise TypeError(f"不支持的路径类型: {type(path).__name__}")
        return os.path.realpath(os.path.abspath(path))

    @staticmethod
    def _as_path_list(paths: Any) -> list:
        """Return *paths* as a list, treating a lone path as one element.

        Iterating a bare string would yield its characters, each of which
        would then be normalized as a separate (bogus) path.
        """
        if paths is None:
            return []
        if isinstance(paths, str) or hasattr(paths, "__fspath__"):
            return [paths]
        return list(paths)

    @staticmethod
    def _is_within(root: str, normalized_path: str) -> bool:
        """Return True when *normalized_path* is *root* or lies beneath it.

        A root that already ends with a separator (a filesystem or drive
        root such as ``/`` or ``C:\\``) must not get a second separator
        appended, otherwise nothing beneath it would ever match.
        """
        if not root:
            return False
        if normalized_path == root:
            return True
        prefix = root if root.endswith(os.sep) else root + os.sep
        return normalized_path.startswith(prefix)

    def configure(
        self,
        allowed_reads: Iterable[Any],
        allowed_write_root: Any,
        allowed_read_roots: Optional[Iterable[Any]] = None,
    ) -> None:
        """Replace the whole policy.

        Args:
            allowed_reads: files that may be read; falsy entries are skipped.
                A single path is accepted as well as an iterable of paths.
            allowed_write_root: directory that may be written to; a falsy
                value disables all writes.
            allowed_read_roots: extra directories that may be read from.
                The parent directory of every allowed file is always added.
        """
        self.allowed_reads = {
            self._normalize(path)
            for path in self._as_path_list(allowed_reads)
            if path
        }
        explicit_roots = {
            self._normalize(path)
            for path in self._as_path_list(allowed_read_roots)
            if path
        }
        derived_roots = {os.path.dirname(path) for path in self.allowed_reads}
        self.allowed_read_roots = explicit_roots | derived_roots
        self.allowed_write_root = (
            self._normalize(allowed_write_root) if allowed_write_root else ""
        )

    def ensure_readable(self, path: Any) -> str:
        """Return the normalized *path* if it may be read.

        Raises:
            PermissionError: if the path is not covered by the policy.
            TypeError: if *path* is not a string or path-like object.
        """
        normalized_path = self._normalize(path)
        if (
            normalized_path in self.allowed_reads
            or self._is_within_read_roots(normalized_path)
            or self._is_within_write_root(normalized_path)
        ):
            return normalized_path
        raise PermissionError(f"禁止读取未授权文件: {normalized_path}")

    def ensure_writable(self, path: Any) -> str:
        """Return the normalized *path* if it may be written.

        Raises:
            PermissionError: if the path is outside the write root.
            TypeError: if *path* is not a string or path-like object.
        """
        normalized_path = self._normalize(path)
        if self._is_within_write_root(normalized_path):
            return normalized_path
        raise PermissionError(f"禁止写入会话目录之外的路径: {normalized_path}")

    def _is_within_write_root(self, normalized_path: str) -> bool:
        """Return True when the path is the write root or lies beneath it."""
        return self._is_within(self.allowed_write_root, normalized_path)

    def _is_within_read_roots(self, normalized_path: str) -> bool:
        """Return True when the path is, or lies beneath, any read root."""
        return any(
            self._is_within(root, normalized_path)
            for root in self.allowed_read_roots
        )
|
|
|
|
|
|
def _write_message(message: Dict[str, Any]) -> None:
    """Send one protocol message to stdout as a single JSON line.

    Non-ASCII characters are written as-is rather than escaped, and the
    stream is flushed immediately after the line is written.
    """
    encoded = json.dumps(message, ensure_ascii=False)
    out = sys.stdout
    out.write(encoded + "\n")
    out.flush()
|
|
|
|
|
|
def _write_log(text: str) -> None:
    """Forward captured output to stderr, ending it with a newline.

    Empty text is ignored. A trailing newline is appended only when the
    text does not already end with one, and the stream is flushed afterwards.
    """
    if text:
        log_stream = sys.stderr
        log_stream.write(text)
        needs_newline = not text.endswith("\n")
        if needs_newline:
            log_stream.write("\n")
        log_stream.flush()
|
|
|
|
|
|
class ExecutionWorker:
    """JSON-line protocol wrapper around CodeExecutor.

    Each request is a dict with ``request_id``, ``type`` and an optional
    ``payload``. Every request produces a response dict whose ``status`` is
    ``"ok"`` (with a ``payload``) or ``"error"`` (with ``error`` and
    ``traceback``). The session must be created with ``init_session`` before
    any request that needs the executor.
    """

    def __init__(self):
        # CodeExecutor instance; created by the init_session request.
        self.executor = None
        self.access_policy = FileAccessPolicy()
        # The guards patch process-wide functions, so install them only once.
        self._patches_installed = False
        # Read roots the client asked for explicitly. Kept separately from
        # the policy's merged set, which also holds roots derived from the
        # allowed files and must be recomputed whenever those files change.
        self._explicit_read_roots = []

    def handle_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
        """Dispatch one decoded request and return the response dict.

        Exceptions raised while handling the request are never propagated;
        they are reported as an ``"error"`` response carrying the message
        and the formatted traceback.
        """
        if not isinstance(request, dict):
            # Valid JSON that is not an object (list, number, ...) has no
            # request_id or type to read; answer instead of raising.
            return self._error("", "无效请求: 请求必须是JSON对象")

        request_id = request.get("request_id", "")
        request_type = request.get("type")
        payload = request.get("payload", {})

        try:
            if request_type == "ping":
                return self._ok(request_id, {"alive": True})
            if request_type == "init_session":
                return self._handle_init_session(request_id, payload)
            if request_type == "execute_code":
                self._require_executor()
                return self._ok(
                    request_id,
                    self.executor.execute_code(payload.get("code", "")),
                )
            if request_type == "set_variable":
                self._require_executor()
                self._handle_set_variable(payload["name"], payload["value"])
                return self._ok(request_id, {"set": True})
            if request_type == "get_environment_info":
                self._require_executor()
                return self._ok(
                    request_id,
                    {"environment_info": self.executor.get_environment_info()},
                )
            if request_type == "reset_environment":
                self._require_executor()
                self.executor.reset_environment()
                return self._ok(request_id, {"reset": True})
            if request_type == "shutdown":
                # Only acknowledged here; the caller decides when to exit.
                return self._ok(request_id, {"shutdown": True})

            raise WorkerProtocolError(f"未知请求类型: {request_type}")
        except Exception as exc:
            # Protocol boundary: report any failure back to the client.
            return self._error(request_id, str(exc), traceback.format_exc())

    def _handle_init_session(
        self, request_id: str, payload: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Create the executor, configure file access and seed variables.

        Raises:
            WorkerProtocolError: if the payload has no ``output_dir``.
        """
        output_dir = payload.get("output_dir")
        if not output_dir:
            raise WorkerProtocolError("init_session 缺少 output_dir")

        # Imported lazily, only once a session is actually requested.
        from utils.code_executor import CodeExecutor

        # Treat a missing or null "variables" entry as no variables.
        variables = payload.get("variables") or {}

        self.executor = CodeExecutor(output_dir)
        self._explicit_read_roots = self._as_root_list(
            variables.get("allowed_read_roots")
        )
        self.access_policy.configure(
            variables.get("allowed_files", []),
            output_dir,
            self._explicit_read_roots,
        )
        self._install_file_guards()

        for name, value in variables.items():
            self.executor.set_variable(name, value)

        return self._ok(request_id, {"initialized": True})

    def _install_file_guards(self) -> None:
        """Patch common file I/O entry points to enforce the access policy.

        Wraps ``builtins.open``, ``pandas.read_csv`` / ``read_excel``,
        ``DataFrame.to_csv`` / ``to_excel``, ``pyplot.savefig`` and
        ``Figure.savefig``. Each wrapper checks string and path-like targets
        against the policy and then calls the original; other targets
        (file objects, buffers, ...) are passed through unchecked.
        The patches are installed at most once per worker.
        """
        if self._patches_installed:
            return

        import builtins
        import matplotlib.figure
        import matplotlib.pyplot as plt
        import pandas as pd

        policy = self.access_policy
        is_path_like = self._is_path_like

        original_open = builtins.open
        original_read_csv = pd.read_csv
        original_read_excel = pd.read_excel
        original_to_csv = pd.DataFrame.to_csv
        original_to_excel = pd.DataFrame.to_excel
        original_plt_savefig = plt.savefig
        original_figure_savefig = matplotlib.figure.Figure.savefig

        def guarded_open(file, mode="r", *args, **kwargs):
            if is_path_like(file):
                # Any mode that can create or modify the file counts as a write.
                if any(flag in mode for flag in ("w", "a", "x", "+")):
                    policy.ensure_writable(file)
                else:
                    policy.ensure_readable(file)
            return original_open(file, mode, *args, **kwargs)

        def guarded_read_csv(filepath_or_buffer, *args, **kwargs):
            if is_path_like(filepath_or_buffer):
                policy.ensure_readable(filepath_or_buffer)
            return original_read_csv(filepath_or_buffer, *args, **kwargs)

        def guarded_read_excel(io, *args, **kwargs):
            if is_path_like(io):
                policy.ensure_readable(io)
            return original_read_excel(io, *args, **kwargs)

        def guarded_to_csv(df, path_or_buf=None, *args, **kwargs):
            if is_path_like(path_or_buf):
                policy.ensure_writable(path_or_buf)
            return original_to_csv(df, path_or_buf, *args, **kwargs)

        def guarded_to_excel(df, excel_writer, *args, **kwargs):
            if is_path_like(excel_writer):
                policy.ensure_writable(excel_writer)
            return original_to_excel(df, excel_writer, *args, **kwargs)

        def guarded_savefig(*args, **kwargs):
            # The target may arrive positionally or as the ``fname`` keyword.
            target = args[0] if args else kwargs.get("fname")
            if target is not None and is_path_like(target):
                policy.ensure_writable(target)
            return original_plt_savefig(*args, **kwargs)

        def guarded_figure_savefig(fig, fname, *args, **kwargs):
            if is_path_like(fname):
                policy.ensure_writable(fname)
            return original_figure_savefig(fig, fname, *args, **kwargs)

        builtins.open = guarded_open
        pd.read_csv = guarded_read_csv
        pd.read_excel = guarded_read_excel
        pd.DataFrame.to_csv = guarded_to_csv
        pd.DataFrame.to_excel = guarded_to_excel
        plt.savefig = guarded_savefig
        matplotlib.figure.Figure.savefig = guarded_figure_savefig

        self._patches_installed = True

    def _require_executor(self) -> None:
        """Raise WorkerProtocolError unless init_session has created the executor."""
        if self.executor is None:
            raise WorkerProtocolError("执行会话尚未初始化")

    def _handle_set_variable(self, name: str, value: Any) -> None:
        """Set a variable on the executor and mirror policy-related ones.

        ``allowed_files``, ``allowed_read_roots`` and ``session_output_dir``
        also reconfigure the file access policy. The policy is always rebuilt
        from the explicitly requested read roots, so directories derived from
        files that are no longer allowed do not stay readable.
        """
        self.executor.set_variable(name, value)

        policy = self.access_policy
        if name == "allowed_files":
            policy.configure(
                value,
                policy.allowed_write_root,
                self._explicit_read_roots,
            )
        elif name == "allowed_read_roots":
            self._explicit_read_roots = self._as_root_list(value)
            policy.configure(
                policy.allowed_reads,
                policy.allowed_write_root,
                self._explicit_read_roots,
            )
        elif name == "session_output_dir":
            policy.configure(
                policy.allowed_reads,
                value,
                self._explicit_read_roots,
            )

    @staticmethod
    def _is_path_like(target: Any) -> bool:
        """Return True for strings, ``Path`` objects and other path-likes."""
        return isinstance(target, (str, Path)) or hasattr(target, "__fspath__")

    @staticmethod
    def _as_root_list(value: Any) -> Any:
        """Return a reusable collection of read roots for later reconfigures.

        A falsy value becomes an empty list and a string is kept unchanged;
        any other iterable is copied into a list so that it can be passed to
        the policy more than once.
        """
        if not value:
            return []
        if isinstance(value, str):
            return value
        return list(value)

    @staticmethod
    def _ok(request_id: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Build a success response for *request_id* carrying *payload*."""
        return {
            "request_id": request_id,
            "status": "ok",
            "payload": payload,
        }

    @staticmethod
    def _error(
        request_id: str, message: str, formatted_traceback: str = ""
    ) -> Dict[str, Any]:
        """Build an error response for *request_id*."""
        return {
            "request_id": request_id,
            "status": "error",
            "error": message,
            "traceback": formatted_traceback,
        }
|
|
|
|
|
|
def main() -> int:
    """Serve JSON-line requests from stdin until shutdown or end of input.

    Each non-blank stdin line is decoded as one JSON request. Anything
    printed to stdout or stderr while the request is handled is captured
    and forwarded to the real stderr, so that stdout carries only protocol
    messages. Lines that are not valid JSON, or that decode to something
    other than an object, are answered with an error message and skipped.

    Returns:
        0, both after a successful ``shutdown`` request and when stdin ends.
    """
    worker = ExecutionWorker()

    for raw_line in sys.stdin:
        raw_line = raw_line.strip()
        if not raw_line:
            continue

        try:
            request = json.loads(raw_line)
        except json.JSONDecodeError as exc:
            _write_message(
                {
                    "request_id": "",
                    "status": "error",
                    "error": f"无效JSON请求: {exc}",
                }
            )
            continue

        if not isinstance(request, dict):
            # e.g. "[1, 2]" or "42": valid JSON, but the request.get() calls
            # below would raise on it and take the whole worker down.
            _write_message(
                {
                    "request_id": "",
                    "status": "error",
                    "error": "无效请求: 请求必须是JSON对象",
                }
            )
            continue

        # Keep output produced while handling the request off the protocol
        # channel.
        captured_stdout = StringIO()
        captured_stderr = StringIO()
        with redirect_stdout(captured_stdout), redirect_stderr(captured_stderr):
            response = worker.handle_request(request)

        _write_log(captured_stdout.getvalue())
        _write_log(captured_stderr.getvalue())
        _write_message(response)

        if request.get("type") == "shutdown" and response.get("status") == "ok":
            return 0

    return 0
|
|
|
|
|
|
# Script entry point: run the worker loop and use its return value as the
# process exit status.
if __name__ == "__main__":
    sys.exit(main())
|