// openzeppelin_relayer/services/plugins/runner.rs

//! Orchestrates the execution of a plugin.
//!
//! 1. Starts a socket server on the socket path handed to the script - `socket.rs`
//! 2. Executes the plugin script - `script_executor.rs`
//! 3. Sends the shutdown signal to the socket server - `socket.rs`
//! 4. Waits for the socket server to finish and collects its traces - `socket.rs`
//! 5. Returns the output of the script, with those traces attached - `script_executor.rs`
//!
9use std::{sync::Arc, time::Duration};
10
11use crate::services::plugins::{RelayerApi, ScriptExecutor, ScriptResult, SocketService};
12use crate::{
13    jobs::JobProducerTrait,
14    models::{
15        NetworkRepoModel, NotificationRepoModel, RelayerRepoModel, SignerRepoModel,
16        ThinDataAppState, TransactionRepoModel,
17    },
18    repositories::{
19        ApiKeyRepositoryTrait, NetworkRepository, PluginRepositoryTrait, RelayerRepository,
20        Repository, TransactionCounterTrait, TransactionRepository,
21    },
22};
23
24use super::PluginError;
25use async_trait::async_trait;
26use tokio::{sync::oneshot, time::timeout};
27
28#[cfg(test)]
29use mockall::automock;
30
/// Runs plugin scripts on behalf of the relayer.
///
/// Kept as a trait so callers can be tested against a mock: under `cfg(test)`, mockall's
/// `automock` generates `MockPluginRunnerTrait`.
#[cfg_attr(test, automock)]
#[async_trait]
pub trait PluginRunnerTrait {
    /// Executes the plugin script at `script_path` for the plugin identified by `plugin_id`.
    ///
    /// # Arguments
    ///
    /// * `socket_path` - path of the socket served while the script runs; the path is passed
    ///   on to the script.
    /// * `timeout_duration` - maximum time the script is allowed to run.
    /// * `script_params` - parameters forwarded to the script as a string (a JSON document in
    ///   this module's tests).
    /// * `http_request_id`, `headers_json` - optional request context forwarded to the script.
    /// * `state` - application state handed to the socket server.
    ///
    /// # Errors
    ///
    /// Returns a [`PluginError`] on failure. For [`PluginRunner`] this covers socket setup or
    /// server failures, script timeouts, and errors produced by the script itself.
    // The method is generic over every repository type held by `ThinDataAppState` and takes the
    // full execution context, hence the silenced complexity/argument-count lints.
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    async fn run<J, RR, TR, NR, NFR, SR, TCR, PR, AKR>(
        &self,
        plugin_id: String,
        socket_path: &str,
        script_path: String,
        timeout_duration: Duration,
        script_params: String,
        http_request_id: Option<String>,
        headers_json: Option<String>,
        state: Arc<ThinDataAppState<J, RR, TR, NR, NFR, SR, TCR, PR, AKR>>,
    ) -> Result<ScriptResult, PluginError>
    where
        J: JobProducerTrait + Send + Sync + 'static,
        RR: RelayerRepository + Repository<RelayerRepoModel, String> + Send + Sync + 'static,
        TR: TransactionRepository
            + Repository<TransactionRepoModel, String>
            + Send
            + Sync
            + 'static,
        NR: NetworkRepository + Repository<NetworkRepoModel, String> + Send + Sync + 'static,
        NFR: Repository<NotificationRepoModel, String> + Send + Sync + 'static,
        SR: Repository<SignerRepoModel, String> + Send + Sync + 'static,
        TCR: TransactionCounterTrait + Send + Sync + 'static,
        PR: PluginRepositoryTrait + Send + Sync + 'static,
        AKR: ApiKeyRepositoryTrait + Send + Sync + 'static;
}

/// Default [`PluginRunnerTrait`] implementation.
///
/// Stateless: every call to `run` sets up its own socket server and script execution, so the
/// runner can be freely copied and shared.
#[derive(Debug, Default, Clone, Copy)]
pub struct PluginRunner;

65#[async_trait]
66impl PluginRunnerTrait for PluginRunner {
67    async fn run<J, RR, TR, NR, NFR, SR, TCR, PR, AKR>(
68        &self,
69        plugin_id: String,
70        socket_path: &str,
71        script_path: String,
72        timeout_duration: Duration,
73        script_params: String,
74        http_request_id: Option<String>,
75        headers_json: Option<String>,
76        state: Arc<ThinDataAppState<J, RR, TR, NR, NFR, SR, TCR, PR, AKR>>,
77    ) -> Result<ScriptResult, PluginError>
78    where
79        J: JobProducerTrait + Send + Sync + 'static,
80        RR: RelayerRepository + Repository<RelayerRepoModel, String> + Send + Sync + 'static,
81        TR: TransactionRepository
82            + Repository<TransactionRepoModel, String>
83            + Send
84            + Sync
85            + 'static,
86        NR: NetworkRepository + Repository<NetworkRepoModel, String> + Send + Sync + 'static,
87        NFR: Repository<NotificationRepoModel, String> + Send + Sync + 'static,
88        SR: Repository<SignerRepoModel, String> + Send + Sync + 'static,
89        TCR: TransactionCounterTrait + Send + Sync + 'static,
90        PR: PluginRepositoryTrait + Send + Sync + 'static,
91        AKR: ApiKeyRepositoryTrait + Send + Sync + 'static,
92    {
93        let socket_service = SocketService::new(socket_path)?;
94        let socket_path_clone = socket_service.socket_path().to_string();
95
96        let (shutdown_tx, shutdown_rx) = oneshot::channel();
97
98        let server_handle = tokio::spawn(async move {
99            let relayer_api = Arc::new(RelayerApi);
100            socket_service.listen(shutdown_rx, state, relayer_api).await
101        });
102
103        let exec_outcome = match timeout(
104            timeout_duration,
105            ScriptExecutor::execute_typescript(
106                plugin_id,
107                script_path,
108                socket_path_clone,
109                script_params,
110                http_request_id,
111                headers_json,
112            ),
113        )
114        .await
115        {
116            Ok(result) => result,
117            Err(_) => {
118                // ensures the socket gets closed.
119                let _ = shutdown_tx.send(());
120                return Err(PluginError::ScriptTimeout(timeout_duration.as_secs()));
121            }
122        };
123
124        let _ = shutdown_tx.send(());
125
126        let server_handle = server_handle
127            .await
128            .map_err(|e| PluginError::SocketError(e.to_string()))?;
129
130        let traces = match server_handle {
131            Ok(traces) => traces,
132            Err(e) => return Err(PluginError::SocketError(e.to_string())),
133        };
134
135        match exec_outcome {
136            Ok(mut script_result) => {
137                // attach traces on success
138                script_result.trace = traces;
139                Ok(script_result)
140            }
141            Err(err) => Err(err.with_traces(traces)),
142        }
143    }
144}
145
#[cfg(test)]
mod tests {
    use actix_web::web;
    use std::fs;

