cleaning up

main
Jordan Orelli 1 month ago
parent 1666a38ad8
commit d33ed584cf

@ -1,8 +1,8 @@
use anyhow; use anyhow;
use clap::Parser; use clap::Parser;
use pyo3::{prelude::*, types::PyBool}; use pyo3::prelude::*;
use std::{rc::Rc, sync::Mutex, thread}; use std::thread;
use tokio::sync::{mpsc, oneshot}; use tokio::sync::oneshot;
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
struct Args {} struct Args {}
@ -38,75 +38,24 @@ mod scripts {
}; };
} }
#[derive(Clone)]
#[pyclass]
struct OpaqueHandle {
handle: tokio::runtime::Handle,
}
/// setup_extension creates a Python module consisting of members that are defined
/// in the farore host process (in Rust).
fn setup_extension(bridge: Bridge) -> Result<Py<PyModule>, PyErr> {
Python::with_gil(|py| {
let bound = py.get_type_bound::<Bridge>();
let module = PyModule::from_code_bound(py, "", "farore.py", "farore")?;
module.add(
"farore",
Core {
handle: bridge.handle.clone(),
},
)?;
Ok(module.unbind())
})
}
#[pyclass]
struct Bridge {
handle: tokio::runtime::Handle,
numbers: mpsc::Receiver<usize>,
}
#[pymethods]
impl Bridge {
fn receive(&mut self, fut: Bound<PyAny>) -> Result<(), PyErr> {
println!("Rust: Bridge receive value: {}", fut);
let fut = fut.unbind();
self.handle.spawn(async move {
println!("Rust: receive task going to sleep");
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
println!("Rust: receive task finished sleep");
match Python::with_gil(|py| -> Result<(), PyErr> {
let main_module = py.import_bound("main")?;
println!("Rust: receive task has gil");
let fut = fut.bind(py);
fut.setattr("_asyncio_future_blocking", true);
fut.getattr("set_result")?.call1((5,))?;
println!("Rust: receive task set the result of a future");
Ok(())
}) {
Ok(()) => {}
Err(e) => println!("Rust: saw python error: {}", e),
}
});
Ok(())
}
}
/// Creates an asyncio Future and schedules a task in asyncio to wait on that /// Creates an asyncio Future and schedules a task in asyncio to wait on that
/// Future. /// Future.
fn create_asyncio_future() -> Result<(Py<PyAny>, Py<PyAny>), PyErr> { fn create_asyncio_future() -> Result<(Py<PyAny>, Py<PyAny>), PyErr> {
Python::with_gil(|py| { Python::with_gil(|py| {
let main = py.import_bound("main")?; let later = py
let run = main.getattr("create_later_slot")?; .import_bound("main")
let later = run.call0()?.unbind(); .and_then(|main| main.getattr("Later"))
.and_then(|class| class.call0())
.map(|later| later.unbind())?;
let dupe = Py::clone_ref(&later, py); let dupe = Py::clone_ref(&later, py);
Ok((later, dupe)) Ok((later, dupe))
}) })
} }
#[pyclass] #[pyclass(name = "Farore")]
struct Core { struct Core {
handle: tokio::runtime::Handle, handle: tokio::runtime::Handle,
tx_shutdown: Option<oneshot::Sender<()>>,
} }
#[pymethods] #[pymethods]
@ -114,58 +63,74 @@ impl Core {
fn next_message(&self) -> Result<Py<PyAny>, PyErr> { fn next_message(&self) -> Result<Py<PyAny>, PyErr> {
let (here, there) = create_asyncio_future()?; let (here, there) = create_asyncio_future()?;
self.handle.spawn(async move { self.handle.spawn(async move {
println!("Going to sleep from the resolve function I guess");
tokio::time::sleep(std::time::Duration::from_millis(500)).await; tokio::time::sleep(std::time::Duration::from_millis(500)).await;
println!("Finished sleeping in the resolve function");
let py_result = Python::with_gil(|py| -> Result<(), PyErr> { let py_result = Python::with_gil(|py| -> Result<(), PyErr> {
println!("We've got the gil in the resolve function");
let later = here.bind(py); let later = here.bind(py);
let set_result = later.getattr("set_result")?; let set_result = later.getattr("set_result")?;
set_result.call1((17,))?; set_result.call1((17,))?;
Ok(()) Ok(())
}); });
match py_result { match py_result {
Ok(()) => {} Ok(()) => {}
Err(e) => { Err(e) => {
println!("Failed to complete future: {}", e); println!("Rust: Failed to complete future: {}", e);
} }
} }
}); });
Ok(there) Ok(there)
} }
fn shutdown(&mut self) {
match self.tx_shutdown.take() {
Some(tx) => {
match tx.send(()) {
Ok(()) => println!("Rust: Propagated shutdown to the tokio runtime"),
Err(_) => {
println!("Rust: Failed to propagate shutdown signal because nobody was listening")
}
}
}
None => {
println!("Rust: Ignoring request to shutdown: shutdown already called");
}
}
}
} }
async fn tokio_main(tx: mpsc::Sender<usize>) { async fn tokio_main(rx_shutdown: oneshot::Receiver<()>) {
println!("Rust: Inside the tokio runtime"); match rx_shutdown.await {
tokio::time::sleep(std::time::Duration::from_secs(100)).await; Ok(_) => {
println!("Rust: Tokio runtime is done?"); println!("Rust: Received shutdown request");
}
Err(e) => {
println!("Rust: Unable to receive shutdown request: {}", e);
}
}
} }
fn start_tokio(tx_bridge: oneshot::Sender<Bridge>) -> Result<(), anyhow::Error> { fn start_tokio() -> Result<Core, anyhow::Error> {
let runtime = tokio::runtime::Runtime::new()?;
let (tx_shutdown, rx_shutdown) = oneshot::channel();
let core = Core {
handle: runtime.handle().clone(),
tx_shutdown: Some(tx_shutdown),
};
// Spawn a new thread to run the tokio runtime
thread::spawn(move || { thread::spawn(move || {
// Start the runtime runtime.block_on(tokio_main(rx_shutdown));
let runtime = tokio::runtime::Runtime::new().unwrap();
// Create a bridge between the runtimes
let (tx, rx) = mpsc::channel(16);
let bridge = Bridge {
handle: runtime.handle().clone(),
numbers: rx,
};
let _ = tx_bridge
.send(bridge)
.inspect_err(|_| println!("Rust: failed to send bridge"));
runtime.block_on(tokio_main(tx));
}); });
Ok(()) Ok(core)
} }
fn run_python(bridge: Bridge) -> Result<Py<PyAny>, PyErr> { fn run_python(core: Core) -> Result<Py<PyAny>, PyErr> {
setup_extension(bridge)?;
Python::with_gil(|py| { Python::with_gil(|py| {
// Initialize the farore module in python
let module = PyModule::from_code_bound(py, "", "farore.py", "farore")?;
module.add("farore", core)?;
// Run the main.py script in the python interpreter
let main = scripts::main.load(py)?; let main = scripts::main.load(py)?;
let run = main.getattr("run")?; let run = main.getattr("run")?;
let returned = run.call0()?; let returned = run.call0()?;
@ -174,10 +139,8 @@ fn run_python(bridge: Bridge) -> Result<Py<PyAny>, PyErr> {
} }
fn main() -> Result<(), anyhow::Error> { fn main() -> Result<(), anyhow::Error> {
let (tx, rx) = oneshot::channel(); let core = start_tokio()?;
start_tokio(tx)?; let v = run_python(core)?;
let bridge = rx.blocking_recv()?;
let v = run_python(bridge)?;
println!("Rust: Python interpreter completed with result: {}", v); println!("Rust: Python interpreter completed with result: {}", v);
Ok(()) Ok(())
} }

@ -1,18 +1,7 @@
from collections.abc import Awaitable
from typing import Any from typing import Any
class Bridge:
def receive(self, fut): ...
bridge: Bridge
class OpaqueHandle: ...
tokio_handle: OpaqueHandle
def next_message(h: OpaqueHandle): ...
class Core: class Core:
async def next_message(self) -> Awaitable[Any]: ... async def next_message(self) -> Any: ...
def shutdown(self): ...
farore: Core farore: Core

@ -4,19 +4,15 @@ import logging
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
log = logging.getLogger("import")
log.debug(f"Python: farore module: {farore}")
async def wait_on(fut: asyncio.Future):
return await fut
class Later: class Later:
"""Later is a wrapper around an asyncio.Future that can be passed to another
thread so that we can set the result of a future from another thread,
potentially from within the Tokio runtime"""
def __init__(self): def __init__(self):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
self.fut = loop.create_future() self.fut = loop.create_future()
self.task = loop.create_task(wait_on(self.fut))
def set_result(self, v): def set_result(self, v):
def f(): def f():
@ -26,20 +22,20 @@ class Later:
loop.call_soon_threadsafe(f) loop.call_soon_threadsafe(f)
def __await__(self): def __await__(self):
return self.task.__await__() return self.fut.__await__()
def create_later_slot() -> Later:
return Later()
async def main_task(): async def main_task():
log = logging.getLogger("main_task") log = logging.getLogger("main_task")
while True: count = 0
log.debug("Getting next message ...") while count < 4:
m = await farore.next_message() m = await farore.next_message()
log.debug(f"next message: {m}") log.debug(f"next message: {m}")
count += 1
log.debug("shutting down farore")
farore.shutdown()
log.debug("farore finished shutdown")
def run(): def run():
asyncio.run(main_task()) return asyncio.run(main_task())

Loading…
Cancel
Save