# File: vibe_data_ana/utils/execution_worker.py
# Listing metadata: 322 lines, 11 KiB, Python

# -*- coding: utf-8 -*-
"""
Subprocess worker that hosts a single CodeExecutor instance for one analysis session.
"""
import json
import os
import sys
import traceback
from pathlib import Path
from contextlib import redirect_stderr, redirect_stdout
from io import StringIO
from typing import Any, Dict, Iterable, Optional
# Put the directory one level above this file at the front of sys.path so the
# `utils.*` import used later in this module resolves when the file is run as
# a standalone script.
PROJECT_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir)
)
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
class WorkerProtocolError(RuntimeError):
    """Error for requests the worker cannot accept.

    Raised in this module for an unknown request type, a missing required
    field, or a request that needs a session that was never initialised.
    """
class FileAccessPolicy:
    """Controls which files the worker may read and where it may write outputs.

    A path is readable when it is an explicitly allowed file, lies under one
    of the read roots, or lies under the write root.  A path is writable only
    when it lies under the write root.  Every comparison is made on
    normalised paths (absolute, with symlinks resolved).
    """

    def __init__(self):
        # Normalised paths of individually allowed files.
        self.allowed_reads = set()
        # Normalised directories whose whole subtree is readable.
        self.allowed_read_roots = set()
        # Normalised output directory; "" means nothing is writable.
        self.allowed_write_root = ""

    @staticmethod
    def _normalize(path: Any) -> str:
        """Return *path* as an absolute path with symlinks resolved.

        Raises:
            TypeError: if *path* is neither a ``str`` nor an object exposing
                ``__fspath__``.
        """
        if isinstance(path, Path):
            path = str(path)
        elif hasattr(path, "__fspath__"):
            path = os.fspath(path)
        elif not isinstance(path, str):
            raise TypeError(f"不支持的路径类型: {type(path).__name__}")
        return os.path.realpath(os.path.abspath(path))

    @staticmethod
    def _as_path_list(paths: Any) -> list:
        """Return *paths* as a list, treating a lone path as one element.

        A ``str`` is itself iterable, so iterating it directly would turn
        every character into a separate (relative) path and grant access to
        unrelated directories.  ``None`` is treated as "no paths".
        """
        if paths is None:
            return []
        if isinstance(paths, (str, Path)) or hasattr(paths, "__fspath__"):
            return [paths]
        return list(paths)

    @staticmethod
    def _is_within(root: str, normalized_path: str) -> bool:
        """Return True if *normalized_path* is *root* or lies beneath it."""
        if normalized_path == root:
            return True
        # A filesystem root such as "/" already ends with the separator;
        # appending another one would produce "//" and never match.
        prefix = root if root.endswith(os.sep) else root + os.sep
        return normalized_path.startswith(prefix)

    def configure(
        self,
        allowed_reads: Iterable[Any],
        allowed_write_root: Any,
        allowed_read_roots: Optional[Iterable[Any]] = None,
    ) -> None:
        """Replace the whole policy.

        Args:
            allowed_reads: files that may be read; falsy entries are skipped.
                A single path is accepted as well as an iterable of paths.
            allowed_write_root: directory under which writes are permitted;
                a falsy value disables writing.
            allowed_read_roots: extra directories whose subtree is readable.

        The parent directory of every allowed file is added to the read
        roots.  All inputs are validated before any attribute is assigned,
        so a failing call leaves the previous policy untouched.

        Raises:
            TypeError: if any supplied path has an unsupported type.
        """
        reads = {
            self._normalize(path)
            for path in self._as_path_list(allowed_reads)
            if path
        }
        explicit_roots = {
            self._normalize(path)
            for path in self._as_path_list(allowed_read_roots)
            if path
        }
        derived_roots = {os.path.dirname(path) for path in reads}
        write_root = (
            self._normalize(allowed_write_root) if allowed_write_root else ""
        )

        self.allowed_reads = reads
        self.allowed_read_roots = explicit_roots | derived_roots
        self.allowed_write_root = write_root

    def ensure_readable(self, path: Any) -> str:
        """Return the normalised *path* if it may be read.

        Raises:
            PermissionError: if the path is not an allowed file and is
                outside every read root and the write root.
            TypeError: if *path* has an unsupported type.
        """
        normalized_path = self._normalize(path)
        if normalized_path in self.allowed_reads:
            return normalized_path
        if self._is_within_read_roots(normalized_path):
            return normalized_path
        # Files the session wrote itself may be read back.
        if self._is_within_write_root(normalized_path):
            return normalized_path
        raise PermissionError(f"禁止读取未授权文件: {normalized_path}")

    def ensure_writable(self, path: Any) -> str:
        """Return the normalised *path* if it lies under the write root.

        Raises:
            PermissionError: if the path is outside the write root or no
                write root is configured.
            TypeError: if *path* has an unsupported type.
        """
        normalized_path = self._normalize(path)
        if self._is_within_write_root(normalized_path):
            return normalized_path
        raise PermissionError(f"禁止写入会话目录之外的路径: {normalized_path}")

    def _is_within_write_root(self, normalized_path: str) -> bool:
        """Return True if *normalized_path* is inside the write root."""
        if not self.allowed_write_root:
            return False
        return self._is_within(self.allowed_write_root, normalized_path)

    def _is_within_read_roots(self, normalized_path: str) -> bool:
        """Return True if *normalized_path* is inside any read root."""
        return any(
            self._is_within(root, normalized_path)
            for root in self.allowed_read_roots
        )
