Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## v3.22.0

### Fixed

Make connection switching thread safe, by fixing a thread safety issue caused by using a (class) instance variable instead of a thread-local variable.

## v3.21.0

### Added
Expand Down
2 changes: 1 addition & 1 deletion active_record_shards.gemspec
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Gem::Specification.new "active_record_shards", "3.21.0" do |s|
Gem::Specification.new "active_record_shards", "3.22.0" do |s|
s.authors = ["Benjamin Quorning", "Gabe Martin-Dempesy", "Pierre Schambacher", "Mick Staugaard", "Eric Chapweske", "Ben Osheroff"]
s.email = ["[email protected]", "[email protected]", "[email protected]", "[email protected]"]
s.homepage = "https://github.com/zendesk/active_record_shards"
Expand Down
15 changes: 11 additions & 4 deletions lib/active_record_shards/connection_switcher.rb
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,11 @@ def on_primary(&block)
alias_method :with_slave_unless, :on_replica_unless

def on_cx_switch_block(which, force: false, construct_ro_scope: nil, &block)
@disallow_replica ||= 0
@disallow_replica += 1 if [:primary, :master].include?(which)
self.disallow_replica += 1 if [:primary, :master].include?(which)

ActiveRecordShards::Deprecation.warn('the `:master` option should be replaced with `:primary`!') if which == :master

switch_to_replica = force || @disallow_replica.zero?
switch_to_replica = force || disallow_replica.zero?
old_options = current_shard_selection.options

switch_connection(replica: switch_to_replica)
Expand All @@ -131,10 +130,18 @@ def on_cx_switch_block(which, force: false, construct_ro_scope: nil, &block)
readonly.scoping(&block)
end
ensure
@disallow_replica -= 1 if [:primary, :master].include?(which)
self.disallow_replica -= 1 if [:primary, :master].include?(which)
switch_connection(old_options) if old_options
end

def disallow_replica=(value)
Thread.current[:__active_record_shards__disallow_replica_by_thread] = value
end

def disallow_replica
Thread.current[:__active_record_shards__disallow_replica_by_thread] ||= 0
end

def supports_sharding?
shard_names.any?
end
Expand Down
1 change: 1 addition & 0 deletions test/database.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ mysql: &MYSQL
port: <%= mysql.port %>
password: <%= mysql.password %>
ssl_mode: :disabled
reaping_frequency: 0 # Prevents ActiveRecord from spawning reaping threads.

# We connect to the unsharded primary database on a different port, via a proxy,
# so we can make the connection unavailable when testing on_replica_by_default
Expand Down
4 changes: 4 additions & 0 deletions test/helper.rb
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# frozen_string_literal: true

# Stop Minitest creating threads we won't use.
# They add noise to the thread safety tests when inspecting `Thread.list`.
ENV["MT_CPU"] ||= "1"

require 'bundler/setup'
require 'minitest/autorun'
require 'minitest/rg'
Expand Down
273 changes: 273 additions & 0 deletions test/thread_safety_test.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,273 @@
# frozen_string_literal: true

require_relative 'helper'
require_relative 'models'

describe "connection switching thread safety" do
with_fresh_databases

before do
ActiveRecord::Base.establish_connection(:test)
use_same_connection_handler_for_all_theads
create_seed_data
end

after do
ActiveRecord::Base.connection_handler.clear_all_connections!
end

it "can safely switch between all database connections in parallel" do
new_thread("switches_through_all_1") do
pause_and_mark_ready
switch_through_all_databases
end
new_thread("switches_through_all_2") do
pause_and_mark_ready
switch_through_all_databases
end
new_thread("switches_through_all_3") do
pause_and_mark_ready
switch_through_all_databases
end

wait_for_threads_to_be_ready
execute_and_wait_for_threads
end

describe "when multiple threads use different databases" do
it "allows threads to parallelize their IO" do
results = []

query_delay = { fast: "0.01", slow: "1", medium: "0.5" }
new_thread("different_db_parallel_thread1") do
ActiveRecord::Base.on_primary do
pause_and_mark_ready
result = execute_sql("SELECT name,'slower query',SLEEP(#{query_delay.fetch(:slow)}) FROM accounts")
assert_equal('Primary account', result.first[0])
results.push(result)
end
end

new_thread("different_db_parallel_thread2") do
ActiveRecord::Base.on_replica do
pause_and_mark_ready
result = execute_sql("SELECT name, 'faster query',SLEEP(#{query_delay.fetch(:fast)}) FROM accounts")
assert_equal('Replica account', result.first[0])
results.push(result)
end
end

new_thread("different_db_parallel_thread3") do
ActiveRecord::Base.on_shard(0) do
pause_and_mark_ready
result = execute_sql("SELECT title, 'medium query',SLEEP(#{query_delay.fetch(:medium)}) FROM tickets")
assert_equal('Shard 0 Primary ticket', result.first[0])
results.push(result)
end
end

wait_for_threads_to_be_ready

thread_exection_time = Benchmark.realtime do
execute_and_wait_for_threads
end

minimum_serial_query_exection_time = query_delay.values.map(&:to_f).sum
# Arbitrarily faster time such that there must have been some parallelization
max_parallel_time = minimum_serial_query_exection_time - 0.1
assert_operator(max_parallel_time, :>, thread_exection_time)

