Add web session analysis platform with follow-up topics
This commit is contained in:
321
utils/execution_worker.py
Normal file
321
utils/execution_worker.py
Normal file
@@ -0,0 +1,321 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Subprocess worker that hosts a single CodeExecutor instance for one analysis session.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from contextlib import redirect_stderr, redirect_stdout
|
||||
from io import StringIO
|
||||
from typing import Any, Dict, Iterable, Optional
|
||||
|
||||
# Absolute path of the directory one level above this file's directory.
# It is put at the front of sys.path so that absolute imports such as
# ``utils.code_executor`` resolve when this file is launched as a script.
_THIS_DIR = os.path.dirname(__file__)
PROJECT_ROOT = os.path.abspath(os.path.join(_THIS_DIR, ".."))
if not any(entry == PROJECT_ROOT for entry in sys.path):
    sys.path.insert(0, PROJECT_ROOT)
|
||||
|
||||
|
||||
class WorkerProtocolError(RuntimeError):
    """Signal that a request violates the worker's JSON-line protocol.

    Used for requests that cannot be served as sent, for example an
    unknown request type or a request that arrives before the session
    has been initialised.
    """
|
||||
|
||||
|
||||
class FileAccessPolicy:
    """Controls which files the worker may read and where it may write outputs.

    The policy holds three pieces of state, all as normalised absolute paths:

    * ``allowed_reads`` - individual files that may be read;
    * ``allowed_read_roots`` - directories whose whole subtree may be read
      (explicitly configured roots plus the parent directory of every
      allowed file);
    * ``allowed_write_root`` - the single directory subtree that may be
      written (and read back); an empty string means "nothing is writable".
    """

    def __init__(self):
        self.allowed_reads = set()
        self.allowed_read_roots = set()
        self.allowed_write_root = ""

    @staticmethod
    def _normalize(path: Any) -> str:
        """Return *path* as an absolute path with symlinks resolved.

        Accepts ``str``, ``bytes`` and ``os.PathLike`` objects.

        Raises:
            TypeError: if *path* is not a supported path type.
        """
        try:
            raw = os.fspath(path)
        except TypeError:
            raise TypeError(f"不支持的路径类型: {type(path).__name__}") from None
        # os.fspath may hand back bytes; the policy compares str paths only.
        return os.path.realpath(os.path.abspath(os.fsdecode(raw)))

    @staticmethod
    def _as_path_list(paths: Any) -> list:
        """Return *paths* as a list of non-empty path values.

        A single path (``str``, ``bytes`` or path-like) is wrapped in a list
        instead of being iterated character by character, which would
        otherwise turn ``"/data/a.csv"`` into grants for ``"/"``, ``"d"``...
        """
        if paths is None:
            return []
        if isinstance(paths, (str, bytes)) or hasattr(paths, "__fspath__"):
            paths = [paths]
        return [path for path in paths if path]

    @staticmethod
    def _is_within(normalized_path: str, root: str) -> bool:
        """Return True if *normalized_path* is *root* or lies beneath it."""
        if normalized_path == root:
            return True
        # A filesystem root ("/" or "C:\\") already ends with the separator;
        # appending another one would build a prefix nothing can match.
        prefix = root if root.endswith(os.sep) else root + os.sep
        return normalized_path.startswith(prefix)

    def configure(
        self,
        allowed_reads: Iterable[Any],
        allowed_write_root: Any,
        allowed_read_roots: Optional[Iterable[Any]] = None,
    ) -> None:
        """Replace the whole policy.

        Args:
            allowed_reads: files that may be read; falsy entries are ignored.
            allowed_write_root: directory that may be written; falsy disables
                writing entirely.
            allowed_read_roots: extra directories whose subtree may be read.

        Raises:
            TypeError: if any entry is not a supported path type. In that
                case the previously configured policy is left untouched.
        """
        # Build everything in locals first so a failure cannot leave the
        # policy half-updated.
        reads = {self._normalize(path) for path in self._as_path_list(allowed_reads)}
        explicit_roots = {
            self._normalize(path) for path in self._as_path_list(allowed_read_roots)
        }
        derived_roots = set()
        for path in reads:
            parent = os.path.dirname(path)
            # Never derive the filesystem root itself from a file that sits
            # directly under it: that would make every file readable.
            if os.path.dirname(parent) != parent:
                derived_roots.add(parent)
        write_root = self._normalize(allowed_write_root) if allowed_write_root else ""

        self.allowed_reads = reads
        self.allowed_read_roots = explicit_roots | derived_roots
        self.allowed_write_root = write_root

    def ensure_readable(self, path: Any) -> str:
        """Return the normalised *path* if it may be read.

        A path is readable when it is an allowed file, lies under a read
        root, or lies under the write root.

        Raises:
            PermissionError: if the path is outside every allowed location.
            TypeError: if *path* is not a supported path type.
        """
        normalized_path = self._normalize(path)
        if normalized_path in self.allowed_reads:
            return normalized_path
        if self._is_within_read_roots(normalized_path):
            return normalized_path
        if self._is_within_write_root(normalized_path):
            return normalized_path
        raise PermissionError(f"禁止读取未授权文件: {normalized_path}")

    def ensure_writable(self, path: Any) -> str:
        """Return the normalised *path* if it lies under the write root.

        Raises:
            PermissionError: if the path is outside the write root.
            TypeError: if *path* is not a supported path type.
        """
        normalized_path = self._normalize(path)
        if self._is_within_write_root(normalized_path):
            return normalized_path
        raise PermissionError(f"禁止写入会话目录之外的路径: {normalized_path}")

    def _is_within_write_root(self, normalized_path: str) -> bool:
        """Return True if *normalized_path* is inside the configured write root."""
        if not self.allowed_write_root:
            return False
        return self._is_within(normalized_path, self.allowed_write_root)

    def _is_within_read_roots(self, normalized_path: str) -> bool:
        """Return True if *normalized_path* is inside any configured read root."""
        return any(
            self._is_within(normalized_path, root) for root in self.allowed_read_roots
        )
