diff --git a/src/main/java/net/juniper/netconf/NetconfSession.java b/src/main/java/net/juniper/netconf/NetconfSession.java index 5b25835..fabc9c3 100644 --- a/src/main/java/net/juniper/netconf/NetconfSession.java +++ b/src/main/java/net/juniper/netconf/NetconfSession.java @@ -39,7 +39,14 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicReference; + +import static java.util.Optional.ofNullable; /** * A {@code NetconfSession} is obtained by first building a @@ -59,6 +66,7 @@ public class NetconfSession { private static final org.slf4j.Logger log = org.slf4j.LoggerFactory.getLogger(NetconfSession.class); + private final ExecutorService singleThreadExecutor = Executors.newSingleThreadExecutor(); private final Channel netconfChannel; private String serverCapability; @@ -127,30 +135,49 @@ private void sendHello(String hello) throws IOException { } @VisibleForTesting - String getRpcReply(String rpc) throws IOException { + String getRpcReply(final String rpc) throws IOException { // write the rpc to the device sendRpcRequest(rpc); - final char[] buffer = new char[BUFFER_SIZE]; - final StringBuilder rpcReply = new StringBuilder(); - final long startTime = System.nanoTime(); - final Reader in = new InputStreamReader(stdInStreamFromDevice, Charsets.UTF_8); - boolean timeoutNotExceeded = true; - int promptPosition; - while ((promptPosition = rpcReply.indexOf(NetconfConstants.DEVICE_PROMPT)) < 0 && - (timeoutNotExceeded = (TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTime) < commandTimeout))) { - int charsRead = in.read(buffer, 0, buffer.length); - if (charsRead < 0) throw new NetconfException("Input Stream has been closed during reading."); - rpcReply.append(buffer, 0, charsRead); - } - - if (!timeoutNotExceeded) + final AtomicReference threadReference = new AtomicReference<>(); + try { + return singleThreadExecutor.submit(() -> { + try { + + threadReference.set(Thread.currentThread()); + final char[] buffer = new char[BUFFER_SIZE]; + final StringBuilder rpcReply = new StringBuilder(); + final Reader in = new InputStreamReader(stdInStreamFromDevice, Charsets.UTF_8); + int promptPosition; + while ((promptPosition = rpcReply.indexOf(NetconfConstants.DEVICE_PROMPT)) < 0) { + int charsRead = in.read(buffer, 0, buffer.length); + if (charsRead < 0) throw new NetconfException("Input Stream has been closed during reading."); + rpcReply.append(buffer, 0, charsRead); + } + + log.debug("Received Netconf RPC-Reply\n{}", rpcReply); + rpcReply.setLength(promptPosition); + return rpcReply.toString(); + + } catch (final Exception e) { + log.warn("Error reading from input stream", e); + throw e; + } + }) + .get(commandTimeout, TimeUnit.MILLISECONDS); + } catch (final InterruptedException e) { + Thread.currentThread().interrupt(); + throw new NetconfException("Thread interrupted whilst waiting for RPC reply", e); + } catch (final ExecutionException e) { + if(e.getCause() instanceof NetconfException) { + throw (NetconfException) e.getCause(); + } + throw new NetconfException("Unexpected exception whilst waiting for RPC reply", e); + } catch (final TimeoutException e) { + // Make sure the thread isn't still running + ofNullable(threadReference.get()).ifPresent(Thread::interrupt); throw new SocketTimeoutException("Command timeout limit was exceeded: " + commandTimeout); - // fixing the rpc reply by removing device prompt - log.debug("Received Netconf RPC-Reply\n{}", rpcReply); - rpcReply.setLength(promptPosition); - - return rpcReply.toString(); + } } private BufferedReader getRpcReplyRunning(String rpc) throws IOException { diff --git a/src/test/java/net/juniper/netconf/NetconfSessionTest.java b/src/test/java/net/juniper/netconf/NetconfSessionTest.java index 25730b7..9c704da 100644 --- a/src/test/java/net/juniper/netconf/NetconfSessionTest.java +++ b/src/test/java/net/juniper/netconf/NetconfSessionTest.java @@ -7,6 +7,7 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.slf4j.Logger; @@ -25,17 +26,17 @@ import java.net.SocketTimeoutException; import java.nio.charset.StandardCharsets; import java.nio.file.Files; +import java.time.Duration; +import java.time.Instant; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.Mockito.anyString; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doCallRealMethod; import static org.mockito.Mockito.eq; -import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; public class NetconfSessionTest { @@ -489,4 +490,67 @@ private String createHelloMessage() { + " 27700\n" + ""; } + + + @Test + @Timeout(value = 2, unit = TimeUnit.SECONDS) + void ifTheDeviceDoesNotRespondAnExceptionWillBeThrown() { + final Duration commandTimeoutDuration = Duration.ofSeconds(1); + + final Instant startTime = Instant.now(); + assertThatThrownBy(() -> createNetconfSession((int) commandTimeoutDuration.toMillis())) + .isInstanceOf(SocketTimeoutException.class) + .hasMessageStartingWith("Command timeout limit was exceeded"); + + final Duration executeRpcDuration = Duration.between(startTime, Instant.now()); + // This should have taken about 1 second to time out + assertThat(executeRpcDuration) + .isGreaterThanOrEqualTo(commandTimeoutDuration); + } + + @Test + @Timeout(value = 2, unit = TimeUnit.SECONDS) + void ifTheDeviceDoesNotRespondTheSessionCanStillBeUsed() throws Exception { + + final Semaphore semaphore = new Semaphore(0); + + final Duration commandTimeoutDuration = Duration.ofSeconds(1); + + new Thread(() -> { + try { + // This is the "hello" from the device, in response to the "Hello" to the initial client ""hello" + outPipe.write(FAKE_RPC_REPLY.getBytes(StandardCharsets.UTF_8)); + outPipe.write(DEVICE_PROMPT_BYTE); + outPipe.flush(); + + // Don't send any response until it's required + semaphore.acquire(); + // Now send a second response + outPipe.write(FAKE_RPC_REPLY.getBytes(StandardCharsets.UTF_8)); + outPipe.write(DEVICE_PROMPT_BYTE); + outPipe.flush(); + outPipe.close(); + } catch (final Exception e) { + log.error("Error in background thread", e); + } + }).start(); + final NetconfSession netconfSession = createNetconfSession((int) commandTimeoutDuration.toMillis()); + // We've now received a "FAKE_RPC_REPLY" + + // Now send a request, but we're expecting a timeout as the device won't send it yet + final Instant startTime = Instant.now(); + assertThatThrownBy(() -> netconfSession.getRpcReply("")) + .isInstanceOf(SocketTimeoutException.class) + .hasMessageStartingWith("Command timeout limit was exceeded"); + final Duration executeRpcDuration = Duration.between(startTime, Instant.now()); + + // This should have taken about 1 second to time out + assertThat(executeRpcDuration) + .isGreaterThanOrEqualTo(commandTimeoutDuration); + + // Try again - we should get a reply + semaphore.release(); // Ensure the device sends a response + final String rpcReply = netconfSession.getRpcReply(""); + assertThat(rpcReply).isEqualTo(FAKE_RPC_REPLY); + } } \ No newline at end of file