# This order cannot be guaranteed but it likely given the artificial delays
rows = results.map(&:first)
result_strings = rows.map { |r| r[1] }
assert_equal(
[
"faster query",
"medium query",
"slower query"
],
result_strings
)
end
end

describe "when multiple threads use the same database" do
it "exposes a different connections to each thread" do
connections = []

new_thread("connection_per_thread1") do
ActiveRecord::Base.on_primary do
pause_and_mark_ready
connections << ActiveRecord::Base.connection
end
end

new_thread("connection_per_thread2") do
ActiveRecord::Base.on_primary do
pause_and_mark_ready
connections << ActiveRecord::Base.connection
end
end

wait_for_threads_to_be_ready
execute_and_wait_for_threads

expect(connections.first).must_be_kind_of(ActiveRecord::ConnectionAdapters::Mysql2Adapter)
assert_equal(2, connections.uniq.size)
end

it "allows threads to parallelize their IO" do
results = []

query_delay = { fast: "0.01", slow: "1", medium: "0.5" }
new_thread("same_db_parallel_thread1") do
ActiveRecord::Base.on_primary do
pause_and_mark_ready
result = execute_sql("SELECT 'slower query',SLEEP(#{query_delay.fetch(:slow)})")
results.push(result)
end
end

new_thread("same_db_parallel_thread2") do
ActiveRecord::Base.on_primary do
pause_and_mark_ready
result = execute_sql("SELECT 'faster query',SLEEP(#{query_delay.fetch(:fast)})")
results.push(result)
end
end

new_thread("same_db_parallel_thread3") do
ActiveRecord::Base.on_primary do
pause_and_mark_ready
result = execute_sql("SELECT 'medium query',SLEEP(#{query_delay.fetch(:medium)})")
results.push(result)
end
end

wait_for_threads_to_be_ready

thread_exection_time = Benchmark.realtime do
execute_and_wait_for_threads
end

minimum_serial_query_exection_time = query_delay.values.map(&:to_f).sum
# Arbitrarily faster time such that there must have been some parallelization
max_parallel_time = minimum_serial_query_exection_time - 0.1
assert_operator(max_parallel_time, :>, thread_exection_time)

rows = results.map(&:first)
result_strings = rows.map(&:first)
# This order cannot be guaranteed but it likely given the artificial delays
assert_equal(
[
"faster query",
"medium query",
"slower query"
],
result_strings
)
end
end

def new_thread(name)
thread = Thread.new do
Thread.current.name = name
yield
end

@test_threads ||= []
@test_threads.push(thread)
end

def switch_through_all_databases
ActiveRecord::Base.on_primary do
result = ActiveRecord::Base.connection.execute("SELECT * from accounts")
assert_equal("Primary account", record_name(result))
end
ActiveRecord::Base.on_replica do
result = ActiveRecord::Base.connection.execute("SELECT * from accounts")
assert_equal("Replica account", record_name(result))
end
ActiveRecord::Base.on_shard(0) do
result = ActiveRecord::Base.connection.execute("SELECT * from tickets")
assert_equal("Shard 0 Primary ticket", record_name(result))

ActiveRecord::Base.on_replica do
result = ActiveRecord::Base.connection.execute("SELECT * from tickets")
assert_equal("Shard 0 Replica ticket", record_name(result))
end
end
ActiveRecord::Base.on_shard(1) do
result = ActiveRecord::Base.connection.execute("SELECT * from tickets")
assert_equal("Shard 1 Primary ticket", record_name(result))

ActiveRecord::Base.on_replica do
result = ActiveRecord::Base.connection.execute("SELECT * from tickets")
assert_equal("Shard 1 Replica ticket", record_name(result))
end
end
end

# This allows us to get all of our threads into a prepared state by pausing
# them at a 'ready' point so as there is as little overhead as possible
# before the interesting code executes.
#
# Here we use 'ready' to mean the thread is spawned, has had its names set
# and has established a database connection.
def pause_and_mark_ready
Thread.current[:ready] = true
sleep
end

def execute_and_wait_for_threads
@test_threads.each { |t| t.wakeup if t.alive? }
@test_threads.each(&:join)
end

def wait_for_threads_to_be_ready
sleep(0.01) until @test_threads.all? { |t| t[:ready] }
end

def use_same_connection_handler_for_all_theads
ActiveRecord::Base.default_connection_handler = ActiveRecord::Base.connection_handler
end

def record_name(db_result)
name_column_index = 1
db_result.first[name_column_index]
end

def execute_sql(query)
ActiveRecord::Base.connection.execute(query)
end

def create_seed_data
ActiveRecord::Base.on_primary_db do
Account.connection.execute(account_insert_sql(name: "Primary account"))

Account.on_replica do
Account.connection.execute(account_insert_sql(name: "Replica account"))
end
end

[0, 1].each do |shard_id|
ActiveRecord::Base.on_shard(shard_id) do
Ticket.connection.execute(ticket_insert_sql(title: "Shard #{shard_id} Primary ticket"))

Ticket.on_replica do
Ticket.connection.execute(ticket_insert_sql(title: "Shard #{shard_id} Replica ticket"))
end
end
end
end

def account_insert_sql(name:)
"INSERT INTO accounts (id, name, created_at, updated_at)" \
" VALUES (1000, '#{name}', NOW(), NOW())"
end

def ticket_insert_sql(title:)
"INSERT INTO tickets (id, title, account_id, created_at, updated_at)" \
" VALUES (1000, '#{title}', 5000, NOW(), NOW())"
end
end