Skip to content

Commit

Permalink
Don't store and use wrong main Lua state in module mode (Lua 5.1/JIT …
Browse files Browse the repository at this point in the history
…only).

When mlua module is loaded from a non-main coroutine we store a reference to it to use later.
If the coroutine is destroyed by GC we can pass a wrong pointer to Lua that will trigger a segfault.
Instead, set main_state as Option and use current (active) state if needed.
Relates to #479
  • Loading branch information
khvzak committed Nov 4, 2024
1 parent b34d67e commit 05778fb
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 35 deletions.
53 changes: 28 additions & 25 deletions src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ impl Lua {
let lua = self.lock();
unsafe {
if (*lua.extra.get()).sandboxed != enabled {
let state = lua.main_state;
let state = lua.main_state();
check_stack(state, 3)?;
protect_lua!(state, 0, 0, |state| {
if enabled {
Expand Down Expand Up @@ -562,10 +562,10 @@ impl Lua {
unsafe {
let state = lua.state();
ffi::lua_sethook(state, None, 0, 0);
match crate::util::get_main_state(lua.main_state) {
Some(main_state) if !ptr::eq(state, main_state) => {
match lua.main_state {
Some(main_state) if state != main_state.as_ptr() => {
// If main_state is different from state, remove hook from it too
ffi::lua_sethook(main_state, None, 0, 0);
ffi::lua_sethook(main_state.as_ptr(), None, 0, 0);
}
_ => {}
};
Expand Down Expand Up @@ -654,7 +654,7 @@ impl Lua {
let lua = self.lock();
unsafe {
(*lua.extra.get()).interrupt_callback = Some(Rc::new(callback));
(*ffi::lua_callbacks(lua.main_state)).interrupt = Some(interrupt_proc);
(*ffi::lua_callbacks(lua.main_state())).interrupt = Some(interrupt_proc);
}
}

Expand All @@ -667,7 +667,7 @@ impl Lua {
let lua = self.lock();
unsafe {
(*lua.extra.get()).interrupt_callback = None;
(*ffi::lua_callbacks(lua.main_state)).interrupt = None;
(*ffi::lua_callbacks(lua.main_state())).interrupt = None;
}
}

Expand Down Expand Up @@ -697,10 +697,9 @@ impl Lua {
}

let lua = self.lock();
let state = lua.main_state;
unsafe {
(*lua.extra.get()).warn_callback = Some(Box::new(callback));
ffi::lua_setwarnf(state, Some(warn_proc), lua.extra.get() as *mut c_void);
ffi::lua_setwarnf(lua.state(), Some(warn_proc), lua.extra.get() as *mut c_void);
}
}

Expand All @@ -715,7 +714,7 @@ impl Lua {
let lua = self.lock();
unsafe {
(*lua.extra.get()).warn_callback = None;
ffi::lua_setwarnf(lua.main_state, None, ptr::null_mut());
ffi::lua_setwarnf(lua.state(), None, ptr::null_mut());
}
}

Expand Down Expand Up @@ -767,13 +766,14 @@ impl Lua {
/// Returns the amount of memory (in bytes) currently used inside this Lua state.
pub fn used_memory(&self) -> usize {
let lua = self.lock();
let state = lua.main_state();
unsafe {
match MemoryState::get(lua.main_state) {
match MemoryState::get(state) {
mem_state if !mem_state.is_null() => (*mem_state).used_memory(),
_ => {
// Get data from the Lua GC
let used_kbytes = ffi::lua_gc(lua.main_state, ffi::LUA_GCCOUNT, 0);
let used_kbytes_rem = ffi::lua_gc(lua.main_state, ffi::LUA_GCCOUNTB, 0);
let used_kbytes = ffi::lua_gc(state, ffi::LUA_GCCOUNT, 0);
let used_kbytes_rem = ffi::lua_gc(state, ffi::LUA_GCCOUNTB, 0);
(used_kbytes as usize) * 1024 + (used_kbytes_rem as usize)
}
}
Expand All @@ -790,7 +790,7 @@ impl Lua {
pub fn set_memory_limit(&self, limit: usize) -> Result<usize> {
let lua = self.lock();
unsafe {
match MemoryState::get(lua.main_state) {
match MemoryState::get(lua.state()) {
mem_state if !mem_state.is_null() => Ok((*mem_state).set_memory_limit(limit)),
_ => Err(Error::MemoryControlNotAvailable),
}
Expand All @@ -803,19 +803,19 @@ impl Lua {
#[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52", feature = "luau"))]
pub fn gc_is_running(&self) -> bool {
let lua = self.lock();
unsafe { ffi::lua_gc(lua.main_state, ffi::LUA_GCISRUNNING, 0) != 0 }
unsafe { ffi::lua_gc(lua.main_state(), ffi::LUA_GCISRUNNING, 0) != 0 }
}

/// Stop the Lua GC from running
pub fn gc_stop(&self) {
let lua = self.lock();
unsafe { ffi::lua_gc(lua.main_state, ffi::LUA_GCSTOP, 0) };
unsafe { ffi::lua_gc(lua.main_state(), ffi::LUA_GCSTOP, 0) };
}

/// Restarts the Lua GC if it is not running
pub fn gc_restart(&self) {
let lua = self.lock();
unsafe { ffi::lua_gc(lua.main_state, ffi::LUA_GCRESTART, 0) };
unsafe { ffi::lua_gc(lua.main_state(), ffi::LUA_GCRESTART, 0) };
}

/// Perform a full garbage-collection cycle.
Expand All @@ -824,9 +824,10 @@ impl Lua {
/// objects. Once to finish the current gc cycle, and once to start and finish the next cycle.
pub fn gc_collect(&self) -> Result<()> {
let lua = self.lock();
let state = lua.main_state();
unsafe {
check_stack(lua.main_state, 2)?;
protect_lua!(lua.main_state, 0, 0, fn(state) ffi::lua_gc(state, ffi::LUA_GCCOLLECT, 0))
check_stack(state, 2)?;
protect_lua!(state, 0, 0, fn(state) ffi::lua_gc(state, ffi::LUA_GCCOLLECT, 0))
}
}

Expand All @@ -843,9 +844,10 @@ impl Lua {
/// finished a collection cycle.
pub fn gc_step_kbytes(&self, kbytes: c_int) -> Result<bool> {
let lua = self.lock();
let state = lua.main_state();
unsafe {
check_stack(lua.main_state, 3)?;
protect_lua!(lua.main_state, 0, 0, |state| {
check_stack(state, 3)?;
protect_lua!(state, 0, 0, |state| {
ffi::lua_gc(state, ffi::LUA_GCSTEP, kbytes) != 0
})
}
Expand All @@ -861,11 +863,12 @@ impl Lua {
/// [documentation]: https://www.lua.org/manual/5.4/manual.html#2.5
pub fn gc_set_pause(&self, pause: c_int) -> c_int {
let lua = self.lock();
let state = lua.main_state();
unsafe {
#[cfg(not(feature = "luau"))]
return ffi::lua_gc(lua.main_state, ffi::LUA_GCSETPAUSE, pause);
return ffi::lua_gc(state, ffi::LUA_GCSETPAUSE, pause);
#[cfg(feature = "luau")]
return ffi::lua_gc(lua.main_state, ffi::LUA_GCSETGOAL, pause);
return ffi::lua_gc(state, ffi::LUA_GCSETGOAL, pause);
}
}

Expand All @@ -877,7 +880,7 @@ impl Lua {
/// [documentation]: https://www.lua.org/manual/5.4/manual.html#2.5
pub fn gc_set_step_multiplier(&self, step_multiplier: c_int) -> c_int {
let lua = self.lock();
unsafe { ffi::lua_gc(lua.main_state, ffi::LUA_GCSETSTEPMUL, step_multiplier) }
unsafe { ffi::lua_gc(lua.main_state(), ffi::LUA_GCSETSTEPMUL, step_multiplier) }
}

/// Changes the collector to incremental mode with the given parameters.
Expand All @@ -888,7 +891,7 @@ impl Lua {
/// [documentation]: https://www.lua.org/manual/5.4/manual.html#2.5.1
pub fn gc_inc(&self, pause: c_int, step_multiplier: c_int, step_size: c_int) -> GCMode {
let lua = self.lock();
let state = lua.main_state;
let state = lua.main_state();

#[cfg(any(
feature = "lua53",
Expand Down Expand Up @@ -941,7 +944,7 @@ impl Lua {
#[cfg_attr(docsrs, doc(cfg(feature = "lua54")))]
pub fn gc_gen(&self, minor_multiplier: c_int, major_multiplier: c_int) -> GCMode {
let lua = self.lock();
let state = lua.main_state;
let state = lua.main_state();
let prev_mode = unsafe { ffi::lua_gc(state, ffi::LUA_GCGEN, minor_multiplier, major_multiplier) };
match prev_mode {
ffi::LUA_GCGEN => GCMode::Generational,
Expand Down
22 changes: 12 additions & 10 deletions src/state/raw.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use std::any::TypeId;
use std::cell::{Cell, UnsafeCell};
use std::ffi::{CStr, CString};
use std::mem;
use std::os::raw::{c_char, c_int, c_void};
use std::panic::resume_unwind;
use std::ptr::{self, NonNull};
use std::result::Result as StdResult;
use std::sync::Arc;
use std::{mem, ptr};

use crate::chunk::ChunkMode;
use crate::error::{Error, Result};
Expand Down Expand Up @@ -41,7 +42,6 @@ use {
crate::multi::MultiValue,
crate::traits::FromLuaMulti,
crate::types::{AsyncCallback, AsyncCallbackUpvalue, AsyncPollUpvalue},
std::ptr::NonNull,
std::task::{Context, Poll, Waker},
};

Expand All @@ -50,7 +50,7 @@ use {
pub struct RawLua {
// The state is dynamic and depends on context
pub(super) state: Cell<*mut ffi::lua_State>,
pub(super) main_state: *mut ffi::lua_State,
pub(super) main_state: Option<NonNull<ffi::lua_State>>,
pub(super) extra: XRc<UnsafeCell<ExtraData>>,
}

Expand All @@ -61,9 +61,9 @@ impl Drop for RawLua {
return;
}

let mem_state = MemoryState::get(self.main_state);
let mem_state = MemoryState::get(self.main_state());

ffi::lua_close(self.main_state);
ffi::lua_close(self.main_state());

// Deallocate `MemoryState`
if !mem_state.is_null() {
Expand Down Expand Up @@ -95,10 +95,11 @@ impl RawLua {
self.state.get()
}

#[cfg(feature = "luau")]
#[inline(always)]
pub(crate) fn main_state(&self) -> *mut ffi::lua_State {
self.main_state
.map(|state| state.as_ptr())
.unwrap_or_else(|| self.state())
}

#[inline(always)]
Expand Down Expand Up @@ -221,7 +222,8 @@ impl RawLua {
#[allow(clippy::arc_with_non_send_sync)]
let rawlua = XRc::new(ReentrantMutex::new(RawLua {
state: Cell::new(state),
main_state,
// Make sure that we don't store current state as main state (if it's not available)
main_state: get_main_state(state).and_then(NonNull::new),
extra: XRc::clone(&extra),
}));
(*extra.get()).set_lua(&rawlua);
Expand Down Expand Up @@ -263,7 +265,7 @@ impl RawLua {
));
}

let res = load_std_libs(self.main_state, libs);
let res = load_std_libs(self.main_state(), libs);

// If `package` library loaded into a safe lua state then disable C modules
let curr_libs = (*self.extra.get()).libs;
Expand Down Expand Up @@ -734,7 +736,7 @@ impl RawLua {
}

// MemoryInfo is empty in module mode so we cannot predict memory limits
match MemoryState::get(self.main_state) {
match MemoryState::get(self.state()) {
mem_state if !mem_state.is_null() => (*mem_state).memory_limit() == 0,
_ => (*self.extra.get()).skip_memory_check, // Check the special flag (only for module mode)
}
Expand Down Expand Up @@ -1095,7 +1097,7 @@ impl RawLua {
#[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52", feature = "luau"))]
unsafe {
if !(*self.extra.get()).libs.contains(StdLib::COROUTINE) {
load_std_libs(self.main_state, StdLib::COROUTINE)?;
load_std_libs(self.main_state(), StdLib::COROUTINE)?;
(*self.extra.get()).libs |= StdLib::COROUTINE;
}
}
Expand Down

0 comments on commit 05778fb

Please sign in to comment.