diff --git a/c_src/ex_dtls/native.c b/c_src/ex_dtls/native.c index 964e50f..93c38bb 100644 --- a/c_src/ex_dtls/native.c +++ b/c_src/ex_dtls/native.c @@ -108,6 +108,7 @@ UNIFEX_TERM do_init(UnifexEnv *env, char *mode_str, int dtls_srtp, state->x509 = NULL; state->mode = 0; state->hsk_finished = 0; + state->closed = 0; state->env = unifex_alloc_env(env); int mode; @@ -244,6 +245,10 @@ UNIFEX_TERM get_cert_fingerprint(UnifexEnv *env, UnifexPayload *cert) { } UNIFEX_TERM do_handshake(UnifexEnv *env, State *state) { + if (state->closed == 1) { + return do_handshake_result_error_closed(env); + } + SSL_do_handshake(state->ssl); UnifexPayload **gen_packets = NULL; @@ -258,7 +263,7 @@ UNIFEX_TERM do_handshake(UnifexEnv *env, State *state) { } else { int timeout = get_timeout(state->ssl); UNIFEX_TERM res_term = - do_handshake_result(env, gen_packets, gen_packets_size, timeout); + do_handshake_result_ok(env, gen_packets, gen_packets_size, timeout); free_payload_array(gen_packets, gen_packets_size); return res_term; @@ -266,6 +271,11 @@ UNIFEX_TERM do_handshake(UnifexEnv *env, State *state) { } UNIFEX_TERM write_data(UnifexEnv *env, State *state, UnifexPayload *payload) { + if (state->closed == 1) { + DEBUG("Cannot write, connection closed"); + return write_data_result_error_closed(env); + } + if (state->hsk_finished != 1) { DEBUG("Cannot write, handshake not finished"); return write_data_result_error_handshake_not_finished(env); @@ -303,6 +313,10 @@ UNIFEX_TERM write_data(UnifexEnv *env, State *state, UnifexPayload *payload) { } UNIFEX_TERM handle_data(UnifexEnv *env, State *state, UnifexPayload *payload) { + if (state->closed == 1) { + return handle_data_result_error_closed(env); + } + (void)env; if (payload->size != 0) { @@ -332,6 +346,32 @@ UNIFEX_TERM handle_data(UnifexEnv *env, State *state, UnifexPayload *payload) { } } +// prefix close with exd (ex_dtls) as close is defined in unistd.h +UNIFEX_TERM exd_close(UnifexEnv *env, State *state) { + if (state->closed == 1) { + return exd_close_result_ok(env, NULL, 0); + } + + state->closed = 1; + if (SSL_shutdown(state->ssl) < 0) { + return exd_close_result_ok(env, NULL, 0); + } else { + UnifexPayload **gen_packets = NULL; + int gen_packets_size = 0; + read_pending_data(&gen_packets, &gen_packets_size, state); + + if (gen_packets == NULL) { + return unifex_raise(state->env, + "Close failed: couldn't read pending data"); + } else { + UNIFEX_TERM res_term = + exd_close_result_ok(env, gen_packets, gen_packets_size); + free_payload_array(gen_packets, gen_packets_size); + return res_term; + } + } +} + UNIFEX_TERM handle_regular_read(State *state, char data[], int ret) { if (ret > 0) { UnifexPayload packets; @@ -351,6 +391,7 @@ UNIFEX_TERM handle_read_error(State *state, int ret) { int error = SSL_get_error(state->ssl, ret); switch (error) { case SSL_ERROR_ZERO_RETURN: + state->closed = 1; return handle_data_result_error_peer_closed_for_writing(state->env); case SSL_ERROR_WANT_READ: DEBUG("SSL WANT READ. This is workaround. Did we get retransmission?"); @@ -452,6 +493,10 @@ UNIFEX_TERM handle_handshake_in_progress(State *state, int ret) { } UNIFEX_TERM handle_timeout(UnifexEnv *env, State *state) { + if (state->closed == 1) { + return handle_timeout_result_error_closed(env); + } + long result = DTLSv1_handle_timeout(state->ssl); if (result != 1) return handle_timeout_result_ok(env); diff --git a/c_src/ex_dtls/native.h b/c_src/ex_dtls/native.h index 1c5365b..bcaf461 100644 --- a/c_src/ex_dtls/native.h +++ b/c_src/ex_dtls/native.h @@ -13,6 +13,7 @@ struct State { X509 *x509; int mode; int hsk_finished; + int closed; }; #include "_generated/native.h" diff --git a/c_src/ex_dtls/native.spec.exs b/c_src/ex_dtls/native.spec.exs index 5d9cde5..8b5c2f7 100644 --- a/c_src/ex_dtls/native.spec.exs +++ b/c_src/ex_dtls/native.spec.exs @@ -21,11 +21,17 @@ spec get_peer_cert(state) :: payload | (nil :: label) spec get_cert_fingerprint(payload) :: payload -spec do_handshake(state) :: {packets :: [payload], timeout :: int} +spec do_handshake(state) :: {:ok :: label, packets :: [payload], timeout :: int} | {:error :: label, :closed :: label} -spec handle_timeout(state) :: (:ok :: label) | {:retransmit :: label, packets :: [payload], timeout :: int} +spec handle_timeout(state) :: + (:ok :: label) + | {:retransmit :: label, packets :: [payload], timeout :: int} + | {:error :: label, :closed :: label} -spec write_data(state, packets :: payload) :: {:ok :: label, packets :: [payload]} | {:error :: label, :handshake_not_finished :: label} +spec write_data(state, packets :: payload) :: + {:ok :: label, packets :: [payload]} + | {:error :: label, :handshake_not_finished :: label} + | {:error :: label, :closed :: label} spec handle_data(state, packets :: payload) :: {:ok :: label, packets :: payload} @@ -34,5 +40,7 @@ spec handle_data(state, packets :: payload) :: | {:handshake_finished :: label, client_keying_material :: payload, server_keying_material :: payload, protection_profile :: int, packets :: [payload]} | {:error :: label, :peer_closed_for_writing :: label} - | {:error :: label, :handshake_error :: label -} + | {:error :: label, :handshake_error :: label} + | {:error :: label, :closed :: label} + +spec exd_close(state) :: {:ok :: label, packets :: [payload]} diff --git a/lib/ex_dtls.ex b/lib/ex_dtls.ex index 9def56d..143ed69 100644 --- a/lib/ex_dtls.ex +++ b/lib/ex_dtls.ex @@ -139,7 +139,8 @@ defmodule ExDTLS do `timeout` is a time in ms after which `handle_timeout/1` should be called. """ - @spec do_handshake(dtls()) :: {packets :: [binary()], timeout :: integer()} + @spec do_handshake(dtls()) :: + {:ok, packets :: [binary()], timeout :: integer()} | {:error, :closed} defdelegate do_handshake(dtls), to: Native @doc """ @@ -148,7 +149,7 @@ defmodule ExDTLS do Generates encrypted packets that need to be passed to the second host. """ @spec write_data(dtls(), data :: binary()) :: - {:ok, packets :: [binary()]} | {:error, :handshake_not_finished} + {:ok, packets :: [binary()]} | {:error, :handshake_not_finished | :closed} defdelegate write_data(dtls, data), to: Native @doc """ @@ -172,7 +173,7 @@ defmodule ExDTLS do remote_keying_material :: binary(), protection_profile_t(), packets :: [binary()]} | {:handshake_finished, local_keying_material :: binary(), remote_keying_material :: binary(), protection_profile_t()} - | {:error, :handshake_error | :peer_closed_for_writing} + | {:error, :handshake_error | :peer_closed_for_writing | :closed} def handle_data(dtls, packets) do case Native.handle_data(dtls, packets) do {:handshake_finished, lkm, rkm, protection_profile, []} -> @@ -192,6 +193,16 @@ defmodule ExDTLS do If there is no timeout to handle, simple `{:ok, dtls()}` tuple is returned. """ - @spec handle_timeout(dtls()) :: :ok | {:retransmit, packets :: [binary()], timeout :: integer()} + @spec handle_timeout(dtls()) :: + :ok | {:retransmit, packets :: [binary()], timeout :: integer()} | {:error, :closed} defdelegate handle_timeout(dtls), to: Native + + @doc """ + Irreversibly closes DTLS session. + + If a handshake has been finished, this function will generate `close_notify` DTLS alert + that should be sent to the other side. + """ + @spec close(dtls()) :: {:ok, packets :: [binary()]} + defdelegate close(dtls), to: Native, as: :exd_close end diff --git a/test/integration_test.exs b/test/integration_test.exs index 9336d14..4c65dac 100644 --- a/test/integration_test.exs +++ b/test/integration_test.exs @@ -5,7 +5,7 @@ defmodule ExDTLS.IntegrationTest do rx_dtls = ExDTLS.init(mode: :server, dtls_srtp: true, verify_peer: true) tx_dtls = ExDTLS.init(mode: :client, dtls_srtp: true, verify_peer: true) - {packets, _timeout} = ExDTLS.do_handshake(tx_dtls) + {:ok, packets, _timeout} = ExDTLS.do_handshake(tx_dtls) assert :ok == loop({rx_dtls, false}, {tx_dtls, false}, packets) @@ -17,7 +17,7 @@ defmodule ExDTLS.IntegrationTest do rx_dtls = ExDTLS.init(mode: :server, dtls_srtp: true) tx_dtls = ExDTLS.init(mode: :client, dtls_srtp: true) - {packets, _timeout} = ExDTLS.do_handshake(tx_dtls) + {:ok, packets, _timeout} = ExDTLS.do_handshake(tx_dtls) assert :ok == loop({rx_dtls, false}, {tx_dtls, false}, packets) @@ -34,7 +34,7 @@ defmodule ExDTLS.IntegrationTest do assert {:error, :handshake_not_finished} = ExDTLS.write_data(sr_dtls, <<1, 2, 3>>) assert {:error, :handshake_not_finished} = ExDTLS.write_data(cl_dtls, <<1, 2, 3>>) - {packets, _timeout} = ExDTLS.do_handshake(cl_dtls) + {:ok, packets, _timeout} = ExDTLS.do_handshake(cl_dtls) assert :ok == loop({sr_dtls, false}, {cl_dtls, false}, packets) msg = <<1, 3, 2, 5>> @@ -55,11 +55,53 @@ defmodule ExDTLS.IntegrationTest do tx_dtls = ExDTLS.init(mode: :client, dtls_srtp: true, verify_peer: true) - {packets, _timeout} = ExDTLS.do_handshake(tx_dtls) + {:ok, packets, _timeout} = ExDTLS.do_handshake(tx_dtls) {:handshake_packets, packets, _timeout} = feed_packets(rx_dtls, packets) assert {:error, :handshake_error} = feed_packets(tx_dtls, packets) end + describe "close/1" do + test "before handshake has finished (client mode)" do + dtls = ExDTLS.init(mode: :client, dtls_srtp: true, verify_peer: true) + assert {:ok, []} = ExDTLS.close(dtls) + # assert that handshake can't be started + assert {:error, :closed} = ExDTLS.do_handshake(dtls) + end + + test "before handshake has finished (server mode)" do + dtls = ExDTLS.init(mode: :server, dtls_srtp: true, verify_peer: true) + assert {:ok, []} = ExDTLS.close(dtls) + # assert that handshake can't be started + assert {:error, :closed} = ExDTLS.do_handshake(dtls) + end + + test "after handshake has finished (client mode)" do + rx_dtls = ExDTLS.init(mode: :server, dtls_srtp: true, verify_peer: true) + tx_dtls = ExDTLS.init(mode: :client, dtls_srtp: true, verify_peer: true) + + {:ok, packets, _timeout} = ExDTLS.do_handshake(tx_dtls) + + assert :ok == loop({rx_dtls, false}, {tx_dtls, false}, packets) + assert {:ok, [packet]} = ExDTLS.close(tx_dtls) + assert {:error, :peer_closed_for_writing} = ExDTLS.handle_data(rx_dtls, packet) + assert {:error, :closed} = ExDTLS.handle_timeout(tx_dtls) + assert {:error, :closed} = ExDTLS.handle_timeout(rx_dtls) + end + + test "after handshake has finished (server mode)" do + rx_dtls = ExDTLS.init(mode: :server, dtls_srtp: true, verify_peer: true) + tx_dtls = ExDTLS.init(mode: :client, dtls_srtp: true, verify_peer: true) + + {:ok, packets, _timeout} = ExDTLS.do_handshake(tx_dtls) + + assert :ok == loop({rx_dtls, false}, {tx_dtls, false}, packets) + assert {:ok, [packet]} = ExDTLS.close(rx_dtls) + assert {:error, :peer_closed_for_writing} = ExDTLS.handle_data(tx_dtls, packet) + assert {:error, :closed} = ExDTLS.handle_timeout(tx_dtls) + assert {:error, :closed} = ExDTLS.handle_timeout(rx_dtls) + end + end + defp loop({_dtls1, true}, {_dtls2, true}, _packets) do :ok end diff --git a/test/retransmission_test.exs b/test/retransmission_test.exs index ffbb88d..f7ec391 100644 --- a/test/retransmission_test.exs +++ b/test/retransmission_test.exs @@ -5,7 +5,7 @@ defmodule ExDTLS.RetransmissionTest do rx_dtls = ExDTLS.init(mode: :server, dtls_srtp: true) tx_dtls = ExDTLS.init(mode: :client, dtls_srtp: true) - {_packets, timeout} = ExDTLS.do_handshake(tx_dtls) + {:ok, _packets, timeout} = ExDTLS.do_handshake(tx_dtls) Process.send_after(self(), {:handle_timeout, :tx}, timeout) {:retransmit, packets, timeout} = wait_for_timeout(tx_dtls, :tx) Process.send_after(self(), {:handle_timeout, :tx}, timeout)