Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion c_src/ex_dtls/native.c
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ UNIFEX_TERM do_init(UnifexEnv *env, char *mode_str, int dtls_srtp,
state->x509 = NULL;
state->mode = 0;
state->hsk_finished = 0;
state->closed = 0;
state->env = unifex_alloc_env(env);

int mode;
Expand Down Expand Up @@ -244,6 +245,10 @@ UNIFEX_TERM get_cert_fingerprint(UnifexEnv *env, UnifexPayload *cert) {
}

UNIFEX_TERM do_handshake(UnifexEnv *env, State *state) {
if (state->closed == 1) {
return do_handshake_result_error_closed(env);
}

SSL_do_handshake(state->ssl);

UnifexPayload **gen_packets = NULL;
Expand All @@ -258,14 +263,19 @@ UNIFEX_TERM do_handshake(UnifexEnv *env, State *state) {
} else {
int timeout = get_timeout(state->ssl);
UNIFEX_TERM res_term =
do_handshake_result(env, gen_packets, gen_packets_size, timeout);
do_handshake_result_ok(env, gen_packets, gen_packets_size, timeout);
free_payload_array(gen_packets, gen_packets_size);

return res_term;
}
}

UNIFEX_TERM write_data(UnifexEnv *env, State *state, UnifexPayload *payload) {
if (state->closed == 1) {
DEBUG("Cannot write, connection closed");
return write_data_result_error_closed(env);
}

if (state->hsk_finished != 1) {
DEBUG("Cannot write, handshake not finished");
return write_data_result_error_handshake_not_finished(env);
Expand Down Expand Up @@ -303,6 +313,10 @@ UNIFEX_TERM write_data(UnifexEnv *env, State *state, UnifexPayload *payload) {
}

UNIFEX_TERM handle_data(UnifexEnv *env, State *state, UnifexPayload *payload) {
if (state->closed == 1) {
return handle_data_result_error_closed(env);
}

(void)env;

if (payload->size != 0) {
Expand Down Expand Up @@ -332,6 +346,32 @@ UNIFEX_TERM handle_data(UnifexEnv *env, State *state, UnifexPayload *payload) {
}
}

// prefix close with exd (ex_dtls) as close is defined in unistd.h
UNIFEX_TERM exd_close(UnifexEnv *env, State *state) {
if (state->closed == 1) {
return exd_close_result_ok(env, NULL, 0);
}

state->closed = 1;
if (SSL_shutdown(state->ssl) < 0) {
return exd_close_result_ok(env, NULL, 0);
} else {
UnifexPayload **gen_packets = NULL;
int gen_packets_size = 0;
read_pending_data(&gen_packets, &gen_packets_size, state);

if (gen_packets == NULL) {
return unifex_raise(state->env,
"Close failed: couldn't read pending data");
} else {
UNIFEX_TERM res_term =
exd_close_result_ok(env, gen_packets, gen_packets_size);
free_payload_array(gen_packets, gen_packets_size);
return res_term;
}
}
}

UNIFEX_TERM handle_regular_read(State *state, char data[], int ret) {
if (ret > 0) {
UnifexPayload packets;
Expand Down
1 change: 1 addition & 0 deletions c_src/ex_dtls/native.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ struct State {
X509 *x509;
int mode;
int hsk_finished;
int closed;
};

#include "_generated/native.h"
18 changes: 13 additions & 5 deletions c_src/ex_dtls/native.spec.exs
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,17 @@ spec get_peer_cert(state) :: payload | (nil :: label)

spec get_cert_fingerprint(payload) :: payload

spec do_handshake(state) :: {packets :: [payload], timeout :: int}
spec do_handshake(state) :: {:ok :: label, packets :: [payload], timeout :: int} | {:error :: label, :closed :: label}

spec handle_timeout(state) :: (:ok :: label) | {:retransmit :: label, packets :: [payload], timeout :: int}
spec handle_timeout(state) ::
(:ok :: label)
| {:retransmit :: label, packets :: [payload], timeout :: int}
| {:error :: label, :closed :: label}

spec write_data(state, packets :: payload) :: {:ok :: label, packets :: [payload]} | {:error :: label, :handshake_not_finished :: label}
spec write_data(state, packets :: payload) ::
{:ok :: label, packets :: [payload]}
| {:error :: label, :handshake_not_finished :: label}
| {:error :: label, :closed :: label}

spec handle_data(state, packets :: payload) ::
{:ok :: label, packets :: payload}
Expand All @@ -34,5 +40,7 @@ spec handle_data(state, packets :: payload) ::
| {:handshake_finished :: label, client_keying_material :: payload,
server_keying_material :: payload, protection_profile :: int, packets :: [payload]}
| {:error :: label, :peer_closed_for_writing :: label}
| {:error :: label, :handshake_error :: label
}
| {:error :: label, :handshake_error :: label}
| {:error :: label, :closed :: label}

spec exd_close(state) :: {:ok :: label, packets :: [payload]}
9 changes: 9 additions & 0 deletions lib/ex_dtls.ex
Original file line number Diff line number Diff line change
Expand Up @@ -194,4 +194,13 @@ defmodule ExDTLS do
"""
@spec handle_timeout(dtls()) :: :ok | {:retransmit, packets :: [binary()], timeout :: integer()}
defdelegate handle_timeout(dtls), to: Native

@doc """
Irreversibly closes DTLS session.

If a handshake has been finished, this function will generate `close_notify` DTLS alert
that should be sent to the other side.
"""
@spec close(dtls()) :: {:ok, packets :: [binary()]}
defdelegate close(dtls), to: Native, as: :exd_close
end
28 changes: 24 additions & 4 deletions test/integration_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ defmodule ExDTLS.IntegrationTest do
rx_dtls = ExDTLS.init(mode: :server, dtls_srtp: true, verify_peer: true)
tx_dtls = ExDTLS.init(mode: :client, dtls_srtp: true, verify_peer: true)

{packets, _timeout} = ExDTLS.do_handshake(tx_dtls)
{:ok, packets, _timeout} = ExDTLS.do_handshake(tx_dtls)

assert :ok == loop({rx_dtls, false}, {tx_dtls, false}, packets)

Expand All @@ -17,7 +17,7 @@ defmodule ExDTLS.IntegrationTest do
rx_dtls = ExDTLS.init(mode: :server, dtls_srtp: true)
tx_dtls = ExDTLS.init(mode: :client, dtls_srtp: true)

{packets, _timeout} = ExDTLS.do_handshake(tx_dtls)
{:ok, packets, _timeout} = ExDTLS.do_handshake(tx_dtls)

assert :ok == loop({rx_dtls, false}, {tx_dtls, false}, packets)

Expand All @@ -34,7 +34,7 @@ defmodule ExDTLS.IntegrationTest do
assert {:error, :handshake_not_finished} = ExDTLS.write_data(sr_dtls, <<1, 2, 3>>)
assert {:error, :handshake_not_finished} = ExDTLS.write_data(cl_dtls, <<1, 2, 3>>)

{packets, _timeout} = ExDTLS.do_handshake(cl_dtls)
{:ok, packets, _timeout} = ExDTLS.do_handshake(cl_dtls)
assert :ok == loop({sr_dtls, false}, {cl_dtls, false}, packets)

msg = <<1, 3, 2, 5>>
Expand All @@ -55,11 +55,31 @@ defmodule ExDTLS.IntegrationTest do

tx_dtls = ExDTLS.init(mode: :client, dtls_srtp: true, verify_peer: true)

{packets, _timeout} = ExDTLS.do_handshake(tx_dtls)
{:ok, packets, _timeout} = ExDTLS.do_handshake(tx_dtls)
{:handshake_packets, packets, _timeout} = feed_packets(rx_dtls, packets)
assert {:error, :handshake_error} = feed_packets(tx_dtls, packets)
end

describe "disconnect" do
test "before handshake has finished" do
dtls = ExDTLS.init(mode: :client, dtls_srtp: true, verify_peer: true)
assert {:ok, []} = ExDTLS.close(dtls)
# assert that handshake can't be started
assert {:error, :closed} = ExDTLS.do_handshake(dtls)
end

test "after handshake has finished" do
rx_dtls = ExDTLS.init(mode: :server, dtls_srtp: true, verify_peer: true)
tx_dtls = ExDTLS.init(mode: :client, dtls_srtp: true, verify_peer: true)

{:ok, packets, _timeout} = ExDTLS.do_handshake(tx_dtls)

assert :ok == loop({rx_dtls, false}, {tx_dtls, false}, packets)
assert {:ok, [packet]} = ExDTLS.close(rx_dtls)
assert {:error, :peer_closed_for_writing} = ExDTLS.handle_data(tx_dtls, packet)
end
end

defp loop({_dtls1, true}, {_dtls2, true}, _packets) do
:ok
end
Expand Down
2 changes: 1 addition & 1 deletion test/retransmission_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ defmodule ExDTLS.RetransmissionTest do
rx_dtls = ExDTLS.init(mode: :server, dtls_srtp: true)
tx_dtls = ExDTLS.init(mode: :client, dtls_srtp: true)

{_packets, timeout} = ExDTLS.do_handshake(tx_dtls)
{:ok, _packets, timeout} = ExDTLS.do_handshake(tx_dtls)
Process.send_after(self(), {:handle_timeout, :tx}, timeout)
{:retransmit, packets, timeout} = wait_for_timeout(tx_dtls, :tx)
Process.send_after(self(), {:handle_timeout, :tx}, timeout)
Expand Down
Loading