From be94ca88a2398644a19e30a86dc812dbb94c0593 Mon Sep 17 00:00:00 2001
From: Georgy Moshkin <gmoshkin@picodata.io>
Date: Mon, 23 Sep 2024 19:43:49 +0300
Subject: [PATCH] fix: RouteBuilder::register now only supports Fn, not FnMut

See doc comments for RouteBuilder::register, but TL;DR it's never safe
to wrap a yielding function in a FnMut.
The user should instead do the regular internal mutability thing.
---
 picoplugin/src/transport/rpc/server.rs | 31 +++++++++++++++++---------
 1 file changed, 20 insertions(+), 11 deletions(-)

diff --git a/picoplugin/src/transport/rpc/server.rs b/picoplugin/src/transport/rpc/server.rs
index ee2cc450fa..39b6c0d5c5 100644
--- a/picoplugin/src/transport/rpc/server.rs
+++ b/picoplugin/src/transport/rpc/server.rs
@@ -68,10 +68,19 @@ impl<'a> RouteBuilder<'a> {
 
     /// Register the RPC endpoint with the currently chosen parameters and the
     /// provided handler.
+    ///
+    /// Note that `f` must implement `Fn`. This is required by rust's semantics
+    /// to allow the RPC handlers to yield. If a handler yields then another
+    /// concurrent RPC request may result in the same handler being executed,
+    /// so we must not hold any `&mut` references in those closures (other than
+    /// ones allowed by rust semantics, see official reference on undefined
+    /// behaviour [here]).
+    ///
+    /// [here]: <https://doc.rust-lang.org/reference/behavior-considered-undefined.html>
     #[track_caller]
     pub fn register<F>(self, f: F) -> Result<(), BoxError>
     where
-        F: FnMut(Request<'_>, &mut Context) -> Result<Response, BoxError> + 'static,
+        F: Fn(Request<'_>, &mut Context) -> Result<Response, BoxError> + 'static,
     {
         let Some(path) = self.path else {
             #[rustfmt::skip]
@@ -84,7 +93,8 @@ impl<'a> RouteBuilder<'a> {
             service: self.service.into(),
             version: self.version.into(),
         };
-        if let Err(e) = register_rpc_handler(&identifier, f) {
+        let handler = FfiRpcHandler::new(&identifier, f);
+        if let Err(e) = register_rpc_handler(handler) {
             // Note: recreating the error to capture the caller's source location
             #[rustfmt::skip]
             return Err(BoxError::new(e.error_code(), e.message()));
@@ -118,12 +128,7 @@ pub struct FfiRpcRouteIdentifier {
 
 /// **For internal use**.
 #[inline]
-fn register_rpc_handler<F>(identifier: &FfiRpcRouteIdentifier, f: F) -> Result<(), BoxError>
-where
-    F: FnMut(Request<'_>, &mut Context) -> Result<Response, BoxError> + 'static,
-{
-    let handler = FfiRpcHandler::new(identifier, f);
-
+fn register_rpc_handler(handler: FfiRpcHandler) -> Result<(), BoxError> {
     // This is safe.
     let rc = unsafe { ffi::pico_ffi_register_rpc_handler(handler) };
     if rc == -1 {
@@ -152,6 +157,10 @@ pub struct FfiRpcHandler {
     callback: RpcHandlerCallback,
     drop: extern "C" fn(*mut FfiRpcHandler),
 
+    /// The pointer to the closure object.
+    ///
+    /// Note that the pointer must be `mut` because we will at some point drop the data pointed to by it.
+    /// But when calling the closure, the `const` pointer should be used.
     closure_pointer: *mut (),
 
     /// Points into [`Self::string_storage`].
@@ -181,7 +190,7 @@ impl Drop for FfiRpcHandler {
 impl FfiRpcHandler {
     fn new<F>(identifier: &FfiRpcRouteIdentifier, f: F) -> Self
     where
-        F: FnMut(Request<'_>, &mut Context) -> Result<Response, BoxError> + 'static,
+        F: Fn(Request<'_>, &mut Context) -> Result<Response, BoxError> + 'static,
     {
         let closure = Box::new(f);
         let closure_pointer: *mut F = Box::into_raw(closure);
@@ -250,10 +259,10 @@ impl FfiRpcHandler {
         output: *mut FfiSafeBytes,
     ) -> std::ffi::c_int
     where
-        F: FnMut(Request<'_>, &mut Context) -> Result<Response, BoxError> + 'static,
+        F: Fn(Request<'_>, &mut Context) -> Result<Response, BoxError> + 'static,
     {
         // This is safe. To verify see `register_rpc_handler` above.
-        let closure_pointer: *mut F = unsafe { (*handler).closure_pointer.cast::<F>() };
+        let closure_pointer: *const F = unsafe { (*handler).closure_pointer.cast::<F>() };
         let closure = unsafe { &*closure_pointer };
         let input = unsafe { input.as_bytes() };
         let context = unsafe { &*context };
-- 
GitLab