diff --git a/lib/realtime_web/channels/realtime_channel.ex b/lib/realtime_web/channels/realtime_channel.ex index 7ada9a347..87b876653 100644 --- a/lib/realtime_web/channels/realtime_channel.ex +++ b/lib/realtime_web/channels/realtime_channel.ex @@ -6,76 +6,40 @@ defmodule RealtimeWeb.RealtimeChannel do require Logger + import Realtime.Helpers, only: [cancel_timer: 1, decrypt!: 2] + alias DBConnection.Backoff - alias Phoenix.Tracker.Shard - alias RealtimeWeb.{ChannelsAuthorization, Endpoint, Presence} - alias Realtime.{GenCounter, RateCounter, PostgresCdc, SignalHandler, Tenants} - import Realtime.Helpers, only: [cancel_timer: 1, decrypt!: 2] + alias Phoenix.Tracker.Shard - defmodule Assigns do - @moduledoc false - defstruct [ - :tenant, - :log_level, - :rate_counter, - :limits, - :tenant_topic, - :pg_sub_ref, - :pg_change_params, - :postgres_extension, - :claims, - :jwt_secret, - :tenant_token, - :access_token, - :postgres_cdc_module, - :channel_name - ] + alias Realtime.GenCounter + alias Realtime.PostgresCdc + alias Realtime.RateCounter + alias Realtime.SignalHandler + alias Realtime.Tenants - @type t :: %__MODULE__{ - tenant: String.t(), - log_level: atom(), - rate_counter: RateCounter.t(), - limits: %{ - max_events_per_second: integer(), - max_concurrent_users: integer(), - max_bytes_per_second: integer(), - max_channels_per_client: integer(), - max_joins_per_second: integer() - }, - tenant_topic: String.t(), - pg_sub_ref: reference() | nil, - pg_change_params: map(), - postgres_extension: map(), - claims: map(), - jwt_secret: String.t(), - tenant_token: String.t(), - access_token: String.t(), - channel_name: String.t() - } - end + alias RealtimeWeb.ChannelsAuthorization + alias RealtimeWeb.Endpoint + alias RealtimeWeb.Presence @confirm_token_ms_interval 1_000 * 60 * 5 @impl true - def join( - "realtime:" <> sub_topic = topic, - params, - %{ - assigns: %{ - tenant: tenant, - log_level: log_level, - postgres_cdc_module: module - }, - channel_pid: channel_pid, - serializer: serializer, - transport_pid: transport_pid - } = socket - ) do + def join("realtime:" <> sub_topic = topic, params, socket) do + %{ + assigns: %{tenant: tenant, log_level: log_level, postgres_cdc_module: module}, + channel_pid: channel_pid, + serializer: serializer, + transport_pid: transport_pid + } = socket + Logger.metadata(external_id: tenant, project: tenant) Logger.put_process_level(self(), log_level) - socket = socket |> assign_access_token(params) |> assign_counter() + socket = + socket + |> assign_access_token(params) + |> assign_counter() start_db_rate_counter(tenant) @@ -83,136 +47,46 @@ defmodule RealtimeWeb.RealtimeChannel do :ok <- limit_joins(socket), :ok <- limit_channels(socket), :ok <- limit_max_users(socket), - {:ok, claims, confirm_token_ref} <- confirm_token(socket) do + {:ok, claims, confirm_token_ref} <- confirm_token(socket), + is_new_api <- is_new_api(params) do Realtime.UsersCounter.add(transport_pid, tenant) tenant_topic = tenant <> ":" <> sub_topic RealtimeWeb.Endpoint.subscribe(tenant_topic) - is_new_api = - case params do - %{"config" => _} -> true - _ -> false - end + pg_change_params = pg_change_params(is_new_api, params, channel_pid, claims, sub_topic) - pg_change_params = - if is_new_api do - send(self(), :sync_presence) - - params["config"]["postgres_changes"] - |> case do - [_ | _] = params_list -> - params_list - |> Enum.map(fn params -> - %{ - id: UUID.uuid1(), - channel_pid: channel_pid, - claims: claims, - params: params - } - end) - - _ -> - [] - end - else - params = - case String.split(sub_topic, ":", parts: 3) do - [schema, table, filter] -> - %{"schema" => schema, "table" => table, "filter" => filter} - - [schema, table] -> - %{"schema" => schema, "table" => table} - - [schema] -> - %{"schema" => schema} - end - - [ - %{ - id: UUID.uuid1(), - channel_pid: channel_pid, - claims: claims, - params: params - } - ] - end - |> case do - [_ | _] = pg_change_params -> - ids = - for %{id: id, params: params} <- pg_change_params do - {UUID.string_to_binary!(id), :erlang.phash2(params)} - end - - metadata = [ - metadata: - {:subscriber_fastlane, transport_pid, serializer, ids, topic, tenant, is_new_api} - ] - - # Endpoint.subscribe("realtime:postgres:" <> tenant, metadata) - - PostgresCdc.subscribe(module, pg_change_params, tenant, metadata) - - pg_change_params - - other -> - other - end - - Logger.debug("Postgres change params: " <> inspect(pg_change_params)) + opts = %{ + is_new_api: is_new_api, + pg_change_params: pg_change_params, + transport_pid: transport_pid, + serializer: serializer, + topic: topic, + tenant: tenant, + module: module + } - if !Enum.empty?(pg_change_params) do - send(self(), :postgres_subscribe) - end + postgres_cdc_subscribe(opts) Logger.debug("Start channel: " <> inspect(pg_change_params)) - presence_key = presence_key(params) - - {:ok, - %{ - postgres_changes: - Enum.map(pg_change_params, fn %{params: params} -> - id = :erlang.phash2(params) - Map.put(params, :id, id) - end) - }, - assign(socket, %{ - ack_broadcast: !!params["config"]["broadcast"]["ack"], - confirm_token_ref: confirm_token_ref, - is_new_api: is_new_api, - pg_sub_ref: nil, - pg_change_params: pg_change_params, - presence_key: presence_key, - self_broadcast: !!params["config"]["broadcast"]["self"], - tenant_topic: tenant_topic, - channel_name: sub_topic - })} + state = %{postgres_changes: postgres_changes(pg_change_params)} + + assigns = %{ + ack_broadcast: !!params["config"]["broadcast"]["ack"], + confirm_token_ref: confirm_token_ref, + is_new_api: is_new_api, + pg_sub_ref: nil, + pg_change_params: pg_change_params, + presence_key: presence_key(params), + self_broadcast: !!params["config"]["broadcast"]["self"], + tenant_topic: tenant_topic, + channel_name: sub_topic + } + + {:ok, state, assign(socket, assigns)} else - {:error, :too_many_channels} = error -> - error_msg = inspect(error) - Logger.warn("Start channel error: " <> error_msg) - {:error, %{reason: error_msg}} - - {:error, :too_many_connections} = error -> - error_msg = inspect(error) - Logger.warn("Start channel error: " <> error_msg) - {:error, %{reason: error_msg}} - - {:error, :too_many_joins} = error -> - error_msg = inspect(error) - Logger.warn("Start channel error: " <> error_msg) - {:error, %{reason: error_msg}} - - {:error, [message: "Invalid token", claim: _claim, claim_val: _value]} = error -> - error_msg = inspect(error) - Logger.warn("Start channel error: " <> error_msg) - {:error, %{reason: error_msg}} - - error -> - error_msg = inspect(error) - Logger.error("Start channel error: " <> error_msg) - {:error, %{reason: error_msg}} + error -> handle_join_error(error) end end @@ -236,7 +110,7 @@ defmodule RealtimeWeb.RealtimeChannel do # TODO: don't use Presence unless client explicitly wants to use Presence # TODO: count these again when that happens - socket = socket |> maybe_log_handle_info(msg) + socket = maybe_log_handle_info(socket, msg) push(socket, "presence_state", presence_dirty_list(topic)) @@ -259,19 +133,18 @@ defmodule RealtimeWeb.RealtimeChannel do end @impl true - def handle_info( - :postgres_subscribe, - %{ - assigns: %{ - tenant: tenant, - pg_sub_ref: pg_sub_ref, - pg_change_params: pg_change_params, - postgres_extension: postgres_extension, - channel_name: channel_name, - postgres_cdc_module: module - } - } = socket - ) do + def handle_info(:postgres_subscribe, socket) do + %{ + assigns: %{ + tenant: tenant, + pg_sub_ref: pg_sub_ref, + pg_change_params: pg_change_params, + postgres_extension: postgres_extension, + channel_name: channel_name, + postgres_cdc_module: module + } + } = socket + cancel_timer(pg_sub_ref) args = Map.put(postgres_extension, "id", tenant) @@ -281,20 +154,14 @@ defmodule RealtimeWeb.RealtimeChannel do case PostgresCdc.after_connect(module, response, postgres_extension, pg_change_params) do {:ok, _response} -> message = "Subscribed to PostgreSQL" - Logger.info(message) - push_system_message("postgres_changes", socket, "ok", message, channel_name) - {:noreply, assign(socket, :pg_sub_ref, nil)} error -> message = "Subscribing to PostgreSQL failed: " <> inspect(error) - - push_system_message("postgres_changes", socket, "error", message, channel_name) - Logger.error(message) - + push_system_message("postgres_changes", socket, "error", message, channel_name) {:noreply, assign(socket, :pg_sub_ref, postgres_subscribe(5, 10))} end @@ -375,7 +242,7 @@ defmodule RealtimeWeb.RealtimeChannel do } = socket ) when is_binary(refresh_token) do - socket = socket |> assign(:access_token, refresh_token) + socket = assign(socket, :access_token, refresh_token) case confirm_token(socket) do {:ok, claims, confirm_token_ref} -> @@ -389,12 +256,13 @@ defmodule RealtimeWeb.RealtimeChannel do _ -> nil end - {:noreply, - assign(socket, %{ - confirm_token_ref: confirm_token_ref, - pg_change_params: pg_change_params, - pg_sub_ref: pg_sub_ref - })} + assigns = %{ + pg_sub_ref: pg_sub_ref, + confirm_token_ref: confirm_token_ref, + pg_change_params: pg_change_params + } + + {:noreply, assign(socket, assigns)} {:error, error} -> message = "Received an invalid access token from client: " <> inspect(error) @@ -433,33 +301,10 @@ defmodule RealtimeWeb.RealtimeChannel do def handle_in( "presence", %{"event" => event} = payload, - %{assigns: %{is_new_api: true, presence_key: presence_key, tenant_topic: tenant_topic}} = - socket + %{assigns: %{is_new_api: true, presence_key: _, tenant_topic: _}} = socket ) do socket = count(socket) - - result = - event - |> String.downcase() - |> case do - "track" -> - payload = Map.get(payload, "payload", %{}) - - with {:error, {:already_tracked, _, _, _}} <- - Presence.track(self(), tenant_topic, presence_key, payload), - {:ok, _} <- Presence.update(self(), tenant_topic, presence_key, payload) do - :ok - else - {:ok, _} -> :ok - {:error, _} -> :error - end - - "untrack" -> - Presence.untrack(self(), tenant_topic, presence_key) - - _ -> - :error - end + result = handle_presence_event(event, payload, socket) {:reply, result, socket} end @@ -469,13 +314,35 @@ defmodule RealtimeWeb.RealtimeChannel do # Log info here so that bad messages from clients won't flood Logflare # Can subscribe to a Channel with `log_level` `info` to see these messages - Logger.info( - "Unexpected message from client of type `#{type}` with payload: " <> inspect(payload) - ) + message = "Unexpected message from client of type `#{type}` with payload: #{inspect(payload)}" + Logger.info(message) {:noreply, socket} end + defp handle_presence_event(event, payload, socket) do + %{assigns: %{presence_key: presence_key, tenant_topic: tenant_topic}} = socket + + case String.downcase(event) do + "track" -> + with payload <- Map.get(payload, "payload", %{}), + {:error, {:already_tracked, _, _, _}} <- + Presence.track(self(), tenant_topic, presence_key, payload), + {:ok, _} <- Presence.update(self(), tenant_topic, presence_key, payload) do + :ok + else + {:ok, _} -> :ok + {:error, _} -> :error + end + + "untrack" -> + Presence.untrack(self(), tenant_topic, presence_key) + + _ -> + :error + end + end + @impl true def terminate(reason, _state) do Logger.debug("Channel terminated with reason: " <> inspect(reason)) @@ -513,15 +380,14 @@ defmodule RealtimeWeb.RealtimeChannel do GenCounter.add(id) case RateCounter.get(id) do - {:ok, %{avg: avg}} -> - if avg < limits.max_joins_per_second do - :ok - else - {:error, :too_many_joins} - end + {:ok, %{avg: avg}} when avg < limits.max_joins_per_second -> + :ok + + {:ok, %{avg: _}} -> + {:error, :too_many_joins} other -> - Logger.error("Unexpected error for #{tenant}: " <> inspect(other)) + Logger.error("Unexpected error for " <> tenant <> ": " <> inspect(other)) {:error, other} end end @@ -568,9 +434,7 @@ defmodule RealtimeWeb.RealtimeChannel do assign(socket, :rate_counter, rate_counter) end - defp assign_counter(socket) do - socket - end + defp assign_counter(socket), do: socket defp count(%{assigns: %{rate_counter: counter}} = socket) do GenCounter.add(counter.id) @@ -583,12 +447,10 @@ defmodule RealtimeWeb.RealtimeChannel do %{assigns: %{log_level: log_level, channel_name: channel_name}} = socket, msg ) do - if Logger.compare_levels(log_level, :error) == :lt, - do: - Logger.log( - log_level, - "HANDLE_INFO INCOMING ON " <> channel_name <> " message: " <> inspect(msg) - ) + if Logger.compare_levels(log_level, :error) == :lt do + msg = "HANDLE_INFO INCOMING ON " <> channel_name <> " message: " <> inspect(msg) + Logger.log(log_level, msg) + end socket end @@ -621,33 +483,20 @@ defmodule RealtimeWeb.RealtimeChannel do assign(socket, :access_token, tenant_token) end - defp confirm_token(%{ - assigns: - %{ - jwt_secret: jwt_secret, - access_token: access_token - } = assigns - }) do + defp confirm_token(%{assigns: %{jwt_secret: jwt_secret, access_token: access_token} = assigns}) do with jwt_secret_dec <- decrypt_jwt_secret(jwt_secret), {:ok, %{"exp" => exp} = claims} when is_integer(exp) <- ChannelsAuthorization.authorize_conn(access_token, jwt_secret_dec), exp_diff when exp_diff > 0 <- exp - Joken.current_time() do if ref = assigns[:confirm_token_ref], do: cancel_timer(ref) - ref = - Process.send_after( - self(), - :confirm_token, - min(@confirm_token_ms_interval, exp_diff * 1_000) - ) + interval = min(@confirm_token_ms_interval, exp_diff * 1_000) + ref = Process.send_after(self(), :confirm_token, interval) {:ok, claims, ref} else - {:error, e} -> - {:error, e} - - e -> - {:error, e} + {:error, e} -> {:error, e} + e -> {:error, e} end end @@ -691,4 +540,105 @@ defmodule RealtimeWeb.RealtimeChannel do } ) end + + defp is_new_api(%{"config" => _}), do: true + defp is_new_api(_), do: false + + defp pg_change_params(true, params, channel_pid, claims, _) do + send(self(), :sync_presence) + + case get_in(params, ["config", "postgres_changes"]) do + [_ | _] = params_list -> + Enum.map(params_list, fn params -> + %{ + id: UUID.uuid1(), + channel_pid: channel_pid, + claims: claims, + params: params + } + end) + + _ -> + [] + end + end + + defp pg_change_params(false, params, channel_pid, claims, sub_topic) do + case String.split(sub_topic, ":", parts: 3) do + [schema, table, filter] -> %{"schema" => schema, "table" => table, "filter" => filter} + [schema, table] -> %{"schema" => schema, "table" => table} + [schema] -> %{"schema" => schema} + end + + [ + %{ + id: UUID.uuid1(), + channel_pid: channel_pid, + claims: claims, + params: params + } + ] + end + + defp postgres_cdc_subscribe(%{pg_change_params: [_ | _]} = opts) do + %{ + is_new_api: is_new_api, + pg_change_params: pg_change_params, + transport_pid: transport_pid, + serializer: serializer, + topic: topic, + tenant: tenant, + module: module + } = opts + + ids = + Enum.map(pg_change_params, fn %{id: id, params: params} -> + {UUID.string_to_binary!(id), :erlang.phash2(params)} + end) + + subscription_metadata = + {:subscriber_fastlane, transport_pid, serializer, ids, topic, tenant, is_new_api} + + metadata = [metadata: subscription_metadata] + + PostgresCdc.subscribe(module, pg_change_params, tenant, metadata) + + send(self(), :postgres_subscribe) + + pg_change_params + end + + defp postgres_cdc_subscribe(%{pg_change_params: pg_change_params}), do: pg_change_params + + defp postgres_changes(pg_change_params) do + Enum.map(pg_change_params, fn %{params: params} -> + id = :erlang.phash2(params) + Map.put(params, :id, id) + end) + end + + defp handle_join_error( + {:error, [message: "Invalid token", claim: _claim, claim_val: _value]} = error + ) do + log_error_message(:warning, error) + end + + defp handle_join_error({:error, type} = error) + when type in [:too_many_channels, :too_many_connections, :too_many_joins] do + log_error_message(:warning, error) + end + + defp handle_join_error(error), do: log_error_message(:error, error) + + defp log_error_message(:warning, error) do + error_msg = inspect(error) + Logger.warn("Start channel error: " <> error_msg) + {:error, %{reason: error_msg}} + end + + defp log_error_message(:error, error) do + error_msg = inspect(error) + Logger.error("Start channel error: " <> error_msg) + {:error, %{reason: error_msg}} + end end diff --git a/lib/realtime_web/channels/realtime_channel/assign.ex b/lib/realtime_web/channels/realtime_channel/assign.ex new file mode 100644 index 000000000..f5ed79bfb --- /dev/null +++ b/lib/realtime_web/channels/realtime_channel/assign.ex @@ -0,0 +1,40 @@ +defmodule RealtimeWeb.RealtimeChannel.Assigns do + defstruct [ + :tenant, + :log_level, + :rate_counter, + :limits, + :tenant_topic, + :pg_sub_ref, + :pg_change_params, + :postgres_extension, + :claims, + :jwt_secret, + :tenant_token, + :access_token, + :postgres_cdc_module, + :channel_name + ] + + @type t :: %__MODULE__{ + tenant: String.t(), + log_level: atom(), + rate_counter: RateCounter.t(), + limits: %{ + max_events_per_second: integer(), + max_concurrent_users: integer(), + max_bytes_per_second: integer(), + max_channels_per_client: integer(), + max_joins_per_second: integer() + }, + tenant_topic: String.t(), + pg_sub_ref: reference() | nil, + pg_change_params: map(), + postgres_extension: map(), + claims: map(), + jwt_secret: String.t(), + tenant_token: String.t(), + access_token: String.t(), + channel_name: String.t() + } +end diff --git a/lib/realtime_web/channels/user_socket.ex b/lib/realtime_web/channels/user_socket.ex index 893059371..e0db727ed 100644 --- a/lib/realtime_web/channels/user_socket.ex +++ b/lib/realtime_web/channels/user_socket.ex @@ -25,10 +25,10 @@ defmodule RealtimeWeb.UserSocket do log_level = params |> Map.get("log_level", @default_log_level) - |> case do + |> then(fn "" -> @default_log_level level -> level - end + end) |> String.to_existing_atom() secure_key = Application.get_env(:realtime, :db_enc_key) @@ -50,24 +50,24 @@ defmodule RealtimeWeb.UserSocket do jwt_secret_dec <- decrypt!(jwt_secret, secure_key), {:ok, claims} <- ChannelsAuthorization.authorize_conn(token, jwt_secret_dec), {:ok, postgres_cdc_module} <- PostgresCdc.driver(postgres_cdc_default) do - assigns = - %RealtimeChannel.Assigns{ - claims: claims, - jwt_secret: jwt_secret, - limits: %{ - max_concurrent_users: max_conn_users, - max_events_per_second: max_events_per_second, - max_bytes_per_second: max_bytes_per_second, - max_joins_per_second: max_joins_per_second, - max_channels_per_client: max_channels_per_client - }, - postgres_extension: PostgresCdc.filter_settings(postgres_cdc_default, extensions), - postgres_cdc_module: postgres_cdc_module, - tenant: external_id, - log_level: log_level, - tenant_token: token - } - |> Map.from_struct() + assigns = %RealtimeChannel.Assigns{ + claims: claims, + jwt_secret: jwt_secret, + limits: %{ + max_concurrent_users: max_conn_users, + max_events_per_second: max_events_per_second, + max_bytes_per_second: max_bytes_per_second, + max_joins_per_second: max_joins_per_second, + max_channels_per_client: max_channels_per_client + }, + postgres_extension: PostgresCdc.filter_settings(postgres_cdc_default, extensions), + postgres_cdc_module: postgres_cdc_module, + tenant: external_id, + log_level: log_level, + tenant_token: token + } + + assigns = Map.from_struct(assigns) {:ok, assign(socket, assigns)} else diff --git a/lib/realtime_web/endpoint.ex b/lib/realtime_web/endpoint.ex index 11fbdd154..569650b40 100644 --- a/lib/realtime_web/endpoint.ex +++ b/lib/realtime_web/endpoint.ex @@ -10,7 +10,7 @@ defmodule RealtimeWeb.Endpoint do signing_salt: "5OUq5X4H" ] - socket "/socket", RealtimeWeb.UserSocket, + socket("/socket", RealtimeWeb.UserSocket, websocket: [ connect_info: [:peer_data, :uri, :x_headers], fullsweep_after: 20, @@ -21,43 +21,62 @@ defmodule RealtimeWeb.Endpoint do ] ], longpoll: true + ) - socket "/live", Phoenix.LiveView.Socket, websocket: [connect_info: [session: @session_options]] + if Mix.env() == :dev do + socket("/realtime/v1", RealtimeWeb.UserSocket, + websocket: [ + connect_info: [:peer_data, :uri, :x_headers], + fullsweep_after: 20, + max_frame_size: 8_000_000, + serializer: [ + {Phoenix.Socket.V1.JSONSerializer, "~> 1.0.0"}, + {Phoenix.Socket.V2.JSONSerializer, "~> 2.0.0"} + ] + ], + longpoll: true + ) + end + + socket("/live", Phoenix.LiveView.Socket, websocket: [connect_info: [session: @session_options]]) # Serve at "/" the static files from "priv/static" directory. # # You should set gzip to true if you are running phx.digest # when deploying your static files in production. - plug Plug.Static, + plug(Plug.Static, at: "/", from: :realtime, gzip: false, only: RealtimeWeb.static_paths() + ) # plug PromEx.Plug, path: "/metrics", prom_ex_module: Realtime.PromEx # Code reloading can be explicitly enabled under the # :code_reloader configuration of your endpoint. if code_reloading? do - socket "/phoenix/live_reload/socket", Phoenix.LiveReloader.Socket - plug Phoenix.LiveReloader - plug Phoenix.CodeReloader + socket("/phoenix/live_reload/socket", Phoenix.LiveReloader.Socket) + plug(Phoenix.LiveReloader) + plug(Phoenix.CodeReloader) end - plug Phoenix.LiveDashboard.RequestLogger, + plug(Phoenix.LiveDashboard.RequestLogger, param_key: "request_logger", cookie_key: "request_logger" + ) - plug Plug.RequestId - plug Plug.Telemetry, event_prefix: [:phoenix, :endpoint] + plug(Plug.RequestId) + plug(Plug.Telemetry, event_prefix: [:phoenix, :endpoint]) - plug Plug.Parsers, + plug(Plug.Parsers, parsers: [:urlencoded, :multipart, :json], pass: ["*/*"], json_decoder: Phoenix.json_library() + ) - plug Plug.MethodOverride - plug Plug.Head - plug Plug.Session, @session_options - plug RealtimeWeb.Router + plug(Plug.MethodOverride) + plug(Plug.Head) + plug(Plug.Session, @session_options) + plug(RealtimeWeb.Router) end