|
||||
|
||||
|
||||
def _write_message(message: Dict[str, Any]) -> None:
|
||||
sys.stdout.write(json.dumps(message, ensure_ascii=False) + "\n")
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
def _write_log(text: str) -> None:
|
||||
if not text:
|
||||
return
|
||||
sys.stderr.write(text)
|
||||
if not text.endswith("\n"):
|
||||
sys.stderr.write("\n")
|
||||
sys.stderr.flush()
|
||||
|
||||
|
||||
class ExecutionWorker:
    """JSON-line protocol wrapper around CodeExecutor.

    One instance serves one analysis session. ``handle_request`` dispatches
    a decoded request to the executor and always returns a response dict;
    it never raises for ordinary errors.
    """

    def __init__(self):
        self.executor = None
        self.access_policy = FileAccessPolicy()
        self._patches_installed = False
        # Read roots exactly as the controller supplied them. The policy's
        # own ``allowed_read_roots`` also contains roots derived from the
        # allowed files, so it must not be fed back in as "explicit" roots.
        self._explicit_read_roots = []

    def handle_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
        """Serve one request and return the response dict.

        Args:
            request: decoded request with ``request_id``, ``type`` and an
                optional ``payload`` mapping.

        Returns:
            ``{"request_id", "status": "ok", "payload"}`` on success, or
            ``{"request_id", "status": "error", "error", "traceback"}`` when
            the request is malformed or handling it raised.
        """
        if not isinstance(request, dict):
            # Valid JSON that is not an object (list, number, string...).
            return {
                "request_id": "",
                "status": "error",
                "error": f"请求必须是JSON对象, 实际类型: {type(request).__name__}",
                "traceback": "",
            }

        request_id = request.get("request_id", "")
        request_type = request.get("type")
        payload = request.get("payload", {})

        try:
            if request_type == "ping":
                return self._ok(request_id, {"alive": True})
            if request_type == "init_session":
                return self._handle_init_session(request_id, payload)
            if request_type == "execute_code":
                self._require_executor()
                return self._ok(
                    request_id,
                    self.executor.execute_code(payload.get("code", "")),
                )
            if request_type == "set_variable":
                self._require_executor()
                self._handle_set_variable(payload["name"], payload["value"])
                return self._ok(request_id, {"set": True})
            if request_type == "get_environment_info":
                self._require_executor()
                return self._ok(
                    request_id,
                    {"environment_info": self.executor.get_environment_info()},
                )
            if request_type == "reset_environment":
                self._require_executor()
                self.executor.reset_environment()
                return self._ok(request_id, {"reset": True})
            if request_type == "shutdown":
                return self._ok(request_id, {"shutdown": True})

            raise WorkerProtocolError(f"未知请求类型: {request_type}")
        except Exception as exc:
            # Protocol boundary: every failure becomes an error response so
            # the worker process keeps serving later requests.
            return {
                "request_id": request_id,
                "status": "error",
                "error": str(exc),
                "traceback": traceback.format_exc(),
            }

    def _handle_init_session(
        self, request_id: str, payload: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Create the executor, configure the access policy, install guards.

        The executor is published on ``self.executor`` only after the policy
        is configured, the guards are installed and all variables are set.
        If any step fails, no new executor becomes available for
        ``execute_code``.

        Raises:
            WorkerProtocolError: if ``output_dir`` is missing from *payload*.
        """
        output_dir = payload.get("output_dir")
        if not output_dir:
            raise WorkerProtocolError("init_session 缺少 output_dir")

        from utils.code_executor import CodeExecutor

        variables = payload.get("variables", {})
        explicit_roots = variables.get("allowed_read_roots", [])

        executor = CodeExecutor(output_dir)
        self.access_policy.configure(
            variables.get("allowed_files", []),
            output_dir,
            explicit_roots,
        )
        self._explicit_read_roots = explicit_roots
        self._install_file_guards()

        for name, value in variables.items():
            executor.set_variable(name, value)

        self.executor = executor
        return self._ok(request_id, {"initialized": True})

    def _install_file_guards(self) -> None:
        """Monkey-patch common file entry points to consult the access policy.

        Patches ``builtins.open``, ``pandas.read_csv`` / ``read_excel``,
        ``DataFrame.to_csv`` / ``to_excel``, ``pyplot.savefig`` and
        ``Figure.savefig``. The patches are process-wide and installed at
        most once. They read ``self.access_policy`` at call time, so later
        ``configure`` calls take effect without re-patching.

        NOTE(review): only these entry points are covered. Other routes to
        the filesystem (for example ``io.open`` or ``os.open``) are not
        intercepted here.
        """
        if self._patches_installed:
            return

        import builtins
        import matplotlib.figure
        import matplotlib.pyplot as plt
        import pandas as pd

        policy = self.access_policy

        original_open = builtins.open
        original_read_csv = pd.read_csv
        original_read_excel = pd.read_excel
        original_to_csv = pd.DataFrame.to_csv
        original_to_excel = pd.DataFrame.to_excel
        original_plt_savefig = plt.savefig
        original_figure_savefig = matplotlib.figure.Figure.savefig

        def is_path(target: Any) -> bool:
            # bytes paths are paths too; buffers and file descriptors are
            # not checked here.
            return isinstance(target, (str, bytes)) or hasattr(target, "__fspath__")

        def check_read(target: Any) -> None:
            if is_path(target):
                policy.ensure_readable(os.fsdecode(target))

        def check_write(target: Any) -> None:
            if is_path(target):
                policy.ensure_writable(os.fsdecode(target))

        def guarded_open(file, mode="r", *args, **kwargs):
            # Any mode that can create or modify the file counts as a write.
            if any(flag in mode for flag in ("w", "a", "x", "+")):
                check_write(file)
            else:
                check_read(file)
            return original_open(file, mode, *args, **kwargs)

        def guarded_read_csv(filepath_or_buffer, *args, **kwargs):
            check_read(filepath_or_buffer)
            return original_read_csv(filepath_or_buffer, *args, **kwargs)

        def guarded_read_excel(io, *args, **kwargs):
            check_read(io)
            return original_read_excel(io, *args, **kwargs)

        def guarded_to_csv(df, path_or_buf=None, *args, **kwargs):
            check_write(path_or_buf)
            return original_to_csv(df, path_or_buf, *args, **kwargs)

        def guarded_to_excel(df, excel_writer, *args, **kwargs):
            check_write(excel_writer)
            return original_to_excel(df, excel_writer, *args, **kwargs)

        def guarded_savefig(*args, **kwargs):
            target = args[0] if args else kwargs.get("fname")
            check_write(target)
            return original_plt_savefig(*args, **kwargs)

        def guarded_figure_savefig(fig, fname, *args, **kwargs):
            check_write(fname)
            return original_figure_savefig(fig, fname, *args, **kwargs)

        builtins.open = guarded_open
        pd.read_csv = guarded_read_csv
        pd.read_excel = guarded_read_excel
        pd.DataFrame.to_csv = guarded_to_csv
        pd.DataFrame.to_excel = guarded_to_excel
        plt.savefig = guarded_savefig
        matplotlib.figure.Figure.savefig = guarded_figure_savefig

        self._patches_installed = True

    def _require_executor(self) -> None:
        """Raise WorkerProtocolError unless ``init_session`` has succeeded."""
        if self.executor is None:
            raise WorkerProtocolError("执行会话尚未初始化")

    def _handle_set_variable(self, name: str, value: Any) -> None:
        """Set a variable on the executor and keep the access policy in sync.

        The names ``allowed_files``, ``allowed_read_roots`` and
        ``session_output_dir`` also reconfigure the access policy. Only the
        roots the controller supplied explicitly are carried over, so
        directories derived from files that are no longer allowed stop being
        readable.
        """
        self.executor.set_variable(name, value)

        policy = self.access_policy
        if name == "allowed_files":
            policy.configure(
                value,
                policy.allowed_write_root,
                self._explicit_read_roots,
            )
        elif name == "allowed_read_roots":
            policy.configure(
                policy.allowed_reads,
                policy.allowed_write_root,
                value,
            )
            # Remember the new roots only once the policy has accepted them.
            self._explicit_read_roots = value
        elif name == "session_output_dir":
            policy.configure(
                policy.allowed_reads,
                value,
                self._explicit_read_roots,
            )

    @staticmethod
    def _ok(request_id: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Build a success response for *request_id* carrying *payload*."""
        return {
            "request_id": request_id,
            "status": "ok",
            "payload": payload,
        }
|
||||
|
||||
|
||||
def main() -> int:
    """Run the worker loop: one JSON request per stdin line, one JSON reply per stdout line.

    While a request is being handled, stdout and stderr are captured and
    afterwards forwarded to the real stderr, so stdout carries only protocol
    messages. The loop ends on end-of-input or after a successful
    ``shutdown`` request.

    Returns:
        Process exit status, always 0.
    """
    worker = ExecutionWorker()

    for raw_line in sys.stdin:
        raw_line = raw_line.strip()
        if not raw_line:
            continue

        try:
            request = json.loads(raw_line)
        except json.JSONDecodeError as exc:
            _write_message(
                {
                    "request_id": "",
                    "status": "error",
                    "error": f"无效JSON请求: {exc}",
                }
            )
            continue

        if not isinstance(request, dict):
            # Valid JSON but not an object (e.g. ``[]`` or ``42``): answer
            # with an error instead of crashing on ``request.get`` below.
            _write_message(
                {
                    "request_id": "",
                    "status": "error",
                    "error": f"无效请求: 期望JSON对象, 实际为 {type(request).__name__}",
                }
            )
            continue

        captured_stdout = StringIO()
        captured_stderr = StringIO()
        with redirect_stdout(captured_stdout), redirect_stderr(captured_stderr):
            response = worker.handle_request(request)

        _write_log(captured_stdout.getvalue())
        _write_log(captured_stderr.getvalue())

        try:
            _write_message(response)
        except (TypeError, ValueError) as exc:
            # The handler returned something json cannot encode. Reply with
            # an error so the caller is not left waiting and the worker
            # stays alive.
            _write_message(
                {
                    "request_id": request.get("request_id", ""),
                    "status": "error",
                    "error": f"响应无法序列化为JSON: {exc}",
                }
            )

        if request.get("type") == "shutdown" and response.get("status") == "ok":
            return 0

    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())
|
||||
Reference in New Issue
Block a user