Skip to content

Commit

Permalink
Ensure ChunkWriter uses current backend
Browse files Browse the repository at this point in the history
When changing backends, the ChunkWriter could still use the old one causing data loss, because chunks assumed to exist on new backend, were written to old one.
  • Loading branch information
grote committed Aug 28, 2024
1 parent 604d8dd commit 97a869f
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import java.io.IOException
import java.security.GeneralSecurityException

internal class SnapshotRetriever(
private val storagePlugin: () -> Backend,
private val backendGetter: () -> Backend,
private val streamCrypto: StreamCrypto = StreamCrypto,
) {

Expand All @@ -27,7 +27,7 @@ internal class SnapshotRetriever(
InvalidProtocolBufferException::class,
)
suspend fun getSnapshot(streamKey: ByteArray, storedSnapshot: StoredSnapshot): BackupSnapshot {
return storagePlugin().load(storedSnapshot.snapshotHandle).use { inputStream ->
return backendGetter().load(storedSnapshot.snapshotHandle).use { inputStream ->
val version = inputStream.readVersion()
val timestamp = storedSnapshot.timestamp
val ad = streamCrypto.getAssociatedDataForSnapshot(timestamp, version.toByte())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ internal class Backup(
} catch (e: GeneralSecurityException) {
throw AssertionError(e)
}
private val chunkWriter = ChunkWriter(streamCrypto, streamKey, chunksCache, backend, androidId)
private val chunkWriter =
ChunkWriter(streamCrypto, streamKey, chunksCache, backendGetter, androidId)
private val hasMediaAccessPerm =
context.checkSelfPermission(ACCESS_MEDIA_LOCATION) == PERMISSION_GRANTED
private val fileBackup = FileBackup(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,12 @@ internal class ChunkWriter(
private val streamCrypto: StreamCrypto,
private val streamKey: ByteArray,
private val chunksCache: ChunksCache,
private val backend: Backend,
private val backendGetter: () -> Backend,
private val androidId: String,
private val bufferSize: Int = DEFAULT_BUFFER_SIZE,
) {

private val backend get() = backendGetter()
private val buffer = ByteArray(bufferSize)

@Throws(IOException::class, GeneralSecurityException::class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import android.text.format.Formatter
import io.mockk.Runs
import io.mockk.coEvery
import io.mockk.coVerify
import io.mockk.coVerifyOrder
import io.mockk.every
import io.mockk.just
import io.mockk.mockk
Expand Down Expand Up @@ -475,6 +476,39 @@ internal class BackupRestoreTest {
}
}

@Test
fun testBackupUpdatesBackend(): Unit = runBlocking {
val backendGetterNew: () -> Backend = mockk()
val backend1: Backend = mockk()
val backend2: Backend = mockk()
val backup = Backup(
context = context,
db = db,
fileScanner = fileScanner,
backendGetter = backendGetterNew,
androidId = androidId,
keyManager = keyManager,
cacheRepopulater = cacheRepopulater,
)
every { backendGetterNew() } returnsMany listOf(backend1, backend2)

coEvery { backend1.list(any(), Blob::class, callback = any()) } just Runs
every { chunksCache.areAllAvailableChunksCached(db, emptySet()) } returns true
every { fileScanner.getFiles() } returns FileScannerResult(emptyList(), emptyList())
every { filesCache.getByUri(any()) } returns null // nothing is cached, all is new

backup.runBackup(null)

// second run uses new backend
coEvery { backend2.list(any(), Blob::class, callback = any()) } just Runs
backup.runBackup(null)

coVerifyOrder {
backend1.list(any(), Blob::class, callback = any())
backend2.list(any(), Blob::class, callback = any())
}
}

private fun getRandomMediaFile(size: Int) = MediaFile(
uri = mockk(),
dir = getRandomString(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ internal class ChunkWriterTest {

private val streamCrypto: StreamCrypto = mockk()
private val chunksCache: ChunksCache = mockk()
private val backendGetter: () -> Backend = mockk()
private val backend: Backend = mockk()
private val androidId: String = getRandomString()
private val streamKey: ByteArray = Random.nextBytes(KEY_SIZE_BYTES)
Expand All @@ -42,7 +43,7 @@ internal class ChunkWriterTest {
streamCrypto = streamCrypto,
streamKey = streamKey,
chunksCache = chunksCache,
backend = backend,
backendGetter = backendGetter,
androidId = androidId,
bufferSize = Random.nextInt(1, 42),
)
Expand All @@ -53,6 +54,7 @@ internal class ChunkWriterTest {

init {
mockLog()
every { backendGetter() } returns backend
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,15 @@ internal class SmallFileBackupIntegrationTest {
private val filesCache: FilesCache = mockk()
private val mac: Mac = mockk()
private val chunksCache: ChunksCache = mockk()
private val backendGetter: () -> Backend = mockk()
private val backend: Backend = mockk()
private val androidId: String = getRandomString()

private val chunkWriter = ChunkWriter(
streamCrypto = StreamCrypto,
streamKey = Random.nextBytes(KEY_SIZE_BYTES),
chunksCache = chunksCache,
backend = backend,
backendGetter = backendGetter,
androidId = androidId,
)
private val zipChunker = ZipChunker(
Expand All @@ -58,6 +59,7 @@ internal class SmallFileBackupIntegrationTest {

init {
mockLog()
every { backendGetter() } returns backend
}

/**
Expand Down

0 comments on commit 97a869f

Please sign in to comment.