From c831d3274a9d6f6791ab8be38e27c9f373138e66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filipe=20Caba=C3=A7o?= Date: Wed, 6 Sep 2023 23:35:11 +0100 Subject: [PATCH] feat: Connect to tenant database on channel join To improve stability we'll connect to the Tenant database upon channel. --- .../postgres_cdc_stream/cdc_stream.ex | 44 +- lib/realtime/helpers.ex | 25 + lib/realtime_web/channels/realtime_channel.ex | 523 +++++++++--------- .../channels/realtime_channel/assign.ex | 44 ++ lib/realtime_web/channels/user_socket.ex | 54 +- mix.exs | 2 +- .../cluster_strategy/postgres_test.exs | 2 +- .../channels/realtime_channel_test.exs | 34 ++ test/support/channel_case.ex | 1 + test/support/conn_case.ex | 40 -- test/support/generators.ex | 43 ++ 11 files changed, 445 insertions(+), 367 deletions(-) create mode 100644 lib/realtime_web/channels/realtime_channel/assign.ex create mode 100644 test/support/generators.ex diff --git a/lib/extensions/postgres_cdc_stream/cdc_stream.ex b/lib/extensions/postgres_cdc_stream/cdc_stream.ex index fc0a466e7..c63962636 100644 --- a/lib/extensions/postgres_cdc_stream/cdc_stream.ex +++ b/lib/extensions/postgres_cdc_stream/cdc_stream.ex @@ -9,8 +9,7 @@ defmodule Extensions.PostgresCdcStream do def handle_connect(opts) do Enum.reduce_while(1..5, nil, fn retry, acc -> - get_manager_conn(opts["id"]) - |> case do + case get_manager_conn(opts["id"]) do nil -> start_distributed(opts) if retry > 1, do: Process.sleep(1_000) @@ -22,13 +21,12 @@ defmodule Extensions.PostgresCdcStream do end) end - def handle_after_connect(_, _, _) do - {:ok, nil} - end + def handle_after_connect(_, _, _), do: {:ok, nil} def handle_subscribe(pg_change_params, tenant, metadata) do Enum.each(pg_change_params, fn e -> - topic(tenant, e.params) + tenant + |> topic(e.params) |> RealtimeWeb.Endpoint.subscribe(metadata) end) end @@ -45,13 +43,9 @@ defmodule Extensions.PostgresCdcStream do @spec get_manager_conn(String.t()) :: nil | {:ok, pid(), pid()} def get_manager_conn(id) do - Phoenix.Tracker.get_by_key(Stream.Tracker, "postgres_cdc_stream", id) - |> case do - [] -> - nil - - [{_, %{manager_pid: pid, conn: conn}}] -> - {:ok, pid, conn} + case Phoenix.Tracker.get_by_key(Stream.Tracker, "postgres_cdc_stream", id) do + [] -> nil + [{_, %{manager_pid: pid, conn: conn}}] -> {:ok, pid, conn} end end @@ -81,27 +75,23 @@ defmodule Extensions.PostgresCdcStream do def start(args) do addrtype = case args["ip_version"] do - 6 -> - :inet6 - - _ -> - :inet + 6 -> :inet6 + _ -> :inet end - args = - Map.merge(args, %{ - "db_socket_opts" => [addrtype] - }) + args = Map.merge(args, %{"db_socket_opts" => [addrtype]}) Logger.debug("Starting postgres stream extension with args: #{inspect(args, pretty: true)}") + opts = %{ + id: args["id"], + start: {Stream.WorkerSupervisor, :start_link, [args]}, + restart: :transient + } + DynamicSupervisor.start_child( {:via, PartitionSupervisor, {Stream.DynamicSupervisor, self()}}, - %{ - id: args["id"], - start: {Stream.WorkerSupervisor, :start_link, [args]}, - restart: :transient - } + opts ) end diff --git a/lib/realtime/helpers.ex b/lib/realtime/helpers.ex index e312dca1d..9db65ac9e 100644 --- a/lib/realtime/helpers.ex +++ b/lib/realtime/helpers.ex @@ -21,6 +21,31 @@ defmodule Realtime.Helpers do |> unpad() end + @spec connect_db(%{ + host: binary(), + port: non_neg_integer(), + name: binary(), + user: binary(), + pass: binary(), + socket_opts: list(), + pool: pos_integer(), + queue_target: pos_integer(), + ssl_enforced: boolean() + }) :: {:ok, pid} | {:error, Postgrex.Error.t() | term()} + def connect_db(%{ + host: host, + port: port, + name: name, + user: user, + pass: pass, + socket_opts: socket_opts, + pool: pool, + queue_target: queue_target, + ssl_enforced: ssl_enforced + }) do + connect_db(host, port, name, user, pass, socket_opts, pool, queue_target, ssl_enforced) + end + @spec connect_db( String.t(), String.t(), diff --git a/lib/realtime_web/channels/realtime_channel.ex b/lib/realtime_web/channels/realtime_channel.ex index 7ada9a347..c86a41c38 100644 --- a/lib/realtime_web/channels/realtime_channel.ex +++ b/lib/realtime_web/channels/realtime_channel.ex @@ -6,213 +6,97 @@ defmodule RealtimeWeb.RealtimeChannel do require Logger + alias Realtime.Helpers + alias DBConnection.Backoff + alias Phoenix.Tracker.Shard - alias RealtimeWeb.{ChannelsAuthorization, Endpoint, Presence} - alias Realtime.{GenCounter, RateCounter, PostgresCdc, SignalHandler, Tenants} - - import Realtime.Helpers, only: [cancel_timer: 1, decrypt!: 2] - - defmodule Assigns do - @moduledoc false - defstruct [ - :tenant, - :log_level, - :rate_counter, - :limits, - :tenant_topic, - :pg_sub_ref, - :pg_change_params, - :postgres_extension, - :claims, - :jwt_secret, - :tenant_token, - :access_token, - :postgres_cdc_module, - :channel_name - ] - @type t :: %__MODULE__{ - tenant: String.t(), - log_level: atom(), - rate_counter: RateCounter.t(), - limits: %{ - max_events_per_second: integer(), - max_concurrent_users: integer(), - max_bytes_per_second: integer(), - max_channels_per_client: integer(), - max_joins_per_second: integer() - }, - tenant_topic: String.t(), - pg_sub_ref: reference() | nil, - pg_change_params: map(), - postgres_extension: map(), - claims: map(), - jwt_secret: String.t(), - tenant_token: String.t(), - access_token: String.t(), - channel_name: String.t() - } - end + alias Realtime.GenCounter + alias Realtime.PostgresCdc + alias Realtime.RateCounter + alias Realtime.SignalHandler + alias Realtime.Tenants + + alias RealtimeWeb.ChannelsAuthorization + alias RealtimeWeb.Endpoint + alias RealtimeWeb.Presence @confirm_token_ms_interval 1_000 * 60 * 5 @impl true - def join( - "realtime:" <> sub_topic = topic, - params, - %{ - assigns: %{ - tenant: tenant, - log_level: log_level, - postgres_cdc_module: module - }, - channel_pid: channel_pid, - serializer: serializer, - transport_pid: transport_pid - } = socket - ) do + def join("realtime:" <> sub_topic = topic, params, socket) do + %{ + assigns: %{tenant: tenant, log_level: log_level, postgres_cdc_module: module}, + channel_pid: channel_pid, + serializer: serializer, + transport_pid: transport_pid + } = socket + Logger.metadata(external_id: tenant, project: tenant) Logger.put_process_level(self(), log_level) - socket = socket |> assign_access_token(params) |> assign_counter() + socket = + socket + |> assign_access_token(params) + |> assign_counter() start_db_rate_counter(tenant) with false <- SignalHandler.shutdown_in_progress?(), + %{extensions: extensions} <- Tenants.get_tenant_by_external_id(tenant), + :ok <- check_tenant_connection(extensions), :ok <- limit_joins(socket), :ok <- limit_channels(socket), :ok <- limit_max_users(socket), - {:ok, claims, confirm_token_ref} <- confirm_token(socket) do + {:ok, claims, confirm_token_ref} <- confirm_token(socket), + is_new_api <- is_new_api(params) do Realtime.UsersCounter.add(transport_pid, tenant) tenant_topic = tenant <> ":" <> sub_topic RealtimeWeb.Endpoint.subscribe(tenant_topic) - is_new_api = - case params do - %{"config" => _} -> true - _ -> false - end - - pg_change_params = - if is_new_api do - send(self(), :sync_presence) - - params["config"]["postgres_changes"] - |> case do - [_ | _] = params_list -> - params_list - |> Enum.map(fn params -> - %{ - id: UUID.uuid1(), - channel_pid: channel_pid, - claims: claims, - params: params - } - end) - - _ -> - [] - end - else - params = - case String.split(sub_topic, ":", parts: 3) do - [schema, table, filter] -> - %{"schema" => schema, "table" => table, "filter" => filter} - - [schema, table] -> - %{"schema" => schema, "table" => table} - - [schema] -> - %{"schema" => schema} - end - - [ - %{ - id: UUID.uuid1(), - channel_pid: channel_pid, - claims: claims, - params: params - } - ] - end - |> case do - [_ | _] = pg_change_params -> - ids = - for %{id: id, params: params} <- pg_change_params do - {UUID.string_to_binary!(id), :erlang.phash2(params)} - end - - metadata = [ - metadata: - {:subscriber_fastlane, transport_pid, serializer, ids, topic, tenant, is_new_api} - ] - - # Endpoint.subscribe("realtime:postgres:" <> tenant, metadata) - - PostgresCdc.subscribe(module, pg_change_params, tenant, metadata) - - pg_change_params - - other -> - other - end + pg_change_params = pg_change_params(is_new_api, params, channel_pid, claims, sub_topic) - Logger.debug("Postgres change params: " <> inspect(pg_change_params)) + opts = %{ + is_new_api: is_new_api, + pg_change_params: pg_change_params, + transport_pid: transport_pid, + serializer: serializer, + topic: topic, + tenant: tenant, + module: module + } - if !Enum.empty?(pg_change_params) do - send(self(), :postgres_subscribe) - end + postgres_cdc_subscribe(opts) Logger.debug("Start channel: " <> inspect(pg_change_params)) - presence_key = presence_key(params) - - {:ok, - %{ - postgres_changes: - Enum.map(pg_change_params, fn %{params: params} -> - id = :erlang.phash2(params) - Map.put(params, :id, id) - end) - }, - assign(socket, %{ - ack_broadcast: !!params["config"]["broadcast"]["ack"], - confirm_token_ref: confirm_token_ref, - is_new_api: is_new_api, - pg_sub_ref: nil, - pg_change_params: pg_change_params, - presence_key: presence_key, - self_broadcast: !!params["config"]["broadcast"]["self"], - tenant_topic: tenant_topic, - channel_name: sub_topic - })} - else - {:error, :too_many_channels} = error -> - error_msg = inspect(error) - Logger.warn("Start channel error: " <> error_msg) - {:error, %{reason: error_msg}} - - {:error, :too_many_connections} = error -> - error_msg = inspect(error) - Logger.warn("Start channel error: " <> error_msg) - {:error, %{reason: error_msg}} - - {:error, :too_many_joins} = error -> - error_msg = inspect(error) - Logger.warn("Start channel error: " <> error_msg) - {:error, %{reason: error_msg}} + state = %{postgres_changes: add_id_to_postgres_changes(pg_change_params)} + + assigns = %{ + ack_broadcast: !!params["config"]["broadcast"]["ack"], + confirm_token_ref: confirm_token_ref, + is_new_api: is_new_api, + pg_sub_ref: nil, + pg_change_params: pg_change_params, + presence_key: presence_key(params), + self_broadcast: !!params["config"]["broadcast"]["self"], + tenant_topic: tenant_topic, + channel_name: sub_topic + } + {:ok, state, assign(socket, assigns)} + else {:error, [message: "Invalid token", claim: _claim, claim_val: _value]} = error -> - error_msg = inspect(error) - Logger.warn("Start channel error: " <> error_msg) - {:error, %{reason: error_msg}} + log_error_message(:warning, error) + + {:error, type} = error + when type in [:too_many_channels, :too_many_connections, :too_many_joins] -> + log_error_message(:warning, error) error -> - error_msg = inspect(error) - Logger.error("Start channel error: " <> error_msg) - {:error, %{reason: error_msg}} + log_error_message(:error, error) end end @@ -236,7 +120,7 @@ defmodule RealtimeWeb.RealtimeChannel do # TODO: don't use Presence unless client explicitly wants to use Presence # TODO: count these again when that happens - socket = socket |> maybe_log_handle_info(msg) + socket = maybe_log_handle_info(socket, msg) push(socket, "presence_state", presence_dirty_list(topic)) @@ -259,20 +143,19 @@ defmodule RealtimeWeb.RealtimeChannel do end @impl true - def handle_info( - :postgres_subscribe, - %{ - assigns: %{ - tenant: tenant, - pg_sub_ref: pg_sub_ref, - pg_change_params: pg_change_params, - postgres_extension: postgres_extension, - channel_name: channel_name, - postgres_cdc_module: module - } - } = socket - ) do - cancel_timer(pg_sub_ref) + def handle_info(:postgres_subscribe, socket) do + %{ + assigns: %{ + tenant: tenant, + pg_sub_ref: pg_sub_ref, + pg_change_params: pg_change_params, + postgres_extension: postgres_extension, + channel_name: channel_name, + postgres_cdc_module: module + } + } = socket + + Helpers.cancel_timer(pg_sub_ref) args = Map.put(postgres_extension, "id", tenant) @@ -281,20 +164,14 @@ defmodule RealtimeWeb.RealtimeChannel do case PostgresCdc.after_connect(module, response, postgres_extension, pg_change_params) do {:ok, _response} -> message = "Subscribed to PostgreSQL" - Logger.info(message) - push_system_message("postgres_changes", socket, "ok", message, channel_name) - {:noreply, assign(socket, :pg_sub_ref, nil)} error -> message = "Subscribing to PostgreSQL failed: " <> inspect(error) - - push_system_message("postgres_changes", socket, "error", message, channel_name) - Logger.error(message) - + push_system_message("postgres_changes", socket, "error", message, channel_name) {:noreply, assign(socket, :pg_sub_ref, postgres_subscribe(5, 10))} end @@ -375,11 +252,11 @@ defmodule RealtimeWeb.RealtimeChannel do } = socket ) when is_binary(refresh_token) do - socket = socket |> assign(:access_token, refresh_token) + socket = assign(socket, :access_token, refresh_token) case confirm_token(socket) do {:ok, claims, confirm_token_ref} -> - cancel_timer(pg_sub_ref) + Helpers.cancel_timer(pg_sub_ref) pg_change_params = Enum.map(pg_change_params, &Map.put(&1, :claims, claims)) @@ -389,12 +266,13 @@ defmodule RealtimeWeb.RealtimeChannel do _ -> nil end - {:noreply, - assign(socket, %{ - confirm_token_ref: confirm_token_ref, - pg_change_params: pg_change_params, - pg_sub_ref: pg_sub_ref - })} + assigns = %{ + pg_sub_ref: pg_sub_ref, + confirm_token_ref: confirm_token_ref, + pg_change_params: pg_change_params + } + + {:noreply, assign(socket, assigns)} {:error, error} -> message = "Received an invalid access token from client: " <> inspect(error) @@ -433,33 +311,10 @@ defmodule RealtimeWeb.RealtimeChannel do def handle_in( "presence", %{"event" => event} = payload, - %{assigns: %{is_new_api: true, presence_key: presence_key, tenant_topic: tenant_topic}} = - socket + %{assigns: %{is_new_api: true, presence_key: _, tenant_topic: _}} = socket ) do socket = count(socket) - - result = - event - |> String.downcase() - |> case do - "track" -> - payload = Map.get(payload, "payload", %{}) - - with {:error, {:already_tracked, _, _, _}} <- - Presence.track(self(), tenant_topic, presence_key, payload), - {:ok, _} <- Presence.update(self(), tenant_topic, presence_key, payload) do - :ok - else - {:ok, _} -> :ok - {:error, _} -> :error - end - - "untrack" -> - Presence.untrack(self(), tenant_topic, presence_key) - - _ -> - :error - end + result = handle_presence_event(event, payload, socket) {:reply, result, socket} end @@ -469,13 +324,35 @@ defmodule RealtimeWeb.RealtimeChannel do # Log info here so that bad messages from clients won't flood Logflare # Can subscribe to a Channel with `log_level` `info` to see these messages - Logger.info( - "Unexpected message from client of type `#{type}` with payload: " <> inspect(payload) - ) + message = "Unexpected message from client of type `#{type}` with payload: #{inspect(payload)}" + Logger.info(message) {:noreply, socket} end + defp handle_presence_event(event, payload, socket) do + %{assigns: %{presence_key: presence_key, tenant_topic: tenant_topic}} = socket + + case String.downcase(event) do + "track" -> + with payload <- Map.get(payload, "payload", %{}), + {:error, {:already_tracked, _, _, _}} <- + Presence.track(self(), tenant_topic, presence_key, payload), + {:ok, _} <- Presence.update(self(), tenant_topic, presence_key, payload) do + :ok + else + {:ok, _} -> :ok + {:error, _} -> :error + end + + "untrack" -> + Presence.untrack(self(), tenant_topic, presence_key) + + _ -> + :error + end + end + @impl true def terminate(reason, _state) do Logger.debug("Channel terminated with reason: " <> inspect(reason)) @@ -485,7 +362,7 @@ defmodule RealtimeWeb.RealtimeChannel do defp decrypt_jwt_secret(secret) do secure_key = Application.get_env(:realtime, :db_enc_key) - decrypt!(secret, secure_key) + Helpers.decrypt!(secret, secure_key) end defp postgres_subscribe(min \\ 1, max \\ 5) do @@ -513,15 +390,14 @@ defmodule RealtimeWeb.RealtimeChannel do GenCounter.add(id) case RateCounter.get(id) do - {:ok, %{avg: avg}} -> - if avg < limits.max_joins_per_second do - :ok - else - {:error, :too_many_joins} - end + {:ok, %{avg: avg}} when avg < limits.max_joins_per_second -> + :ok + + {:ok, %{avg: _}} -> + {:error, :too_many_joins} other -> - Logger.error("Unexpected error for #{tenant}: " <> inspect(other)) + Logger.error("Unexpected error for " <> tenant <> ": " <> inspect(other)) {:error, other} end end @@ -568,9 +444,7 @@ defmodule RealtimeWeb.RealtimeChannel do assign(socket, :rate_counter, rate_counter) end - defp assign_counter(socket) do - socket - end + defp assign_counter(socket), do: socket defp count(%{assigns: %{rate_counter: counter}} = socket) do GenCounter.add(counter.id) @@ -583,12 +457,10 @@ defmodule RealtimeWeb.RealtimeChannel do %{assigns: %{log_level: log_level, channel_name: channel_name}} = socket, msg ) do - if Logger.compare_levels(log_level, :error) == :lt, - do: - Logger.log( - log_level, - "HANDLE_INFO INCOMING ON " <> channel_name <> " message: " <> inspect(msg) - ) + if Logger.compare_levels(log_level, :error) == :lt do + msg = "HANDLE_INFO INCOMING ON " <> channel_name <> " message: " <> inspect(msg) + Logger.log(log_level, msg) + end socket end @@ -621,33 +493,20 @@ defmodule RealtimeWeb.RealtimeChannel do assign(socket, :access_token, tenant_token) end - defp confirm_token(%{ - assigns: - %{ - jwt_secret: jwt_secret, - access_token: access_token - } = assigns - }) do + defp confirm_token(%{assigns: %{jwt_secret: jwt_secret, access_token: access_token} = assigns}) do with jwt_secret_dec <- decrypt_jwt_secret(jwt_secret), {:ok, %{"exp" => exp} = claims} when is_integer(exp) <- ChannelsAuthorization.authorize_conn(access_token, jwt_secret_dec), exp_diff when exp_diff > 0 <- exp - Joken.current_time() do - if ref = assigns[:confirm_token_ref], do: cancel_timer(ref) + if ref = assigns[:confirm_token_ref], do: Helpers.cancel_timer(ref) - ref = - Process.send_after( - self(), - :confirm_token, - min(@confirm_token_ms_interval, exp_diff * 1_000) - ) + interval = min(@confirm_token_ms_interval, exp_diff * 1_000) + ref = Process.send_after(self(), :confirm_token, interval) {:ok, claims, ref} else - {:error, e} -> - {:error, e} - - e -> - {:error, e} + {:error, e} -> {:error, e} + e -> {:error, e} end end @@ -691,4 +550,128 @@ defmodule RealtimeWeb.RealtimeChannel do } ) end + + defp is_new_api(%{"config" => _}), do: true + defp is_new_api(_), do: false + + defp pg_change_params(true, params, channel_pid, claims, _) do + send(self(), :sync_presence) + + case get_in(params, ["config", "postgres_changes"]) do + [_ | _] = params_list -> + Enum.map(params_list, fn params -> + %{ + id: UUID.uuid1(), + channel_pid: channel_pid, + claims: claims, + params: params + } + end) + + _ -> + [] + end + end + + defp pg_change_params(false, params, channel_pid, claims, sub_topic) do + case String.split(sub_topic, ":", parts: 3) do + [schema, table, filter] -> %{"schema" => schema, "table" => table, "filter" => filter} + [schema, table] -> %{"schema" => schema, "table" => table} + [schema] -> %{"schema" => schema} + end + + [ + %{ + id: UUID.uuid1(), + channel_pid: channel_pid, + claims: claims, + params: params + } + ] + end + + defp postgres_cdc_subscribe(%{pg_change_params: []}), do: [] + + defp postgres_cdc_subscribe(opts) do + %{ + is_new_api: is_new_api, + pg_change_params: pg_change_params, + transport_pid: transport_pid, + serializer: serializer, + topic: topic, + tenant: tenant, + module: module + } = opts + + ids = + Enum.map(pg_change_params, fn %{id: id, params: params} -> + {UUID.string_to_binary!(id), :erlang.phash2(params)} + end) + + subscription_metadata = + {:subscriber_fastlane, transport_pid, serializer, ids, topic, tenant, is_new_api} + + metadata = [metadata: subscription_metadata] + + PostgresCdc.subscribe(module, pg_change_params, tenant, metadata) + + send(self(), :postgres_subscribe) + + pg_change_params + end + + defp add_id_to_postgres_changes(pg_change_params) do + Enum.map(pg_change_params, fn %{params: params} -> + id = :erlang.phash2(params) + Map.put(params, :id, id) + end) + end + + defp check_tenant_connection(extensions) do + extensions + |> Enum.map(fn %{settings: settings} -> + ssl_enforced = Helpers.default_ssl_param(settings) + + host = settings["db_host"] + port = settings["db_port"] + name = settings["db_name"] + user = settings["db_user"] + password = settings["db_password"] + socket_opts = settings["db_socket_opts"] + + opts = %{ + host: host, + port: port, + name: name, + user: user, + pass: password, + socket_opts: socket_opts, + pool: 1, + queue_target: 1000, + ssl_enforced: ssl_enforced + } + + with {:ok, conn} <- Helpers.connect_db(opts), + {:ok, _} <- Postgrex.query(conn, "SELECT 1", []) do + :ok + end + end) + |> Enum.any?(fn res -> res == :ok end) + |> then(fn + true -> :ok + false -> {:error, :tenant_database_unavailable} + end) + end + + defp log_error_message(:warning, error) do + error_msg = inspect(error) + Logger.warn("Start channel error: " <> error_msg) + {:error, %{reason: error_msg}} + end + + defp log_error_message(:error, error) do + error_msg = inspect(error) + Logger.error("Start channel error: " <> error_msg) + {:error, %{reason: error_msg}} + end end diff --git a/lib/realtime_web/channels/realtime_channel/assign.ex b/lib/realtime_web/channels/realtime_channel/assign.ex new file mode 100644 index 000000000..b108ff049 --- /dev/null +++ b/lib/realtime_web/channels/realtime_channel/assign.ex @@ -0,0 +1,44 @@ +defmodule RealtimeWeb.RealtimeChannel.Assigns do + @moduledoc """ + Assigns for RealtimeChannel + """ + + defstruct [ + :tenant, + :log_level, + :rate_counter, + :limits, + :tenant_topic, + :pg_sub_ref, + :pg_change_params, + :postgres_extension, + :claims, + :jwt_secret, + :tenant_token, + :access_token, + :postgres_cdc_module, + :channel_name + ] + + @type t :: %__MODULE__{ + tenant: String.t(), + log_level: atom(), + rate_counter: RateCounter.t(), + limits: %{ + max_events_per_second: integer(), + max_concurrent_users: integer(), + max_bytes_per_second: integer(), + max_channels_per_client: integer(), + max_joins_per_second: integer() + }, + tenant_topic: String.t(), + pg_sub_ref: reference() | nil, + pg_change_params: map(), + postgres_extension: map(), + claims: map(), + jwt_secret: String.t(), + tenant_token: String.t(), + access_token: String.t(), + channel_name: String.t() + } +end diff --git a/lib/realtime_web/channels/user_socket.ex b/lib/realtime_web/channels/user_socket.ex index 893059371..3b4a04b78 100644 --- a/lib/realtime_web/channels/user_socket.ex +++ b/lib/realtime_web/channels/user_socket.ex @@ -3,12 +3,14 @@ defmodule RealtimeWeb.UserSocket do require Logger - alias Realtime.{PostgresCdc, Api} + import Realtime.Helpers, only: [decrypt!: 2, get_external_id: 1] + alias Api.Tenant + alias Realtime.Api + alias Realtime.PostgresCdc alias Realtime.Tenants alias RealtimeWeb.ChannelsAuthorization alias RealtimeWeb.RealtimeChannel - import Realtime.Helpers, only: [decrypt!: 2, get_external_id: 1] ## Channels channel("realtime:*", RealtimeChannel) @@ -25,10 +27,10 @@ defmodule RealtimeWeb.UserSocket do log_level = params |> Map.get("log_level", @default_log_level) - |> case do + |> then(fn "" -> @default_log_level level -> level - end + end) |> String.to_existing_atom() secure_key = Application.get_env(:realtime, :db_enc_key) @@ -50,24 +52,24 @@ defmodule RealtimeWeb.UserSocket do jwt_secret_dec <- decrypt!(jwt_secret, secure_key), {:ok, claims} <- ChannelsAuthorization.authorize_conn(token, jwt_secret_dec), {:ok, postgres_cdc_module} <- PostgresCdc.driver(postgres_cdc_default) do - assigns = - %RealtimeChannel.Assigns{ - claims: claims, - jwt_secret: jwt_secret, - limits: %{ - max_concurrent_users: max_conn_users, - max_events_per_second: max_events_per_second, - max_bytes_per_second: max_bytes_per_second, - max_joins_per_second: max_joins_per_second, - max_channels_per_client: max_channels_per_client - }, - postgres_extension: PostgresCdc.filter_settings(postgres_cdc_default, extensions), - postgres_cdc_module: postgres_cdc_module, - tenant: external_id, - log_level: log_level, - tenant_token: token - } - |> Map.from_struct() + assigns = %RealtimeChannel.Assigns{ + claims: claims, + jwt_secret: jwt_secret, + limits: %{ + max_concurrent_users: max_conn_users, + max_events_per_second: max_events_per_second, + max_bytes_per_second: max_bytes_per_second, + max_joins_per_second: max_joins_per_second, + max_channels_per_client: max_channels_per_client + }, + postgres_extension: PostgresCdc.filter_settings(postgres_cdc_default, extensions), + postgres_cdc_module: postgres_cdc_module, + tenant: external_id, + log_level: log_level, + tenant_token: token + } + + assigns = Map.from_struct(assigns) {:ok, assign(socket, assigns)} else @@ -90,12 +92,8 @@ defmodule RealtimeWeb.UserSocket do end @impl true - def id(%{assigns: %{tenant: tenant}}) do - subscribers_id(tenant) - end + def id(%{assigns: %{tenant: tenant}}), do: subscribers_id(tenant) @spec subscribers_id(String.t()) :: String.t() - def subscribers_id(tenant) do - "user_socket:" <> tenant - end + def subscribers_id(tenant), do: "user_socket:" <> tenant end diff --git a/mix.exs b/mix.exs index 14a8a9cac..3253e700b 100644 --- a/mix.exs +++ b/mix.exs @@ -4,7 +4,7 @@ defmodule Realtime.MixProject do def project do [ app: :realtime, - version: "2.22.21", + version: "2.22.22", elixir: "~> 1.14.0", elixirc_paths: elixirc_paths(Mix.env()), start_permanent: Mix.env() == :prod, diff --git a/test/realtime/cluster_strategy/postgres_test.exs b/test/realtime/cluster_strategy/postgres_test.exs index 8891326c0..1bc6aa681 100644 --- a/test/realtime/cluster_strategy/postgres_test.exs +++ b/test/realtime/cluster_strategy/postgres_test.exs @@ -24,7 +24,7 @@ defmodule Realtime.Cluster.Strategy.PostgresTest do {:ok, conn_notif} = PN.start_link(state.meta.opts.()) PN.listen(conn_notif, channel_name) node = "#{node()}" - assert_receive {:notification, _, _, channel_name, ^node} + assert_receive {:notification, _, _, ^channel_name, ^node} end defp libcluster_state() do diff --git a/test/realtime_web/channels/realtime_channel_test.exs b/test/realtime_web/channels/realtime_channel_test.exs index a66b26d92..be48cddfe 100644 --- a/test/realtime_web/channels/realtime_channel_test.exs +++ b/test/realtime_web/channels/realtime_channel_test.exs @@ -118,4 +118,38 @@ defmodule RealtimeWeb.RealtimeChannelTest do end end end + + describe "checks tenant db connectivity" do + setup_with_mocks([ + {ChannelsAuthorization, [], + [ + authorize_conn: fn _, _ -> + {:ok, %{"exp" => Joken.current_time() + 1_000, "role" => "postgres"}} + end + ]} + ]) do + :ok + end + + test "successful connection proceeds with join" do + {:ok, %Socket{} = socket} = connect(UserSocket, %{}, @default_conn_opts) + assert {:ok, _, %Socket{}} = subscribe_and_join(socket, "realtime:test", %{}) + end + + test "unsuccessful connection halts join" do + tenant = tenant_fixture() + + conn_opts = [ + connect_info: %{ + uri: %{host: "#{tenant.external_id}.localhost:4000/socket/websocket", query: ""}, + x_headers: [{"x-api-key", "token123"}] + } + ] + + {:ok, %Socket{} = socket} = connect(UserSocket, %{}, conn_opts) + + assert {:error, %{reason: "{:error, :tenant_database_unavailable}"}} = + subscribe_and_join(socket, "realtime:test", %{}) + end + end end diff --git a/test/support/channel_case.ex b/test/support/channel_case.ex index 1689a1688..c2ef770b0 100644 --- a/test/support/channel_case.ex +++ b/test/support/channel_case.ex @@ -22,6 +22,7 @@ defmodule RealtimeWeb.ChannelCase do quote do # Import conveniences for testing with channels import Phoenix.ChannelTest + import Generators # The default endpoint for testing @endpoint RealtimeWeb.Endpoint diff --git a/test/support/conn_case.ex b/test/support/conn_case.ex index b33737349..d9d3900eb 100644 --- a/test/support/conn_case.ex +++ b/test/support/conn_case.ex @@ -18,46 +18,6 @@ defmodule RealtimeWeb.ConnCase do use ExUnit.CaseTemplate alias Ecto.Adapters.SQL.Sandbox - defmodule Generators do - def tenant_fixture(override \\ %{}) do - create_attrs = %{ - "external_id" => rand_string(), - "name" => "localhost", - "extensions" => [ - %{ - "type" => "postgres_cdc_rls", - "settings" => %{ - "db_host" => "127.0.0.1", - "db_name" => "postgres", - "db_user" => "postgres", - "db_password" => "postgres", - "db_port" => "6432", - "poll_interval" => 100, - "poll_max_changes" => 100, - "poll_max_record_bytes" => 1_048_576, - "region" => "us-east-1" - } - } - ], - "postgres_cdc_default" => "postgres_cdc_rls", - "jwt_secret" => "new secret" - } - - {:ok, tenant} = - create_attrs - |> Map.merge(override) - |> Realtime.Api.create_tenant() - - tenant - end - - def rand_string(length \\ 10) do - length - |> :crypto.strong_rand_bytes() - |> Base.encode32() - end - end - using do quote do # Import conveniences for testing with connections diff --git a/test/support/generators.ex b/test/support/generators.ex new file mode 100644 index 000000000..c3479a346 --- /dev/null +++ b/test/support/generators.ex @@ -0,0 +1,43 @@ +defmodule Generators do + @moduledoc """ + Data genarators for tests. + """ + + def tenant_fixture(override \\ %{}) do + create_attrs = %{ + "external_id" => rand_string(), + "name" => "localhost", + "extensions" => [ + %{ + "type" => "postgres_cdc_rls", + "settings" => %{ + "db_host" => "127.0.0.1", + "db_name" => "postgres", + "db_user" => "postgres", + "db_password" => "postgres", + "db_port" => "6432", + "poll_interval" => 100, + "poll_max_changes" => 100, + "poll_max_record_bytes" => 1_048_576, + "region" => "us-east-1" + } + } + ], + "postgres_cdc_default" => "postgres_cdc_rls", + "jwt_secret" => "new secret" + } + + {:ok, tenant} = + create_attrs + |> Map.merge(override) + |> Realtime.Api.create_tenant() + + tenant + end + + def rand_string(length \\ 10) do + length + |> :crypto.strong_rand_bytes() + |> Base.encode32() + end +end