Skip to content
This repository has been archived by the owner on Jun 28, 2024. It is now read-only.

Commit

Permalink
Merge pull request #202 from shazam/fixes-threading-issue
Browse files Browse the repository at this point in the history
Fixes threading issue
  • Loading branch information
jbaginski committed Jun 21, 2024
2 parents 977fde3 + a68ba85 commit a229ed1
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 55 deletions.
48 changes: 30 additions & 18 deletions fork-runner/src/main/java/com/shazam/fork/ForkRunner.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;

import static com.google.common.util.concurrent.Uninterruptibles.awaitTerminationUninterruptibly;
import static com.shazam.fork.Utils.namedExecutor;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toSet;
Expand Down Expand Up @@ -66,15 +68,14 @@ public ForkRunner(PoolLoader poolLoader,
}

public boolean run() {
ExecutorService poolExecutor = null;
try {
Collection<Pool> pools = poolLoader.loadPools();
poolExecutor = namedExecutor(pools.size(), "PoolExecutor-%d");

Collection<TestCaseEvent> testCases = testClassLoader.loadTestSuite();
summaryGeneratorHook.registerHook(pools, testCases);

executeTests(poolExecutor, pools, testCases);
executeTests(pools, testCases);


AggregatedTestResult aggregatedTestResult = aggregator.aggregateTestResults(pools, testCases);
if (!aggregatedTestResult.getFatalCrashedTests().isEmpty()) {
Expand All @@ -83,7 +84,7 @@ public boolean run() {

Collection<TestCaseEvent> fatalCrashedTestCases =
findFatalCrashedTestCases(testCases, aggregatedTestResult.getFatalCrashedTests());
executeTests(poolExecutor, pools, fatalCrashedTestCases);
executeTests(pools, fatalCrashedTestCases);

aggregatedTestResult = aggregator.aggregateTestResults(pools, testCases);

Expand All @@ -107,29 +108,40 @@ public boolean run() {
} catch (Exception e) {
logger.error("Error while Fork was executing", e);
return false;
} finally {
if (poolExecutor != null) {
poolExecutor.shutdown();
}
}
}

private void executeTests(ExecutorService poolExecutor,
Collection<Pool> pools,
Collection<TestCaseEvent> testCases) throws InterruptedException {
private void executeTests(
Collection<Pool> pools,
Collection<TestCaseEvent> testCases
) {
ProgressReporter progressReporter = progressReporterFactory.createProgressReporter();
progressReporter.start();

CountDownLatch poolCountDownLatch = new CountDownLatch(pools.size());
ExecutorService poolExecutor = null;
try {
poolExecutor = namedExecutor(pools.size(), "PoolExecutor-%d");

for (Pool pool : pools) {
Runnable poolTestRunner = poolTestRunnerFactory.createPoolTestRunner(
pool,
testCases,
progressReporter
);
poolExecutor.submit(poolTestRunner);
}

for (Pool pool : pools) {
Runnable poolTestRunner =
poolTestRunnerFactory.createPoolTestRunner(pool, testCases, poolCountDownLatch, progressReporter);
poolExecutor.execute(poolTestRunner);
poolExecutor.shutdown();
awaitTerminationUninterruptibly(poolExecutor);
} finally {
progressReporter.stop();

if (poolExecutor != null && !poolExecutor.isTerminated()) {
poolExecutor.shutdownNow();
awaitTerminationUninterruptibly(poolExecutor);
}
}
poolCountDownLatch.await();

progressReporter.stop();
}

private static void reportMissingTests(AggregatedTestResult aggregatedTestResult) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.Queue;
import java.util.concurrent.CountDownLatch;
import java.util.Deque;

import static com.shazam.fork.system.io.RemoteFileManager.*;

Expand All @@ -33,25 +32,22 @@ public class DeviceTestRunner implements Runnable {
private final Installer installer;
private final Pool pool;
private final Device device;
private final Queue<TestCaseEvent> queueOfTestsInPool;
private final CountDownLatch deviceCountDownLatch;
private final Deque<TestCaseEvent> queueOfTestsInPool;
private final ProgressReporter progressReporter;
private final ScreenRecorder screenRecorder;
private final TestRunFactory testRunFactory;

public DeviceTestRunner(Installer installer,
Pool pool,
Device device,
Queue<TestCaseEvent> queueOfTestsInPool,
CountDownLatch deviceCountDownLatch,
Deque<TestCaseEvent> queueOfTestsInPool,
ProgressReporter progressReporter,
ScreenRecorder screenRecorder,
TestRunFactory testRunFactory) {
this.installer = installer;
this.pool = pool;
this.device = device;
this.queueOfTestsInPool = queueOfTestsInPool;
this.deviceCountDownLatch = deviceCountDownLatch;
this.progressReporter = progressReporter;
this.screenRecorder = screenRecorder;
this.testRunFactory = testRunFactory;
Expand Down Expand Up @@ -81,7 +77,6 @@ public void run() {
}
} finally {
logger.info("Device {} from pool {} finished", device.getSerial(), pool.getName());
deviceCountDownLatch.countDown();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
import com.shazam.fork.model.TestCaseEvent;
import com.shazam.fork.system.adb.Installer;

import java.util.Queue;
import java.util.concurrent.CountDownLatch;
import java.util.Deque;

public class DeviceTestRunnerFactory {

Expand All @@ -30,8 +29,7 @@ public DeviceTestRunnerFactory(Installer installer, TestRunFactory testRunFactor
}

public Runnable createDeviceTestRunner(Pool pool,
Queue<TestCaseEvent> testClassQueue,
CountDownLatch deviceInPoolCountDownLatch,
Deque<TestCaseEvent> testClassQueue,
Device device,
ProgressReporter progressReporter
) {
Expand All @@ -40,7 +38,6 @@ public Runnable createDeviceTestRunner(Pool pool,
pool,
device,
testClassQueue,
deviceInPoolCountDownLatch,
progressReporter,
new ScreenRecorderImpl(device),
testRunFactory);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,29 +19,26 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Queue;
import java.util.concurrent.CountDownLatch;
import java.util.Deque;
import java.util.concurrent.ExecutorService;

import static com.google.common.util.concurrent.Uninterruptibles.awaitTerminationUninterruptibly;
import static com.shazam.fork.Utils.namedExecutor;

public class PoolTestRunner implements Runnable {
private final Logger logger = LoggerFactory.getLogger(PoolTestRunner.class);
public static final String DROPPED_BY = "DroppedBy-";

private final Pool pool;
private final Queue<TestCaseEvent> testCases;
private final CountDownLatch poolCountDownLatch;
private final Deque<TestCaseEvent> testCases;
private final DeviceTestRunnerFactory deviceTestRunnerFactory;
private final ProgressReporter progressReporter;

public PoolTestRunner(DeviceTestRunnerFactory deviceTestRunnerFactory, Pool pool,
Queue<TestCaseEvent> testCases,
CountDownLatch poolCountDownLatch,
Deque<TestCaseEvent> testCases,
ProgressReporter progressReporter) {
this.pool = pool;
this.testCases = testCases;
this.poolCountDownLatch = poolCountDownLatch;
this.deviceTestRunnerFactory = deviceTestRunnerFactory;
this.progressReporter = progressReporter;
}
Expand All @@ -52,23 +49,25 @@ public void run() {
try {
int devicesInPool = pool.size();
concurrentDeviceExecutor = namedExecutor(devicesInPool, "DeviceExecutor-%d");
CountDownLatch deviceCountDownLatch = new CountDownLatch(devicesInPool);
logger.info("Pool {} started", poolName);
for (Device device : pool.getDevices()) {
Runnable deviceTestRunner = deviceTestRunnerFactory.createDeviceTestRunner(pool, testCases,
deviceCountDownLatch, device, progressReporter);
concurrentDeviceExecutor.execute(deviceTestRunner);
Runnable deviceTestRunner = deviceTestRunnerFactory.createDeviceTestRunner(
pool,
testCases,
device,
progressReporter
);
concurrentDeviceExecutor.submit(deviceTestRunner);
}
deviceCountDownLatch.await();
} catch (InterruptedException e) {
logger.warn("Pool {} was interrupted while running", poolName);

concurrentDeviceExecutor.shutdown();
awaitTerminationUninterruptibly(concurrentDeviceExecutor);
} finally {
if (concurrentDeviceExecutor != null) {
concurrentDeviceExecutor.shutdown();
if (concurrentDeviceExecutor != null && !concurrentDeviceExecutor.isTerminated()) {
concurrentDeviceExecutor.shutdownNow();
awaitTerminationUninterruptibly(concurrentDeviceExecutor);
}
logger.info("Pool {} finished", poolName);
poolCountDownLatch.countDown();
logger.info("Pools remaining: {}", poolCountDownLatch.getCount());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
import com.shazam.fork.model.TestCaseEvent;

import java.util.Collection;
import java.util.LinkedList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ConcurrentLinkedDeque;

public class PoolTestRunnerFactory {
private final DeviceTestRunnerFactory deviceTestRunnerFactory;
Expand All @@ -26,7 +25,6 @@ public PoolTestRunnerFactory(DeviceTestRunnerFactory deviceTestRunnerFactory) {

public Runnable createPoolTestRunner(Pool pool,
Collection<TestCaseEvent> testCases,
CountDownLatch poolCountDownLatch,
ProgressReporter progressReporter) {

int totalTests = testCases.size();
Expand All @@ -35,8 +33,7 @@ public Runnable createPoolTestRunner(Pool pool,
return new PoolTestRunner(
deviceTestRunnerFactory,
pool,
new LinkedList<>(testCases),
poolCountDownLatch,
new ConcurrentLinkedDeque<>(testCases),
progressReporter);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import java.io.IOException;
import java.util.List;
import java.util.concurrent.TimeUnit;

import static java.lang.String.format;

Expand Down Expand Up @@ -64,7 +65,7 @@ public void execute() {
}
runner.setRunName(poolName);
runner.setMethodName(testClassName, testMethodName);
runner.setMaxtimeToOutputResponse(testRunParameters.getTestOutputTimeout());
runner.setMaxTimeToOutputResponse(testRunParameters.getTestOutputTimeout(), TimeUnit.MILLISECONDS);

if (testRunParameters.isCoverageEnabled()) {
runner.setCoverage(true);
Expand Down

0 comments on commit a229ed1

Please sign in to comment.