Skip to content
Merged
227 changes: 155 additions & 72 deletions crates/protect-ffi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use cipherstash_client::{
ScopedCipher, SteVec, TypeParseError,
},
schema::ColumnConfig,
zerokms::{self, encrypted_record, EncryptedRecord, WithContext, ZeroKMSWithClientKey},
zerokms::{self, EncryptedRecord, WithContext, ZeroKMSWithClientKey},
};
use encrypt_config::{EncryptConfig, Identifier};
use neon::prelude::*;
Expand Down Expand Up @@ -43,8 +43,8 @@ impl Finalize for Client {}
pub enum Encrypted {
#[serde(rename = "ct")]
Ciphertext {
#[serde(rename = "c", with = "encrypted_record::formats::mp_base85")]
ciphertext: EncryptedRecord,
#[serde(rename = "c")]
ciphertext: String,
#[serde(rename = "o")]
ore_index: Option<Vec<String>>,
#[serde(rename = "m")]
Expand Down Expand Up @@ -139,11 +139,12 @@ async fn new_client_inner(encrypt_config: EncryptConfig) -> Result<Client, Error

fn encrypt(mut cx: FunctionContext) -> JsResult<JsPromise> {
let client = (**cx.argument::<JsBox<Client>>(0)?).clone();
let plaintext = cx.argument::<JsString>(1)?.value(&mut cx);
let column_name = cx.argument::<JsString>(2)?.value(&mut cx);
let table_name = cx.argument::<JsString>(3)?.value(&mut cx);
let lock_context = encryption_context_from_js_value(cx.argument_opt(4), &mut cx)?;
let service_token = service_token_from_js_value(cx.argument_opt(5), &mut cx)?;
let (plaintext_target, ident) = plaintext_target_from_js_object(
cx.argument::<JsObject>(1)?,
&client.encrypt_config,
&mut cx,
)?;
let service_token = service_token_from_js_value(cx.argument_opt(2), &mut cx)?;

let rt = runtime(&mut cx)?;
let channel = cx.channel();
Expand All @@ -159,15 +160,7 @@ fn encrypt(mut cx: FunctionContext) -> JsResult<JsPromise> {
//
// This task will _not_ block the JavaScript main thread.
rt.spawn(async move {
let ciphertext_result = encrypt_inner(
client,
plaintext,
column_name,
table_name,
lock_context,
service_token,
)
.await;
let ciphertext_result = encrypt_inner(client, plaintext_target, ident, service_token).await;

// Settle the promise from the result of a closure. JavaScript exceptions
// will be converted to a Promise rejection.
Expand All @@ -177,8 +170,7 @@ fn encrypt(mut cx: FunctionContext) -> JsResult<JsPromise> {
// should be performed outside of it.
deferred.settle_with(&channel, move |mut cx| {
let ciphertext = ciphertext_result.or_else(|err| cx.throw_error(err.to_string()))?;

Ok(cx.string(ciphertext))
eql_encrypted_to_js(ciphertext, &mut cx)
});
});

Expand All @@ -187,24 +179,13 @@ fn encrypt(mut cx: FunctionContext) -> JsResult<JsPromise> {

async fn encrypt_inner(
client: Client,
plaintext: String,
column_name: String,
table_name: String,
encryption_context: Vec<zerokms::Context>,
plaintext_target: PlaintextTarget,
ident: Identifier,
service_token: Option<ServiceToken>,
) -> Result<String, Error> {
let ident = Identifier::new(table_name, column_name);

let column_config = client
.encrypt_config
.get(&ident)
.ok_or_else(|| Error::UnknownColumn(ident.clone()))?;

) -> Result<Encrypted, Error> {
let mut pipeline = ReferencedPendingPipeline::new(client.cipher);
let mut encryptable = PlaintextTarget::new(plaintext, column_config.clone());
encryptable.context = encryption_context;

pipeline.add_with_ref::<PlaintextTarget>(encryptable, 0)?;
pipeline.add_with_ref::<PlaintextTarget>(plaintext_target, 0)?;

let mut source_encrypted = pipeline.encrypt(service_token).await?;

Expand All @@ -214,9 +195,7 @@ async fn encrypt_inner(
)
})?;

let eql_payload = to_eql_encrypted(encrypted, &ident)?;

eql_encrypted_to_json_string(&eql_payload)
to_eql_encrypted(encrypted, &ident)
}

fn encrypt_bulk(mut cx: FunctionContext) -> JsResult<JsPromise> {
Expand All @@ -238,7 +217,7 @@ fn encrypt_bulk(mut cx: FunctionContext) -> JsResult<JsPromise> {

deferred.settle_with(&channel, move |mut cx| {
let ciphertexts = ciphertexts_result.or_else(|err| cx.throw_error(err.to_string()))?;
js_array_from_string_vec(ciphertexts, &mut cx)
js_array_from_eql_encrypted_vec(ciphertexts, &mut cx)
});
});

Expand All @@ -249,7 +228,7 @@ async fn encrypt_bulk_inner(
client: Client,
plaintext_targets: Vec<(PlaintextTarget, Identifier)>,
service_token: Option<ServiceToken>,
) -> Result<Vec<String>, Error> {
) -> Result<Vec<Encrypted>, Error> {
let len = plaintext_targets.len();
let mut pipeline = ReferencedPendingPipeline::new(client.cipher);
let (plaintext_targets, identifiers): (Vec<PlaintextTarget>, Vec<Identifier>) =
Expand All @@ -261,7 +240,7 @@ async fn encrypt_bulk_inner(

let mut source_encrypted = pipeline.encrypt(service_token).await?;

let mut results: Vec<String> = Vec::with_capacity(len);
let mut results: Vec<Encrypted> = Vec::with_capacity(len);

for i in 0..len {
let encrypted = source_encrypted.remove(i).ok_or_else(|| {
Expand All @@ -278,7 +257,7 @@ async fn encrypt_bulk_inner(

let eql_payload = to_eql_encrypted(encrypted, ident)?;

results.push(eql_encrypted_to_json_string(&eql_payload)?);
results.push(eql_payload);
}

Ok(results)
Expand Down Expand Up @@ -404,18 +383,19 @@ fn service_token_from_js_value(
value: Option<Handle<JsValue>>,
cx: &mut FunctionContext,
) -> NeonResult<Option<ServiceToken>> {
if let Some(service_token) = value {
let service_token: Handle<JsObject> = service_token.downcast_or_throw(cx)?;
match value {
Some(service_token) if is_defined(service_token, cx) => {
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the change that fixes up errors on the Rust side when the token from JS is undefined.

let service_token: Handle<JsObject> = service_token.downcast_or_throw(cx)?;

let token = service_token
.get::<JsString, _, _>(cx, "accessToken")?
.value(cx);
let token = service_token
.get::<JsString, _, _>(cx, "accessToken")?
.value(cx);

let expiry = service_token.get::<JsNumber, _, _>(cx, "expiry")?.value(cx);
let expiry = service_token.get::<JsNumber, _, _>(cx, "expiry")?.value(cx);

Ok(Some(ServiceToken::new(token, expiry as u64)))
} else {
Ok(None)
Ok(Some(ServiceToken::new(token, expiry as u64)))
}
_ => Ok(None),
}
}

Expand All @@ -430,29 +410,38 @@ fn plaintext_targets_from_js_array(

for js_value in js_values {
let obj: Handle<JsObject> = js_value.downcast_or_throw(cx)?;
let (plaintext_target, ident) = plaintext_target_from_js_object(obj, &encrypt_config, cx)?;

plaintext_targets.push((plaintext_target, ident));
}

let plaintext = obj.get::<JsString, _, _>(cx, "plaintext")?.value(cx);
Ok(plaintext_targets)
}

let column = obj.get::<JsString, _, _>(cx, "column")?.value(cx);
let table = obj.get::<JsString, _, _>(cx, "table")?.value(cx);
fn plaintext_target_from_js_object(
value: Handle<'_, JsObject>,
encrypt_config: &Arc<HashMap<Identifier, ColumnConfig>>,
cx: &mut FunctionContext,
) -> NeonResult<(PlaintextTarget, Identifier)> {
let plaintext = value.get::<JsString, _, _>(cx, "plaintext")?.value(cx);

let lock_context = obj.get_opt::<JsValue, _, _>(cx, "lockContext")?;
let lock_context = encryption_context_from_js_value(lock_context, cx)?;
let column = value.get::<JsString, _, _>(cx, "column")?.value(cx);
let table = value.get::<JsString, _, _>(cx, "table")?.value(cx);

let ident = Identifier::new(table, column);
let lock_context = value.get_opt::<JsValue, _, _>(cx, "lockContext")?;
let lock_context = encryption_context_from_js_value(lock_context, cx)?;

let column_config = encrypt_config
.get(&ident)
.ok_or_else(|| Error::UnknownColumn(ident.clone()))
.or_else(|err| cx.throw_error(err.to_string()))?;
let ident = Identifier::new(table, column);

let mut plaintext_target = PlaintextTarget::new(plaintext, column_config.clone());
plaintext_target.context = lock_context;
let column_config = encrypt_config
.get(&ident)
.ok_or_else(|| Error::UnknownColumn(ident.clone()))
.or_else(|err| cx.throw_error(err.to_string()))?;

plaintext_targets.push((plaintext_target, ident));
}
let mut plaintext_target = PlaintextTarget::new(plaintext, column_config.clone());
plaintext_target.context = lock_context;

Ok(plaintext_targets)
Ok((plaintext_target, ident))
}

fn ciphertexts_from_js_array(
Expand Down Expand Up @@ -490,6 +479,34 @@ fn js_array_from_string_vec<'a, C: Context<'a>>(
Ok(js_array)
}

fn js_array_from_u16_vec<'a, C: Context<'a>>(
vec: Vec<u16>,
cx: &mut C,
) -> NeonResult<Handle<'a, JsArray>> {
let js_array = JsArray::new(cx, vec.len());

for (i, value) in vec.iter().enumerate() {
let js_number = cx.number(*value);
js_array.set(cx, i as u32, js_number)?;
}

Ok(js_array)
}

fn js_array_from_eql_encrypted_vec<'a, C: Context<'a>>(
vec: Vec<Encrypted>,
cx: &mut C,
) -> NeonResult<Handle<'a, JsArray>> {
let js_array = JsArray::new(cx, vec.len());

for (i, value) in vec.into_iter().enumerate() {
let js_obj = eql_encrypted_to_js(value, cx)?;
js_array.set(cx, i as u32, js_obj)?;
}

Ok(js_array)
}

fn encrypted_record_from_mp_base85(
base85str: &str,
encryption_context: Vec<zerokms::Context>,
Expand Down Expand Up @@ -554,6 +571,12 @@ fn to_eql_encrypted(
};
}

let ciphertext = ciphertext
.to_mp_base85()
// The error type from `to_mp_base85` isn't public, so we don't derive an error for this one.
// Instead, we use `map_err`.
.map_err(|err| Error::Base85(err.to_string()))?;

Ok(Encrypted::Ciphertext {
ciphertext,
identifier: identifier.to_owned(),
Expand All @@ -571,6 +594,71 @@ fn to_eql_encrypted(
}
}

fn eql_encrypted_to_js<'cx, C: Context<'cx>>(
encrypted: Encrypted,
cx: &mut C,
) -> NeonResult<Handle<'cx, JsObject>> {
let obj: Handle<JsObject> = cx.empty_object();

let Encrypted::Ciphertext {
ciphertext,
ore_index,
match_index,
unique_index,
identifier,
version,
} = encrypted
else {
return cx
.throw_error(Error::Unimplemented("encrypted JSON columns".to_string()).to_string());
};

let k = cx.string("ct");
obj.set(cx, "k", k)?;

let c = cx.string(ciphertext);
obj.set(cx, "c", c)?;

if let Some(ore_index) = ore_index {
let o = js_array_from_string_vec(ore_index, cx)?;
obj.set(cx, "o", o)?;
} else {
let o = cx.null();
obj.set(cx, "o", o)?;
}

if let Some(match_index) = match_index {
let m = js_array_from_u16_vec(match_index, cx)?;
obj.set(cx, "m", m)?;
} else {
let m = cx.null();
obj.set(cx, "m", m)?;
}

if let Some(unique_index) = unique_index {
let u = cx.string(unique_index);
obj.set(cx, "u", u)?;
} else {
let u = cx.null();
obj.set(cx, "u", u)?;
}

let i = cx.empty_object();

let col = cx.string(identifier.column);
i.set(cx, "c", col)?;

let t = cx.string(identifier.table);
i.set(cx, "t", t)?;

obj.set(cx, "i", i)?;

let v = cx.number(version);
obj.set(cx, "v", v)?;

Ok(obj)
}

fn format_index_term_binary(bytes: &Vec<u8>) -> String {
hex::encode(bytes)
}
Expand All @@ -596,13 +684,8 @@ fn format_index_term_ore(bytes: &Vec<u8>) -> Vec<String> {
vec![format_index_term_ore_bytea(bytes)]
}

fn eql_encrypted_to_json_string(encrypted: &Encrypted) -> Result<String, Error> {
serde_json::to_string(encrypted).map_err(|_| {
Error::InvariantViolation(
"expected EQL payload to be serialiable as JSON, but it could not be serialized"
.to_string(),
)
})
fn is_defined(js_value: Handle<'_, JsValue>, cx: &mut FunctionContext) -> bool {
!js_value.is_a::<JsUndefined, _>(cx)
}

#[neon::main]
Expand Down
Loading
Loading