    use crate::{
        jobs::MockJobProducerTrait,
        repositories::{
            ApiKeyRepositoryStorage, NetworkRepositoryStorage, NotificationRepositoryStorage,
            PluginRepositoryStorage, RelayerRepositoryStorage, SignerRepositoryStorage,
            TransactionCounterRepositoryStorage, TransactionRepositoryStorage,
        },
        services::plugins::LogLevel,
        utils::mocks::mockutils::create_mock_app_state,
    };
    use tempfile::tempdir;

    use super::*;

    static TS_CONFIG: &str = r#"
        {
            "compilerOptions": {
              "target": "es2016",
              "module": "commonjs",
              "esModuleInterop": true,
              "forceConsistentCasingInFileNames": true,
              "strict": true,
              "skipLibCheck": true
            }
          }
    "#;

    /// Reports whether a run failed only because the environment refused socket operations
    /// (e.g. a sandboxed runner). Callers skip the test in that case.
    fn socket_blocked(outcome: &Result<ScriptResult, PluginError>) -> bool {
        match outcome {
            Err(PluginError::SocketError(msg)) => msg.contains("Operation not permitted"),
            _ => false,
        }
    }

    /// Writes `script` to `<stem>.ts` inside `dir`, next to the shared `tsconfig.json`. It then
    /// runs the script through [`PluginRunner`] on the socket `<stem>.sock`, backed by the
    /// in-memory repository storages.
    async fn run_script(
        dir: &std::path::Path,
        stem: &str,
        script: &str,
        limit: Duration,
        params: &str,
    ) -> Result<ScriptResult, PluginError> {
        let ts_file = dir.join(stem).with_extension("ts");
        let sock_file = dir.join(stem).with_extension("sock");

        fs::write(&ts_file, script).unwrap();
        fs::write(dir.join("tsconfig.json"), TS_CONFIG.as_bytes()).unwrap();

        let app_state = create_mock_app_state(None, None, None, None, None, None).await;
        let runner = PluginRunner;
        let sock = sock_file.display().to_string();

        runner
            .run::<
                MockJobProducerTrait,
                RelayerRepositoryStorage,
                TransactionRepositoryStorage,
                NetworkRepositoryStorage,
                NotificationRepositoryStorage,
                SignerRepositoryStorage,
                TransactionCounterRepositoryStorage,
                PluginRepositoryStorage,
                ApiKeyRepositoryStorage,
            >(
                "test-plugin".to_string(),
                &sock,
                ts_file.display().to_string(),
                limit,
                params.to_string(),
                None,
                None,
                Arc::new(web::ThinData(app_state)),
            )
            .await
    }

    #[tokio::test]
    async fn test_run() {
        let workdir = tempdir().unwrap();

        let script = r#"
            export async function handler(api: any, params: any) {
                console.log('test');
                console.error('test-error');
                return 'test-result';
            }
        "#;

        let outcome = run_script(
            workdir.path(),
            "test_run",
            script,
            Duration::from_secs(10),
            "{ \"test\": \"test\" }",
        )
        .await;

        if socket_blocked(&outcome) {
            eprintln!("skipping test_run due to sandbox socket restrictions");
            return;
        }

        let output = outcome.expect("runner should complete without error");
        assert_eq!(output.logs[0].level, LogLevel::Log);
        assert_eq!(output.logs[0].message, "test");
        assert_eq!(output.logs[1].level, LogLevel::Error);
        assert_eq!(output.logs[1].message, "test-error");
        assert_eq!(output.return_value, "test-result");
    }

    #[tokio::test]
    async fn test_run_timeout() {
        let workdir = tempdir().unwrap();

        // Script that takes 200ms
        let script = r#"
            function sleep(ms) {
                return new Promise(resolve => setTimeout(resolve, ms));
            }

            async function main() {
                await sleep(200); // 200ms
                console.log(JSON.stringify({ level: 'result', message: 'Should not reach here' }));
            }

            main();
        "#;

        // Use 100ms timeout for a 200ms script
        let outcome = run_script(
            workdir.path(),
            "test_simple_timeout",
            script,
            Duration::from_millis(100),
            "{}",
        )
        .await;

        // Should timeout, unless the sandbox blocks sockets entirely
        if socket_blocked(&outcome) {
            eprintln!("skipping test_run_timeout due to sandbox socket restrictions");
            return;
        }

        let err = outcome.expect_err("runner should timeout");
        assert!(err.to_string().contains("Script execution timed out after"));
    }
}