Skip to content

Commit 414c0bd

Browse files
committed
Clone objects when lowering them (mozilla#1797)
Currently, `lower()` always returns a borrow of the object handle. This is fine for function arguments, since you know the object is still alive on the stack while the function is being called. However, for function returns this is not correct. To fix this: clone the handle in `lower()`. Added a test for this -- it was surprisingly easy to cause a segfault with the current behavior.
1 parent 35d8770 commit 414c0bd

File tree

18 files changed

+185
-45
lines changed

18 files changed

+185
-45
lines changed

fixtures/coverall/src/coverall.udl

+4
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ namespace coverall {
2727
ReturnOnlyEnum output_return_only_enum();
2828

2929
void try_input_return_only_dict(ReturnOnlyDict d);
30+
31+
Getters test_round_trip_through_rust(Getters getters);
32+
void test_round_trip_through_foreign(Getters getters);
3033
};
3134

3235
dictionary SimpleDict {
@@ -229,6 +232,7 @@ interface Getters {
229232
string? get_option(string v, boolean arg2);
230233
sequence<i32> get_list(sequence<i32> v, boolean arg2);
231234
void get_nothing(string v);
235+
Coveralls round_trip_object(Coveralls coveralls);
232236
};
233237

234238
// Test trait #2

fixtures/coverall/src/lib.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@ use std::time::SystemTime;
1010
use once_cell::sync::Lazy;
1111

1212
mod traits;
13-
pub use traits::{ancestor_names, get_traits, make_rust_getters, test_getters, Getters, NodeTrait};
13+
pub use traits::{
14+
ancestor_names, get_traits, make_rust_getters, test_getters, test_round_trip_through_foreign,
15+
test_round_trip_through_rust, Getters, NodeTrait,
16+
};
1417

1518
static NUM_ALIVE: Lazy<RwLock<u64>> = Lazy::new(|| RwLock::new(0));
1619

fixtures/coverall/src/traits.rs

+15-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
* License, v. 2.0. If a copy of the MPL was not distributed with this
33
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
44

5-
use super::{ComplexError, CoverallError};
5+
use super::{ComplexError, CoverallError, Coveralls};
66
use std::sync::{Arc, Mutex};
77

88
// namespace functions.
@@ -41,6 +41,16 @@ pub trait Getters: Send + Sync {
4141
fn get_option(&self, v: String, arg2: bool) -> Result<Option<String>, ComplexError>;
4242
fn get_list(&self, v: Vec<i32>, arg2: bool) -> Vec<i32>;
4343
fn get_nothing(&self, v: String);
44+
fn round_trip_object(&self, coveralls: Arc<Coveralls>) -> Arc<Coveralls>;
45+
}
46+
47+
pub fn test_round_trip_through_rust(getters: Arc<dyn Getters>) -> Arc<dyn Getters> {
48+
getters
49+
}
50+
51+
pub fn test_round_trip_through_foreign(getters: Arc<dyn Getters>) {
52+
let coveralls = getters.round_trip_object(Arc::new(Coveralls::new("round-trip".to_owned())));
53+
assert_eq!(coveralls.get_name(), "round-trip");
4454
}
4555

4656
struct RustGetters;
@@ -90,6 +100,10 @@ impl Getters for RustGetters {
90100
}
91101

92102
fn get_nothing(&self, _v: String) {}
103+
104+
fn round_trip_object(&self, coveralls: Arc<Coveralls>) -> Arc<Coveralls> {
105+
coveralls
106+
}
93107
}
94108

95109
pub fn make_rust_getters() -> Arc<dyn Getters> {

fixtures/coverall/tests/bindings/test_coverall.kts

+11
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,10 @@ class KotlinGetters : Getters {
280280

281281
@Suppress("UNUSED_PARAMETER")
282282
override fun getNothing(v: String) = Unit
283+
284+
override fun roundTripObject(coveralls: Coveralls): Coveralls {
285+
return coveralls
286+
}
283287
}
284288

285289
// Test traits implemented in Rust
@@ -395,6 +399,13 @@ getTraits().let { traits ->
395399
// not possible through the `NodeTrait` interface (see #1787).
396400
}
397401

402+
makeRustGetters().let { rustGetters ->
403+
// Check that these don't cause use-after-free bugs
404+
testRoundTripThroughRust(rustGetters)
405+
406+
testRoundTripThroughForeign(KotlinGetters())
407+
}
408+
398409
// This tests that the UniFFI-generated scaffolding doesn't introduce any unexpected locking.
399410
// We have one thread busy-wait for a some period of time, while a second thread repeatedly
400411
// increments the counter and then checks if the object is still busy. The second thread should

fixtures/coverall/tests/bindings/test_coverall.py

+11
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,9 @@ def get_list(self, v, arg2):
333333
def get_nothing(self, _v):
334334
return None
335335

336+
def round_trip_object(self, coveralls):
337+
return coveralls
338+
336339
class PyNode:
337340
def __init__(self):
338341
self.parent = None
@@ -432,5 +435,13 @@ def test_path(self):
432435
py_node.set_parent(None)
433436
traits[0].set_parent(None)
434437

438+
def test_round_tripping(self):
439+
rust_getters = make_rust_getters();
440+
coveralls = Coveralls("test_round_tripping")
441+
# Check that these don't cause use-after-free bugs
442+
test_round_trip_through_rust(rust_getters)
443+
444+
test_round_trip_through_foreign(PyGetters())
445+
435446
if __name__=='__main__':
436447
unittest.main()

fixtures/coverall/tests/bindings/test_coverall.swift

+13
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,10 @@ class SwiftGetters: Getters {
315315
func getList(v: [Int32], arg2: Bool) -> [Int32] { arg2 ? v : [] }
316316
func getNothing(v: String) -> () {
317317
}
318+
319+
func roundTripObject(coveralls: Coveralls) -> Coveralls {
320+
return coveralls
321+
}
318322
}
319323

320324

@@ -444,3 +448,12 @@ do {
444448
swiftNode.setParent(parent: nil)
445449
traits[0].setParent(parent: nil)
446450
}
451+
452+
// Test round tripping
453+
do {
454+
let rustGetters = makeRustGetters()
455+
// Check that these don't cause use-after-free bugs
456+
let _ = testRoundTripThroughRust(getters: rustGetters)
457+
458+
testRoundTripThroughForeign(getters: SwiftGetters())
459+
}

uniffi_bindgen/src/bindings/kotlin/templates/ObjectRuntime.kt

+10-3
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,14 @@ abstract class FFIObject: Disposable, AutoCloseable {
106106
private val callCounter = AtomicLong(1)
107107

108108
open protected fun freeRustArcPtr() {
109-
// To be overridden in subclasses.
109+
// Overridden by generated subclasses, the default method exists to allow users to manually
110+
// implement the interface
111+
}
112+
113+
open fun uniffiClonePointer(): Pointer {
114+
// Overridden by generated subclasses, the default method exists to allow users to manually
115+
// implement the interface
116+
throw RuntimeException("uniffiClonePointer not implemented")
110117
}
111118

112119
override fun destroy() {
@@ -139,7 +146,7 @@ abstract class FFIObject: Disposable, AutoCloseable {
139146
} while (! this.callCounter.compareAndSet(c, c + 1L))
140147
// Now we can safely do the method call without the pointer being freed concurrently.
141148
try {
142-
return block(this.pointer!!)
149+
return block(this.uniffiClonePointer())
143150
} finally {
144151
// This decrement always matches the increment we performed above.
145152
if (this.callCounter.decrementAndGet() == 0L) {
@@ -150,4 +157,4 @@ abstract class FFIObject: Disposable, AutoCloseable {
150157
}
151158

152159
/** Used to instantiate a [FFIObject] without an actual pointer, for fakes in tests, mostly. */
153-
object NoPointer
160+
object NoPointer

uniffi_bindgen/src/bindings/kotlin/templates/ObjectTemplate.kt

+7-1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,12 @@ open class {{ impl_class_name }} : FFIObject, {{ interface_name }} {
2929
{%- when None %}
3030
{%- endmatch %}
3131

32+
override fun uniffiClonePointer(): Pointer {
33+
return rustCall() { status ->
34+
_UniFFILib.INSTANCE.{{ obj.ffi_object_clone().name() }}(pointer!!, status)
35+
}
36+
}
37+
3238
/**
3339
* Disconnect the object from the underlying Rust object.
3440
*
@@ -165,7 +171,7 @@ public object {{ obj|ffi_converter_name }}: FfiConverter<{{ type_name }}, Pointe
165171
override fun lower(value: {{ type_name }}): Pointer {
166172
{%- match obj.imp() %}
167173
{%- when ObjectImpl::Struct %}
168-
return value.callWithPointer { it }
174+
return value.uniffiClonePointer()
169175
{%- when ObjectImpl::Trait %}
170176
return Pointer(handleMap.insert(value))
171177
{%- endmatch %}

uniffi_bindgen/src/bindings/python/templates/ObjectTemplate.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ def __del__(self):
2525
if pointer is not None:
2626
_rust_call(_UniffiLib.{{ obj.ffi_object_free().name() }}, pointer)
2727

28+
def _uniffi_clone_pointer(self):
29+
return _rust_call(_UniffiLib.{{ obj.ffi_object_clone().name() }}, self._pointer)
30+
2831
# Used by alternative constructors or any methods which return this type.
2932
@classmethod
3033
def _make_instance_(cls, pointer):
@@ -60,13 +63,13 @@ def __eq__(self, other: object) -> {{ eq.return_type().unwrap()|type_name }}:
6063
if not isinstance(other, {{ type_name }}):
6164
return NotImplemented
6265

63-
return {{ eq.return_type().unwrap()|lift_fn }}({% call py::to_ffi_call_with_prefix("self._pointer", eq) %})
66+
return {{ eq.return_type().unwrap()|lift_fn }}({% call py::to_ffi_call_with_prefix("self._uniffi_clone_pointer()", eq) %})
6467

6568
def __ne__(self, other: object) -> {{ ne.return_type().unwrap()|type_name }}:
6669
if not isinstance(other, {{ type_name }}):
6770
return NotImplemented
6871

69-
return {{ ne.return_type().unwrap()|lift_fn }}({% call py::to_ffi_call_with_prefix("self._pointer", ne) %})
72+
return {{ ne.return_type().unwrap()|lift_fn }}({% call py::to_ffi_call_with_prefix("self._uniffi_clone_pointer()", ne) %})
7073
{%- when UniffiTrait::Hash { hash } %}
7174
{%- call py::method_decl("__hash__", hash) %}
7275
{% endmatch %}
@@ -103,7 +106,7 @@ def lower(value: {{ protocol_name }}):
103106
{%- when ObjectImpl::Struct %}
104107
if not isinstance(value, {{ impl_name }}):
105108
raise TypeError("Expected {{ impl_name }} instance, {} found".format(type(value).__name__))
106-
return value._pointer
109+
return value._uniffi_clone_pointer()
107110
{%- when ObjectImpl::Trait %}
108111
return {{ ffi_converter_name }}._handle_map.insert(value)
109112
{%- endmatch %}

uniffi_bindgen/src/bindings/python/templates/macros.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def {{ py_method_name }}(self, {% call arg_list_decl(meth) %}):
120120
{%- call setup_args_extra_indent(meth) %}
121121
return _uniffi_rust_call_async(
122122
_UniffiLib.{{ meth.ffi_func().name() }}(
123-
self._pointer, {% call arg_list_lowered(meth) %}
123+
self._uniffi_clone_pointer(), {% call arg_list_lowered(meth) %}
124124
),
125125
_UniffiLib.{{ meth.ffi_rust_future_poll(ci) }},
126126
_UniffiLib.{{ meth.ffi_rust_future_complete(ci) }},
@@ -150,15 +150,15 @@ def {{ py_method_name }}(self, {% call arg_list_decl(meth) %}) -> "{{ return_typ
150150
{%- call docstring(meth, 8) %}
151151
{%- call setup_args_extra_indent(meth) %}
152152
return {{ return_type|lift_fn }}(
153-
{% call to_ffi_call_with_prefix("self._pointer", meth) %}
153+
{% call to_ffi_call_with_prefix("self._uniffi_clone_pointer()", meth) %}
154154
)
155155

156156
{%- when None %}
157157

158158
def {{ py_method_name }}(self, {% call arg_list_decl(meth) %}):
159159
{%- call docstring(meth, 8) %}
160160
{%- call setup_args_extra_indent(meth) %}
161-
{% call to_ffi_call_with_prefix("self._pointer", meth) %}
161+
{% call to_ffi_call_with_prefix("self._uniffi_clone_pointer()", meth) %}
162162
{% endmatch %}
163163
{% endif %}
164164

uniffi_bindgen/src/bindings/ruby/templates/ObjectTemplate.rb

+10-3
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,15 @@ def self._uniffi_check(inst)
3131
end
3232
end
3333

34+
def _uniffi_clone_pointer()
35+
return {{ ci.namespace()|class_name_rb }}.rust_call(
36+
:{{ obj.ffi_object_clone().name() }},
37+
@pointer
38+
)
39+
end
40+
3441
def self._uniffi_lower(inst)
35-
return inst.instance_variable_get :@pointer
42+
return inst._uniffi_clone_pointer()
3643
end
3744

3845
{%- match obj.primary_constructor() %}
@@ -62,14 +69,14 @@ def self.{{ cons.name()|fn_name_rb }}({% call rb::arg_list_decl(cons) %})
6269
{%- when Some with (return_type) -%}
6370
def {{ meth.name()|fn_name_rb }}({% call rb::arg_list_decl(meth) %})
6471
{%- call rb::setup_args_extra_indent(meth) %}
65-
result = {% call rb::to_ffi_call_with_prefix("@pointer", meth) %}
72+
result = {% call rb::to_ffi_call_with_prefix("_uniffi_clone_pointer()", meth) %}
6673
return {{ "result"|lift_rb(return_type) }}
6774
end
6875
6976
{%- when None -%}
7077
def {{ meth.name()|fn_name_rb }}({% call rb::arg_list_decl(meth) %})
7178
{%- call rb::setup_args_extra_indent(meth) %}
72-
{% call rb::to_ffi_call_with_prefix("@pointer", meth) %}
79+
{% call rb::to_ffi_call_with_prefix("_uniffi_clone_pointer()", meth) %}
7380
end
7481
{% endmatch %}
7582
{% endfor %}

uniffi_bindgen/src/bindings/swift/templates/ObjectTemplate.swift

+12-8
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ public class {{ impl_class_name }}:
3030
self.pointer = pointer
3131
}
3232

33+
public func uniffiClonePointer() -> UnsafeMutableRawPointer {
34+
return try! rustCall { {{ obj.ffi_object_clone().name() }}(self.pointer, $0) }
35+
}
36+
3337
{%- match obj.primary_constructor() %}
3438
{%- when Some with (cons) %}
3539
{%- call swift::docstring(cons, 4) %}
@@ -59,7 +63,7 @@ public class {{ impl_class_name }}:
5963
return {% call swift::try(meth) %} await uniffiRustCallAsync(
6064
rustFutureFunc: {
6165
{{ meth.ffi_func().name() }}(
62-
self.pointer
66+
self.uniffiClonePointer()
6367
{%- for arg in meth.arguments() -%}
6468
,
6569
{{ arg|lower_fn }}({{ arg.name()|var_name }})
@@ -92,14 +96,14 @@ public class {{ impl_class_name }}:
9296
{%- call swift::docstring(meth, 4) %}
9397
public func {{ meth.name()|fn_name }}({% call swift::arg_list_decl(meth) %}) {% call swift::throws(meth) %} -> {{ return_type|type_name }} {
9498
return {% call swift::try(meth) %} {{ return_type|lift_fn }}(
95-
{% call swift::to_ffi_call_with_prefix("self.pointer", meth) %}
99+
{% call swift::to_ffi_call_with_prefix("self.uniffiClonePointer()", meth) %}
96100
)
97101
}
98102

99103
{%- when None %}
100104
{%- call swift::docstring(meth, 4) %}
101105
public func {{ meth.name()|fn_name }}({% call swift::arg_list_decl(meth) %}) {% call swift::throws(meth) %} {
102-
{% call swift::to_ffi_call_with_prefix("self.pointer", meth) %}
106+
{% call swift::to_ffi_call_with_prefix("self.uniffiClonePointer()", meth) %}
103107
}
104108

105109
{%- endmatch -%}
@@ -111,25 +115,25 @@ public class {{ impl_class_name }}:
111115
{%- when UniffiTrait::Display { fmt } %}
112116
public var description: String {
113117
return {% call swift::try(fmt) %} {{ fmt.return_type().unwrap()|lift_fn }}(
114-
{% call swift::to_ffi_call_with_prefix("self.pointer", fmt) %}
118+
{% call swift::to_ffi_call_with_prefix("self.uniffiClonePointer()", fmt) %}
115119
)
116120
}
117121
{%- when UniffiTrait::Debug { fmt } %}
118122
public var debugDescription: String {
119123
return {% call swift::try(fmt) %} {{ fmt.return_type().unwrap()|lift_fn }}(
120-
{% call swift::to_ffi_call_with_prefix("self.pointer", fmt) %}
124+
{% call swift::to_ffi_call_with_prefix("self.uniffiClonePointer()", fmt) %}
121125
)
122126
}
123127
{%- when UniffiTrait::Eq { eq, ne } %}
124128
public static func == (lhs: {{ impl_class_name }}, other: {{ impl_class_name }}) -> Bool {
125129
return {% call swift::try(eq) %} {{ eq.return_type().unwrap()|lift_fn }}(
126-
{% call swift::to_ffi_call_with_prefix("lhs.pointer", eq) %}
130+
{% call swift::to_ffi_call_with_prefix("lhs.uniffiClonePointer()", eq) %}
127131
)
128132
}
129133
{%- when UniffiTrait::Hash { hash } %}
130134
public func hash(into hasher: inout Hasher) {
131135
let val = {% call swift::try(hash) %} {{ hash.return_type().unwrap()|lift_fn }}(
132-
{% call swift::to_ffi_call_with_prefix("self.pointer", hash) %}
136+
{% call swift::to_ffi_call_with_prefix("self.uniffiClonePointer()", hash) %}
133137
)
134138
hasher.combine(val)
135139
}
@@ -161,7 +165,7 @@ public struct {{ ffi_converter_name }}: FfiConverter {
161165
public static func lower(_ value: {{ type_name }}) -> UnsafeMutableRawPointer {
162166
{%- match obj.imp() %}
163167
{%- when ObjectImpl::Struct %}
164-
return value.pointer
168+
return value.uniffiClonePointer()
165169
{%- when ObjectImpl::Trait %}
166170
guard let ptr = UnsafeMutableRawPointer(bitPattern: UInt(truncatingIfNeeded: handleMap.insert(obj: value))) else {
167171
fatalError("Cast to UnsafeMutableRawPointer failed")

0 commit comments

Comments
 (0)