Skip to content

Commit

Permalink
use non null in squirrel functions
Browse files Browse the repository at this point in the history
  • Loading branch information
catornot committed Mar 21, 2024
1 parent c031667 commit 0da0815
Show file tree
Hide file tree
Showing 8 changed files with 136 additions and 111 deletions.
14 changes: 7 additions & 7 deletions rrplug_proc/src/impl_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ pub fn get_from_sqvm_impl_struct(input: DeriveInput) -> TokenStream {
#[allow(clippy::not_unsafe_ptr_arg_deref)] // smth should be done about this
#[inline]
fn get_from_sqvm(
sqvm: *mut HSquirrelVM,
sqvm: std::ptr::NonNull<HSquirrelVM>,
sqfunctions: &SquirrelFunctions,
stack_pos: i32,
) -> Self {
use rrplug::{high::squirrel_traits::GetFromSQObject,bindings::squirreldatatypes::SQObject};
let sqstruct = unsafe {
let sqvm = sqvm.as_ref().expect("sqvm has to be valid");
let sqvm = sqvm.as_ref();
((*sqvm._stackOfCurrentFunction.add(stack_pos as usize))
._VAL
.asStructInstance)
Expand Down Expand Up @@ -82,12 +82,12 @@ pub fn push_to_sqvm_impl_struct(input: DeriveInput) -> TokenStream {
impl<#generics> PushToSquirrelVm for #ident<#generics> {
#[allow(clippy::not_unsafe_ptr_arg_deref)]
#[inline]
fn push_to_sqvm(self, sqvm: *mut HSquirrelVM, sqfunctions: &SquirrelFunctions) {
fn push_to_sqvm(self, sqvm: std::ptr::NonNull<HSquirrelVM>, sqfunctions: &SquirrelFunctions) {
unsafe {
(sqfunctions.sq_pushnewstructinstance)(sqvm, #field_amount);
(sqfunctions.sq_pushnewstructinstance)(sqvm.as_ptr(), #field_amount);
#(
self.#field_idents.push_to_sqvm(sqvm,sqfunctions);
(sqfunctions.sq_sealstructslot)(sqvm, #field_amount_iter);
(sqfunctions.sq_sealstructslot)(sqvm.as_ptr(), #field_amount_iter);
)*
}
}
Expand Down Expand Up @@ -126,7 +126,7 @@ pub fn get_from_sqvm_impl_enum(input: DeriveInput) -> TokenStream {
#[inline]
#[allow(clippy::not_unsafe_ptr_arg_deref)]
fn get_from_sqvm(
sqvm: *mut HSquirrelVM,
sqvm: std::ptr::NonNull<HSquirrelVM>,
sqfunctions: &SquirrelFunctions,
stack_pos: i32,
) -> Self {
Expand Down Expand Up @@ -158,7 +158,7 @@ pub fn push_to_sqvm_impl_enum(input: DeriveInput) -> TokenStream {
impl<#generics> PushToSquirrelVm for #ident<#generics> {
#[inline]
#[allow(clippy::not_unsafe_ptr_arg_deref)]
fn push_to_sqvm(self, sqvm: *mut HSquirrelVM, sqfunctions: &SquirrelFunctions) {
fn push_to_sqvm(self, sqvm: std::ptr::NonNull<HSquirrelVM>, sqfunctions: &SquirrelFunctions) {
unsafe { rrplug::mid::squirrel::push_sq_int(sqvm, sqfunctions, self as i32) };
}
}
Expand Down
3 changes: 2 additions & 1 deletion rrplug_proc/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,12 +173,13 @@ pub fn sqfunction(attr: TokenStream, item: TokenStream) -> TokenStream {
#[doc(hidden)]
#[doc = "generated ffi function for #func_name"]
#vis extern "C" fn #sq_functions_func (sqvm: *mut rrplug::bindings::squirreldatatypes::HSquirrelVM) -> rrplug::bindings::squirrelclasstypes::SQRESULT {
let sqvm = std::ptr::NonNull::new(sqvm).expect("sqvm has to be non null");
use rrplug::high::squirrel_traits::{GetFromSquirrelVm,ReturnToVm};
let sq_functions = SQFUNCTIONS.from_sqvm(sqvm);

#(#sub_stms)*

fn inner_function( sqvm: *mut rrplug::bindings::squirreldatatypes::HSquirrelVM, sq_functions: &'static SquirrelFunctions #(, #input_vec)* ) #output {
fn inner_function( sqvm: std::ptr::NonNull<rrplug::bindings::squirreldatatypes::HSquirrelVM>, sq_functions: &'static SquirrelFunctions #(, #input_vec)* ) #output {
let engine_token = unsafe { rrplug::high::engine::EngineToken::new_unchecked() };
#(#stmts)*
}
Expand Down
5 changes: 4 additions & 1 deletion src/high/engine_sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ pub enum AsyncEngineMessage {
function_name: String,
/// the arguments that will passed to it via the closure (use `AsyncEngineMessage::run_squirrel_func`)
args: Box<
dyn FnOnce(*mut HSquirrelVM, &'static SquirrelFunctions) -> i32 + 'static + Send + Sync,
dyn FnOnce(NonNull<HSquirrelVM>, &'static SquirrelFunctions) -> i32
+ 'static
+ Send
+ Sync,
>,
},
/// contains a closure that will be executed once on the next engine frame
Expand Down
56 changes: 32 additions & 24 deletions src/high/squirrel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
//! squirrel vm related function and statics

use parking_lot::Mutex;
use std::marker::PhantomData;
use std::{marker::PhantomData, ptr::NonNull};

use super::{
squirrel_traits::{GetFromSQObject, IntoSquirrelArgs, IsSQObject},
Expand Down Expand Up @@ -31,30 +31,38 @@ pub static FUNCTION_SQ_REGISTER: Mutex<Vec<SQFuncInfo>> = Mutex::new(Vec::new())
/// also has the current vm type
#[derive(Debug)]
pub struct CSquirrelVMHandle {
handle: *mut CSquirrelVM,
handle: NonNull<CSquirrelVM>,
vm_type: ScriptContext,
}

impl CSquirrelVMHandle {
/// **should** not be used outside of the [`crate::entry`] macro
#[doc(hidden)]
pub fn new(
handle: *mut CSquirrelVM,
mut handle: NonNull<CSquirrelVM>,
context: ScriptContext,
is_being_dropped: bool,
token: EngineToken,
) -> Self {
unsafe {
match (context, is_being_dropped) {
(ScriptContext::SERVER, false) => {
_ = SQVM_SERVER.get(token).replace(Some((*handle).sqvm))
_ = SQVM_SERVER.get(token).replace(Some(
NonNull::new(handle.as_mut().sqvm).expect("sqvm cannot be null"),
))
}
(ScriptContext::SERVER, true) => _ = SQVM_SERVER.get(token).replace(None),
(ScriptContext::CLIENT, false) => {
_ = SQVM_CLIENT.get(token).replace(Some((*handle).sqvm))
_ = SQVM_CLIENT.get(token).replace(Some(
NonNull::new(handle.as_mut().sqvm).expect("sqvm cannot be null"),
))
}
(ScriptContext::CLIENT, true) => _ = SQVM_CLIENT.get(token).replace(None),
(ScriptContext::UI, false) => _ = SQVM_UI.get(token).replace(Some((*handle).sqvm)),
(ScriptContext::UI, false) => {
_ = SQVM_UI.get(token).replace(Some(
NonNull::new(handle.as_mut().sqvm).expect("sqvm cannot be null"),
))
}
(ScriptContext::UI, true) => _ = SQVM_UI.get(token).replace(None),
}
}
Expand All @@ -79,7 +87,7 @@ impl CSquirrelVMHandle {

let name = to_cstring(&name);

unsafe { (sqfunctions.sq_defconst)(self.handle, name.as_ptr(), value.into()) }
unsafe { (sqfunctions.sq_defconst)(self.handle.as_ptr(), name.as_ptr(), value.into()) }
}

/// gets the raw pointer to [`HSquirrelVM`]
Expand All @@ -92,8 +100,8 @@ impl CSquirrelVMHandle {
/// [`UnsafeHandle`] : when used outside of engine thread can cause race conditions or ub
///
/// [`UnsafeHandle`] should only be used to transfer the pointers to other places in the engine thread like sqfunctions or runframe
pub unsafe fn get_sqvm(&self) -> UnsafeHandle<*mut HSquirrelVM> {
unsafe { UnsafeHandle::internal_new((*self.handle).sqvm) }
pub const unsafe fn get_sqvm(&self) -> UnsafeHandle<NonNull<HSquirrelVM>> {
unsafe { UnsafeHandle::internal_new(NonNull::new_unchecked(self.handle.as_ref().sqvm)) }
}
/// gets the raw pointer to [`CSquirrelVM`]
///
Expand All @@ -105,7 +113,7 @@ impl CSquirrelVMHandle {
/// [`UnsafeHandle`] : when used outside of engine thread can cause race conditions or ub
///
/// [`UnsafeHandle`] should only be used to transfer the pointers to other places in the engine thread like sqfunctions or runframe
pub const unsafe fn get_cs_sqvm(&self) -> UnsafeHandle<*mut CSquirrelVM> {
pub const unsafe fn get_cs_sqvm(&self) -> UnsafeHandle<NonNull<CSquirrelVM>> {
UnsafeHandle::internal_new(self.handle)
}

Expand Down Expand Up @@ -183,17 +191,17 @@ impl<T: IntoSquirrelArgs> SquirrelFn<T> {
/// This function will return an error if the fails to execute for some reason which is unlikely since it would be type checked
pub fn run(
&mut self,
sqvm: *mut HSquirrelVM,
sqvm: NonNull<HSquirrelVM>,
sqfunctions: &'static SquirrelFunctions,
args: T,
) -> Result<(), CallError> {
unsafe {
let amount = args.into_push(sqvm, sqfunctions);

(sqfunctions.sq_pushobject)(sqvm, self.func.as_callable());
(sqfunctions.sq_pushroottable)(sqvm);
(sqfunctions.sq_pushobject)(sqvm.as_ptr(), self.func.as_callable());
(sqfunctions.sq_pushroottable)(sqvm.as_ptr());

if (sqfunctions.sq_call)(sqvm, amount, true as u32, true as u32)
if (sqfunctions.sq_call)(sqvm.as_ptr(), amount, true as u32, true as u32)
== SQRESULT::SQRESULT_ERROR
{
return Err(CallError::FunctionFailedToExecute);
Expand All @@ -209,7 +217,7 @@ impl<T: IntoSquirrelArgs> SquirrelFn<T> {
/// This function will return an error if the fails to execute for some reason which is unlikely since it would be type checked
pub fn call(
&mut self,
sqvm: *mut HSquirrelVM,
sqvm: NonNull<HSquirrelVM>,
sqfunctions: &'static SquirrelFunctions,
args: T,
) -> Result<(), CallError> {
Expand Down Expand Up @@ -264,7 +272,7 @@ pub fn register_sq_functions(get_info_func: FuncSQFuncInfo) {
/// }
/// ```
pub fn call_sq_function<R: GetFromSQObject>(
sqvm: *mut HSquirrelVM,
sqvm: NonNull<HSquirrelVM>,
sqfunctions: &'static SquirrelFunctions,
function_name: impl AsRef<str>,
) -> Result<R, CallError> {
Expand All @@ -274,7 +282,7 @@ pub fn call_sq_function<R: GetFromSQObject>(
let function_name = try_cstring(function_name.as_ref())?;

let result = unsafe {
(sqfunctions.sq_getfunction)(sqvm, function_name.as_ptr(), ptr, std::ptr::null())
(sqfunctions.sq_getfunction)(sqvm.as_ptr(), function_name.as_ptr(), ptr, std::ptr::null())
};

if result != 0 {
Expand Down Expand Up @@ -310,7 +318,7 @@ pub fn call_sq_function<R: GetFromSQObject>(
/// }
/// ```
pub fn call_sq_object_function<R: GetFromSQObject>(
sqvm: *mut HSquirrelVM,
sqvm: NonNull<HSquirrelVM>,
sqfunctions: &'static SquirrelFunctions,
mut obj: SQHandle<SQClosure>,
) -> Result<R, CallError> {
Expand All @@ -319,12 +327,12 @@ pub fn call_sq_object_function<R: GetFromSQObject>(

#[inline]
fn _call_sq_object_function<R: GetFromSQObject>(
sqvm: *mut HSquirrelVM,
mut sqvm: NonNull<HSquirrelVM>,
sqfunctions: &'static SquirrelFunctions,
ptr: *mut SQObject,
) -> Result<R, CallError> {
unsafe {
let sqvm = &mut *sqvm;
let sqvm = sqvm.as_mut();
(sqfunctions.sq_pushobject)(sqvm, ptr);
(sqfunctions.sq_pushroottable)(sqvm);

Expand Down Expand Up @@ -358,7 +366,7 @@ fn _call_sq_object_function<R: GetFromSQObject>(
/// }
/// ```
pub fn compile_string(
sqvm: *mut HSquirrelVM,
sqvm: NonNull<HSquirrelVM>,
sqfunctions: &SquirrelFunctions,
should_throw_error: bool,
code: impl AsRef<str>,
Expand All @@ -376,17 +384,17 @@ pub fn compile_string(

unsafe {
let result = (sqfunctions.sq_compilebuffer)(
sqvm,
sqvm.as_ptr(),
&mut compile_buffer as *mut CompileBufferState,
BUFFER_NAME,
-1,
should_throw_error as u32,
);

if result != SQRESULT::SQRESULT_ERROR {
(sqfunctions.sq_pushroottable)(sqvm);
(sqfunctions.sq_pushroottable)(sqvm.as_ptr());

if (sqfunctions.sq_call)(sqvm, 1, 0, 0) == SQRESULT::SQRESULT_ERROR {
if (sqfunctions.sq_call)(sqvm.as_ptr(), 1, 0, 0) == SQRESULT::SQRESULT_ERROR {
Err(SQCompileError::BufferFailedToExecute)
} else {
Ok(())
Expand Down
Loading

0 comments on commit 0da0815

Please sign in to comment.