def _write_message(message: Dict[str, Any]) -> None:
    """Emit *message* on stdout as one JSON document followed by a newline.

    Non-ASCII characters are written as-is rather than escaped, and the
    stream is flushed after the line is written.
    """
    line = json.dumps(message, ensure_ascii=False)
    out = sys.stdout
    out.write(line + "\n")
    out.flush()
def _write_log(text: str) -> None:
    """Write *text* to stderr, newline-terminated, then flush.

    Empty text is ignored; a trailing newline is appended only when *text*
    does not already end with one.
    """
    if text:
        err = sys.stderr
        err.write(text)
        needs_newline = not text.endswith("\n")
        if needs_newline:
            err.write("\n")
        err.flush()
class ExecutionWorker:
    """JSON-line protocol wrapper around CodeExecutor.

    Holds one executor, the file-access policy that guards it, and the
    bookkeeping needed to install the guards only once per process.
    """

    def __init__(self):
        self.executor = None
        self.access_policy = FileAccessPolicy()
        self._patches_installed = False
        # Read roots supplied by the caller.  Kept separately because
        # ``access_policy.allowed_read_roots`` also contains roots derived
        # from the allowed files; feeding those back into ``configure``
        # would keep old directories readable forever.
        self._explicit_read_roots = []

    def handle_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
        """Dispatch one decoded request and return the response dict.

        Never raises for ordinary errors: any ``Exception`` raised while
        handling the request is converted into a ``status: "error"``
        response carrying the message and traceback.  A request that is not
        a JSON object is rejected the same way.
        """
        if not isinstance(request, dict):
            return self._error(
                "",
                f"无效请求: 期望JSON对象, 实际为 {type(request).__name__}",
            )

        request_id = request.get("request_id", "")
        request_type = request.get("type")
        payload = request.get("payload", {})
        try:
            if request_type == "ping":
                return self._ok(request_id, {"alive": True})
            if request_type == "init_session":
                return self._handle_init_session(request_id, payload)
            if request_type == "execute_code":
                self._require_executor()
                return self._ok(
                    request_id,
                    self.executor.execute_code(payload.get("code", "")),
                )
            if request_type == "set_variable":
                self._require_executor()
                self._handle_set_variable(payload["name"], payload["value"])
                return self._ok(request_id, {"set": True})
            if request_type == "get_environment_info":
                self._require_executor()
                return self._ok(
                    request_id,
                    {"environment_info": self.executor.get_environment_info()},
                )
            if request_type == "reset_environment":
                self._require_executor()
                self.executor.reset_environment()
                return self._ok(request_id, {"reset": True})
            if request_type == "shutdown":
                return self._ok(request_id, {"shutdown": True})
            raise WorkerProtocolError(f"未知请求类型: {request_type}")
        except Exception as exc:
            return self._error(request_id, str(exc), traceback.format_exc())

    def _handle_init_session(
        self, request_id: str, payload: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Create the executor, configure the policy and install the guards.

        The new executor is stored on ``self`` only after the policy has
        been configured and the guards installed.  If any step fails, a
        first-time init leaves ``self.executor`` as ``None`` so later
        ``execute_code`` requests are refused instead of running unguarded.

        Raises:
            WorkerProtocolError: if ``output_dir`` is missing from *payload*.
        """
        output_dir = payload.get("output_dir")
        if not output_dir:
            raise WorkerProtocolError("init_session 缺少 output_dir")

        from utils.code_executor import CodeExecutor

        variables = payload.get("variables") or {}
        explicit_roots = variables.get("allowed_read_roots", [])

        executor = CodeExecutor(output_dir)
        self.access_policy.configure(
            variables.get("allowed_files", []),
            output_dir,
            explicit_roots,
        )
        self._install_file_guards()
        for name, value in variables.items():
            executor.set_variable(name, value)

        # Publish only now that the sandbox is fully in place.
        self._explicit_read_roots = explicit_roots
        self.executor = executor
        return self._ok(request_id, {"initialized": True})

    @staticmethod
    def _is_path_like(target: Any) -> bool:
        """Return True for ``str``, ``Path`` and ``__fspath__`` objects."""
        return isinstance(target, (str, Path)) or hasattr(target, "__fspath__")

    def _install_file_guards(self) -> None:
        """Wrap common file entry points with policy checks (once).

        Patches ``builtins.open``, ``pandas.read_csv`` / ``read_excel``,
        ``DataFrame.to_csv`` / ``to_excel``, ``pyplot.savefig`` and
        ``Figure.savefig``.  Each wrapper checks path-like targets against
        the policy and then delegates to the original callable; targets
        that are not paths (buffers, file descriptors) pass through.
        """
        if self._patches_installed:
            return

        import builtins

        import matplotlib.figure
        import matplotlib.pyplot as plt
        import pandas as pd

        policy = self.access_policy
        is_path_like = self._is_path_like
        original_open = builtins.open
        original_read_csv = pd.read_csv
        original_read_excel = pd.read_excel
        original_to_csv = pd.DataFrame.to_csv
        original_to_excel = pd.DataFrame.to_excel
        original_plt_savefig = plt.savefig
        original_figure_savefig = matplotlib.figure.Figure.savefig

        def guarded_open(file, mode="r", *args, **kwargs):
            # ``open`` also accepts bytes paths; decode them so they are
            # checked instead of slipping past the policy.
            target = os.fsdecode(file) if isinstance(file, bytes) else file
            if is_path_like(target):
                if any(flag in mode for flag in ("w", "a", "x", "+")):
                    policy.ensure_writable(target)
                else:
                    policy.ensure_readable(target)
            return original_open(file, mode, *args, **kwargs)

        def guarded_read_csv(filepath_or_buffer, *args, **kwargs):
            if is_path_like(filepath_or_buffer):
                policy.ensure_readable(filepath_or_buffer)
            return original_read_csv(filepath_or_buffer, *args, **kwargs)

        def guarded_read_excel(io, *args, **kwargs):
            if is_path_like(io):
                policy.ensure_readable(io)
            return original_read_excel(io, *args, **kwargs)

        def guarded_to_csv(df, path_or_buf=None, *args, **kwargs):
            if is_path_like(path_or_buf):
                policy.ensure_writable(path_or_buf)
            return original_to_csv(df, path_or_buf, *args, **kwargs)

        def guarded_to_excel(df, excel_writer, *args, **kwargs):
            if is_path_like(excel_writer):
                policy.ensure_writable(excel_writer)
            return original_to_excel(df, excel_writer, *args, **kwargs)

        def guarded_savefig(*args, **kwargs):
            target = args[0] if args else kwargs.get("fname")
            if target is not None and is_path_like(target):
                policy.ensure_writable(target)
            return original_plt_savefig(*args, **kwargs)

        def guarded_figure_savefig(fig, fname, *args, **kwargs):
            if is_path_like(fname):
                policy.ensure_writable(fname)
            return original_figure_savefig(fig, fname, *args, **kwargs)

        builtins.open = guarded_open
        pd.read_csv = guarded_read_csv
        pd.read_excel = guarded_read_excel
        pd.DataFrame.to_csv = guarded_to_csv
        pd.DataFrame.to_excel = guarded_to_excel
        plt.savefig = guarded_savefig
        matplotlib.figure.Figure.savefig = guarded_figure_savefig
        self._patches_installed = True

    def _require_executor(self) -> None:
        """Raise WorkerProtocolError unless a session has been initialised."""
        if self.executor is None:
            raise WorkerProtocolError("执行会话尚未初始化")

    def _handle_set_variable(self, name: str, value: Any) -> None:
        """Set a variable on the executor and mirror policy-related ones.

        ``allowed_files``, ``allowed_read_roots`` and ``session_output_dir``
        also reconfigure the access policy.  Only the caller-supplied read
        roots are carried over, so replacing ``allowed_files`` revokes the
        directories that were derived from the previous file list.
        """
        self.executor.set_variable(name, value)
        policy = self.access_policy
        if name == "allowed_files":
            policy.configure(
                value,
                policy.allowed_write_root,
                self._explicit_read_roots,
            )
        elif name == "allowed_read_roots":
            policy.configure(
                policy.allowed_reads,
                policy.allowed_write_root,
                value,
            )
            # Remember the new roots only once they were accepted.
            self._explicit_read_roots = value
        elif name == "session_output_dir":
            policy.configure(
                policy.allowed_reads,
                value,
                self._explicit_read_roots,
            )

    @staticmethod
    def _ok(request_id: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Build a success response for *request_id* carrying *payload*."""
        return {
            "request_id": request_id,
            "status": "ok",
            "payload": payload,
        }

    @staticmethod
    def _error(
        request_id: str, message: str, trace: str = ""
    ) -> Dict[str, Any]:
        """Build an error response with the given message and traceback."""
        return {
            "request_id": request_id,
            "status": "error",
            "error": message,
            "traceback": trace,
        }
def main() -> int:
    """Serve JSON-line requests from stdin until shutdown or EOF.

    Each non-blank input line is decoded as JSON and passed to the worker.
    Anything printed while the request is handled is captured and forwarded
    to stderr, so stdout carries nothing but protocol messages.  Malformed
    input (invalid JSON, or JSON that is not an object) and responses that
    cannot be serialised produce an error message instead of terminating
    the worker.

    Returns:
        0, both after a successful ``shutdown`` request and at end of input.
    """
    worker = ExecutionWorker()
    for raw_line in sys.stdin:
        raw_line = raw_line.strip()
        if not raw_line:
            continue

        try:
            request = json.loads(raw_line)
        except json.JSONDecodeError as exc:
            _write_message(
                {
                    "request_id": "",
                    "status": "error",
                    "error": f"无效JSON请求: {exc}",
                }
            )
            continue

        # Valid JSON may still be a list, number or string; the protocol
        # needs an object, and ``.get`` below would fail on anything else.
        if not isinstance(request, dict):
            _write_message(
                {
                    "request_id": "",
                    "status": "error",
                    "error": (
                        "无效请求: 期望JSON对象, "
                        f"实际为 {type(request).__name__}"
                    ),
                }
            )
            continue

        captured_stdout = StringIO()
        captured_stderr = StringIO()
        with redirect_stdout(captured_stdout), redirect_stderr(captured_stderr):
            response = worker.handle_request(request)
        _write_log(captured_stdout.getvalue())
        _write_log(captured_stderr.getvalue())

        try:
            _write_message(response)
        except (TypeError, ValueError) as exc:
            # The payload held something json cannot encode; report that
            # to the caller rather than letting the worker die.
            _write_message(
                {
                    "request_id": request.get("request_id", ""),
                    "status": "error",
                    "error": f"响应无法序列化为JSON: {exc}",
                }
            )

        if request.get("type") == "shutdown" and response.get("status") == "ok":
            return 0
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())