diff --git a/src/main.rs b/src/main.rs index 19d32f5..ab5dfc6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,8 +1,8 @@ use anyhow; use clap::Parser; -use pyo3::{prelude::*, types::PyBool}; -use std::{rc::Rc, sync::Mutex, thread}; -use tokio::sync::{mpsc, oneshot}; +use pyo3::prelude::*; +use std::thread; +use tokio::sync::oneshot; #[derive(Parser, Debug)] struct Args {} @@ -38,75 +38,24 @@ mod scripts { }; } -#[derive(Clone)] -#[pyclass] -struct OpaqueHandle { - handle: tokio::runtime::Handle, -} - -/// setup_extension creates a Python module consisting of members that are defined -/// in the farore host process (in Rust). -fn setup_extension(bridge: Bridge) -> Result, PyErr> { - Python::with_gil(|py| { - let bound = py.get_type_bound::(); - let module = PyModule::from_code_bound(py, "", "farore.py", "farore")?; - module.add( - "farore", - Core { - handle: bridge.handle.clone(), - }, - )?; - Ok(module.unbind()) - }) -} - -#[pyclass] -struct Bridge { - handle: tokio::runtime::Handle, - numbers: mpsc::Receiver, -} - -#[pymethods] -impl Bridge { - fn receive(&mut self, fut: Bound) -> Result<(), PyErr> { - println!("Rust: Bridge receive value: {}", fut); - let fut = fut.unbind(); - self.handle.spawn(async move { - println!("Rust: receive task going to sleep"); - tokio::time::sleep(std::time::Duration::from_millis(100)).await; - println!("Rust: receive task finished sleep"); - match Python::with_gil(|py| -> Result<(), PyErr> { - let main_module = py.import_bound("main")?; - println!("Rust: receive task has gil"); - let fut = fut.bind(py); - fut.setattr("_asyncio_future_blocking", true); - fut.getattr("set_result")?.call1((5,))?; - println!("Rust: receive task set the result of a future"); - Ok(()) - }) { - Ok(()) => {} - Err(e) => println!("Rust: saw python error: {}", e), - } - }); - Ok(()) - } -} - /// Creates an asyncio Future and schedules a task in asyncio to wait on that /// Future. fn create_asyncio_future() -> Result<(Py, Py), PyErr> { Python::with_gil(|py| { - let main = py.import_bound("main")?; - let run = main.getattr("create_later_slot")?; - let later = run.call0()?.unbind(); + let later = py + .import_bound("main") + .and_then(|main| main.getattr("Later")) + .and_then(|class| class.call0()) + .map(|later| later.unbind())?; let dupe = Py::clone_ref(&later, py); Ok((later, dupe)) }) } -#[pyclass] +#[pyclass(name = "Farore")] struct Core { handle: tokio::runtime::Handle, + tx_shutdown: Option>, } #[pymethods] @@ -114,58 +63,74 @@ impl Core { fn next_message(&self) -> Result, PyErr> { let (here, there) = create_asyncio_future()?; self.handle.spawn(async move { - println!("Going to sleep from the resolve function I guess"); tokio::time::sleep(std::time::Duration::from_millis(500)).await; - println!("Finished sleeping in the resolve function"); + let py_result = Python::with_gil(|py| -> Result<(), PyErr> { - println!("We've got the gil in the resolve function"); let later = here.bind(py); let set_result = later.getattr("set_result")?; set_result.call1((17,))?; - Ok(()) }); match py_result { Ok(()) => {} Err(e) => { - println!("Failed to complete future: {}", e); + println!("Rust: Failed to complete future: {}", e); } } }); Ok(there) } + + fn shutdown(&mut self) { + match self.tx_shutdown.take() { + Some(tx) => { + match tx.send(()) { + Ok(()) => println!("Rust: Propagated shutdown to the tokio runtime"), + Err(_) => { + println!("Rust: Failed to propagate shutdown signal because nobody was listening") + } + } + } + None => { + println!("Rust: Ignoring request to shutdown: shutdown already called"); + } + } + } } -async fn tokio_main(tx: mpsc::Sender) { - println!("Rust: Inside the tokio runtime"); - tokio::time::sleep(std::time::Duration::from_secs(100)).await; - println!("Rust: Tokio runtime is done?"); +async fn tokio_main(rx_shutdown: oneshot::Receiver<()>) { + match rx_shutdown.await { + Ok(_) => { + println!("Rust: Received shutdown request"); + } + Err(e) => { + println!("Rust: Unable to receive shutdown request: {}", e); + } + } } -fn start_tokio(tx_bridge: oneshot::Sender) -> Result<(), anyhow::Error> { +fn start_tokio() -> Result { + let runtime = tokio::runtime::Runtime::new()?; + let (tx_shutdown, rx_shutdown) = oneshot::channel(); + let core = Core { + handle: runtime.handle().clone(), + tx_shutdown: Some(tx_shutdown), + }; + + // Spawn a new thread to run the tokio runtime thread::spawn(move || { - // Start the runtime - let runtime = tokio::runtime::Runtime::new().unwrap(); - - // Create a bridge between the runtimes - let (tx, rx) = mpsc::channel(16); - let bridge = Bridge { - handle: runtime.handle().clone(), - numbers: rx, - }; - let _ = tx_bridge - .send(bridge) - .inspect_err(|_| println!("Rust: failed to send bridge")); - - runtime.block_on(tokio_main(tx)); + runtime.block_on(tokio_main(rx_shutdown)); }); - Ok(()) + Ok(core) } -fn run_python(bridge: Bridge) -> Result, PyErr> { - setup_extension(bridge)?; - +fn run_python(core: Core) -> Result, PyErr> { Python::with_gil(|py| { + // Initialize the farore module in python + let module = PyModule::from_code_bound(py, "", "farore.py", "farore")?; + module.add("farore", core)?; + + // Run the main.py script in the python interpreter let main = scripts::main.load(py)?; let run = main.getattr("run")?; let returned = run.call0()?; @@ -174,10 +139,8 @@ fn run_python(bridge: Bridge) -> Result, PyErr> { } fn main() -> Result<(), anyhow::Error> { - let (tx, rx) = oneshot::channel(); - start_tokio(tx)?; - let bridge = rx.blocking_recv()?; - let v = run_python(bridge)?; + let core = start_tokio()?; + let v = run_python(core)?; println!("Rust: Python interpreter completed with result: {}", v); Ok(()) } diff --git a/src/python/farore.pyi b/src/python/farore.pyi index 6e4d2d3..bc8914b 100644 --- a/src/python/farore.pyi +++ b/src/python/farore.pyi @@ -1,18 +1,7 @@ -from collections.abc import Awaitable from typing import Any -class Bridge: - def receive(self, fut): ... - -bridge: Bridge - -class OpaqueHandle: ... - -tokio_handle: OpaqueHandle - -def next_message(h: OpaqueHandle): ... - class Core: - async def next_message(self) -> Awaitable[Any]: ... + async def next_message(self) -> Any: ... + def shutdown(self): ... farore: Core diff --git a/src/python/main.py b/src/python/main.py index 9697a00..fc827ec 100644 --- a/src/python/main.py +++ b/src/python/main.py @@ -4,19 +4,15 @@ import logging logging.basicConfig(level=logging.DEBUG) -log = logging.getLogger("import") -log.debug(f"Python: farore module: {farore}") - - -async def wait_on(fut: asyncio.Future): - return await fut - class Later: + """Later is a wrapper around an asyncio.Future that can be passed to another + thread so that we can set the result of a future from another thread, + potentially from within the Tokio runtime""" + def __init__(self): loop = asyncio.get_event_loop() self.fut = loop.create_future() - self.task = loop.create_task(wait_on(self.fut)) def set_result(self, v): def f(): @@ -26,20 +22,20 @@ class Later: loop.call_soon_threadsafe(f) def __await__(self): - return self.task.__await__() - - -def create_later_slot() -> Later: - return Later() + return self.fut.__await__() async def main_task(): log = logging.getLogger("main_task") - while True: - log.debug("Getting next message ...") + count = 0 + while count < 4: m = await farore.next_message() log.debug(f"next message: {m}") + count += 1 + log.debug("shutting down farore") + farore.shutdown() + log.debug("farore finished shutdown") def run(): - asyncio.run(main_task()) + return asyncio.run(main_task())